Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 209 additions & 55 deletions datafusion/physical-expr/src/expressions/in_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -793,63 +793,71 @@ impl PhysicalExpr for InListExpr {
// comparator for unsupported types (nested, RunEndEncoded, etc.).
let value = value.into_array(num_rows)?;
let lhs_supports_arrow_eq = supports_arrow_eq(value.data_type());
let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold(
BooleanArray::new(BooleanBuffer::new_unset(num_rows), None),
|result, expr| -> Result<BooleanArray> {
let rhs = match expr? {
ColumnarValue::Array(array) => {
if lhs_supports_arrow_eq
&& supports_arrow_eq(array.data_type())
{
arrow_eq(&value, &array)?
} else {
let cmp = make_comparator(
value.as_ref(),
array.as_ref(),
SortOptions::default(),
)?;
(0..num_rows)
.map(|i| {
if value.is_null(i) || array.is_null(i) {
return None;
}
Some(cmp(i, i).is_eq())
})
.collect::<BooleanArray>()
}

// Helper: compare value against a single list expression
let compare_one = |expr: &Arc<dyn PhysicalExpr>| -> Result<BooleanArray> {
match expr.evaluate(batch)? {
ColumnarValue::Array(array) => {
if lhs_supports_arrow_eq
&& supports_arrow_eq(array.data_type())
{
Ok(arrow_eq(&value, &array)?)
} else {
let cmp = make_comparator(
value.as_ref(),
array.as_ref(),
SortOptions::default(),
)?;
let buffer = BooleanBuffer::collect_bool(num_rows, |i| {
cmp(i, i).is_eq()
});
let nulls =
NullBuffer::union(value.nulls(), array.nulls());
Ok(BooleanArray::new(buffer, nulls))
}
ColumnarValue::Scalar(scalar) => {
// Check if scalar is null once, before the loop
if scalar.is_null() {
// If scalar is null, all comparisons return null
BooleanArray::from(vec![None; num_rows])
} else if lhs_supports_arrow_eq {
let scalar_datum = scalar.to_scalar()?;
arrow_eq(&value, &scalar_datum)?
} else {
// Convert scalar to 1-element array
let array = scalar.to_array()?;
let cmp = make_comparator(
value.as_ref(),
array.as_ref(),
SortOptions::default(),
)?;
// Compare each row of value with the single scalar element
(0..num_rows)
.map(|i| {
if value.is_null(i) {
None
} else {
Some(cmp(i, 0).is_eq())
}
})
.collect::<BooleanArray>()
}
}
ColumnarValue::Scalar(scalar) => {
// Check if scalar is null once, before the loop
if scalar.is_null() {
// If scalar is null, all comparisons return null
Ok(BooleanArray::from(vec![None; num_rows]))
} else if lhs_supports_arrow_eq {
let scalar_datum = scalar.to_scalar()?;
Ok(arrow_eq(&value, &scalar_datum)?)
} else {
// Convert scalar to 1-element array
let array = scalar.to_array()?;
let cmp = make_comparator(
value.as_ref(),
array.as_ref(),
SortOptions::default(),
)?;
// Compare each row of value with the single scalar element
let buffer = BooleanBuffer::collect_bool(num_rows, |i| {
cmp(i, 0).is_eq()
});
Ok(BooleanArray::new(buffer, value.nulls().cloned()))
}
};
Ok(or_kleene(&result, &rhs)?)
},
)?;
}
}
};

// Evaluate first expression directly to avoid a redundant
// or_kleene with an all-false accumulator.
let mut found = if let Some(first) = self.list.first() {
compare_one(first)?
} else {
BooleanArray::new(BooleanBuffer::new_unset(num_rows), None)
};

for expr in self.list.iter().skip(1) {
// Short-circuit: if every non-null row is already true,
// no further list items can change the result.
if found.null_count() == 0 && found.true_count() == num_rows {
break;
}
found = or_kleene(&found, &compare_one(expr)?)?;
}

if self.negated { not(&found)? } else { found }
}
Expand Down Expand Up @@ -3724,4 +3732,150 @@ mod tests {
assert_eq!(result, &BooleanArray::from(vec![true, false, false]));
Ok(())
}

/// Tests that short-circuit evaluation produces correct results.
/// When all rows match after the first list item, remaining items
/// should be skipped without affecting correctness.
#[test]
fn test_in_list_with_columns_short_circuit() -> Result<()> {
// a IN (b, c) where b already matches every row of a
// The short-circuit should skip evaluating c
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Int32, false),
]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Int32Array::from(vec![1, 2, 3])), // b == a for all rows
Arc::new(Int32Array::from(vec![99, 99, 99])),
],
)?;

