Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ public boolean match(ByteBuf in) {
public void config(ChannelHandlerContext ctx, ByteBuf msg) {
ctx.pipeline().addLast(
this.encoderSupplier.get(),
new NettyDecoder(),
this.remotingCodeDistributionHandlerSupplier.get(),
new NettyDecoder(this.remotingCodeDistributionHandlerSupplier.get()),
this.connectionManageHandlerSupplier.get(),
this.serverHandlerSupplier.get()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,15 @@ public class NettyDecoder extends LengthFieldBasedFrameDecoder {
private static final int FRAME_MAX_LENGTH =
Integer.parseInt(System.getProperty("com.rocketmq.remoting.frameMaxLength", "16777216"));

private final RemotingCodeDistributionHandler distributionHandler;

public NettyDecoder() {
this(null);
}

public NettyDecoder(RemotingCodeDistributionHandler distributionHandler) {
super(FRAME_MAX_LENGTH, 0, 4, 0, 4);
this.distributionHandler = distributionHandler;
}

@Override
Expand All @@ -45,7 +52,13 @@ public Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
if (null == frame) {
return null;
}
// readableBytes() is the frame size after stripping the 4-byte length prefix;
// add 4 back to get the actual wire size.
int wireSize = frame.readableBytes() + 4;
RemotingCommand cmd = RemotingCommand.decode(frame);
if (distributionHandler != null) {
distributionHandler.recordInbound(cmd.getCode(), wireSize);
}
cmd.setProcessTimer(timer);
return cmd;
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,30 @@
public class NettyEncoder extends MessageToByteEncoder<RemotingCommand> {
private static final Logger log = LoggerFactory.getLogger(LoggerName.ROCKETMQ_REMOTING_NAME);

private final RemotingCodeDistributionHandler distributionHandler;

public NettyEncoder() {
this(null);
}

public NettyEncoder(RemotingCodeDistributionHandler distributionHandler) {
this.distributionHandler = distributionHandler;
}

@Override
public void encode(ChannelHandlerContext ctx, RemotingCommand remotingCommand, ByteBuf out)
throws Exception {
try {
int beginIndex = out.writerIndex();
remotingCommand.fastEncodeHeader(out);
byte[] body = remotingCommand.getBody();
if (body != null) {
out.writeBytes(body);
}
if (distributionHandler != null) {
distributionHandler.recordOutbound(
remotingCommand.getCode(), out.writerIndex() - beginIndex);
}
} catch (Exception e) {
log.error("encode exception, " + RemotingHelper.parseChannelRemoteAddr(ctx.channel()), e);
if (remotingCommand != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,10 @@ public class NettyRemotingServer extends NettyRemotingAbstract implements Remoti

// sharable handlers
protected final TlsModeHandler tlsModeHandler = new TlsModeHandler(TlsSystemConfig.tlsMode);
protected final NettyEncoder encoder = new NettyEncoder();
protected final NettyEncoder encoder;
protected final NettyConnectManageHandler connectionManageHandler = new NettyConnectManageHandler();
protected final NettyServerHandler serverHandler = new NettyServerHandler();
protected final RemotingCodeDistributionHandler distributionHandler = new RemotingCodeDistributionHandler();
protected final RemotingCodeDistributionHandler distributionHandler;

public NettyRemotingServer(final NettyServerConfig nettyServerConfig) {
this(nettyServerConfig, null);
Expand All @@ -136,6 +136,8 @@ public NettyRemotingServer(final NettyServerConfig nettyServerConfig,
this.serverBootstrap = new ServerBootstrap();
this.nettyServerConfig = nettyServerConfig;
this.channelEventListener = channelEventListener;
this.distributionHandler = new RemotingCodeDistributionHandler();
this.encoder = new NettyEncoder(distributionHandler);

this.publicExecutor = buildPublicExecutor(nettyServerConfig);
this.scheduledExecutorService = buildScheduleExecutor();
Expand Down Expand Up @@ -276,8 +278,7 @@ protected ChannelPipeline configChannel(SocketChannel ch) {
HANDSHAKE_HANDLER_NAME, new HandshakeHandler())
.addLast(getDefaultEventExecutorGroup(),
encoder,
new NettyDecoder(),
distributionHandler,
new NettyDecoder(distributionHandler),
new IdleStateHandler(0, 0,
nettyServerConfig.getServerChannelMaxIdleTimeSeconds()),
connectionManageHandler,
Expand Down Expand Up @@ -426,6 +427,18 @@ private void printRemotingCodeDistribution() {
TRAFFIC_LOGGER.info("Port: {}, ResponseCode Distribution: {}",
nettyServerConfig.getListenPort(), outBoundSnapshotString);
}

String inBoundTrafficSnapshotString = distributionHandler.getInBoundTrafficSnapshotString();
if (inBoundTrafficSnapshotString != null) {
TRAFFIC_LOGGER.info("Port: {}, RequestCode Traffic(byte): {}",
nettyServerConfig.getListenPort(), inBoundTrafficSnapshotString);
}

String outBoundTrafficSnapshotString = distributionHandler.getOutBoundTrafficSnapshotString();
if (outBoundTrafficSnapshotString != null) {
TRAFFIC_LOGGER.info("Port: {}, ResponseCode Traffic(byte): {}",
nettyServerConfig.getListenPort(), outBoundTrafficSnapshotString);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,89 +16,76 @@
*/
package org.apache.rocketmq.remoting.netty;

import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.LongAdder;
import org.apache.rocketmq.remoting.protocol.RemotingCommand;

@ChannelHandler.Sharable
public class RemotingCodeDistributionHandler extends ChannelDuplexHandler {
/**
* Thread-safe tracker for per-requestCode count and traffic distribution.
* <p>
*/
public class RemotingCodeDistributionHandler {

private final ConcurrentMap<Integer, LongAdder> inboundDistribution;
private final ConcurrentMap<Integer, LongAdder> outboundDistribution;
private final ConcurrentMap<Integer, TrafficStats> inboundStats = new ConcurrentHashMap<>();
private final ConcurrentMap<Integer, TrafficStats> outboundStats = new ConcurrentHashMap<>();

public RemotingCodeDistributionHandler() {
inboundDistribution = new ConcurrentHashMap<>();
outboundDistribution = new ConcurrentHashMap<>();
public void recordInbound(int code, int wireSize) {
TrafficStats stats = inboundStats.computeIfAbsent(code, k -> new TrafficStats());
stats.count.increment();
stats.trafficSize.add(wireSize);
}

private void countInbound(int requestCode) {
LongAdder item = inboundDistribution.computeIfAbsent(requestCode, k -> new LongAdder());
item.increment();
public void recordOutbound(int code, int wireSize) {
TrafficStats stats = outboundStats.computeIfAbsent(code, k -> new TrafficStats());
stats.count.increment();
stats.trafficSize.add(wireSize);
}

private void countOutbound(int responseCode) {
LongAdder item = outboundDistribution.computeIfAbsent(responseCode, k -> new LongAdder());
item.increment();
public String getInBoundSnapshotString() {
return snapshotToString(getSnapshot(inboundStats, true));
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
if (msg instanceof RemotingCommand) {
RemotingCommand cmd = (RemotingCommand) msg;
countInbound(cmd.getCode());
}
ctx.fireChannelRead(msg);
public String getOutBoundSnapshotString() {
return snapshotToString(getSnapshot(outboundStats, true));
}

@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
if (msg instanceof RemotingCommand) {
RemotingCommand cmd = (RemotingCommand) msg;
countOutbound(cmd.getCode());
}
ctx.write(msg, promise);
public String getInBoundTrafficSnapshotString() {
return snapshotToString(getSnapshot(inboundStats, false));
}

public String getOutBoundTrafficSnapshotString() {
return snapshotToString(getSnapshot(outboundStats, false));
}

private Map<Integer, Long> getDistributionSnapshot(Map<Integer, LongAdder> countMap) {
Map<Integer, Long> map = new HashMap<>(countMap.size());
for (Map.Entry<Integer, LongAdder> entry : countMap.entrySet()) {
map.put(entry.getKey(), entry.getValue().sumThenReset());
private Map<Integer, Long> getSnapshot(ConcurrentMap<Integer, TrafficStats> statsMap, boolean count) {
Map<Integer, Long> map = new HashMap<>(statsMap.size());
for (Map.Entry<Integer, TrafficStats> entry : statsMap.entrySet()) {
LongAdder adder = count ? entry.getValue().count : entry.getValue().trafficSize;
map.put(entry.getKey(), adder.sumThenReset());
}
return map;
}

private String snapshotToString(Map<Integer, Long> distribution) {
if (null != distribution && !distribution.isEmpty()) {
StringBuilder sb = new StringBuilder("{");
boolean first = true;
for (Map.Entry<Integer, Long> entry : distribution.entrySet()) {
if (0L == entry.getValue()) {
continue;
}
sb.append(first ? "" : ", ").append(entry.getKey()).append(":").append(entry.getValue());
first = false;
}
if (first) {
return null;
if (null == distribution || distribution.isEmpty()) {
return null;
}
StringBuilder sb = new StringBuilder("{");
boolean first = true;
for (Map.Entry<Integer, Long> entry : distribution.entrySet()) {
if (0L == entry.getValue()) {
continue;
}
sb.append("}");
return sb.toString();
sb.append(first ? "" : ", ").append(entry.getKey()).append(":").append(entry.getValue());
first = false;
}
return null;
}

public String getInBoundSnapshotString() {
return this.snapshotToString(this.getDistributionSnapshot(this.inboundDistribution));
return first ? null : sb.append("}").toString();
}

public String getOutBoundSnapshotString() {
return this.snapshotToString(this.getDistributionSnapshot(this.outboundDistribution));
static class TrafficStats {
final LongAdder count = new LongAdder();
final LongAdder trafficSize = new LongAdder();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,57 +16,110 @@
*/
package org.apache.rocketmq.remoting.netty;

import java.lang.reflect.Method;
import java.time.Duration;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.rocketmq.common.ThreadFactoryImpl;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import static org.awaitility.Awaitility.await;

public class RemotingCodeDistributionHandlerTest {

private final RemotingCodeDistributionHandler distributionHandler = new RemotingCodeDistributionHandler();
private RemotingCodeDistributionHandler handler;

@Before
public void setUp() {
handler = new RemotingCodeDistributionHandler();
}

@Test
public void testInboundCountAndTraffic() {
handler.recordInbound(100, 512);
handler.recordInbound(100, 1024);

Assert.assertEquals("{100:2}", handler.getInBoundSnapshotString());
Assert.assertEquals("{100:1536}", handler.getInBoundTrafficSnapshotString());
}

@Test
public void remotingCodeCountTest() throws Exception {
Class<RemotingCodeDistributionHandler> clazz = RemotingCodeDistributionHandler.class;
Method methodIn = clazz.getDeclaredMethod("countInbound", int.class);
Method methodOut = clazz.getDeclaredMethod("countOutbound", int.class);
methodIn.setAccessible(true);
methodOut.setAccessible(true);
public void testOutboundCountAndTraffic() {
handler.recordOutbound(0, 256);
handler.recordOutbound(0, 256);
handler.recordOutbound(0, 512);

Assert.assertEquals("{0:3}", handler.getOutBoundSnapshotString());
Assert.assertEquals("{0:1024}", handler.getOutBoundTrafficSnapshotString());
}

@Test
public void testMultipleRequestCodes() {
handler.recordInbound(10, 200);
handler.recordInbound(10, 200);
handler.recordInbound(20, 300);

String countSnapshot = handler.getInBoundSnapshotString();
Assert.assertNotNull(countSnapshot);
Assert.assertTrue(countSnapshot.contains("10:2"));
Assert.assertTrue(countSnapshot.contains("20:1"));

String trafficSnapshot = handler.getInBoundTrafficSnapshotString();
Assert.assertNotNull(trafficSnapshot);
Assert.assertTrue(trafficSnapshot.contains("10:400"));
Assert.assertTrue(trafficSnapshot.contains("20:300"));
}

@Test
public void testSnapshotResetsAfterRead() {
handler.recordInbound(400, 100);

Assert.assertNotNull(handler.getInBoundSnapshotString());
Assert.assertNotNull(handler.getInBoundTrafficSnapshotString());

// Second read returns null after sumThenReset
Assert.assertNull(handler.getInBoundSnapshotString());
Assert.assertNull(handler.getInBoundTrafficSnapshotString());
}

@Test
public void testEmptySnapshotReturnsNull() {
Assert.assertNull(handler.getInBoundSnapshotString());
Assert.assertNull(handler.getOutBoundSnapshotString());
Assert.assertNull(handler.getInBoundTrafficSnapshotString());
Assert.assertNull(handler.getOutBoundTrafficSnapshotString());
}

@Test
public void testConcurrentAccess() throws Exception {
int threadCount = 4;
int count = 1000 * 1000;
int countPerThread = 100_000;
int wireSize = 512;
CountDownLatch latch = new CountDownLatch(threadCount);
AtomicBoolean result = new AtomicBoolean(true);
ExecutorService executorService = Executors.newFixedThreadPool(threadCount, new ThreadFactoryImpl("RemotingCodeTest_"));
AtomicBoolean success = new AtomicBoolean(true);
ExecutorService executor = Executors.newFixedThreadPool(threadCount);

for (int i = 0; i < threadCount; i++) {
executorService.submit(() -> {
executor.submit(() -> {
try {
for (int j = 0; j < count; j++) {
methodIn.invoke(distributionHandler, 1);
methodOut.invoke(distributionHandler, 2);
for (int j = 0; j < countPerThread; j++) {
handler.recordInbound(1, wireSize);
}
} catch (Exception e) {
result.set(false);
success.set(false);
} finally {
latch.countDown();
}
});
}

latch.await();
Assert.assertTrue(result.get());
await().pollInterval(Duration.ofMillis(100)).atMost(Duration.ofSeconds(10)).until(() -> {
boolean f1 = ("{1:" + count * threadCount + "}").equals(distributionHandler.getInBoundSnapshotString());
boolean f2 = ("{2:" + count * threadCount + "}").equals(distributionHandler.getOutBoundSnapshotString());
return f1 && f2;
});
Assert.assertTrue(success.get());

long totalCount = threadCount * (long) countPerThread;
long totalTraffic = totalCount * wireSize;
Assert.assertEquals("{1:" + totalCount + "}", handler.getInBoundSnapshotString());
Assert.assertEquals("{1:" + totalTraffic + "}", handler.getInBoundTrafficSnapshotString());

executor.shutdown();
}
}
}
Loading