Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 55 additions & 35 deletions src/amp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,42 +791,62 @@ def query_and_load_streaming(
self.logger.warning(f'Failed to load checkpoint, starting from beginning: {e}')

try:
# Execute streaming query with Flight SQL
# Create a CommandStatementQuery message
command_query = FlightSql_pb2.CommandStatementQuery()
command_query.query = query

# Add resume watermark if provided
if resume_watermark:
# TODO: Add watermark to query metadata when Flight SQL supports it
self.logger.info(f'Resuming stream from watermark: {resume_watermark}')

# Wrap the CommandStatementQuery in an Any type
any_command = Any()
any_command.Pack(command_query)
cmd = any_command.SerializeToString()

self.logger.info('Establishing Flight SQL connection...')
flight_descriptor = flight.FlightDescriptor.for_command(cmd)
info = self.conn.get_flight_info(flight_descriptor)
reader = self.conn.do_get(info.endpoints[0].ticket)

# Create streaming iterator
stream_iterator = StreamingResultIterator(reader)
self.logger.info('Stream connection established, waiting for data...')

# Optionally wrap with reorg detection
if with_reorg_detection:
stream_iterator = ReorgAwareStream(stream_iterator)
self.logger.info('Reorg detection enabled for streaming query')

# Start continuous loading with checkpoint support
with loader_instance:
self.logger.info(f'Starting continuous load to {destination}. Press Ctrl+C to stop.')
# Pass connection_name for checkpoint saving
yield from loader_instance.load_stream_continuous(
stream_iterator, destination, connection_name=connection_name, **load_config.__dict__
)
while True:
# Execute streaming query with Flight SQL
# Create a CommandStatementQuery message
command_query = FlightSql_pb2.CommandStatementQuery()
command_query.query = query

# Add resume watermark if provided
if resume_watermark:
# TODO: Add watermark to query metadata when Flight SQL supports it
self.logger.info(f'Resuming stream from watermark: {resume_watermark}')

# Wrap the CommandStatementQuery in an Any type
any_command = Any()
any_command.Pack(command_query)
cmd = any_command.SerializeToString()

self.logger.info('Establishing Flight SQL connection...')
flight_descriptor = flight.FlightDescriptor.for_command(cmd)
info = self.conn.get_flight_info(flight_descriptor)
reader = self.conn.do_get(info.endpoints[0].ticket)

# Create streaming iterator
stream_iterator = StreamingResultIterator(reader)
self.logger.info('Stream connection established, waiting for data...')

# Optionally wrap with reorg detection
if with_reorg_detection:
stream_iterator = ReorgAwareStream(stream_iterator, resume_watermark=resume_watermark)
self.logger.info('Reorg detection enabled for streaming query')

# Start continuous loading with checkpoint support
self.logger.info(f'Starting continuous load to {destination}. Press Ctrl+C to stop.')

reorg_result = None
# Pass connection_name for checkpoint saving
for result in loader_instance.load_stream_continuous(
stream_iterator, destination, connection_name=connection_name, **load_config.__dict__
):
yield result
# Break on reorg to restart stream
if result.is_reorg:
reorg_result = result
break

# Check if we need to restart due to reorg
if reorg_result:
# Close the old stream before restarting
if hasattr(stream_iterator, 'close'):
stream_iterator.close()
self.logger.info('Reorg detected, restarting stream with new resume position...')
resume_watermark = loader_instance.state_store.get_resume_position(connection_name, destination)
continue

# Normal exit - stream completed
break

except Exception as e:
self.logger.error(f'Streaming query failed: {e}')
Expand Down
71 changes: 50 additions & 21 deletions src/amp/streaming/reorg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Dict, Iterator, List

from .iterator import StreamingResultIterator
from .types import BlockRange, ResponseBatch
from .types import BlockRange, ResponseBatch, ResumeWatermark


class ReorgAwareStream:
Expand All @@ -16,20 +16,32 @@ class ReorgAwareStream:
This class monitors the block ranges in consecutive batches to detect chain
reorganizations (reorgs). When a reorg is detected, a ResponseBatch with
is_reorg=True is emitted containing the invalidation ranges.

Supports cross-restart reorg detection by initializing from a resume watermark
that contains the last known block hashes from persistent state.
"""

def __init__(self, stream_iterator: StreamingResultIterator):
def __init__(self, stream_iterator: StreamingResultIterator, resume_watermark: ResumeWatermark = None):
    """
    Initialize the reorg-aware stream.

    Args:
        stream_iterator: The underlying streaming result iterator
        resume_watermark: Optional watermark from persistent state (LMDB) containing
            last known block ranges with hashes for cross-restart reorg detection
    """
    # NOTE(review): default is None, so the annotation should read
    # Optional[ResumeWatermark] — `Optional` is not currently imported; confirm.
    self.stream_iterator = stream_iterator
    # Track the latest range for each network
    # (network name -> most recently seen BlockRange for that network)
    self.prev_ranges_by_network: Dict[str, BlockRange] = {}
    self.logger = logging.getLogger(__name__)

    if resume_watermark:
        # Seed per-network state from the persisted watermark so that a hash
        # mismatch on the very first batch after a restart is still detected.
        for block_range in resume_watermark.ranges:
            self.prev_ranges_by_network[block_range.network] = block_range
            self.logger.debug(
                f'Initialized reorg detection for {block_range.network} '
                f'from block {block_range.end} hash {block_range.hash}'
            )

def __iter__(self) -> Iterator[ResponseBatch]:
    """Return self; this object is its own iterator (batches come from __next__)."""
    return self
Expand Down Expand Up @@ -63,20 +75,16 @@ def __next__(self) -> ResponseBatch:
for range in batch.metadata.ranges:
self.prev_ranges_by_network[range.network] = range

# If we detected a reorg, yield the reorg notification first
# If we detected a reorg, return reorg batch
# Caller decides whether to stop/restart or continue
if invalidation_ranges:
self.logger.info(f'Reorg detected with {len(invalidation_ranges)} invalidation ranges')
# Store the batch to yield after the reorg
self._pending_batch = batch
# Clear memory for affected networks so restart works correctly
for inv_range in invalidation_ranges:
if inv_range.network in self.prev_ranges_by_network:
del self.prev_ranges_by_network[inv_range.network]
return ResponseBatch.reorg_batch(invalidation_ranges)

# Check if we have a pending batch from a previous reorg detection
# REVIEW: I think we should remove this
if hasattr(self, '_pending_batch'):
pending = self._pending_batch
delattr(self, '_pending_batch')
return pending

# Normal case - just return the data batch
return batch

Expand All @@ -89,9 +97,9 @@ def _detect_reorg(self, current_ranges: List[BlockRange]) -> List[BlockRange]:
"""
Detect reorganizations by comparing current ranges with previous ranges.

A reorg is detected when:
- A range starts at or before the end of the previous range for the same network
- The range is different from the previous range
A reorg is detected when either:
1. Block number overlap: current range starts at or before previous range end
2. Hash mismatch: server's prev_hash doesn't match our stored hash (cross-restart detection)

Args:
current_ranges: Block ranges from the current batch
Expand All @@ -102,18 +110,39 @@ def _detect_reorg(self, current_ranges: List[BlockRange]) -> List[BlockRange]:
invalidation_ranges = []

for current_range in current_ranges:
# Get the previous range for this network
prev_range = self.prev_ranges_by_network.get(current_range.network)

if prev_range:
# Check if this indicates a reorg
is_reorg = False

# Detection 1: Block number overlap (original logic)
if current_range != prev_range and current_range.start <= prev_range.end:
# Reorg detected - create invalidation range
# Invalidate from the start of the current range to the max end
is_reorg = True
self.logger.info(
f'Reorg detected via block overlap: {current_range.network} '
f'current start {current_range.start} <= prev end {prev_range.end}'
)

# Detection 2: Hash mismatch (cross-restart detection)
# Server sends prev_hash = hash of block before current range
# If it doesn't match our stored hash, chain has changed
elif (
current_range.prev_hash is not None
and prev_range.hash is not None
and current_range.prev_hash != prev_range.hash
):
is_reorg = True
self.logger.info(
f'Reorg detected via hash mismatch: {current_range.network} '
f'server prev_hash {current_range.prev_hash} != stored hash {prev_range.hash}'
)

if is_reorg:
invalidation = BlockRange(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we invalidate the entire previous range in this case to be safe?

Copy link
Member Author

@incrypto32 incrypto32 Feb 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, good catch. For the case of a hash mismatch we need to invalidate the entire previous range to be safe. But just setting the invalidation to the previous range would create a gap, since that previous range would be skipped completely once processing had already started in the next range.
So some changes are needed in how this ReorgAwareStream works. I'll look into it and come up with a better solution. I also identified a bug in ReorgAwareStream: currently, when there is a _pending_batch it gets returned, but the current batch that was just fetched from the stream gets dropped completely.

I'll come up with a fix for that as well

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The proper fix would be to trigger a backfill when we invalidate more than the current range.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fordN I have added a fix for both issues; the changes are fairly substantial. We now restart the stream from the checkpoint store when a reorg occurs. I don't know if it's the right approach — the other option was to do a backfill for the skipped blocks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, cool. I'll take a look in the morning

network=current_range.network,
start=current_range.start,
start=prev_range.start,
end=max(current_range.end, prev_range.end),
hash=prev_range.hash,
)
invalidation_ranges.append(invalidation)

Expand Down
110 changes: 106 additions & 4 deletions tests/unit/test_streaming_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ class MockIterator:

assert len(invalidations) == 1
assert invalidations[0].network == 'ethereum'
assert invalidations[0].start == 180
assert invalidations[0].start == 100 # prev_range.start
assert invalidations[0].end == 280 # max(280, 200)

def test_detect_reorg_multiple_networks(self):
Expand Down Expand Up @@ -504,12 +504,12 @@ class MockIterator:

# Check ethereum reorg
eth_inv = next(inv for inv in invalidations if inv.network == 'ethereum')
assert eth_inv.start == 150
assert eth_inv.start == 100 # prev_range.start
assert eth_inv.end == 250

# Check polygon reorg
poly_inv = next(inv for inv in invalidations if inv.network == 'polygon')
assert poly_inv.start == 140
assert poly_inv.start == 50 # prev_range.start
assert poly_inv.end == 240

def test_detect_reorg_same_range_no_reorg(self):
Expand Down Expand Up @@ -546,7 +546,7 @@ class MockIterator:
invalidations = stream._detect_reorg(current_ranges)

assert len(invalidations) == 1
assert invalidations[0].start == 250
assert invalidations[0].start == 100 # prev_range.start
assert invalidations[0].end == 300 # max(280, 300)

def test_is_duplicate_batch_all_same(self):
Expand Down Expand Up @@ -619,3 +619,105 @@ class MockIterator:
stream = ReorgAwareStream(MockIterator())

assert stream._is_duplicate_batch([]) == False

def test_init_from_resume_watermark(self):
    """Seeding from a persisted watermark populates per-network state with hashes."""

    class MockIterator:
        pass

    seeded_ranges = [
        BlockRange(network='ethereum', start=100, end=200, hash='0xabc123'),
        BlockRange(network='polygon', start=50, end=150, hash='0xdef456'),
    ]
    wm = ResumeWatermark(ranges=seeded_ranges)

    detector = ReorgAwareStream(MockIterator(), resume_watermark=wm)

    tracked = detector.prev_ranges_by_network
    assert 'ethereum' in tracked
    assert 'polygon' in tracked
    assert tracked['ethereum'].hash == '0xabc123'
    assert tracked['polygon'].hash == '0xdef456'

def test_detect_reorg_hash_mismatch(self):
    """A stored hash that differs from the server's prev_hash flags a reorg."""

    class MockIterator:
        pass

    detector = ReorgAwareStream(MockIterator())
    detector.prev_ranges_by_network = {
        'ethereum': BlockRange(network='ethereum', start=100, end=200, hash='0xoriginal'),
    }

    incoming = [BlockRange(network='ethereum', start=201, end=300, prev_hash='0xdifferent')]
    result = detector._detect_reorg(incoming)

    assert len(result) == 1
    first = result[0]
    assert first.network == 'ethereum'
    assert first.hash == '0xoriginal'

def test_detect_reorg_hash_match_no_reorg(self):
    """Matching stored hash and server prev_hash yields no invalidations."""

    class MockIterator:
        pass

    detector = ReorgAwareStream(MockIterator())
    detector.prev_ranges_by_network = {
        'ethereum': BlockRange(network='ethereum', start=100, end=200, hash='0xsame'),
    }

    incoming = [BlockRange(network='ethereum', start=201, end=300, prev_hash='0xsame')]

    assert detector._detect_reorg(incoming) == []

def test_detect_reorg_hash_mismatch_with_none_prev_hash(self):
    """A missing server prev_hash (e.g. genesis) must not trigger a reorg."""

    class MockIterator:
        pass

    detector = ReorgAwareStream(MockIterator())
    detector.prev_ranges_by_network = {
        'ethereum': BlockRange(network='ethereum', start=0, end=0, hash='0xgenesis'),
    }

    incoming = [BlockRange(network='ethereum', start=1, end=100, prev_hash=None)]

    assert detector._detect_reorg(incoming) == []

def test_detect_reorg_hash_mismatch_with_none_stored_hash(self):
    """A missing stored hash disables hash-based detection (no false positive)."""

    class MockIterator:
        pass

    detector = ReorgAwareStream(MockIterator())
    detector.prev_ranges_by_network = {
        'ethereum': BlockRange(network='ethereum', start=100, end=200, hash=None),
    }

    incoming = [BlockRange(network='ethereum', start=201, end=300, prev_hash='0xsome_hash')]

    assert detector._detect_reorg(incoming) == []
Loading