From 24d4de17acad45d465b187bfe67f8ffef3d5c2f4 Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Fri, 9 Jan 2026 21:00:27 +0200 Subject: [PATCH 1/4] Reduce data copies in connection receive path Optimize buffer management in `_ConnectionIOBuffer` to avoid unnecessary byte allocations during high-throughput reads. 1. **Buffer Compaction**: Replaced `io.BytesIO(buffer.read())` with a `getbuffer()` slice. The previous method `read()` allocated a new `bytes` object for the remaining content before creating the new generic `BytesIO`. The new approach uses a zero-copy memoryview slice for initialization. 2. **Header Peeking**: Replaced `getvalue()` in `_read_frame_header` with `getbuffer()`. This allows inspecting the protocol version and frame length without materializing the entire buffer contents into a new `bytes` string. Signed-off-by: Yaniv Kaul --- cassandra/connection.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/cassandra/connection.py b/cassandra/connection.py index 87f860f32b..0338ec70e6 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -630,14 +630,25 @@ def readable_io_bytes(self): def readable_cql_frame_bytes(self): return self.cql_frame_buffer.tell() + @staticmethod + def _reset_buffer(buf): + """ + Reset a BytesIO buffer by discarding consumed data. + Avoid an intermediate bytes copy from .read(); slice the existing buffer. + BytesIO will still materialize its own backing store, but this reduces + one full-buffer allocation on the hot receive path. + """ + pos = buf.tell() + new_buf = io.BytesIO(buf.getbuffer()[pos:]) + new_buf.seek(0, 2) # 2 == SEEK_END + return new_buf + def reset_io_buffer(self): - self._io_buffer = io.BytesIO(self._io_buffer.read()) - self._io_buffer.seek(0, 2) # 2 == SEEK_END + self._io_buffer = self._reset_buffer(self._io_buffer) def reset_cql_frame_buffer(self): if self.is_checksumming_enabled: - self._cql_frame_buffer = io.BytesIO(self._cql_frame_buffer.read()) - self._cql_frame_buffer.seek(0, 2) # 2 == SEEK_END + self._cql_frame_buffer = self._reset_buffer(self._cql_frame_buffer) else: self.reset_io_buffer() @@ -1191,9 +1202,10 @@ def control_conn_disposed(self): @defunct_on_error def _read_frame_header(self): - buf = self._io_buffer.cql_frame_buffer.getvalue() - pos = len(buf) + cql_buf = self._io_buffer.cql_frame_buffer + pos = cql_buf.tell() if pos: + buf = cql_buf.getbuffer() version = buf[0] & PROTOCOL_VERSION_MASK if version not in ProtocolVersion.SUPPORTED_VERSIONS: raise ProtocolError("This version of the driver does not support protocol version %d" % version) From c93e0d025c7a7d0281be5f778d330bab0671bfaf Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Fri, 9 Jan 2026 22:26:15 +0200 Subject: [PATCH 2/4] Optimize frame body extraction to reduce memory copies Replace BytesIO.read() with direct buffer slicing to eliminate one intermediate bytes allocation per received message frame. Changes: - Use getbuffer() to get memoryview of underlying buffer - Slice directly at [body_offset:end_pos] instead of seek+read - Convert memoryview slice to bytes in single operation - Maintain buffer position tracking for proper reset behavior Benefits: - Eliminates one full-frame allocation on hot receive path - Maintains compatibility with existing protocol decoder The memoryview is immediately converted to bytes and released, preventing buffer resize issues while still gaining the allocation savings. Signed-off-by: Yaniv Kaul --- cassandra/connection.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cassandra/connection.py b/cassandra/connection.py index 0338ec70e6..988fd0a946 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1269,8 +1269,12 @@ def process_io_buffer(self): return else: frame = self._current_frame - self._io_buffer.cql_frame_buffer.seek(frame.body_offset) - msg = self._io_buffer.cql_frame_buffer.read(frame.end_pos - frame.body_offset) + # Use memoryview to avoid intermediate allocation, then convert to bytes + # Explicitly scope the buffer to ensure memoryview is released before reset + cql_buf = self._io_buffer.cql_frame_buffer + msg = bytes(cql_buf.getbuffer()[frame.body_offset:frame.end_pos]) + # Advance buffer position to end of frame before reset + cql_buf.seek(frame.end_pos) self.process_msg(frame, msg) self._io_buffer.reset_cql_frame_buffer() self._current_frame = None From e3e581212dc66c59d656590349a87d2da4d6dddf Mon Sep 17 00:00:00 2001 From: Yaniv Michael Kaul Date: Fri, 9 Jan 2026 22:56:04 +0200 Subject: [PATCH 3/4] Add BytesReader to replace BytesIO in decode_message() Introduce a lightweight BytesReader class that provides the same read() interface as io.BytesIO but operates directly on the input bytes/memoryview without internal buffering overhead. Changes: - Add BytesReader class with __slots__ for memory efficiency - Replace io.BytesIO(body) with BytesReader(body) in decode_message() - BytesReader.read() returns slices directly, converting memoryview to bytes only when necessary for compatibility Benefits: - Eliminates BytesIO's internal buffer allocation and management - Reduces memory overhead for protocol message decoding - Works seamlessly with both bytes and memoryview inputs - Maintains full API compatibility with existing read_* functions The BytesReader is a minimal implementation focused on the read() method needed by the protocol decoder. It avoids the overhead of io.BytesIO's full file-like interface. Signed-off-by: Yaniv Kaul --- cassandra/protocol.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index f37633a756..f74fefbe1d 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -53,6 +53,34 @@ class NotSupportedError(Exception): class InternalError(Exception): pass + +class BytesReader: + """ + Lightweight reader for bytes data without BytesIO overhead. + Provides the same read() interface but operates directly on a + bytes or memoryview object, avoiding internal buffer copies. + """ + __slots__ = ('_data', '_pos', '_size') + + def __init__(self, data): + self._data = data + self._pos = 0 + self._size = len(data) + + def read(self, n=-1): + if n < 0: + result = self._data[self._pos:] + self._pos = self._size + else: + end = self._pos + n + if end > self._size: + raise EOFError("Cannot read past the end of the buffer") + result = self._data[self._pos:end] + self._pos = end + # Return bytes to maintain compatibility with unpack functions + return bytes(result) if isinstance(result, memoryview) else result + + ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type']) HEADER_DIRECTION_TO_CLIENT = 0x80 @@ -1154,7 +1182,8 @@ def decode_message(cls, protocol_version, protocol_features, user_type_map, stre body = decompressor(body) flags ^= COMPRESSED_FLAG - body = io.BytesIO(body) + # Use lightweight BytesReader instead of io.BytesIO to avoid buffer copy + body = BytesReader(body) if flags & TRACING_FLAG: trace_id = UUID(bytes=body.read(16)) flags ^= TRACING_FLAG From 1ee8e85059b17a4439a3921a3026fc5742660176 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Sun, 22 Feb 2026 22:51:50 -0400 Subject: [PATCH 4/4] Add tests and fix BytesReader to match BytesIO.read() semantics BytesReader.read(n) now returns partial data at end-of-buffer instead of raising EOFError, matching BytesIO behavior. Added 11 unit tests for BytesReader and 6 tests verifying _ConnectionIOBuffer reset position semantics under both checksumming and non-checksumming paths. --- cassandra/protocol.py | 11 ++--- tests/unit/test_connection.py | 80 ++++++++++++++++++++++++++++++++++- tests/unit/test_protocol.py | 60 +++++++++++++++++++++++++- 3 files changed, 144 insertions(+), 7 deletions(-) diff --git a/cassandra/protocol.py b/cassandra/protocol.py index f74fefbe1d..009c308c96 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -57,8 +57,11 @@ class InternalError(Exception): class BytesReader: """ Lightweight reader for bytes data without BytesIO overhead. - Provides the same read() interface but operates directly on a - bytes or memoryview object, avoiding internal buffer copies. + Provides the same read() interface as BytesIO but operates directly + on a bytes or memoryview object, avoiding internal buffer copies. + + read(n) behaves like BytesIO.read(n): returns up to n bytes and + returns fewer bytes (or empty bytes) when the end of data is reached. """ __slots__ = ('_data', '_pos', '_size') @@ -72,9 +75,7 @@ def read(self, n=-1): result = self._data[self._pos:] self._pos = self._size else: - end = self._pos + n - if end > self._size: - raise EOFError("Cannot read past the end of the buffer") + end = min(self._pos + n, self._size) result = self._data[self._pos:end] self._pos = end # Return bytes to maintain compatibility with unpack functions diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 6ac63ff761..30dcb81b38 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -22,7 +22,8 @@ from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, - ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator) + ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator, + _ConnectionIOBuffer) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, SupportedMessage, ProtocolHandler) @@ -571,3 +572,80 @@ def test_generate_is_repeatable_with_same_mock(self, mock_randrange): second_run = list(itertools.islice(gen.generate(0, 2), 5)) assert first_run == second_run + + +class TestConnectionIOBufferReset(unittest.TestCase): + """Verify _reset_buffer and reset_cql_frame_buffer position semantics.""" + + def test_reset_buffer_discards_consumed_data(self): + buf = BytesIO(b'\x01\x02\x03\x04\x05') + buf.seek(0) + # Consume first 3 bytes + assert buf.read(3) == b'\x01\x02\x03' + new_buf = _ConnectionIOBuffer._reset_buffer(buf) + # New buffer should contain only unconsumed data + new_buf.seek(0) + assert new_buf.read() == b'\x04\x05' + + def test_reset_buffer_position_at_end(self): + buf = BytesIO(b'\x01\x02\x03') + buf.seek(0) + buf.read(1) + new_buf = _ConnectionIOBuffer._reset_buffer(buf) + # Position should be at end (ready for appending) + assert new_buf.tell() == 2 + + def test_reset_buffer_fully_consumed(self): + buf = BytesIO(b'\x01\x02') + buf.seek(0) + buf.read(2) + new_buf = _ConnectionIOBuffer._reset_buffer(buf) + new_buf.seek(0) + assert new_buf.read() == b'' + + def test_reset_buffer_nothing_consumed(self): + buf = BytesIO(b'\x01\x02\x03') + buf.seek(0) + new_buf = _ConnectionIOBuffer._reset_buffer(buf) + new_buf.seek(0) + assert new_buf.read() == b'\x01\x02\x03' + + @staticmethod + def _make_iobuf(checksumming=False): + conn = Mock() + conn._is_checksumming_enabled = checksumming + iobuf = _ConnectionIOBuffer(conn) + # Keep a strong reference so the weakref.proxy inside iobuf stays valid + iobuf._conn_ref = conn + if checksumming: + iobuf.set_checksumming_buffer() + return iobuf + + def test_reset_cql_frame_buffer_checksumming_uses_tell_position(self): + """ + When checksumming is enabled, reset_cql_frame_buffer delegates to + _reset_buffer which relies on tell() to determine consumed data. + Verify that seeking to an arbitrary position before reset correctly + preserves only the unconsumed tail. + """ + iobuf = self._make_iobuf(checksumming=True) + # Write some data into the cql_frame_buffer + iobuf.cql_frame_buffer.write(b'\xAA\xBB\xCC\xDD\xEE') + # Seek to position 3, simulating that first 3 bytes were consumed + iobuf.cql_frame_buffer.seek(3) + iobuf.reset_cql_frame_buffer() + # After reset, only the unconsumed tail should remain + iobuf.cql_frame_buffer.seek(0) + assert iobuf.cql_frame_buffer.read() == b'\xDD\xEE' + + def test_reset_cql_frame_buffer_no_checksumming_resets_io_buffer(self): + """ + Without checksumming, reset_cql_frame_buffer delegates to + reset_io_buffer (since cql_frame_buffer IS the io_buffer). + """ + iobuf = self._make_iobuf(checksumming=False) + iobuf.io_buffer.write(b'\x01\x02\x03\x04') + iobuf.io_buffer.seek(2) + iobuf.reset_cql_frame_buffer() + iobuf.io_buffer.seek(0) + assert iobuf.io_buffer.read() == b'\x03\x04' diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py index 9704811239..4b9ec27bad 100644 --- a/tests/unit/test_protocol.py +++ b/tests/unit/test_protocol.py @@ -21,7 +21,7 @@ PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation, _PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG, _PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG, - BatchMessage + BatchMessage, BytesReader ) from cassandra.query import BatchType from cassandra.marshal import uint32_unpack @@ -189,3 +189,61 @@ def test_batch_message_with_keyspace(self): (b'\x00\x03',), (b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',)) ) + + +class BytesReaderTest(unittest.TestCase): + + def test_read_exact_n_bytes(self): + reader = BytesReader(b'\x01\x02\x03\x04\x05') + assert reader.read(3) == b'\x01\x02\x03' + + def test_read_all_with_negative(self): + reader = BytesReader(b'hello') + assert reader.read(-1) == b'hello' + + def test_read_all_default(self): + reader = BytesReader(b'hello') + reader.read(2) + assert reader.read(-1) == b'llo' + + def test_sequential_reads(self): + reader = BytesReader(b'\x01\x02\x03\x04') + assert reader.read(2) == b'\x01\x02' + assert reader.read(2) == b'\x03\x04' + + def test_read_past_end_returns_partial(self): + reader = BytesReader(b'\x01\x02') + assert reader.read(3) == b'\x01\x02' + + def test_read_past_end_after_partial_consume(self): + reader = BytesReader(b'\x01\x02\x03') + reader.read(2) + assert reader.read(2) == b'\x03' + + def test_read_at_end_returns_empty(self): + reader = BytesReader(b'\x01\x02') + reader.read(2) + assert reader.read(1) == b'' + + def test_read_zero_bytes(self): + reader = BytesReader(b'abc') + assert reader.read(0) == b'' + + def test_memoryview_input_returns_bytes(self): + data = b'\x01\x02\x03\x04' + reader = BytesReader(memoryview(data)) + result = reader.read(2) + assert isinstance(result, bytes) + assert result == b'\x01\x02' + + def test_memoryview_read_all_returns_bytes(self): + data = b'\x01\x02\x03' + reader = BytesReader(memoryview(data)) + result = reader.read(-1) + assert isinstance(result, bytes) + assert result == b'\x01\x02\x03' + + def test_empty_input(self): + reader = BytesReader(b'') + assert reader.read(-1) == b'' + assert reader.read(1) == b''