diff --git a/proxy/src/main/java/org/apache/rocketmq/proxy/remoting/protocol/remoting/RemotingProtocolHandler.java b/proxy/src/main/java/org/apache/rocketmq/proxy/remoting/protocol/remoting/RemotingProtocolHandler.java index 49fea89cdd3..1da4432618e 100644 --- a/proxy/src/main/java/org/apache/rocketmq/proxy/remoting/protocol/remoting/RemotingProtocolHandler.java +++ b/proxy/src/main/java/org/apache/rocketmq/proxy/remoting/protocol/remoting/RemotingProtocolHandler.java @@ -52,8 +52,7 @@ public boolean match(ByteBuf in) { public void config(ChannelHandlerContext ctx, ByteBuf msg) { ctx.pipeline().addLast( this.encoderSupplier.get(), - new NettyDecoder(), - this.remotingCodeDistributionHandlerSupplier.get(), + new NettyDecoder(this.remotingCodeDistributionHandlerSupplier.get()), this.connectionManageHandlerSupplier.get(), this.serverHandlerSupplier.get() ); diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyDecoder.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyDecoder.java index 19624d74028..50760a9a056 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyDecoder.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyDecoder.java @@ -32,8 +32,15 @@ public class NettyDecoder extends LengthFieldBasedFrameDecoder { private static final int FRAME_MAX_LENGTH = Integer.parseInt(System.getProperty("com.rocketmq.remoting.frameMaxLength", "16777216")); + private final RemotingCodeDistributionHandler distributionHandler; + public NettyDecoder() { + this(null); + } + + public NettyDecoder(RemotingCodeDistributionHandler distributionHandler) { super(FRAME_MAX_LENGTH, 0, 4, 0, 4); + this.distributionHandler = distributionHandler; } @Override @@ -45,7 +52,13 @@ public Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception { if (null == frame) { return null; } + // readableBytes() is the frame size after stripping the 4-byte length prefix; + // add 4 back to get the actual wire size. + int wireSize = frame.readableBytes() + 4; RemotingCommand cmd = RemotingCommand.decode(frame); + if (distributionHandler != null) { + distributionHandler.recordInbound(cmd.getCode(), wireSize); + } cmd.setProcessTimer(timer); return cmd; } catch (Exception e) { diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyEncoder.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyEncoder.java index 2af0af6b725..6cfa63d471d 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyEncoder.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyEncoder.java @@ -30,15 +30,30 @@ public class NettyEncoder extends MessageToByteEncoder { private static final Logger log = LoggerFactory.getLogger(LoggerName.ROCKETMQ_REMOTING_NAME); + private final RemotingCodeDistributionHandler distributionHandler; + + public NettyEncoder() { + this(null); + } + + public NettyEncoder(RemotingCodeDistributionHandler distributionHandler) { + this.distributionHandler = distributionHandler; + } + @Override public void encode(ChannelHandlerContext ctx, RemotingCommand remotingCommand, ByteBuf out) throws Exception { try { + int beginIndex = out.writerIndex(); remotingCommand.fastEncodeHeader(out); byte[] body = remotingCommand.getBody(); if (body != null) { out.writeBytes(body); } + if (distributionHandler != null) { + distributionHandler.recordOutbound( + remotingCommand.getCode(), out.writerIndex() - beginIndex); + } } catch (Exception e) { log.error("encode exception, " + RemotingHelper.parseChannelRemoteAddr(ctx.channel()), e); if (remotingCommand != null) { diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java index 578c102daa4..06f9314a638 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/NettyRemotingServer.java @@ -121,10 +121,10 @@ public class NettyRemotingServer extends NettyRemotingAbstract implements Remoti // sharable handlers protected final TlsModeHandler tlsModeHandler = new TlsModeHandler(TlsSystemConfig.tlsMode); - protected final NettyEncoder encoder = new NettyEncoder(); + protected final NettyEncoder encoder; protected final NettyConnectManageHandler connectionManageHandler = new NettyConnectManageHandler(); protected final NettyServerHandler serverHandler = new NettyServerHandler(); - protected final RemotingCodeDistributionHandler distributionHandler = new RemotingCodeDistributionHandler(); + protected final RemotingCodeDistributionHandler distributionHandler; public NettyRemotingServer(final NettyServerConfig nettyServerConfig) { this(nettyServerConfig, null); @@ -136,6 +136,8 @@ public NettyRemotingServer(final NettyServerConfig nettyServerConfig, this.serverBootstrap = new ServerBootstrap(); this.nettyServerConfig = nettyServerConfig; this.channelEventListener = channelEventListener; + this.distributionHandler = new RemotingCodeDistributionHandler(); + this.encoder = new NettyEncoder(distributionHandler); this.publicExecutor = buildPublicExecutor(nettyServerConfig); this.scheduledExecutorService = buildScheduleExecutor(); @@ -276,8 +278,7 @@ protected ChannelPipeline configChannel(SocketChannel ch) { HANDSHAKE_HANDLER_NAME, new HandshakeHandler()) .addLast(getDefaultEventExecutorGroup(), encoder, - new NettyDecoder(), - distributionHandler, + new NettyDecoder(distributionHandler), new IdleStateHandler(0, 0, nettyServerConfig.getServerChannelMaxIdleTimeSeconds()), connectionManageHandler, @@ -426,6 +427,18 @@ private void printRemotingCodeDistribution() { TRAFFIC_LOGGER.info("Port: {}, ResponseCode Distribution: {}", nettyServerConfig.getListenPort(), outBoundSnapshotString); } + + String inBoundTrafficSnapshotString = distributionHandler.getInBoundTrafficSnapshotString(); + if (inBoundTrafficSnapshotString != null) { + TRAFFIC_LOGGER.info("Port: {}, RequestCode Traffic(byte): {}", + nettyServerConfig.getListenPort(), inBoundTrafficSnapshotString); + } + + String outBoundTrafficSnapshotString = distributionHandler.getOutBoundTrafficSnapshotString(); + if (outBoundTrafficSnapshotString != null) { + TRAFFIC_LOGGER.info("Port: {}, ResponseCode Traffic(byte): {}", + nettyServerConfig.getListenPort(), outBoundTrafficSnapshotString); + } } } diff --git a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/RemotingCodeDistributionHandler.java b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/RemotingCodeDistributionHandler.java index c6a97fe441b..e23c2f3fc92 100644 --- a/remoting/src/main/java/org/apache/rocketmq/remoting/netty/RemotingCodeDistributionHandler.java +++ b/remoting/src/main/java/org/apache/rocketmq/remoting/netty/RemotingCodeDistributionHandler.java @@ -16,89 +16,76 @@ */ package org.apache.rocketmq.remoting.netty; -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelHandler; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.LongAdder; -import org.apache.rocketmq.remoting.protocol.RemotingCommand; -@ChannelHandler.Sharable -public class RemotingCodeDistributionHandler extends ChannelDuplexHandler { +/** + * Thread-safe tracker for per-requestCode count and traffic distribution. + *

+ */ +public class RemotingCodeDistributionHandler { - private final ConcurrentMap inboundDistribution; - private final ConcurrentMap outboundDistribution; + private final ConcurrentMap inboundStats = new ConcurrentHashMap<>(); + private final ConcurrentMap outboundStats = new ConcurrentHashMap<>(); - public RemotingCodeDistributionHandler() { - inboundDistribution = new ConcurrentHashMap<>(); - outboundDistribution = new ConcurrentHashMap<>(); + public void recordInbound(int code, int wireSize) { + TrafficStats stats = inboundStats.computeIfAbsent(code, k -> new TrafficStats()); + stats.count.increment(); + stats.trafficSize.add(wireSize); } - private void countInbound(int requestCode) { - LongAdder item = inboundDistribution.computeIfAbsent(requestCode, k -> new LongAdder()); - item.increment(); + public void recordOutbound(int code, int wireSize) { + TrafficStats stats = outboundStats.computeIfAbsent(code, k -> new TrafficStats()); + stats.count.increment(); + stats.trafficSize.add(wireSize); } - private void countOutbound(int responseCode) { - LongAdder item = outboundDistribution.computeIfAbsent(responseCode, k -> new LongAdder()); - item.increment(); + public String getInBoundSnapshotString() { + return snapshotToString(getSnapshot(inboundStats, true)); } - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) { - if (msg instanceof RemotingCommand) { - RemotingCommand cmd = (RemotingCommand) msg; - countInbound(cmd.getCode()); - } - ctx.fireChannelRead(msg); + public String getOutBoundSnapshotString() { + return snapshotToString(getSnapshot(outboundStats, true)); } - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - if (msg instanceof RemotingCommand) { - RemotingCommand cmd = (RemotingCommand) msg; - countOutbound(cmd.getCode()); - } - ctx.write(msg, promise); + public String getInBoundTrafficSnapshotString() { + return snapshotToString(getSnapshot(inboundStats, false)); + } + + public String getOutBoundTrafficSnapshotString() { + return snapshotToString(getSnapshot(outboundStats, false)); } - private Map getDistributionSnapshot(Map countMap) { - Map map = new HashMap<>(countMap.size()); - for (Map.Entry entry : countMap.entrySet()) { - map.put(entry.getKey(), entry.getValue().sumThenReset()); + private Map getSnapshot(ConcurrentMap statsMap, boolean count) { + Map map = new HashMap<>(statsMap.size()); + for (Map.Entry entry : statsMap.entrySet()) { + LongAdder adder = count ? entry.getValue().count : entry.getValue().trafficSize; + map.put(entry.getKey(), adder.sumThenReset()); } return map; } private String snapshotToString(Map distribution) { - if (null != distribution && !distribution.isEmpty()) { - StringBuilder sb = new StringBuilder("{"); - boolean first = true; - for (Map.Entry entry : distribution.entrySet()) { - if (0L == entry.getValue()) { - continue; - } - sb.append(first ? "" : ", ").append(entry.getKey()).append(":").append(entry.getValue()); - first = false; - } - if (first) { - return null; + if (null == distribution || distribution.isEmpty()) { + return null; + } + StringBuilder sb = new StringBuilder("{"); + boolean first = true; + for (Map.Entry entry : distribution.entrySet()) { + if (0L == entry.getValue()) { + continue; } - sb.append("}"); - return sb.toString(); + sb.append(first ? "" : ", ").append(entry.getKey()).append(":").append(entry.getValue()); + first = false; } - return null; - } - - public String getInBoundSnapshotString() { - return this.snapshotToString(this.getDistributionSnapshot(this.inboundDistribution)); + return first ? null : sb.append("}").toString(); } - public String getOutBoundSnapshotString() { - return this.snapshotToString(this.getDistributionSnapshot(this.outboundDistribution)); + static class TrafficStats { + final LongAdder count = new LongAdder(); + final LongAdder trafficSize = new LongAdder(); } } diff --git a/remoting/src/test/java/org/apache/rocketmq/remoting/netty/RemotingCodeDistributionHandlerTest.java b/remoting/src/test/java/org/apache/rocketmq/remoting/netty/RemotingCodeDistributionHandlerTest.java index eb623a9de92..36eab6fa09a 100644 --- a/remoting/src/test/java/org/apache/rocketmq/remoting/netty/RemotingCodeDistributionHandlerTest.java +++ b/remoting/src/test/java/org/apache/rocketmq/remoting/netty/RemotingCodeDistributionHandlerTest.java @@ -16,45 +16,96 @@ */ package org.apache.rocketmq.remoting.netty; -import java.lang.reflect.Method; -import java.time.Duration; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; -import org.apache.rocketmq.common.ThreadFactoryImpl; import org.junit.Assert; +import org.junit.Before; import org.junit.Test; -import static org.awaitility.Awaitility.await; - public class RemotingCodeDistributionHandlerTest { - private final RemotingCodeDistributionHandler distributionHandler = new RemotingCodeDistributionHandler(); + private RemotingCodeDistributionHandler handler; + + @Before + public void setUp() { + handler = new RemotingCodeDistributionHandler(); + } + + @Test + public void testInboundCountAndTraffic() { + handler.recordInbound(100, 512); + handler.recordInbound(100, 1024); + + Assert.assertEquals("{100:2}", handler.getInBoundSnapshotString()); + Assert.assertEquals("{100:1536}", handler.getInBoundTrafficSnapshotString()); + } @Test - public void remotingCodeCountTest() throws Exception { - Class clazz = RemotingCodeDistributionHandler.class; - Method methodIn = clazz.getDeclaredMethod("countInbound", int.class); - Method methodOut = clazz.getDeclaredMethod("countOutbound", int.class); - methodIn.setAccessible(true); - methodOut.setAccessible(true); + public void testOutboundCountAndTraffic() { + handler.recordOutbound(0, 256); + handler.recordOutbound(0, 256); + handler.recordOutbound(0, 512); + Assert.assertEquals("{0:3}", handler.getOutBoundSnapshotString()); + Assert.assertEquals("{0:1024}", handler.getOutBoundTrafficSnapshotString()); + } + + @Test + public void testMultipleRequestCodes() { + handler.recordInbound(10, 200); + handler.recordInbound(10, 200); + handler.recordInbound(20, 300); + + String countSnapshot = handler.getInBoundSnapshotString(); + Assert.assertNotNull(countSnapshot); + Assert.assertTrue(countSnapshot.contains("10:2")); + Assert.assertTrue(countSnapshot.contains("20:1")); + + String trafficSnapshot = handler.getInBoundTrafficSnapshotString(); + Assert.assertNotNull(trafficSnapshot); + Assert.assertTrue(trafficSnapshot.contains("10:400")); + Assert.assertTrue(trafficSnapshot.contains("20:300")); + } + + @Test + public void testSnapshotResetsAfterRead() { + handler.recordInbound(400, 100); + + Assert.assertNotNull(handler.getInBoundSnapshotString()); + Assert.assertNotNull(handler.getInBoundTrafficSnapshotString()); + + // Second read returns null after sumThenReset + Assert.assertNull(handler.getInBoundSnapshotString()); + Assert.assertNull(handler.getInBoundTrafficSnapshotString()); + } + + @Test + public void testEmptySnapshotReturnsNull() { + Assert.assertNull(handler.getInBoundSnapshotString()); + Assert.assertNull(handler.getOutBoundSnapshotString()); + Assert.assertNull(handler.getInBoundTrafficSnapshotString()); + Assert.assertNull(handler.getOutBoundTrafficSnapshotString()); + } + + @Test + public void testConcurrentAccess() throws Exception { int threadCount = 4; - int count = 1000 * 1000; + int countPerThread = 100_000; + int wireSize = 512; CountDownLatch latch = new CountDownLatch(threadCount); - AtomicBoolean result = new AtomicBoolean(true); - ExecutorService executorService = Executors.newFixedThreadPool(threadCount, new ThreadFactoryImpl("RemotingCodeTest_")); + AtomicBoolean success = new AtomicBoolean(true); + ExecutorService executor = Executors.newFixedThreadPool(threadCount); for (int i = 0; i < threadCount; i++) { - executorService.submit(() -> { + executor.submit(() -> { try { - for (int j = 0; j < count; j++) { - methodIn.invoke(distributionHandler, 1); - methodOut.invoke(distributionHandler, 2); + for (int j = 0; j < countPerThread; j++) { + handler.recordInbound(1, wireSize); } } catch (Exception e) { - result.set(false); + success.set(false); } finally { latch.countDown(); } @@ -62,11 +113,13 @@ public void remotingCodeCountTest() throws Exception { } latch.await(); - Assert.assertTrue(result.get()); - await().pollInterval(Duration.ofMillis(100)).atMost(Duration.ofSeconds(10)).until(() -> { - boolean f1 = ("{1:" + count * threadCount + "}").equals(distributionHandler.getInBoundSnapshotString()); - boolean f2 = ("{2:" + count * threadCount + "}").equals(distributionHandler.getOutBoundSnapshotString()); - return f1 && f2; - }); + Assert.assertTrue(success.get()); + + long totalCount = threadCount * (long) countPerThread; + long totalTraffic = totalCount * wireSize; + Assert.assertEquals("{1:" + totalCount + "}", handler.getInBoundSnapshotString()); + Assert.assertEquals("{1:" + totalTraffic + "}", handler.getInBoundTrafficSnapshotString()); + + executor.shutdown(); } -} \ No newline at end of file +}