Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,19 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.stream.ChunkedStream;
import io.netty.handler.stream.ChunkedInput;
import io.netty.util.ReferenceCountUtil;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.util.JavaUtils;

/**
* A wrapper message that holds two separate pieces (a header and a body).
*
* The header must be a ByteBuf, while the body can be any InputStream or ChunkedStream
* The header must be a ByteBuf, while the body can be a ByteBuf, InputStream, or ChunkedStream.
*/
public class EncryptedMessageWithHeader implements ChunkedInput<ByteBuf> {

Expand All @@ -60,8 +62,9 @@ public class EncryptedMessageWithHeader implements ChunkedInput<ByteBuf> {

public EncryptedMessageWithHeader(
@Nullable ManagedBuffer managedBuffer, ByteBuf header, Object body, long bodyLength) {
JavaUtils.checkArgument(body instanceof InputStream || body instanceof ChunkedStream,
"Body must be an InputStream or a ChunkedStream.");
JavaUtils.checkArgument(
body instanceof ByteBuf || body instanceof InputStream || body instanceof ChunkedStream,
"Body must be a ByteBuf, an InputStream, or a ChunkedStream.");
this.managedBuffer = managedBuffer;
this.header = header;
this.headerLength = header.readableBytes();
Expand All @@ -81,38 +84,45 @@ public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception {
return null;
}

if (totalBytesTransferred < headerLength) {
totalBytesTransferred += headerLength;
return header.retain();
} else if (body instanceof InputStream stream) {
int available = stream.available();
if (available <= 0) {
available = (int) (length() - totalBytesTransferred);
} else {
available = (int) Math.min(available, length() - totalBytesTransferred);
}
ByteBuf buffer = allocator.buffer(available);
int toRead = Math.min(available, buffer.writableBytes());
int read = buffer.writeBytes(stream, toRead);
if (read >= 0) {
totalBytesTransferred += read;
return buffer;
} else {
throw new EOFException("Unable to read bytes from InputStream");
}
} else if (body instanceof ChunkedStream stream) {
long old = stream.transferredBytes();
ByteBuf buffer = stream.readChunk(allocator);
long read = stream.transferredBytes() - old;
if (read >= 0) {
totalBytesTransferred += read;
assert(totalBytesTransferred <= length());
return buffer;
if (body instanceof ByteBuf) {
// For ByteBuf bodies, return header + body as a single composite buffer.
ByteBuf bodyBuf = (ByteBuf) body;
totalBytesTransferred = headerLength + bodyLength;
return Unpooled.wrappedBuffer(header.retain(), bodyBuf.retain());
} else {
if (totalBytesTransferred < headerLength) {
totalBytesTransferred += headerLength;
return header.retain();
} else if (body instanceof InputStream stream) {
int available = stream.available();
if (available <= 0) {
available = (int) (length() - totalBytesTransferred);
} else {
available = (int) Math.min(available, length() - totalBytesTransferred);
}
ByteBuf buffer = allocator.buffer(available);
int toRead = Math.min(available, buffer.writableBytes());
int read = buffer.writeBytes(stream, toRead);
if (read >= 0) {
totalBytesTransferred += read;
return buffer;
} else {
throw new EOFException("Unable to read bytes from InputStream");
}
} else if (body instanceof ChunkedStream stream) {
long old = stream.transferredBytes();
ByteBuf buffer = stream.readChunk(allocator);
long read = stream.transferredBytes() - old;
if (read >= 0) {
totalBytesTransferred += read;
assert(totalBytesTransferred <= length());
return buffer;
} else {
throw new EOFException("Unable to read bytes from ChunkedStream");
}
} else {
throw new EOFException("Unable to read bytes from ChunkedStream");
return null;
}
} else {
return null;
}
}

Expand All @@ -134,6 +144,7 @@ public boolean isEndOfInput() throws Exception {
@Override
public void close() throws Exception {
header.release();
ReferenceCountUtil.release(body);
if (managedBuffer != null) {
managedBuffer.release();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,12 @@
*/
package org.apache.spark.network.protocol;

import java.io.InputStream;
import java.util.List;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageEncoder;
import io.netty.handler.stream.ChunkedStream;

import org.apache.spark.internal.SparkLogger;
import org.apache.spark.internal.SparkLoggerFactory;
Expand Down Expand Up @@ -94,15 +91,9 @@ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) thro
assert header.writableBytes() == 0;

if (body != null && bodyLength > 0) {
if (body instanceof ByteBuf byteBuf) {
out.add(Unpooled.wrappedBuffer(header, byteBuf));
} else if (body instanceof InputStream || body instanceof ChunkedStream) {
// For now, assume the InputStream is doing proper chunking.
out.add(new EncryptedMessageWithHeader(in.body(), header, body, bodyLength));
} else {
throw new IllegalArgumentException(
"Body must be a ByteBuf, ChunkedStream or an InputStream");
}
// We transfer ownership of the reference on in.body() to EncryptedMessageWithHeader.
// This reference will be freed when EncryptedMessageWithHeader.close() is called.
out.add(new EncryptedMessageWithHeader(in.body(), header, body, bodyLength));
} else {
out.add(header);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,43 @@ public void testChunkedStream() throws Exception {
assertEquals(0, header.refCnt());
}

// Tests the case where the body is a ByteBuf and that we manage the refcounts of the
// header, body, and managed buffer properly
@Test
// NOTE(review): the next three lines are the *removed* side of the diff (the old
// testByteBufIsNotSupported test) rendered inline — they must not remain in the
// merged file alongside the new method declared below. Verify against the raw diff.
public void testByteBufIsNotSupported() throws Exception {
// Validate that ByteBufs are not supported. This test can be updated
// when we add support for them
public void testByteBufBodyFromManagedBuffer() throws Exception {
// Random payload so the body comparison below is not trivially all-zero.
byte[] randomData = new byte[128];
new Random().nextBytes(randomData);
ByteBuf sourceBuffer = Unpooled.copiedBuffer(randomData);
// convertToNettyForSsl() returns buf.duplicate().retain(), simulate that here
ByteBuf body = sourceBuffer.duplicate().retain();
ByteBuf header = Unpooled.copyLong(42);
// NOTE(review): this assertThrows span is the removed old-test body shown by the diff
// view (ByteBuf bodies used to be rejected); it should be absent from the new test.
assertThrows(IllegalArgumentException.class, () -> {
EncryptedMessageWithHeader msg = new EncryptedMessageWithHeader(
null, header, header, 4);
});

// Capture the header's long before any reads so it can be compared after readChunk.
long expectedHeaderValue = header.getLong(header.readerIndex());
assertEquals(1, header.refCnt());
assertEquals(2, sourceBuffer.refCnt()); // original + duplicate retain
ManagedBuffer managedBuf = new NettyManagedBuffer(sourceBuffer);

EncryptedMessageWithHeader msg = new EncryptedMessageWithHeader(
managedBuf, header, body, managedBuf.size());
ByteBufAllocator allocator = ByteBufAllocator.DEFAULT;

assertFalse(msg.isEndOfInput());

// Single read should return header + body as a composite buffer
ByteBuf result = msg.readChunk(allocator);
assertEquals(header.capacity() + randomData.length, result.readableBytes());
assertEquals(expectedHeaderValue, result.readLong());
for (int i = 0; i < randomData.length; i++) {
assertEquals(randomData[i], result.readByte());
}
assertTrue(msg.isEndOfInput());

// Release the chunk (simulates Netty writing it out)
result.release();

// Closing the message should release the source buffer via managedBuffer.release()
msg.close();
assertEquals(0, sourceBuffer.refCnt());
assertEquals(0, header.refCnt());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.protocol;

import java.util.ArrayList;
import java.util.List;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.ReferenceCountUtil;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;

import org.apache.spark.network.buffer.NettyManagedBuffer;

/**
* Verifies reference counting correctness in SslMessageEncoder.encode() for messages with a
* NettyManagedBuffer body.
*
* <p>When convertToNettyForSsl() returns a ByteBuf, the encoder wraps it in an
* EncryptedMessageWithHeader whose close() releases the ManagedBuffer. This mirrors the non-SSL
* MessageEncoder which uses MessageWithHeader.deallocate().
*/
public class SslMessageEncoderSuite {

  /**
   * Regression test for reference counting: encoding an RpcRequest whose body is a
   * NettyManagedBuffer must leave the wrapped ByteBuf fully released (refCnt == 0) once
   * Netty has drained and closed the resulting EncryptedMessageWithHeader.
   */
  @Test
  public void testNettyManagedBufferBodyIsReleasedAfterEncoding() throws Exception {
    ByteBuf payload = Unpooled.copyLong(1L);
    assertEquals(1, payload.refCnt());

    RpcRequest request = new RpcRequest(1L, new NettyManagedBuffer(payload));

    // encode() only needs an allocator from the context, so a thin mock suffices.
    ChannelHandlerContext context = mock(ChannelHandlerContext.class);
    when(context.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);

    List<Object> encoded = new ArrayList<>();
    SslMessageEncoder.INSTANCE.encode(context, request, encoded);

    // A non-empty body must be emitted as a single EncryptedMessageWithHeader.
    assertEquals(1, encoded.size());
    assertTrue(encoded.get(0) instanceof EncryptedMessageWithHeader);
    EncryptedMessageWithHeader wrapped = (EncryptedMessageWithHeader) encoded.get(0);

    // convertToNettyForSsl() retained a duplicate, so the payload now carries two
    // references (original + duplicate); the ManagedBuffer release happens only in close().
    assertEquals(2, payload.refCnt());

    // Mimic ChunkedWriteHandler: pull the chunk, confirm the input is drained, release it.
    ByteBuf transferred = wrapped.readChunk(UnpooledByteBufAllocator.DEFAULT);
    assertNotNull(transferred);
    assertTrue(wrapped.isEndOfInput());
    ReferenceCountUtil.release(transferred);

    // Mimic Netty closing the ChunkedInput after the transfer completes.
    wrapped.close();

    // close() releases the ManagedBuffer, dropping the payload back to zero references.
    assertEquals(0, payload.refCnt());
  }
}
Loading