let col_a = col("a", &schema)?;
let list = vec![col("b", &schema)?, col("c", &schema)?];
let expr = make_in_list_with_columns(col_a, list, false);

let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
let result = as_boolean_array(&result);
assert_eq!(result, &BooleanArray::from(vec![true, true, true]));
Ok(())
}

/// Short-circuit must NOT skip when nulls are present (three-valued logic).
/// Even if all non-null values are true, null rows keep the result as null.
#[test]
fn test_in_list_with_columns_short_circuit_with_nulls() -> Result<()> {
// a IN (b, c) where a has nulls
// Even if b matches all non-null rows, result should preserve nulls
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Int32, false),
]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])),
Arc::new(Int32Array::from(vec![1, 2, 3])), // matches non-null rows
Arc::new(Int32Array::from(vec![99, 99, 99])),
],
)?;

let col_a = col("a", &schema)?;
let list = vec![col("b", &schema)?, col("c", &schema)?];
let expr = make_in_list_with_columns(col_a, list, false);

let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
let result = as_boolean_array(&result);
// row 0: 1 IN (1, 99) → true
// row 1: NULL IN (2, 99) → NULL
// row 2: 3 IN (3, 99) → true
assert_eq!(
result,
&BooleanArray::from(vec![Some(true), None, Some(true)])
);
Ok(())
}

/// Tests the make_comparator + collect_bool fallback path using
/// struct column references (nested types don't support arrow_eq).
#[test]
fn test_in_list_with_columns_struct() -> Result<()> {
let struct_fields = Fields::from(vec![
Field::new("x", DataType::Int32, false),
Field::new("y", DataType::Utf8, false),
]);
let struct_dt = DataType::Struct(struct_fields.clone());

let schema = Schema::new(vec![
Field::new("a", struct_dt.clone(), true),
Field::new("b", struct_dt.clone(), false),
Field::new("c", struct_dt.clone(), false),
]);

// a: [{1,"a"}, {2,"b"}, NULL, {4,"d"}]
// b: [{1,"a"}, {9,"z"}, {3,"c"}, {4,"d"}]
// c: [{9,"z"}, {2,"b"}, {9,"z"}, {9,"z"}]
let a = Arc::new(StructArray::new(
struct_fields.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3, 4])),
Arc::new(StringArray::from(vec!["a", "b", "c", "d"])),
],
Some(vec![true, true, false, true].into()),
));
let b = Arc::new(StructArray::new(
struct_fields.clone(),
vec![
Arc::new(Int32Array::from(vec![1, 9, 3, 4])),
Arc::new(StringArray::from(vec!["a", "z", "c", "d"])),
],
None,
));
let c = Arc::new(StructArray::new(
struct_fields.clone(),
vec![
Arc::new(Int32Array::from(vec![9, 2, 9, 9])),
Arc::new(StringArray::from(vec!["z", "b", "z", "z"])),
],
None,
));

let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b, c])?;

let col_a = col("a", &schema)?;
let list = vec![col("b", &schema)?, col("c", &schema)?];
let expr = make_in_list_with_columns(col_a, list, false);

let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
let result = as_boolean_array(&result);
// row 0: {1,"a"} IN ({1,"a"}, {9,"z"}) → true (matches b)
// row 1: {2,"b"} IN ({9,"z"}, {2,"b"}) → true (matches c)
// row 2: NULL IN ({3,"c"}, {9,"z"}) → NULL
// row 3: {4,"d"} IN ({4,"d"}, {9,"z"}) → true (matches b)
assert_eq!(
result,
&BooleanArray::from(vec![Some(true), Some(true), None, Some(true)])
);

// Also test NOT IN
let col_a = col("a", &schema)?;
let list = vec![col("b", &schema)?, col("c", &schema)?];
let expr = make_in_list_with_columns(col_a, list, true);

let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
let result = as_boolean_array(&result);
// row 0: {1,"a"} NOT IN ({1,"a"}, {9,"z"}) → false
// row 1: {2,"b"} NOT IN ({9,"z"}, {2,"b"}) → false
// row 2: NULL NOT IN ({3,"c"}, {9,"z"}) → NULL
// row 3: {4,"d"} NOT IN ({4,"d"}, {9,"z"}) → false
assert_eq!(
result,
&BooleanArray::from(vec![Some(false), Some(false), None, Some(false)])
);
Ok(())
}
}
Loading