Commit 7882a4b

base loader: fix micro batch is_processed marking, add tests
1 parent: 710c4e3

File tree

4 files changed: +265, -13 lines


src/amp/loaders/base.py

Lines changed: 24 additions & 8 deletions
@@ -484,6 +484,7 @@ def load_stream_continuous(
                     table_name,
                     connection_name,
                     response.metadata.ranges,
+                    ranges_complete=response.metadata.ranges_complete,
                 )
             else:
                 # Non-transactional loading (separate check, load, mark)
@@ -494,6 +495,7 @@ def load_stream_continuous(
                     table_name,
                     connection_name,
                     response.metadata.ranges,
+                    ranges_complete=response.metadata.ranges_complete,
                     **filtered_kwargs,
                 )

@@ -611,6 +613,7 @@ def _process_batch_transactional(
         table_name: str,
         connection_name: str,
         ranges: List[BlockRange],
+        ranges_complete: bool = False,
     ) -> LoadResult:
         """
         Process a data batch using transactional exactly-once semantics.
@@ -622,6 +625,7 @@ def _process_batch_transactional(
             table_name: Target table name
             connection_name: Connection identifier
             ranges: Block ranges for this batch
+            ranges_complete: True when this RecordBatch completes a microbatch (streaming only)

         Returns:
             LoadResult with operation outcome
@@ -630,13 +634,17 @@ def _process_batch_transactional(
         try:
             # Delegate to loader-specific transactional implementation
             # Loaders that support transactions implement load_batch_transactional()
-            rows_loaded_batch = self.load_batch_transactional(batch_data, table_name, connection_name, ranges)
+            rows_loaded_batch = self.load_batch_transactional(
+                batch_data, table_name, connection_name, ranges, ranges_complete
+            )
             duration = time.time() - start_time

-            # Mark batches as processed in state store after successful transaction
-            if ranges:
+            # Mark batches as processed ONLY when microbatch is complete
+            # multiple RecordBatches can share the same microbatch ID
+            if ranges and ranges_complete:
                 batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
                 self.state_store.mark_processed(connection_name, table_name, batch_ids)
+                self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')

             return LoadResult(
                 rows_loaded=rows_loaded_batch,
@@ -648,6 +656,7 @@ def _process_batch_transactional(
                 metadata={
                     'operation': 'transactional_load' if rows_loaded_batch > 0 else 'skip_duplicate',
                     'ranges': [r.to_dict() for r in ranges],
+                    'ranges_complete': ranges_complete,
                 },
             )

@@ -670,6 +679,7 @@ def _process_batch_non_transactional(
         table_name: str,
         connection_name: str,
         ranges: Optional[List[BlockRange]],
+        ranges_complete: bool = False,
         **kwargs,
     ) -> Optional[LoadResult]:
         """
@@ -682,21 +692,25 @@ def _process_batch_non_transactional(
             table_name: Target table name
             connection_name: Connection identifier
             ranges: Block ranges for this batch (if available)
+            ranges_complete: True when this RecordBatch completes a microbatch (streaming only)
             **kwargs: Additional options passed to load_batch

         Returns:
             LoadResult, or None if batch was skipped as duplicate
         """
         # Check if batch already processed (idempotency / exactly-once)
-        if ranges and self.state_enabled:
+        # For streaming: only check when ranges_complete=True (end of microbatch)
+        # Multiple RecordBatches can share the same microbatch ID, so we must wait
+        # until the entire microbatch is delivered before checking/marking as processed
+        if ranges and self.state_enabled and ranges_complete:
             try:
                 batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
                 is_duplicate = self.state_store.is_processed(connection_name, table_name, batch_ids)

                 if is_duplicate:
                     # Skip this batch - already processed
                     self.logger.info(
-                        f'Skipping duplicate batch: {len(ranges)} ranges already processed for {table_name}'
+                        f'Skipping duplicate microbatch: {len(ranges)} ranges already processed for {table_name}'
                     )
                     return LoadResult(
                         rows_loaded=0,
@@ -711,14 +725,16 @@ def _process_batch_non_transactional(
                 # BlockRange missing hash - log and continue without idempotency check
                 self.logger.warning(f'Cannot check for duplicates: {e}. Processing batch anyway.')

-        # Load batch
+        # Load batch (always load, even if part of larger microbatch)
         result = self.load_batch(batch_data, table_name, **kwargs)

-        if result.success and ranges and self.state_enabled:
-            # Mark batch as processed (for exactly-once semantics)
+        # Mark batch as processed ONLY when microbatch is complete
+        # This ensures we don't skip subsequent RecordBatches within the same microbatch
+        if result.success and ranges and self.state_enabled and ranges_complete:
            try:
                 batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
                 self.state_store.mark_processed(connection_name, table_name, batch_ids)
+                self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')
             except Exception as e:
                 self.logger.error(f'Failed to mark batches as processed: {e}')
                 # Continue anyway - state store provides resume capability
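For orientation, the pattern these base-loader hunks introduce can be reduced to a small self-contained sketch: every RecordBatch of a microbatch is loaded, but the shared BlockRange is checked against and recorded in the state store only when the final chunk arrives with ranges_complete=True. The FakeBlockRange, FakeStateStore, and process_record_batch names below are illustrative stand-ins, not classes from src/amp; only the gating logic mirrors the diff.

from dataclasses import dataclass
from typing import List, Set, Tuple


@dataclass(frozen=True)
class FakeBlockRange:
    network: str
    start: int
    end: int
    hash: str


class FakeStateStore:
    """In-memory stand-in for the real state store (illustration only)."""

    def __init__(self) -> None:
        self._seen: Set[Tuple[str, int, int, str]] = set()

    def is_processed(self, r: FakeBlockRange) -> bool:
        return (r.network, r.start, r.end, r.hash) in self._seen

    def mark_processed(self, r: FakeBlockRange) -> None:
        self._seen.add((r.network, r.start, r.end, r.hash))


def process_record_batch(state: FakeStateStore, rows: List[int], rng: FakeBlockRange, ranges_complete: bool) -> int:
    # Duplicate check and marking happen only at the microbatch boundary;
    # intermediate chunks that share the same BlockRange must always be loaded.
    if ranges_complete and state.is_processed(rng):
        return 0  # the whole microbatch was already delivered earlier
    loaded = len(rows)  # stand-in for the real load into the target table
    if ranges_complete:
        state.mark_processed(rng)
    return loaded


store = FakeStateStore()
rng = FakeBlockRange('ethereum', 100, 110, '0xabc123')
# Three chunks of one microbatch: all three load; only the last marks the range.
assert process_record_batch(store, [1, 2], rng, False) == 2
assert process_record_batch(store, [3, 4], rng, False) == 2
assert process_record_batch(store, [5, 6], rng, True) == 2
# A replay of the completed microbatch is now skipped.
assert process_record_batch(store, [7, 8], rng, True) == 0

Before the fix, the first chunk was marked as processed immediately, so the second and third chunks of the same microbatch were misclassified as duplicates; this sketch is the behavior the commit moves to.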

src/amp/loaders/implementations/postgresql_loader.py

Lines changed: 10 additions & 5 deletions
@@ -119,6 +119,7 @@ def load_batch_transactional(
         table_name: str,
         connection_name: str,
         ranges: List[BlockRange],
+        ranges_complete: bool = False,
     ) -> int:
         """
         Load a batch with transactional exactly-once semantics using in-memory state.
@@ -135,6 +136,7 @@ def load_batch_transactional(
             table_name: Target table name
             connection_name: Connection identifier for tracking
             ranges: Block ranges covered by this batch
+            ranges_complete: True when this RecordBatch completes a microbatch (streaming only)

         Returns:
             Number of rows loaded (0 if duplicate)
@@ -149,24 +151,27 @@ def load_batch_transactional(
             self.logger.warning(f'Cannot create batch identifiers: {e}. Loading without duplicate check.')
             batch_ids = []

-        # Check if already processed (using in-memory state)
-        if batch_ids and self.state_store.is_processed(connection_name, table_name, batch_ids):
+        # Check if already processed ONLY when microbatch is complete
+        # Multiple RecordBatches can share the same microbatch ID (BlockRange)
+        if batch_ids and ranges_complete and self.state_store.is_processed(connection_name, table_name, batch_ids):
            self.logger.info(
                 f'Batch already processed (ranges: {[f"{r.network}:{r.start}-{r.end}" for r in ranges]}), '
                 f'skipping (state check)'
             )
             return 0

-        # Load data
+        # Load data (always load, even if part of larger microbatch)
         conn = self.pool.getconn()
         try:
             with conn.cursor() as cur:
                 self._copy_arrow_data(cur, batch, table_name)
                 conn.commit()

-                # Mark as processed after successful load
-                if batch_ids:
+                # Mark as processed ONLY when microbatch is complete
+                # This ensures we don't skip subsequent RecordBatches within the same microbatch
+                if batch_ids and ranges_complete:
                     self.state_store.mark_processed(connection_name, table_name, batch_ids)
+                    self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')

                 self.logger.debug(
                     f'Batch load committed: {batch.num_rows} rows, '
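From the caller's side the new parameter simply threads through: in load_stream_continuous the flag comes from response.metadata.ranges_complete. The same contract can be sketched with a small hypothetical helper that feeds one microbatch to a loader chunk by chunk; load_microbatch and the duck-typed loader argument are assumptions for illustration, while the method signature matches the one added above.

from typing import Any, List, Sequence


def load_microbatch(loader: Any, table_name: str, connection_name: str,
                    chunks: Sequence[Any], ranges: List[Any]) -> int:
    """Feed every chunk of one microbatch to the loader; mark it processed only once."""
    total = 0
    for i, chunk in enumerate(chunks):
        # All chunks carry the same BlockRange list. ranges_complete is True only
        # for the final chunk, so the state store records the microbatch exactly once
        # and intermediate chunks are never mistaken for duplicates.
        is_last = i == len(chunks) - 1
        total += loader.load_batch_transactional(
            chunk, table_name, connection_name, ranges, ranges_complete=is_last
        )
    return total

In the real streaming path the chunk position is not known up front, which is why the server-provided ranges_complete flag is used instead of this positional check.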

tests/integration/test_postgresql_loader.py

Lines changed: 123 additions & 0 deletions
@@ -692,3 +692,126 @@ def test_reorg_preserves_different_networks(self, postgresql_test_config, test_t

         finally:
             loader.pool.putconn(conn)
+
+    def test_microbatch_deduplication(self, postgresql_test_config, test_table_name, cleanup_tables):
+        """
+        Test that multiple RecordBatches within the same microbatch are all loaded,
+        and deduplication only happens at microbatch boundaries when ranges_complete=True.
+
+        This test verifies the fix for the critical bug where we were marking batches
+        as processed after every RecordBatch instead of waiting for ranges_complete=True.
+        """
+        from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch
+
+        cleanup_tables.append(test_table_name)
+
+        # Enable state management to test deduplication
+        config_with_state = {
+            **postgresql_test_config,
+            'state': {'enabled': True, 'storage': 'memory', 'store_batch_id': True},
+        }
+        loader = PostgreSQLLoader(config_with_state)
+
+        with loader:
+            # Create table first from the schema
+            batch1_data = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]})
+            loader._create_table_from_schema(batch1_data.schema, test_table_name)
+
+            # Simulate a microbatch sent as 3 RecordBatches with the same BlockRange
+            # This happens when the server sends large microbatches in smaller chunks
+
+            # First RecordBatch of the microbatch (ranges_complete=False)
+            response1 = ResponseBatch.data_batch(
+                data=batch1_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],
+                    ranges_complete=False,  # Not the last batch in this microbatch
+                ),
+            )
+
+            # Second RecordBatch of the microbatch (ranges_complete=False)
+            batch2_data = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]})
+            response2 = ResponseBatch.data_batch(
+                data=batch2_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],  # Same BlockRange!
+                    ranges_complete=False,  # Still not the last batch
+                ),
+            )
+
+            # Third RecordBatch of the microbatch (ranges_complete=True)
+            batch3_data = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]})
+            response3 = ResponseBatch.data_batch(
+                data=batch3_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],  # Same BlockRange!
+                    ranges_complete=True,  # Last batch in this microbatch - safe to mark as processed
+                ),
+            )
+
+            # Process the microbatch stream
+            stream = [response1, response2, response3]
+            results = list(
+                loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_connection')
+            )
+
+            # CRITICAL: All 3 RecordBatches should be loaded successfully
+            # Before the fix, only the first batch would load (the other 2 would be skipped as "duplicates")
+            assert len(results) == 3, 'All RecordBatches within microbatch should be processed'
+            assert all(r.success for r in results), 'All batches should succeed'
+            assert results[0].rows_loaded == 2, 'First batch should load 2 rows'
+            assert results[1].rows_loaded == 2, 'Second batch should load 2 rows (not skipped!)'
+            assert results[2].rows_loaded == 2, 'Third batch should load 2 rows (not skipped!)'
+
+            # Verify total rows in table (all batches loaded)
+            conn = loader.pool.getconn()
+            try:
+                with conn.cursor() as cur:
+                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                    total_count = cur.fetchone()[0]
+                    assert total_count == 6, 'All 6 rows from 3 RecordBatches should be in the table'
+
+                    # Verify the actual IDs are present
+                    cur.execute(f'SELECT id FROM {test_table_name} ORDER BY id')
+                    all_ids = [row[0] for row in cur.fetchall()]
+                    assert all_ids == [1, 2, 3, 4, 5, 6], 'All rows from all RecordBatches should be present'
+
+            finally:
+                loader.pool.putconn(conn)
+
+            # Now test that re-sending the complete microbatch is properly deduplicated
+            # This time, the first batch has ranges_complete=True (entire microbatch in one RecordBatch)
+            duplicate_batch = pa.RecordBatch.from_pydict({'id': [7, 8], 'value': [700, 800]})
+            duplicate_response = ResponseBatch.data_batch(
+                data=duplicate_batch,
+                metadata=BatchMetadata(
+                    ranges=[
+                        BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')
+                    ],  # Same range as before!
+                    ranges_complete=True,  # Complete microbatch
+                ),
+            )
+
+            # Process duplicate microbatch
+            duplicate_results = list(
+                loader.load_stream_continuous(
+                    iter([duplicate_response]), test_table_name, connection_name='test_connection'
+                )
+            )
+
+            # The duplicate microbatch should be skipped (already processed)
+            assert len(duplicate_results) == 1
+            assert duplicate_results[0].success is True
+            assert duplicate_results[0].rows_loaded == 0, 'Duplicate microbatch should be skipped'
+            assert duplicate_results[0].metadata.get('operation') == 'skip_duplicate', 'Should be marked as duplicate'
+
+            # Verify row count unchanged (duplicate was skipped)
+            conn = loader.pool.getconn()
+            try:
+                with conn.cursor() as cur:
+                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                    final_count = cur.fetchone()[0]
+                    assert final_count == 6, 'Row count should not increase after duplicate microbatch'
+
+            finally:
+                loader.pool.putconn(conn)
