diff --git a/broker/src/main/java/org/apache/rocketmq/broker/client/ConsumerManager.java b/broker/src/main/java/org/apache/rocketmq/broker/client/ConsumerManager.java index 176456043b0..82e0daf3efd 100644 --- a/broker/src/main/java/org/apache/rocketmq/broker/client/ConsumerManager.java +++ b/broker/src/main/java/org/apache/rocketmq/broker/client/ConsumerManager.java @@ -236,12 +236,10 @@ public boolean registerConsumer(final String group, final ClientChannelInfo clie } for (SubscriptionData subscriptionData : subList) { - Set groups = this.topicGroupTable.get(subscriptionData.getTopic()); - if (groups == null) { - Set tmp = new HashSet<>(); - Set prev = this.topicGroupTable.putIfAbsent(subscriptionData.getTopic(), tmp); - groups = prev != null ? prev : tmp; - } + Set groups = this.topicGroupTable.computeIfAbsent( + subscriptionData.getTopic(), + k -> ConcurrentHashMap.newKeySet() + ); groups.add(group); } @@ -287,12 +285,10 @@ public boolean registerConsumerWithoutSub(final String group, final ClientChanne } for (SubscriptionData subscriptionData : consumerGroupInfo.getSubscriptionTable().values()) { - Set groups = this.topicGroupTable.get(subscriptionData.getTopic()); - if (groups == null) { - Set tmp = new HashSet<>(); - Set prev = this.topicGroupTable.putIfAbsent(subscriptionData.getTopic(), tmp); - groups = prev != null ? 
prev : tmp; - } + Set groups = this.topicGroupTable.computeIfAbsent( + subscriptionData.getTopic(), + k -> ConcurrentHashMap.newKeySet() + ); groups.add(group); } diff --git a/broker/src/main/java/org/apache/rocketmq/broker/transaction/queue/TransactionalMessageServiceImpl.java b/broker/src/main/java/org/apache/rocketmq/broker/transaction/queue/TransactionalMessageServiceImpl.java index fb6c9de3f3b..cef725a8a82 100644 --- a/broker/src/main/java/org/apache/rocketmq/broker/transaction/queue/TransactionalMessageServiceImpl.java +++ b/broker/src/main/java/org/apache/rocketmq/broker/transaction/queue/TransactionalMessageServiceImpl.java @@ -67,6 +67,8 @@ public class TransactionalMessageServiceImpl implements TransactionalMessageServ private static final int SLEEP_WHILE_NO_OP = 1000; + private static final int PUT_BACK_RETRY_TIMES = 3; + private final ConcurrentHashMap deleteContext = new ConcurrentHashMap<>(); private ServiceThread transactionalOpBatchService; @@ -298,9 +300,30 @@ public void check(long transactionTimeout, int transactionCheckMax, if (isNeedCheck) { - if (!putBackHalfMsgQueue(msgExt, i)) { + int retryTimes = 0; + boolean putBackSuccess = false; + while (retryTimes < PUT_BACK_RETRY_TIMES) { + putBackSuccess = putBackHalfMsgQueue(msgExt, i); + if (putBackSuccess) { + break; + } + retryTimes++; + if (retryTimes < PUT_BACK_RETRY_TIMES) { + try { + Thread.sleep(100L * retryTimes); + } catch (InterruptedException ignored) { + } + } + } + + if (!putBackSuccess) { + log.error("PutBackToHalfQueue failed after {} retries, skip this message. topic={}, queueId={}, offset={}, msgId={}", + PUT_BACK_RETRY_TIMES, msgExt.getTopic(), msgExt.getQueueId(), i, msgExt.getMsgId()); + newOffset = i + 1; + i++; continue; } + putInQueueCount++; log.info("Check transaction. 
real_topic={},uniqKey={},offset={},commitLogOffset={}", msgExt.getUserProperty(MessageConst.PROPERTY_REAL_TOPIC), diff --git a/broker/src/test/java/org/apache/rocketmq/broker/client/ConsumerManagerConcurrentTest.java b/broker/src/test/java/org/apache/rocketmq/broker/client/ConsumerManagerConcurrentTest.java new file mode 100644 index 00000000000..595831964b5 --- /dev/null +++ b/broker/src/test/java/org/apache/rocketmq/broker/client/ConsumerManagerConcurrentTest.java @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.rocketmq.broker.client; + +import io.netty.channel.Channel; +import org.apache.rocketmq.common.BrokerConfig; +import org.apache.rocketmq.common.consumer.ConsumeFromWhere; +import org.apache.rocketmq.remoting.protocol.LanguageCode; +import org.apache.rocketmq.remoting.protocol.heartbeat.ConsumeType; +import org.apache.rocketmq.remoting.protocol.heartbeat.MessageModel; +import org.apache.rocketmq.remoting.protocol.heartbeat.SubscriptionData; +import org.apache.rocketmq.store.stats.BrokerStatsManager; +import org.junit.Before; +import org.junit.Test; + +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; + +/** + * Test concurrent registration to verify thread safety of ConsumerManager. + * This test ensures that the fix for the concurrent HashSet modification bug works correctly. + */ +public class ConsumerManagerConcurrentTest { + + private ConsumerManager consumerManager; + private final BrokerConfig brokerConfig = new BrokerConfig(); + + @Before + public void before() { + DefaultConsumerIdsChangeListener defaultConsumerIdsChangeListener = + new DefaultConsumerIdsChangeListener(null); + BrokerStatsManager brokerStatsManager = new BrokerStatsManager(brokerConfig); + consumerManager = new ConsumerManager(defaultConsumerIdsChangeListener, brokerStatsManager, brokerConfig); + } + + /** + * Test concurrent consumer registration for the same topic. + * This test verifies that no data is lost when multiple threads register consumers concurrently. 
+ * + * Before fix: Using HashSet in topicGroupTable could cause data loss (60% reproduction rate) + * After fix: Using ConcurrentHashMap.newKeySet() ensures thread safety + */ + @Test + public void testConcurrentRegisterConsumer() throws InterruptedException { + int threadCount = 100; + String topic = "TestTopic"; + + ExecutorService executor = Executors.newFixedThreadPool(50); + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch endLatch = new CountDownLatch(threadCount); + AtomicInteger successCount = new AtomicInteger(0); + + for (int i = 0; i < threadCount; i++) { + final int index = i; + executor.submit(() -> { + try { + startLatch.await(); + + String group = "Group_" + index; + Channel channel = mock(Channel.class); + ClientChannelInfo clientChannelInfo = new ClientChannelInfo( + channel, "Client_" + index, LanguageCode.JAVA, 1); + + Set subList = new HashSet<>(); + subList.add(new SubscriptionData(topic, "*")); + + boolean registered = consumerManager.registerConsumer( + group, + clientChannelInfo, + ConsumeType.CONSUME_PASSIVELY, + MessageModel.CLUSTERING, + ConsumeFromWhere.CONSUME_FROM_FIRST_OFFSET, + subList, + false + ); + + if (registered) { + successCount.incrementAndGet(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + endLatch.countDown(); + } + }); + } + + // Start all threads at the same time to maximize contention + startLatch.countDown(); + + // Wait for all threads to complete + boolean finished = endLatch.await(10, TimeUnit.SECONDS); + assertThat(finished).isTrue(); + executor.shutdown(); + + // Verify the result + HashSet groups = consumerManager.queryTopicConsumeByWho(topic); + + // After fix, we should have exactly threadCount groups (no data loss) + assertThat(groups.size()).isEqualTo(threadCount); + assertThat(successCount.get()).isEqualTo(threadCount); + } + + /** + * Test concurrent registration with multiple runs to ensure consistency. 
+ */ + @Test + public void testConcurrentRegisterConsistency() throws InterruptedException { + int iterations = 10; + int threadCount = 50; + + for (int iter = 0; iter < iterations; iter++) { + final int iteration = iter; + String topic = "Topic_" + iteration; + + ExecutorService executor = Executors.newFixedThreadPool(30); + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch endLatch = new CountDownLatch(threadCount); + + for (int i = 0; i < threadCount; i++) { + final int index = i; + final String topicFinal = topic; + executor.submit(() -> { + try { + startLatch.await(); + + String group = "Group_" + iteration + "_" + index; + Channel channel = mock(Channel.class); + ClientChannelInfo clientChannelInfo = new ClientChannelInfo( + channel, "Client_" + index, LanguageCode.JAVA, 1); + + Set subList = new HashSet<>(); + subList.add(new SubscriptionData(topicFinal, "*")); + + consumerManager.registerConsumer( + group, + clientChannelInfo, + ConsumeType.CONSUME_PASSIVELY, + MessageModel.CLUSTERING, + ConsumeFromWhere.CONSUME_FROM_FIRST_OFFSET, + subList, + false + ); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + endLatch.countDown(); + } + }); + } + + startLatch.countDown(); + boolean finished = endLatch.await(5, TimeUnit.SECONDS); + assertThat(finished).isTrue(); + executor.shutdown(); + + // Verify no data loss in each iteration + HashSet groups = consumerManager.queryTopicConsumeByWho(topic); + assertThat(groups.size()).isEqualTo(threadCount); + } + } + + /** + * Test high stress scenario with more threads. 
+ */ + @Test + public void testHighConcurrencyStress() throws InterruptedException { + int threadCount = 200; + String topic = "StressTestTopic"; + + ExecutorService executor = Executors.newFixedThreadPool(100); + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch endLatch = new CountDownLatch(threadCount); + + for (int i = 0; i < threadCount; i++) { + final int index = i; + executor.submit(() -> { + try { + startLatch.await(); + + String group = "StressGroup_" + index; + Channel channel = mock(Channel.class); + ClientChannelInfo clientChannelInfo = new ClientChannelInfo( + channel, "StressClient_" + index, LanguageCode.JAVA, 1); + + Set subList = new HashSet<>(); + subList.add(new SubscriptionData(topic, "*")); + + consumerManager.registerConsumer( + group, + clientChannelInfo, + ConsumeType.CONSUME_PASSIVELY, + MessageModel.CLUSTERING, + ConsumeFromWhere.CONSUME_FROM_FIRST_OFFSET, + subList, + false + ); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + endLatch.countDown(); + } + }); + } + + startLatch.countDown(); + boolean finished = endLatch.await(15, TimeUnit.SECONDS); + assertThat(finished).isTrue(); + executor.shutdown(); + + // Verify no data loss under high stress + HashSet groups = consumerManager.queryTopicConsumeByWho(topic); + assertThat(groups.size()).isEqualTo(threadCount); + } + + /** + * Test concurrent registration for multiple topics. 
+ */ + @Test + public void testConcurrentRegisterMultipleTopics() throws InterruptedException { + int threadCount = 50; + int topicCount = 10; + + ExecutorService executor = Executors.newFixedThreadPool(50); + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch endLatch = new CountDownLatch(threadCount * topicCount); + + for (int t = 0; t < topicCount; t++) { + final String topic = "MultiTopic_" + t; + for (int i = 0; i < threadCount; i++) { + final int index = i; + executor.submit(() -> { + try { + startLatch.await(); + + String group = "MultiGroup_" + topic + "_" + index; + Channel channel = mock(Channel.class); + ClientChannelInfo clientChannelInfo = new ClientChannelInfo( + channel, "MultiClient_" + index, LanguageCode.JAVA, 1); + + Set subList = new HashSet<>(); + subList.add(new SubscriptionData(topic, "*")); + + consumerManager.registerConsumer( + group, + clientChannelInfo, + ConsumeType.CONSUME_PASSIVELY, + MessageModel.CLUSTERING, + ConsumeFromWhere.CONSUME_FROM_FIRST_OFFSET, + subList, + false + ); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + endLatch.countDown(); + } + }); + } + } + + startLatch.countDown(); + boolean finished = endLatch.await(15, TimeUnit.SECONDS); + assertThat(finished).isTrue(); + executor.shutdown(); + + // Verify each topic has exactly threadCount groups + for (int t = 0; t < topicCount; t++) { + String topic = "MultiTopic_" + t; + HashSet groups = consumerManager.queryTopicConsumeByWho(topic); + assertThat(groups.size()).isEqualTo(threadCount); + } + } +} diff --git a/broker/src/test/java/org/apache/rocketmq/broker/transaction/queue/TransactionalMessageServiceImplTest.java b/broker/src/test/java/org/apache/rocketmq/broker/transaction/queue/TransactionalMessageServiceImplTest.java index b92c07dd478..9bebf793124 100644 --- a/broker/src/test/java/org/apache/rocketmq/broker/transaction/queue/TransactionalMessageServiceImplTest.java +++ 
b/broker/src/test/java/org/apache/rocketmq/broker/transaction/queue/TransactionalMessageServiceImplTest.java @@ -179,6 +179,73 @@ public void testOpen() { assertThat(isOpen).isTrue(); } + @Test + public void testCheck_putBackFailedShouldNotInfiniteLoop() { + // This test verifies that when putBackHalfMsgQueue fails, the check method should not enter an infinite loop + // The check should retry 3 times and then skip the message to continue processing subsequent messages + + when(bridge.fetchMessageQueues(TopicValidator.RMQ_SYS_TRANS_HALF_TOPIC)).thenReturn(createMessageQueueSet(TopicValidator.RMQ_SYS_TRANS_HALF_TOPIC)); + // Create a message that needs to be checked (old enough) + when(bridge.getHalfMessage(0, 0, 1)).thenReturn(createPullResult(TopicValidator.RMQ_SYS_TRANS_HALF_TOPIC, 5, "hello", 1)); + when(bridge.getHalfMessage(0, 1, 1)).thenReturn(createPullResult(TopicValidator.RMQ_SYS_TRANS_HALF_TOPIC, 6, "hello2", 0)); + when(bridge.getOpMessage(anyInt(), anyLong(), anyInt())).thenReturn(createPullResult(TopicValidator.RMQ_SYS_TRANS_OP_HALF_TOPIC, 1, "5", 0)); + when(bridge.getBrokerController()).thenReturn(this.brokerController); + when(bridge.renewHalfMessageInner(any(MessageExtBrokerInner.class))).thenReturn(createMessageBrokerInner()); + // Simulate putBack failure - return CREATE_MAPPED_FILE_FAILED status + when(bridge.putMessageReturnResult(any(MessageExtBrokerInner.class))) + .thenReturn(new PutMessageResult(PutMessageStatus.CREATE_MAPPED_FILE_FAILED, null)); + // Mock fetchConsumeOffset to return valid offset + when(bridge.fetchConsumeOffset(any(MessageQueue.class))).thenReturn(0L); + + long timeOut = this.brokerController.getBrokerConfig().getTransactionTimeOut(); + final int checkMax = this.brokerController.getBrokerConfig().getTransactionCheckMax(); + + // This should complete without getting stuck in an infinite loop + long startTime = System.currentTimeMillis(); + queueTransactionMsgService.check(timeOut, checkMax, listener); + long elapsedTime =
System.currentTimeMillis() - startTime; + + // The check should complete quickly (within a few seconds), not run for MAX_PROCESS_TIME_LIMIT (60s) + assertThat(elapsedTime).isLessThan(5000L); + // Verify that putMessageReturnResult was called 3 times (retry limit) + verify(bridge, org.mockito.Mockito.times(3)).putMessageReturnResult(any(MessageExtBrokerInner.class)); + } + + @Test + public void testCheck_putBackSucceedsAfterRetry() { + // This test verifies that if putBackHalfMsgQueue succeeds after retry, the check continues normally + + when(bridge.fetchMessageQueues(TopicValidator.RMQ_SYS_TRANS_HALF_TOPIC)).thenReturn(createMessageQueueSet(TopicValidator.RMQ_SYS_TRANS_HALF_TOPIC)); + when(bridge.getHalfMessage(0, 0, 1)).thenReturn(createPullResult(TopicValidator.RMQ_SYS_TRANS_HALF_TOPIC, 5, "hello", 1)); + when(bridge.getHalfMessage(0, 1, 1)).thenReturn(createPullResult(TopicValidator.RMQ_SYS_TRANS_HALF_TOPIC, 6, "hello2", 0)); + when(bridge.getOpMessage(anyInt(), anyLong(), anyInt())).thenReturn(createPullResult(TopicValidator.RMQ_SYS_TRANS_OP_HALF_TOPIC, 1, "5", 0)); + when(bridge.getBrokerController()).thenReturn(this.brokerController); + when(bridge.renewHalfMessageInner(any(MessageExtBrokerInner.class))).thenReturn(createMessageBrokerInner()); + when(bridge.fetchConsumeOffset(any(MessageQueue.class))).thenReturn(0L); + + // First call fails, second call succeeds + org.mockito.Mockito.when(bridge.putMessageReturnResult(any(MessageExtBrokerInner.class))) + .thenReturn(new PutMessageResult(PutMessageStatus.CREATE_MAPPED_FILE_FAILED, null)) + .thenReturn(new PutMessageResult(PutMessageStatus.PUT_OK, new AppendMessageResult(AppendMessageStatus.PUT_OK))); + + long timeOut = this.brokerController.getBrokerConfig().getTransactionTimeOut(); + final int checkMax = this.brokerController.getBrokerConfig().getTransactionCheckMax(); + + final AtomicInteger checkMessage = new AtomicInteger(0); + doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock 
invocation) { + checkMessage.addAndGet(1); + return checkMessage; + } + }).when(listener).resolveHalfMsg(any(MessageExt.class)); + + queueTransactionMsgService.check(timeOut, checkMax, listener); + + // resolveHalfMsg should be called once since putBack succeeded on retry + assertThat(checkMessage.get()).isEqualTo(1); + } + private PullResult createDiscardPullResult(String topic, long queueOffset, String body, int size) { PullResult result = createPullResult(topic, queueOffset, body, size); List msgs = result.getMsgFoundList();