diff --git a/client/src/main/java/org/apache/rocketmq/client/impl/factory/MQClientInstance.java b/client/src/main/java/org/apache/rocketmq/client/impl/factory/MQClientInstance.java index e0b28fef646..e10807cf50b 100644 --- a/client/src/main/java/org/apache/rocketmq/client/impl/factory/MQClientInstance.java +++ b/client/src/main/java/org/apache/rocketmq/client/impl/factory/MQClientInstance.java @@ -96,6 +96,7 @@ public class MQClientInstance { private final static long LOCK_TIMEOUT_MILLIS = 3000; + private final static long RESET_OFFSET_MAX_WAIT = 10; private final static Logger log = LoggerFactory.getLogger(MQClientInstance.class); private final ClientConfig clientConfig; private final String clientId; @@ -1380,9 +1381,11 @@ public synchronized void resetOffset(String topic, String group, Map iterator = processQueueTable.keySet().iterator(); @@ -1391,8 +1394,10 @@ public synchronized void resetOffset(String topic, String group, Map getConsumerStatus(String topic, String group) { MQConsumerInner impl = this.consumerTable.get(group); diff --git a/client/src/test/java/org/apache/rocketmq/client/impl/factory/MQClientInstanceTest.java b/client/src/test/java/org/apache/rocketmq/client/impl/factory/MQClientInstanceTest.java index 376ff9da8e1..82b9080438f 100644 --- a/client/src/test/java/org/apache/rocketmq/client/impl/factory/MQClientInstanceTest.java +++ b/client/src/test/java/org/apache/rocketmq/client/impl/factory/MQClientInstanceTest.java @@ -74,7 +74,12 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReadWriteLock; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertEquals; @@ -397,6 +402,119 @@ public void testResetOffset() throws IllegalAccessException { eq(0L)); } + @Test + public void testResetOffsetOrderly() { + topicRouteTable.put(topic, createTopicRouteData()); + brokerAddrTable.put(defaultBroker, createBrokerAddrMap()); + MessageQueue messageQueue = createMessageQueue(); + ProcessQueue processQueue = new ProcessQueue(); + RebalanceImpl rebalanceImpl = mock(RebalanceImpl.class); + when(rebalanceImpl.removeUnnecessaryMessageQueue(eq(messageQueue), eq(processQueue))) + .thenReturn(false, false, true); + consumerTable.put(group, createMQConsumerInner(processQueue, true, rebalanceImpl)); + Map offsetTable = new HashMap<>(); + offsetTable.put(messageQueue, 0L); + + mqClientInstance.resetOffset(topic, group, offsetTable); + + verify(rebalanceImpl).removeUnnecessaryMessageQueue(messageQueue, processQueue); + } + + @Test + public void testResetOffsetOrderlyWhenWaitTimesOut() throws InterruptedException { + topicRouteTable.put(topic, createTopicRouteData()); + brokerAddrTable.put(defaultBroker, createBrokerAddrMap()); + MessageQueue messageQueue = createMessageQueue(); + ProcessQueue processQueue = mock(ProcessQueue.class); + ReadWriteLock consumeLock = mock(ReadWriteLock.class); + Lock writeLock = mock(Lock.class); + RebalanceImpl rebalanceImpl = mock(RebalanceImpl.class); + when(processQueue.getConsumeLock()).thenReturn(consumeLock); + when(consumeLock.writeLock()).thenReturn(writeLock); + when(writeLock.tryLock(10, TimeUnit.SECONDS)).thenReturn(false); + DefaultMQPushConsumerImpl consumer = (DefaultMQPushConsumerImpl) createMQConsumerInner(processQueue, true, rebalanceImpl); + consumerTable.put(group, consumer); + Map offsetTable = new HashMap<>(); + offsetTable.put(messageQueue, 0L); + + mqClientInstance.resetOffset(topic, group, offsetTable); + + verify(consumer).updateConsumeOffset(messageQueue, 0L); + verify(rebalanceImpl).removeUnnecessaryMessageQueue(messageQueue, processQueue); + verify(writeLock, times(1)).tryLock(10, TimeUnit.SECONDS); + verify(writeLock, times(0)).unlock(); + } + + @Test + public void testResetOffsetOrderlyWaitsForInflightConsumptionBeforeUpdatingOffset() throws Exception { + topicRouteTable.put(topic, createTopicRouteData()); + brokerAddrTable.put(defaultBroker, createBrokerAddrMap()); + MessageQueue messageQueue = createMessageQueue(); + ProcessQueue processQueue = new ProcessQueue(); + RebalanceImpl rebalanceImpl = mock(RebalanceImpl.class); + when(rebalanceImpl.removeUnnecessaryMessageQueue(eq(messageQueue), eq(processQueue))).thenReturn(true); + DefaultMQPushConsumerImpl consumer = (DefaultMQPushConsumerImpl) createMQConsumerInner(processQueue, true, rebalanceImpl); + consumerTable.put(group, consumer); + Map offsetTable = new HashMap<>(); + offsetTable.put(messageQueue, 0L); + + CountDownLatch consumeLockHeld = new CountDownLatch(1); + CountDownLatch releaseConsumeLock = new CountDownLatch(1); + CountDownLatch suspendCalled = new CountDownLatch(1); + CountDownLatch updateOffsetCalled = new CountDownLatch(1); + AtomicReference backgroundFailure = new AtomicReference<>(); + + doAnswer(invocation -> { + suspendCalled.countDown(); + return null; + }).when(consumer).suspend(); + doAnswer(invocation -> { + updateOffsetCalled.countDown(); + return null; + }).when(consumer).updateConsumeOffset(messageQueue, 0L); + + Thread consumingThread = new Thread(() -> { + processQueue.getConsumeLock().readLock().lock(); + try { + consumeLockHeld.countDown(); + if (!releaseConsumeLock.await(5, TimeUnit.SECONDS)) { + backgroundFailure.compareAndSet(null, + new AssertionError("Timed out while waiting to release orderly consume lock")); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + backgroundFailure.compareAndSet(null, e); + } finally { + processQueue.getConsumeLock().readLock().unlock(); + } + }); + Thread resetThread = new Thread(() -> { + try { + mqClientInstance.resetOffset(topic, group, offsetTable); + } catch (Throwable t) { + backgroundFailure.compareAndSet(null, t); + } + }); + + consumingThread.start(); + assertTrue(consumeLockHeld.await(5, TimeUnit.SECONDS)); + + resetThread.start(); + assertTrue(suspendCalled.await(5, TimeUnit.SECONDS)); + assertFalse(updateOffsetCalled.await(200, TimeUnit.MILLISECONDS)); + + releaseConsumeLock.countDown(); + consumingThread.join(5000); + resetThread.join(5000); + + assertNull(backgroundFailure.get()); + assertFalse(consumingThread.isAlive()); + assertFalse(resetThread.isAlive()); + assertTrue(updateOffsetCalled.await(1, TimeUnit.SECONDS)); + verify(consumer).updateConsumeOffset(messageQueue, 0L); + verify(rebalanceImpl).removeUnnecessaryMessageQueue(messageQueue, processQueue); + } + @Test public void testGetConsumerStatus() { topicRouteTable.put(topic, createTopicRouteData()); @@ -475,17 +593,26 @@ private HashMap createBrokerAddrMap() { } private MQConsumerInner createMQConsumerInner() { + RebalanceImpl rebalanceImpl = mock(RebalanceImpl.class); + when(rebalanceImpl.removeUnnecessaryMessageQueue(any(MessageQueue.class), any(ProcessQueue.class))).thenReturn(true); + return createMQConsumerInner(new ProcessQueue(), false, rebalanceImpl); + } + + private MQConsumerInner createMQConsumerInner(ProcessQueue processQueue, boolean orderly, RebalanceImpl rebalanceImpl) { + ConcurrentMap processQueueMap = new ConcurrentHashMap<>(); + processQueueMap.put(createMessageQueue(), processQueue); + return createMQConsumerInner(processQueueMap, orderly, rebalanceImpl); + } + + private MQConsumerInner createMQConsumerInner(ConcurrentMap processQueueMap, boolean orderly, RebalanceImpl rebalanceImpl) { DefaultMQPushConsumerImpl result = mock(DefaultMQPushConsumerImpl.class); Set subscriptionDataSet = new HashSet<>(); SubscriptionData subscriptionData = mock(SubscriptionData.class); subscriptionDataSet.add(subscriptionData); when(result.subscriptions()).thenReturn(subscriptionDataSet); - RebalanceImpl rebalanceImpl = mock(RebalanceImpl.class); - ConcurrentMap processQueueMap = new ConcurrentHashMap<>(); - ProcessQueue processQueue = new ProcessQueue(); - processQueueMap.put(createMessageQueue(), processQueue); when(rebalanceImpl.getProcessQueueTable()).thenReturn(processQueueMap); when(result.getRebalanceImpl()).thenReturn(rebalanceImpl); + when(result.isConsumeOrderly()).thenReturn(orderly); OffsetStore offsetStore = mock(OffsetStore.class); when(result.getOffsetStore()).thenReturn(offsetStore); ConsumeMessageService consumeMessageService = mock(ConsumeMessageService.class);