From 3dda40cefa1d6b9f025decc6169e892c49185a09 Mon Sep 17 00:00:00 2001 From: aviralgarg05 Date: Sat, 21 Feb 2026 20:37:27 +0530 Subject: [PATCH 1/3] fix(topk): avoid overflow panic in interleave emission --- datafusion/physical-plan/src/topk/mod.rs | 246 ++++++++++++++++++----- 1 file changed, 193 insertions(+), 53 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 4b93e6a188d57..9f56a25bc20d0 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -24,7 +24,8 @@ use arrow::{ }; use datafusion_expr::{ColumnarValue, Operator}; use std::mem::size_of; -use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; +use std::panic::{AssertUnwindSafe, catch_unwind}; +use std::{any::Any, cmp::Ordering, collections::BinaryHeap, sync::Arc}; use super::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, RecordOutput, @@ -35,7 +36,7 @@ use crate::{SendableRecordBatchStream, stream::RecordBatchStreamAdapter}; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::SchemaRef; use datafusion_common::{ - HashMap, Result, ScalarValue, internal_datafusion_err, internal_err, + DataFusionError, HashMap, Result, ScalarValue, internal_datafusion_err, internal_err, }; use datafusion_execution::{ memory_pool::{MemoryConsumer, MemoryReservation}, @@ -109,8 +110,6 @@ pub struct TopK { metrics: TopKMetrics, /// Reservation reservation: MemoryReservation, - /// The target number of rows for output batches - batch_size: usize, /// sort expressions expr: LexOrdering, /// row converter, for sort keys @@ -216,7 +215,6 @@ impl TopK { schema: Arc::clone(&schema), metrics: TopKMetrics::new(metrics, partition_id), reservation, - batch_size, expr, row_converter, scratch_rows, @@ -588,7 +586,6 @@ impl TopK { schema, metrics, reservation: _, - batch_size, expr: _, row_converter: _, scratch_rows: _, @@ -605,20 +602,10 @@ impl TopK { // break into record batches as needed let mut batches = vec![]; - if let Some(mut batch) = heap.emit()? { + for batch in heap.emit()? { (&batch).record_output(&metrics.baseline); - - loop { - if batch.num_rows() <= batch_size { - batches.push(Ok(batch)); - break; - } else { - batches.push(Ok(batch.slice(0, batch_size))); - let remaining_length = batch.num_rows() - batch_size; - batch = batch.slice(batch_size, remaining_length); - } - } - }; + batches.push(Ok(batch)); + } Ok(Box::pin(RecordBatchStreamAdapter::new( schema, futures::stream::iter(batches), @@ -748,24 +735,34 @@ impl TopKHeap { } /// Returns the values stored in this heap, from values low to - /// high, as a single [`RecordBatch`], resetting the inner heap - pub fn emit(&mut self) -> Result> { + /// high, as [`RecordBatch`]es, resetting the inner heap + pub fn emit(&mut self) -> Result> { Ok(self.emit_with_state()?.0) } /// Returns the values stored in this heap, from values low to - /// high, as a single [`RecordBatch`], and a sorted vec of the + /// high, as [`RecordBatch`]es, and a sorted vec of the /// current heap's contents - pub fn emit_with_state(&mut self) -> Result<(Option, Vec)> { + pub fn emit_with_state(&mut self) -> Result<(Vec, Vec)> { // generate sorted rows let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec(); if self.store.is_empty() { - return Ok((None, topk_rows)); + return Ok((Vec::new(), topk_rows)); } - // Collect the batches into a vec and store the "batch_id -> array_pos" mapping, to then - // build the `indices` vec below. This is needed since the batch ids are not continuous. + let batches = self.interleave_topk_rows(&topk_rows, self.batch_size)?; + + Ok((batches, topk_rows)) + } + + fn interleave_topk_rows( + &self, + topk_rows: &[TopKRow], + max_rows_per_batch: usize, + ) -> Result> { + // Collect the batches into a vec and store the "batch_id -> array_pos" mapping. + // This is needed since the batch ids are not continuous. let mut record_batches = Vec::new(); let mut batch_id_array_pos = HashMap::new(); for (array_pos, (batch_id, batch)) in self.store.batches.iter().enumerate() { @@ -773,22 +770,57 @@ impl TopKHeap { batch_id_array_pos.insert(*batch_id, array_pos); } - let indices: Vec<_> = topk_rows + let all_indices = topk_rows .iter() - .map(|k| (batch_id_array_pos[&k.batch_id], k.index)) - .collect(); + .map(|k| { + let array_pos = batch_id_array_pos.get(&k.batch_id).ok_or_else(|| { + internal_datafusion_err!( + "TopK row references missing batch id {}", + k.batch_id + ) + })?; + Ok((*array_pos, k.index)) + }) + .collect::>>()?; - // At this point `indices` contains indexes within the - // rows and `input_arrays` contains a reference to the - // relevant RecordBatch for that index. `interleave_record_batch` pulls - // them together into a single new batch - let new_batch = interleave_record_batch(&record_batches, &indices)?; + let max_rows_per_batch = max_rows_per_batch.max(1); + let mut batches = Vec::new(); + let mut start = 0; + while start < all_indices.len() { + let remaining_rows = all_indices.len() - start; + let mut chunk_size = remaining_rows.min(max_rows_per_batch); - Ok((Some(new_batch), topk_rows)) + loop { + let indices = &all_indices[start..start + chunk_size]; + match try_interleave_record_batch(&record_batches, indices) { + Ok(batch) => { + batches.push(batch); + start += chunk_size; + break; + } + Err(InterleaveError::Overflow(_)) if chunk_size > 1 => { + chunk_size = chunk_size.div_ceil(2); + if chunk_size == 0 { + return internal_err!( + "Invalid TopK chunk size during interleave" + ); + } + } + Err(InterleaveError::Overflow(message)) => { + return internal_err!( + "TopK failed to interleave a single row due to offset overflow: {message}" + ); + } + Err(InterleaveError::DataFusion(err)) => return Err(err), + } + } + } + + Ok(batches) } - /// Compact this heap, rewriting all stored batches into a single - /// input batch + /// Compact this heap, rewriting all stored batches into new input + /// batches. pub fn maybe_compact(&mut self) -> Result<()> { // we compact if the number of "unused" rows in the store is // past some pre-defined threshold. Target holding up to @@ -802,32 +834,40 @@ impl TopKHeap { if self.store.len() <= 2 || unused_rows < max_unused_rows { return Ok(()); } - // at first, compact the entire thing always into a new batch + // at first, compact the entire thing always into new batches // (maybe we can get fancier in the future about ignoring // batches that have a high usage ratio already - // Note: new batch is in the same order as inner - let num_rows = self.inner.len(); - let (new_batch, mut topk_rows) = self.emit_with_state()?; - let Some(new_batch) = new_batch else { + // Note: new batches are in the same order as inner + let (new_batches, mut topk_rows) = self.emit_with_state()?; + if new_batches.is_empty() { return Ok(()); - }; + } // clear all old entries in store (this invalidates all // store_ids in `inner`) self.store.clear(); - let mut batch_entry = self.register_batch(new_batch); - batch_entry.uses = num_rows; - - // rewrite all existing entries to use the new batch, and - // remove old entries. The sortedness and their relative - // position do not change - for (i, topk_row) in topk_rows.iter_mut().enumerate() { - topk_row.batch_id = batch_entry.id; - topk_row.index = i; + // rewrite all existing entries to use the compacted batches. + // The sortedness and their relative position do not change. + let mut row_offset = 0; + for new_batch in new_batches { + let mut batch_entry = self.register_batch(new_batch); + batch_entry.uses = batch_entry.batch.num_rows(); + + for (index, topk_row) in topk_rows[row_offset..row_offset + batch_entry.uses] + .iter_mut() + .enumerate() + { + topk_row.batch_id = batch_entry.id; + topk_row.index = index; + } + row_offset += batch_entry.uses; + self.insert_batch_entry(batch_entry); } - self.insert_batch_entry(batch_entry); + + debug_assert_eq!(row_offset, topk_rows.len()); + // restore the heap self.inner = BinaryHeap::from(topk_rows); @@ -884,6 +924,56 @@ impl TopKHeap { } } +enum InterleaveError { + Overflow(String), + DataFusion(DataFusionError), +} + +fn try_interleave_record_batch( + record_batches: &[&RecordBatch], + indices: &[(usize, usize)], +) -> std::result::Result { + let result = catch_unwind(AssertUnwindSafe(|| { + interleave_record_batch(record_batches, indices) + })); + + match result { + Ok(Ok(batch)) => Ok(batch), + Ok(Err(err)) => { + let message = err.to_string(); + if is_overflow_message(&message) { + Err(InterleaveError::Overflow(message)) + } else { + Err(InterleaveError::DataFusion(err.into())) + } + } + Err(payload) => { + let message = panic_message(payload.as_ref()); + if is_overflow_message(&message) { + Err(InterleaveError::Overflow(message)) + } else { + Err(InterleaveError::DataFusion(internal_datafusion_err!( + "TopK interleave panicked: {message}" + ))) + } + } + } +} + +fn is_overflow_message(message: &str) -> bool { + message.to_ascii_lowercase().contains("overflow") +} + +fn panic_message(payload: &(dyn Any + Send)) -> String { + if let Some(message) = payload.downcast_ref::<&str>() { + (*message).to_string() + } else if let Some(message) = payload.downcast_ref::() { + message.clone() + } else { + "unknown panic payload".to_string() + } +} + /// Represents one of the top K rows held in this heap. Orders /// according to memcmp of row (e.g. the arrow Row format, but could /// also be primitive values) @@ -1110,6 +1200,56 @@ mod tests { assert_eq!(record_batch_store.batches_size, 0); } + #[test] + fn test_topk_heap_emit_with_state_respects_batch_size() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let mut heap = TopKHeap::new(5, 2); + + let batch_a = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![5, 3, 1]))], + )?; + let mut entry_a = heap.register_batch(batch_a); + for (index, key) in [5_u8, 3_u8, 1_u8].into_iter().enumerate() { + heap.add(&mut entry_a, [key], index); + } + heap.insert_batch_entry(entry_a); + + let batch_b = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from(vec![4, 2]))], + )?; + let mut entry_b = heap.register_batch(batch_b); + for (index, key) in [4_u8, 2_u8].into_iter().enumerate() { + heap.add(&mut entry_b, [key], index); + } + heap.insert_batch_entry(entry_b); + + let (batches, topk_rows) = heap.emit_with_state()?; + assert_eq!(batches.len(), 3); + assert_eq!( + batches + .iter() + .map(RecordBatch::num_rows) + .collect::>(), + vec![2, 2, 1] + ); + assert_eq!( + topk_rows.iter().map(|row| row.row[0]).collect::>(), + vec![1, 2, 3, 4, 5] + ); + + assert_batches_eq!( + &[ + "+---+", "| a |", "+---+", "| 1 |", "| 2 |", "| 3 |", "| 4 |", "| 5 |", + "+---+", + ], + &batches + ); + + Ok(()) + } + /// This test validates that the `try_finish` method marks the TopK operator as finished /// when the prefix (on column "a") of the last row in the current batch is strictly greater /// than the max top‑k row. From b2861cbaa336de5cfefe3af28d0d23ab0aedf32f Mon Sep 17 00:00:00 2001 From: aviralgarg05 Date: Tue, 17 Mar 2026 17:55:05 +0530 Subject: [PATCH 2/3] fix(topk): split interleave by offset sizes --- datafusion/physical-plan/src/topk/mod.rs | 276 ++++++++++++++++------- 1 file changed, 196 insertions(+), 80 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 9f56a25bc20d0..f2c1e4eac4828 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -18,14 +18,14 @@ //! TopK: Combination of Sort / LIMIT use arrow::{ - array::{Array, AsArray}, + array::{Array, AsArray, BinaryArray, ListArray, MapArray, StringArray}, compute::{FilterBuilder, interleave_record_batch, prep_null_mask_filter}, row::{RowConverter, Rows, SortField}, }; use datafusion_expr::{ColumnarValue, Operator}; use std::mem::size_of; -use std::panic::{AssertUnwindSafe, catch_unwind}; -use std::{any::Any, cmp::Ordering, collections::BinaryHeap, sync::Arc}; +use std::ops::Range; +use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; use super::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, RecordOutput, @@ -34,9 +34,9 @@ use crate::spill::get_record_batch_memory_size; use crate::{SendableRecordBatchStream, stream::RecordBatchStreamAdapter}; use arrow::array::{ArrayRef, RecordBatch}; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{DataType, SchemaRef}; use datafusion_common::{ - DataFusionError, HashMap, Result, ScalarValue, internal_datafusion_err, internal_err, + HashMap, Result, ScalarValue, internal_datafusion_err, internal_err, }; use datafusion_execution::{ memory_pool::{MemoryConsumer, MemoryReservation}, @@ -761,6 +761,26 @@ impl TopKHeap { topk_rows: &[TopKRow], max_rows_per_batch: usize, ) -> Result> { + let (record_batches, batch_id_array_pos) = self.collect_record_batches(); + let all_indices = self.collect_indices(topk_rows, &batch_id_array_pos)?; + let index_ranges = split_indices_by_i32_offsets( + &record_batches, + &all_indices, + max_rows_per_batch.max(1), + I32_OFFSET_LIMIT, + )?; + + let mut batches = Vec::with_capacity(index_ranges.len()); + for range in index_ranges { + let indices = &all_indices[range]; + let batch = interleave_record_batch(&record_batches, indices)?; + batches.push(batch); + } + + Ok(batches) + } + + fn collect_record_batches(&self) -> (Vec<&RecordBatch>, HashMap) { // Collect the batches into a vec and store the "batch_id -> array_pos" mapping. // This is needed since the batch ids are not continuous. let mut record_batches = Vec::new(); @@ -769,8 +789,15 @@ impl TopKHeap { record_batches.push(&batch.batch); batch_id_array_pos.insert(*batch_id, array_pos); } + (record_batches, batch_id_array_pos) + } - let all_indices = topk_rows + fn collect_indices( + &self, + topk_rows: &[TopKRow], + batch_id_array_pos: &HashMap, + ) -> Result> { + topk_rows .iter() .map(|k| { let array_pos = batch_id_array_pos.get(&k.batch_id).ok_or_else(|| { @@ -781,42 +808,7 @@ impl TopKHeap { })?; Ok((*array_pos, k.index)) }) - .collect::>>()?; - - let max_rows_per_batch = max_rows_per_batch.max(1); - let mut batches = Vec::new(); - let mut start = 0; - while start < all_indices.len() { - let remaining_rows = all_indices.len() - start; - let mut chunk_size = remaining_rows.min(max_rows_per_batch); - - loop { - let indices = &all_indices[start..start + chunk_size]; - match try_interleave_record_batch(&record_batches, indices) { - Ok(batch) => { - batches.push(batch); - start += chunk_size; - break; - } - Err(InterleaveError::Overflow(_)) if chunk_size > 1 => { - chunk_size = chunk_size.div_ceil(2); - if chunk_size == 0 { - return internal_err!( - "Invalid TopK chunk size during interleave" - ); - } - } - Err(InterleaveError::Overflow(message)) => { - return internal_err!( - "TopK failed to interleave a single row due to offset overflow: {message}" - ); - } - Err(InterleaveError::DataFusion(err)) => return Err(err), - } - } - } - - Ok(batches) + .collect::>>() } /// Compact this heap, rewriting all stored batches into new input @@ -923,57 +915,164 @@ impl TopKHeap { Ok(Some(scalar_values)) } } +const I32_OFFSET_LIMIT: i64 = i32::MAX as i64; -enum InterleaveError { - Overflow(String), - DataFusion(DataFusionError), -} - -fn try_interleave_record_batch( +fn split_indices_by_i32_offsets( record_batches: &[&RecordBatch], - indices: &[(usize, usize)], -) -> std::result::Result { - let result = catch_unwind(AssertUnwindSafe(|| { - interleave_record_batch(record_batches, indices) - })); - - match result { - Ok(Ok(batch)) => Ok(batch), - Ok(Err(err)) => { - let message = err.to_string(); - if is_overflow_message(&message) { - Err(InterleaveError::Overflow(message)) - } else { - Err(InterleaveError::DataFusion(err.into())) - } + all_indices: &[(usize, usize)], + max_rows_per_batch: usize, + max_offset: i64, +) -> Result>> { + if all_indices.is_empty() { + return Ok(Vec::new()); + } + + let var_width_columns = + collect_var_width_columns(record_batches.first().ok_or_else(|| { + internal_datafusion_err!("Missing record batches for TopK interleave") + })?); + + if var_width_columns.is_empty() { + return Ok(split_indices_by_row_count( + all_indices.len(), + max_rows_per_batch, + )); + } + + let mut ranges = Vec::new(); + let mut start = 0; + let mut totals = vec![0_i64; var_width_columns.len()]; + + for (pos, (batch_pos, row_index)) in all_indices.iter().enumerate() { + if pos - start >= max_rows_per_batch { + ranges.push(start..pos); + start = pos; + totals.fill(0); } - Err(payload) => { - let message = panic_message(payload.as_ref()); - if is_overflow_message(&message) { - Err(InterleaveError::Overflow(message)) - } else { - Err(InterleaveError::DataFusion(internal_datafusion_err!( - "TopK interleave panicked: {message}" - ))) + + let batch = record_batches.get(*batch_pos).ok_or_else(|| { + internal_datafusion_err!("Invalid batch position in TopK indices") + })?; + + let mut row_sizes = Vec::with_capacity(var_width_columns.len()); + for column in &var_width_columns { + let array = batch.column(column.column_index); + let size = column.row_size(array.as_ref(), *row_index)?; + if size > max_offset { + return internal_err!( + "TopK row requires {size} offsets which exceeds i32::MAX" + ); } + row_sizes.push(size); + } + + if totals + .iter() + .zip(row_sizes.iter()) + .any(|(total, size)| total + size > max_offset) + { + ranges.push(start..pos); + start = pos; + totals.fill(0); + } + + for (total, size) in totals.iter_mut().zip(row_sizes.iter()) { + *total += *size; } } + + if start < all_indices.len() { + ranges.push(start..all_indices.len()); + } + + Ok(ranges) } -fn is_overflow_message(message: &str) -> bool { - message.to_ascii_lowercase().contains("overflow") +fn split_indices_by_row_count( + total_rows: usize, + max_rows_per_batch: usize, +) -> Vec> { + let mut ranges = Vec::new(); + let mut start = 0; + let max_rows_per_batch = max_rows_per_batch.max(1); + while start < total_rows { + let end = (start + max_rows_per_batch).min(total_rows); + ranges.push(start..end); + start = end; + } + ranges } -fn panic_message(payload: &(dyn Any + Send)) -> String { - if let Some(message) = payload.downcast_ref::<&str>() { - (*message).to_string() - } else if let Some(message) = payload.downcast_ref::() { - message.clone() - } else { - "unknown panic payload".to_string() +fn collect_var_width_columns(batch: &RecordBatch) -> Vec { + batch + .columns() + .iter() + .enumerate() + .filter_map(|(index, array)| VarWidthColumn::new(index, array.data_type())) + .collect() +} + +struct VarWidthColumn { + column_index: usize, + kind: VarWidthKind, +} + +impl VarWidthColumn { + fn new(column_index: usize, data_type: &DataType) -> Option { + let kind = match data_type { + DataType::Utf8 => VarWidthKind::Utf8, + DataType::Binary => VarWidthKind::Binary, + DataType::List(_) => VarWidthKind::List, + DataType::Map(_, _) => VarWidthKind::Map, + _ => return None, + }; + + Some(Self { column_index, kind }) + } + + fn row_size(&self, array: &dyn Array, row: usize) -> Result { + let size = match self.kind { + VarWidthKind::Utf8 => array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("Expected Utf8 array for TopK interleave") + })? + .value_length(row) as i64, + VarWidthKind::Binary => array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("Expected Binary array for TopK interleave") + })? + .value_length(row) as i64, + VarWidthKind::List => array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("Expected List array for TopK interleave") + })? + .value_length(row) as i64, + VarWidthKind::Map => array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("Expected Map array for TopK interleave") + })? + .value_length(row) as i64, + }; + + Ok(size) } } +enum VarWidthKind { + Utf8, + Binary, + List, + Map, +} + /// Represents one of the top K rows held in this heap. Orders /// according to memcmp of row (e.g. the arrow Row format, but could /// also be primitive values) @@ -1250,6 +1349,23 @@ mod tests { Ok(()) } + #[test] + fn test_split_indices_by_i32_offsets_uses_sizes() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(StringArray::from(vec!["aaa", "bb", "cccc"]))], + )?; + let record_batches = vec![&batch]; + let all_indices = vec![(0, 0), (0, 1), (0, 2)]; + + let ranges = split_indices_by_i32_offsets(&record_batches, &all_indices, 10, 5)?; + + assert_eq!(ranges, vec![0..2, 2..3]); + + Ok(()) + } + /// This test validates that the `try_finish` method marks the TopK operator as finished /// when the prefix (on column "a") of the last row in the current batch is strictly greater /// than the max top‑k row. From 0d20bcc0bf12d6f9b34917b6fd65a07d82442ca1 Mon Sep 17 00:00:00 2001 From: aviralgarg05 Date: Sun, 22 Mar 2026 19:19:20 +0530 Subject: [PATCH 3/3] fix(topk): handle nested/view var-width offsets --- datafusion/physical-plan/src/topk/mod.rs | 199 +++++++++++++++++++---- 1 file changed, 168 insertions(+), 31 deletions(-) diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 00aaec27f6c40..47f6b30031369 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -18,7 +18,7 @@ //! TopK: Combination of Sort / LIMIT use arrow::{ - array::{Array, AsArray, BinaryArray, ListArray, MapArray, StringArray}, + array::{Array, AsArray, BinaryArray, ListArray, MapArray, StringArray, StructArray}, compute::{FilterBuilder, interleave_record_batch, prep_null_mask_filter}, row::{RowConverter, Rows, SortField}, }; @@ -940,6 +940,27 @@ fn split_indices_by_i32_offsets( )); } + // Fast path: if the combined data across *all* batches is well under the + // limit, no single interleave chunk can overflow regardless of which rows + // are selected, so we can skip the per-row accounting loop entirely. + let total_across_batches: i64 = record_batches + .iter() + .flat_map(|batch| { + var_width_columns.iter().map(|col| { + col.get_array(batch) + .map(|a| col.total_data_size(a)) + .unwrap_or(0) + }) + }) + .fold(0_i64, |acc, v| acc.saturating_add(v)); + + if total_across_batches <= max_offset / 2 { + return Ok(split_indices_by_row_count( + all_indices.len(), + max_rows_per_batch, + )); + } + let mut ranges = Vec::new(); let mut start = 0; let mut totals = vec![0_i64; var_width_columns.len()]; @@ -957,8 +978,8 @@ fn split_indices_by_i32_offsets( let mut row_sizes = Vec::with_capacity(var_width_columns.len()); for column in &var_width_columns { - let array = batch.column(column.column_index); - let size = column.row_size(array.as_ref(), *row_index)?; + let array = column.get_array(batch)?; + let size = column.row_size(array, *row_index)?; if size > max_offset { return internal_err!( "TopK row requires {size} offsets which exceeds i32::MAX" @@ -1004,33 +1025,100 @@ fn split_indices_by_row_count( ranges } +/// Recursively collect all variable-width leaf columns from `batch`, walking +/// into `Struct` fields. Each returned `VarWidthColumn` stores the full index +/// path needed to reach its array (top-level index, then zero or more struct +/// child indices). fn collect_var_width_columns(batch: &RecordBatch) -> Vec { - batch - .columns() - .iter() - .enumerate() - .filter_map(|(index, array)| VarWidthColumn::new(index, array.data_type())) - .collect() + let mut columns = Vec::new(); + collect_var_width_from_arrays(batch.columns(), &[], &mut columns); + columns +} + +fn collect_var_width_from_arrays( + arrays: &[ArrayRef], + path_prefix: &[usize], + out: &mut Vec, +) { + for (idx, array) in arrays.iter().enumerate() { + let mut path = path_prefix.to_vec(); + path.push(idx); + match array.data_type() { + DataType::Utf8 => out.push(VarWidthColumn { + column_path: path, + kind: VarWidthKind::Utf8, + }), + DataType::Binary => out.push(VarWidthColumn { + column_path: path, + kind: VarWidthKind::Binary, + }), + DataType::Utf8View => out.push(VarWidthColumn { + column_path: path, + kind: VarWidthKind::Utf8View, + }), + DataType::BinaryView => out.push(VarWidthColumn { + column_path: path, + kind: VarWidthKind::BinaryView, + }), + DataType::List(_) => out.push(VarWidthColumn { + column_path: path, + kind: VarWidthKind::List, + }), + DataType::Map(_, _) => out.push(VarWidthColumn { + column_path: path, + kind: VarWidthKind::Map, + }), + DataType::Struct(_) => { + // Recurse: any Utf8/Binary (etc.) field inside a Struct gets the + // same i32 overflow treatment when interleave recurses into it. + if let Some(struct_array) = array.as_any().downcast_ref::() { + collect_var_width_from_arrays(struct_array.columns(), &path, out); + } + } + _ => {} + } + } } struct VarWidthColumn { - column_index: usize, + /// Path of column indices to the target array. + /// `column_path[0]` is the top-level index in the `RecordBatch`; + /// any further elements are child indices inside `StructArray`s. + column_path: Vec, kind: VarWidthKind, } impl VarWidthColumn { - fn new(column_index: usize, data_type: &DataType) -> Option { - let kind = match data_type { - DataType::Utf8 => VarWidthKind::Utf8, - DataType::Binary => VarWidthKind::Binary, - DataType::List(_) => VarWidthKind::List, - DataType::Map(_, _) => VarWidthKind::Map, - _ => return None, - }; - - Some(Self { column_index, kind }) + /// Walk `batch` along `column_path` to reach the target array. + fn get_array<'a>(&self, batch: &'a RecordBatch) -> Result<&'a dyn Array> { + let first = self.column_path.first().ok_or_else(|| { + internal_datafusion_err!("Empty column path in VarWidthColumn") + })?; + let mut array: &dyn Array = batch.column(*first).as_ref(); + for &child_idx in &self.column_path[1..] { + array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!( + "Expected StructArray while following nested column path in TopK" + ) + })? + .column(child_idx) + .as_ref(); + } + Ok(array) } + /// Return the offset increment contributed by row `row`. + /// + /// For `Utf8`/`Binary`/`Utf8View`/`BinaryView` this is the byte length of + /// the value, which is what Arrow accumulates in its i32 offset buffer. + /// + /// For `List`/`Map` this is the *element count* in that row's list/map, + /// which is what Arrow accumulates in its i32 offset buffer for those types. + /// The comparison threshold is still `i32::MAX` in both cases — just + /// different units. fn row_size(&self, array: &dyn Array, row: usize) -> Result { let size = match self.kind { VarWidthKind::Utf8 => array @@ -1047,29 +1135,78 @@ impl VarWidthColumn { internal_datafusion_err!("Expected Binary array for TopK interleave") })? .value_length(row) as i64, + VarWidthKind::Utf8View => array.as_string_view().value(row).len() as i64, + VarWidthKind::BinaryView => array.as_binary_view().value(row).len() as i64, + VarWidthKind::List => { + // value_length returns child element count — the correct unit for i32 offset overflow + array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!( + "Expected List array for TopK interleave" + ) + })? + .value_length(row) as i64 + } + VarWidthKind::Map => { + // value_length returns entry count — the correct unit for i32 offset overflow + array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("Expected Map array for TopK interleave") + })? + .value_length(row) as i64 + } + }; + Ok(size) + } + + /// Return the total accumulated offset across the entire `array`. + /// Used for the fast-path check in `split_indices_by_i32_offsets`. + /// + /// For `Utf8`/`Binary` this is the total byte count (last value offset). + /// For `List`/`Map` this is the total child-element count (last offset). + /// For view arrays the format differs; we return `i64::MAX` to always + /// fall through to per-row accounting. + fn total_data_size(&self, array: &dyn Array) -> i64 { + match self.kind { + VarWidthKind::Utf8 => array + .as_any() + .downcast_ref::() + .and_then(|a| a.value_offsets().last().copied()) + .map(|v| v as i64) + .unwrap_or(0), + VarWidthKind::Binary => array + .as_any() + .downcast_ref::() + .and_then(|a| a.value_offsets().last().copied()) + .map(|v| v as i64) + .unwrap_or(0), + // View arrays don't use i32 offset buffers; conservatively skip fast path. + VarWidthKind::Utf8View | VarWidthKind::BinaryView => i64::MAX, VarWidthKind::List => array .as_any() .downcast_ref::() - .ok_or_else(|| { - internal_datafusion_err!("Expected List array for TopK interleave") - })? - .value_length(row) as i64, + .and_then(|a| a.offsets().last().copied()) + .map(|v| v as i64) + .unwrap_or(0), VarWidthKind::Map => array .as_any() .downcast_ref::() - .ok_or_else(|| { - internal_datafusion_err!("Expected Map array for TopK interleave") - })? - .value_length(row) as i64, - }; - - Ok(size) + .and_then(|a| a.offsets().last().copied()) + .map(|v| v as i64) + .unwrap_or(0), + } } } enum VarWidthKind { Utf8, Binary, + Utf8View, + BinaryView, List, Map, }