diff --git a/datafusion/core/benches/sql_planner_extended.rs b/datafusion/core/benches/sql_planner_extended.rs index d4955313c79c3..73141bd8668e4 100644 --- a/datafusion/core/benches/sql_planner_extended.rs +++ b/datafusion/core/benches/sql_planner_extended.rs @@ -18,7 +18,10 @@ use arrow::array::{ArrayRef, RecordBatch}; use arrow_schema::DataType; use arrow_schema::TimeUnit::Nanosecond; -use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use criterion::{ + BenchmarkGroup, BenchmarkId, Criterion, criterion_group, criterion_main, + measurement::WallTime, +}; use datafusion::prelude::{DataFrame, SessionContext}; use datafusion_catalog::MemTable; use datafusion_common::ScalarValue; @@ -27,12 +30,17 @@ use datafusion_expr::{cast, col, lit, not, try_cast, when}; use datafusion_functions::expr_fn::{ btrim, length, regexp_like, regexp_replace, to_timestamp, upper, }; +use std::env; use std::fmt::Write; use std::hint::black_box; use std::ops::Rem; use std::sync::Arc; use tokio::runtime::Runtime; +const FULL_PREDICATE_SWEEP: [usize; 5] = [10, 20, 30, 40, 60]; +const FULL_DEPTH_SWEEP: [usize; 3] = [1, 2, 3]; +const DEFAULT_SWEEP_POINTS: [(usize, usize); 3] = [(10, 1), (30, 2), (60, 3)]; + // This benchmark suite is designed to test the performance of // logical planning with a large plan containing unions, many columns // with a variety of operations in it. 
@@ -252,26 +260,6 @@ fn build_case_heavy_left_join_query(predicate_count: usize, case_depth: usize) - query } -fn build_case_heavy_left_join_df_with_push_down_filter( - rt: &Runtime, - predicate_count: usize, - case_depth: usize, - push_down_filter_enabled: bool, -) -> DataFrame { - let ctx = SessionContext::new(); - register_string_table(&ctx, 100, 1000); - if !push_down_filter_enabled { - let removed = ctx.remove_optimizer_rule("push_down_filter"); - assert!( - removed, - "push_down_filter rule should be present in the default optimizer" - ); - } - - let query = build_case_heavy_left_join_query(predicate_count, case_depth); - rt.block_on(async { ctx.sql(&query).await.unwrap() }) -} - fn build_non_case_left_join_query( predicate_count: usize, nesting_depth: usize, @@ -304,10 +292,11 @@ fn build_non_case_left_join_query( query } -fn build_non_case_left_join_df_with_push_down_filter( +fn build_left_join_df_with_push_down_filter( rt: &Runtime, + query_builder: impl Fn(usize, usize) -> String, predicate_count: usize, - nesting_depth: usize, + depth: usize, push_down_filter_enabled: bool, ) -> DataFrame { let ctx = SessionContext::new(); @@ -320,10 +309,103 @@ fn build_non_case_left_join_df_with_push_down_filter( ); } - let query = build_non_case_left_join_query(predicate_count, nesting_depth); + let query = query_builder(predicate_count, depth); rt.block_on(async { ctx.sql(&query).await.unwrap() }) } +fn build_case_heavy_left_join_df_with_push_down_filter( + rt: &Runtime, + predicate_count: usize, + case_depth: usize, + push_down_filter_enabled: bool, +) -> DataFrame { + build_left_join_df_with_push_down_filter( + rt, + build_case_heavy_left_join_query, + predicate_count, + case_depth, + push_down_filter_enabled, + ) +} + +fn build_non_case_left_join_df_with_push_down_filter( + rt: &Runtime, + predicate_count: usize, + nesting_depth: usize, + push_down_filter_enabled: bool, +) -> DataFrame { + build_left_join_df_with_push_down_filter( + rt, + 
build_non_case_left_join_query, + predicate_count, + nesting_depth, + push_down_filter_enabled, + ) +} + +fn include_full_push_down_filter_sweep() -> bool { + env::var("DATAFUSION_PUSH_DOWN_FILTER_FULL_SWEEP") + .map(|value| value == "1" || value.eq_ignore_ascii_case("true")) + .unwrap_or(false) +} + +fn push_down_filter_sweep_points() -> Vec<(usize, usize)> { + if include_full_push_down_filter_sweep() { + FULL_DEPTH_SWEEP + .into_iter() + .flat_map(|depth| { + FULL_PREDICATE_SWEEP + .into_iter() + .map(move |predicate_count| (predicate_count, depth)) + }) + .collect() + } else { + DEFAULT_SWEEP_POINTS.to_vec() + } +} + +fn bench_push_down_filter_ab<BuildFn>( + group: &mut BenchmarkGroup<'_, WallTime>, + rt: &Runtime, + sweep_points: &[(usize, usize)], + build_df: BuildFn, +) where + BuildFn: Fn(&Runtime, usize, usize, bool) -> DataFrame, +{ + for &(predicate_count, depth) in sweep_points { + let with_push_down_filter = build_df(rt, predicate_count, depth, true); + let without_push_down_filter = build_df(rt, predicate_count, depth, false); + + let input_label = format!("predicates={predicate_count},nesting_depth={depth}"); + + group.bench_with_input( + BenchmarkId::new("with_push_down_filter", &input_label), + &with_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), + ); + }) + }, + ); + + group.bench_with_input( + BenchmarkId::new("without_push_down_filter", &input_label), + &without_push_down_filter, + |b, df| { + b.iter(|| { + let df_clone = df.clone(); + black_box( + rt.block_on(async { df_clone.into_optimized_plan().unwrap() }), + ); + }) + }, + ); + } +} + fn criterion_benchmark(c: &mut Criterion) { let baseline_ctx = SessionContext::new(); let case_heavy_ctx = SessionContext::new(); @@ -349,116 +431,40 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - let predicate_sweep = [10, 20, 30, 40, 60]; - let case_depth_sweep = [1, 2, 3]; + let sweep_points = 
push_down_filter_sweep_points(); let mut hotspot_group = c.benchmark_group("push_down_filter_hotspot_case_heavy_left_join_ab"); - for case_depth in case_depth_sweep { - for predicate_count in predicate_sweep { - let with_push_down_filter = - build_case_heavy_left_join_df_with_push_down_filter( - &rt, - predicate_count, - case_depth, - true, - ); - let without_push_down_filter = - build_case_heavy_left_join_df_with_push_down_filter( - &rt, - predicate_count, - case_depth, - false, - ); - - let input_label = - format!("predicates={predicate_count},case_depth={case_depth}"); - // A/B interpretation: - // - with_push_down_filter: default optimizer path (rule enabled) - // - without_push_down_filter: control path with the rule removed - // Compare both IDs at the same sweep point to isolate rule impact. - hotspot_group.bench_with_input( - BenchmarkId::new("with_push_down_filter", &input_label), - &with_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { - df_clone.into_optimized_plan().unwrap() - }), - ); - }) - }, - ); - hotspot_group.bench_with_input( - BenchmarkId::new("without_push_down_filter", &input_label), - &without_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { - df_clone.into_optimized_plan().unwrap() - }), - ); - }) - }, - ); - } - } + bench_push_down_filter_ab( + &mut hotspot_group, + &rt, + &sweep_points, + |rt, predicate_count, depth, enable| { + build_case_heavy_left_join_df_with_push_down_filter( + rt, + predicate_count, + depth, + enable, + ) + }, + ); hotspot_group.finish(); let mut control_group = c.benchmark_group("push_down_filter_control_non_case_left_join_ab"); - for nesting_depth in case_depth_sweep { - for predicate_count in predicate_sweep { - let with_push_down_filter = build_non_case_left_join_df_with_push_down_filter( - &rt, + bench_push_down_filter_ab( + &mut control_group, + &rt, + &sweep_points, + |rt, predicate_count, 
depth, enable| { + build_non_case_left_join_df_with_push_down_filter( + rt, predicate_count, - nesting_depth, - true, - ); - let without_push_down_filter = - build_non_case_left_join_df_with_push_down_filter( - &rt, - predicate_count, - nesting_depth, - false, - ); - - let input_label = - format!("predicates={predicate_count},nesting_depth={nesting_depth}"); - control_group.bench_with_input( - BenchmarkId::new("with_push_down_filter", &input_label), - &with_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { - df_clone.into_optimized_plan().unwrap() - }), - ); - }) - }, - ); - control_group.bench_with_input( - BenchmarkId::new("without_push_down_filter", &input_label), - &without_push_down_filter, - |b, df| { - b.iter(|| { - let df_clone = df.clone(); - black_box( - rt.block_on(async { - df_clone.into_optimized_plan().unwrap() - }), - ); - }) - }, - ); - } - } + depth, + enable, + ) + }, + ); control_group.finish(); } diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 9a1dc5502ee60..a245227382d83 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -67,6 +67,7 @@ pub mod create_drop; pub mod explain_analyze; pub mod joins; mod path_partition; +mod push_down_filter_regressions; mod runtime_config; pub mod select; mod sql_api; diff --git a/datafusion/core/tests/sql/push_down_filter_regressions.rs b/datafusion/core/tests/sql/push_down_filter_regressions.rs new file mode 100644 index 0000000000000..a1ff8293c97a1 --- /dev/null +++ b/datafusion/core/tests/sql/push_down_filter_regressions.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use super::*; +use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; + +const WINDOW_SCALAR_SUBQUERY_SQL: &str = r#" + WITH suppliers AS ( + SELECT * + FROM (VALUES (1, 10.0), (1, 20.0)) AS t(nation, acctbal) + ) + SELECT + ROW_NUMBER() OVER (PARTITION BY nation ORDER BY acctbal DESC) AS rn + FROM suppliers AS s + WHERE acctbal > ( + SELECT AVG(acctbal) FROM suppliers + ) +"#; + +fn sqllogictest_style_ctx(push_down_filter_enabled: bool) -> SessionContext { + let ctx = + SessionContext::new_with_config(SessionConfig::new().with_target_partitions(4)); + if !push_down_filter_enabled { + assert!(ctx.remove_optimizer_rule("push_down_filter")); + } + ctx +} + +async fn capture_window_scalar_subquery_plans( + push_down_filter_enabled: bool, +) -> Result<(String, String)> { + let ctx = sqllogictest_style_ctx(push_down_filter_enabled); + let df = ctx.sql(WINDOW_SCALAR_SUBQUERY_SQL).await?; + let optimized_plan = df.clone().into_optimized_plan()?; + let physical_plan = df.create_physical_plan().await?; + + Ok(( + optimized_plan.display_indent_schema().to_string(), + displayable(physical_plan.as_ref()).indent(true).to_string(), + )) +} + +#[tokio::test] +async fn window_scalar_subquery_regression() -> Result<()> { + let ctx = SessionContext::new(); + let results = ctx.sql(WINDOW_SCALAR_SUBQUERY_SQL).await?.collect().await?; + + assert_batches_eq!( + &["+----+", 
"| rn |", "+----+", "| 1 |", "+----+",], + &results + ); + + Ok(()) +} + +#[tokio::test] +async fn window_scalar_subquery_sqllogictest_style_regression() -> Result<()> { + let ctx = sqllogictest_style_ctx(true); + let results = ctx.sql(WINDOW_SCALAR_SUBQUERY_SQL).await?.collect().await?; + + assert_batches_eq!( + &["+----+", "| rn |", "+----+", "| 1 |", "+----+",], + &results + ); + + Ok(()) +} + +#[tokio::test] +async fn aggregate_regr_functions_regression() -> Result<()> { + let ctx = SessionContext::new(); + let batch = RecordBatch::try_from_iter(vec![ + ( + "c11", + Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0])) as ArrayRef, + ), + ( + "c12", + Arc::new(Float64Array::from(vec![2.0, 4.0, 6.0])) as ArrayRef, + ), + ])?; + ctx.register_batch("aggregate_test_100", batch)?; + + let sql = r#" + select + regr_slope(c12, c11), + regr_intercept(c12, c11), + regr_count(c12, c11), + regr_r2(c12, c11), + regr_avgx(c12, c11), + regr_avgy(c12, c11), + regr_sxx(c12, c11), + regr_syy(c12, c11), + regr_sxy(c12, c11) + from aggregate_test_100 + "#; + + let rows = execute(&ctx, sql).await; + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].len(), 9); + assert!(rows[0].iter().all(|value| value != "NULL")); + + Ok(()) +} + +#[tokio::test] +async fn correlated_in_subquery_regression() -> Result<()> { + let ctx = SessionContext::new(); + let t1 = RecordBatch::try_from_iter(vec![ + ("t1_id", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef), + ( + "t1_name", + Arc::new(StringArray::from(vec!["alpha", "beta"])) as ArrayRef, + ), + ("t1_int", Arc::new(Int32Array::from(vec![1, 0])) as ArrayRef), + ])?; + let t2 = RecordBatch::try_from_iter(vec![( + "t2_id", + Arc::new(Int32Array::from(vec![12, 99])) as ArrayRef, + )])?; + ctx.register_batch("t1", t1)?; + ctx.register_batch("t2", t2)?; + + let sql = r#" + select t1.t1_id, + t1.t1_name, + t1.t1_int + from t1 + where t1.t1_id + 12 in ( + select t2.t2_id + 1 from t2 where t1.t1_int > 0 + ) + "#; + + let results = 
ctx.sql(sql).await?.collect().await?; + + assert_batches_sorted_eq!( + &[ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 1 | alpha | 1 |", + "+-------+---------+--------+", + ], + &results + ); + + Ok(()) +} + +#[tokio::test] +async fn natural_join_union_regression() -> Result<()> { + let ctx = SessionContext::new(); + let t1 = RecordBatch::try_from_iter(vec![ + ("v0", Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef), + ( + "v2", + Arc::new(Int32Array::from(vec![None, Some(5)])) as ArrayRef, + ), + ])?; + // Keep `v2` only on the left side so the natural join key remains `v0`. + let t2 = RecordBatch::try_from_iter(vec![( + "v0", + Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef, + )])?; + ctx.register_batch("t1", t1)?; + ctx.register_batch("t2", t2)?; + + let sql = r#" + SELECT t1.v2, t1.v0 FROM t2 NATURAL JOIN t1 + UNION ALL + SELECT t1.v2, t1.v0 FROM t2 NATURAL JOIN t1 WHERE (t1.v2 IS NULL) + "#; + + let results = ctx.sql(sql).await?.collect().await?; + + assert_batches_sorted_eq!( + &[ + "+----+----+", + "| v2 | v0 |", + "+----+----+", + "| | 1 |", + "| | 1 |", + "| 5 | 2 |", + "+----+----+", + ], + &results + ); + + Ok(()) +} + +#[tokio::test(flavor = "current_thread")] +async fn window_scalar_subquery_optimizer_delta() -> Result<()> { + let (enabled_optimized, enabled_physical) = + capture_window_scalar_subquery_plans(true).await?; + let (disabled_optimized, disabled_physical) = + capture_window_scalar_subquery_plans(false).await?; + + assert!( + enabled_optimized + .contains("Filter: s.acctbal > __scalar_sq_1.avg(suppliers.acctbal)") + ); + assert!(enabled_optimized.contains("Cross Join:")); + assert!( + disabled_optimized + .contains("Filter: s.acctbal > __scalar_sq_1.avg(suppliers.acctbal)") + ); + assert!(disabled_optimized.contains("Cross Join:")); + + assert!( + enabled_physical.contains("FilterExec: acctbal@1 > avg(suppliers.acctbal)@2") + ); + 
assert!(enabled_physical.contains("CrossJoinExec")); + assert!( + disabled_physical.contains("FilterExec: acctbal@1 > avg(suppliers.acctbal)@2") + ); + assert!(disabled_physical.contains("CrossJoinExec")); + + assert_eq!(enabled_optimized, disabled_optimized); + assert_eq!(enabled_physical, disabled_physical); + + Ok(()) +} diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 03a7a0b864177..14286a4480835 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -431,7 +431,10 @@ fn push_down_all_join( left_push.push(predicate); } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); - } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? { + } else if is_inner_join + && can_promote_post_join_filter_to_join_condition(&join) + && can_evaluate_as_join_condition(&predicate)? + { // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate // and convert to the join on condition join_conditions.push(predicate); @@ -512,6 +515,20 @@ fn push_down_all_join( Ok(Transformed::yes(plan)) } +/// Returns true when post-join filters are allowed to be promoted to join conditions. +/// +/// Protection is necessary for scalar-side joins and cross joins to avoid incorrectly +/// rewriting a post-join filter into the join condition when one side is empty or +/// limited to at most one row (`max_rows() == Some(1)`). +/// +/// - `join.on` non-empty means existing join predicates already exist; promotion is safe. +/// - if neither side is scalar (`max_rows() == Some(1)`), promotion is safe. 
+fn can_promote_post_join_filter_to_join_condition(join: &Join) -> bool { + !join.on.is_empty() + || !(matches!(join.left.max_rows(), Some(1)) + || matches!(join.right.max_rows(), Some(1))) +} + fn push_down_join( join: Join, parent_predicate: Option<&Expr>, @@ -1477,7 +1494,7 @@ mod tests { use crate::simplify_expressions::SimplifyExpressions; use crate::test::udfs::leaf_udf_expr; use crate::test::*; - use datafusion_expr::test::function_stub::sum; + use datafusion_expr::test::function_stub::{avg, sum}; use insta::assert_snapshot; use super::*; @@ -3579,6 +3596,34 @@ mod tests { ) } + #[test] + fn cross_join_with_scalar_side_keeps_post_join_filter() -> Result<()> { + let left = LogicalPlanBuilder::from(test_table_scan()?) + .project(vec![col("a"), col("b")])? + .build()?; + let right = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) + .project(vec![col("a")])? + .aggregate(Vec::<Expr>::new(), vec![avg(col("a")).alias("avg_a")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .cross_join(right)? + .filter(col("test.b").gt(col("avg_a")))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Filter: test.b > avg_a + Cross Join: + Projection: test.a, test.b + TableScan: test + Aggregate: groupBy=[[]], aggr=[[avg(test1.a) AS avg_a]] + Projection: test1.a + TableScan: test1 + " + ) + } + #[test] fn left_semi_join() -> Result<()> { let left = test_table_scan_with_name("test1")?; diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 7e038d2392022..79dcefb312c3c 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -17,9 +17,37 @@ //! Utility functions leveraged by the query optimizer rules +mod null_restriction; + use std::collections::{BTreeSet, HashMap, HashSet}; +use std::sync::Arc; + +#[cfg(test)] +use std::sync::Mutex; use crate::analyzer::type_coercion::TypeCoercionRewriter; + +/// Null restriction evaluation mode for optimizer tests. 
+#[cfg(test)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub(crate) enum NullRestrictionEvalMode { + Auto, + AuthoritativeOnly, +} + +#[cfg(test)] +static NULL_RESTRICTION_EVAL_MODE: Mutex<NullRestrictionEvalMode> = + Mutex::new(NullRestrictionEvalMode::Auto); + +#[cfg(test)] +pub(crate) fn set_null_restriction_eval_mode_for_test(mode: NullRestrictionEvalMode) { + *NULL_RESTRICTION_EVAL_MODE.lock().unwrap() = mode; +} + +#[cfg(test)] +fn null_restriction_eval_mode() -> NullRestrictionEvalMode { + *NULL_RESTRICTION_EVAL_MODE.lock().unwrap() +} use arrow::array::{Array, RecordBatch, new_null_array}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::cast::as_boolean_array; @@ -30,7 +58,6 @@ use datafusion_expr::expr_rewriter::replace_col; use datafusion_expr::{ColumnarValue, Expr, logical_plan::LogicalPlan}; use datafusion_physical_expr::create_physical_expr; use log::{debug, trace}; -use std::sync::Arc; /// Re-export of `NamesPreserver` for backwards compatibility, /// as it was initially placed here and then moved elsewhere. @@ -79,24 +106,45 @@ pub fn is_restrict_null_predicate<'a>( return Ok(true); } - // If result is single `true`, return false; - // If result is single `NULL` or `false`, return true; - Ok( - match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? { - ColumnarValue::Array(array) => { - if array.len() == 1 { - let boolean_array = as_boolean_array(&array)?; - boolean_array.is_null(0) || !boolean_array.value(0) - } else { - false - } - } - ColumnarValue::Scalar(scalar) => matches!( - scalar, - ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false)) - ), - }, - ) + // Collect join columns so they can be used in both the fast-path check and the + // fallback evaluation path below. 
+ let join_cols: HashSet<&Column> = join_cols_of_predicate.into_iter().collect(); + + // Fast path: if the predicate references columns outside the join key set, + // `evaluate_expr_with_null_column` would fail because the null schema only + // contains a placeholder for the join key columns. Callers treat such errors as + // non-restricting (false) via `matches!(_, Ok(true))`, so we return false early + // and avoid the expensive physical-expression compilation pipeline entirely. + if !null_restriction::predicate_uses_only_columns(&predicate, &join_cols) { + return Ok(false); + } + + #[cfg(test)] + if matches!( + null_restriction_eval_mode(), + NullRestrictionEvalMode::AuthoritativeOnly + ) { + return authoritative_restrict_null_predicate(predicate, join_cols); + } + + if let Some(is_restricting) = + null_restriction::syntactic_restrict_null_predicate(&predicate, &join_cols) + { + #[cfg(debug_assertions)] + { + let authoritative = authoritative_restrict_null_predicate( + predicate.clone(), + join_cols.iter().copied(), + )?; + debug_assert_eq!( + is_restricting, authoritative, + "syntactic fast path disagrees with authoritative null-restriction evaluation for predicate: {predicate}" + ); + } + return Ok(is_restricting); + } + + authoritative_restrict_null_predicate(predicate, join_cols) } /// Determines if an expression will always evaluate to null. @@ -146,6 +194,28 @@ fn evaluate_expr_with_null_column<'a>( .evaluate(&input_batch) } +fn authoritative_restrict_null_predicate<'a>( + predicate: Expr, + join_cols_of_predicate: impl IntoIterator<Item = &'a Column>, +) -> Result<bool> { + Ok( + match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? 
{ + ColumnarValue::Array(array) => { + if array.len() == 1 { + let boolean_array = as_boolean_array(&array)?; + boolean_array.is_null(0) || !boolean_array.value(0) + } else { + false + } + } + ColumnarValue::Scalar(scalar) => matches!( + scalar, + ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false)) + ), + }, + ) +} + fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> { let mut expr_rewrite = TypeCoercionRewriter { schema }; expr.rewrite(&mut expr_rewrite).data() } @@ -154,7 +224,9 @@ fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> { #[cfg(test)] mod tests { use super::*; - use datafusion_expr::{Operator, binary_expr, case, col, in_list, is_null, lit}; + use datafusion_expr::{ + Operator, binary_expr, case, col, in_list, is_null, lit, when, + }; #[test] fn expr_is_restrict_null_predicate() -> Result<()> { @@ -193,6 +265,27 @@ mod tests { .otherwise(lit(false))?, true, ), + // CASE 1 WHEN 1 THEN true ELSE false END + ( + case(lit(1i64)) + .when(lit(1i64), lit(true)) + .otherwise(lit(false))?, + false, + ), + // CASE 1 WHEN 1 THEN NULL ELSE false END + ( + case(lit(1i64)) + .when(lit(1i64), lit(ScalarValue::Null)) + .otherwise(lit(false))?, + true, + ), + // CASE true WHEN true THEN false ELSE true END + ( + case(lit(true)) + .when(lit(true), lit(false)) + .otherwise(lit(true))?, + true, + ), // CASE a WHEN 0 THEN false ELSE true END ( case(col("a")) @@ -246,16 +339,128 @@ mod tests { in_list(col("a"), vec![Expr::Literal(ScalarValue::Null, None)], true), true, ), + // CASE WHEN a IS NOT NULL THEN a ELSE b END > 2 + ( + binary_expr( + when(Expr::IsNotNull(Box::new(col("a"))), col("a")) + .otherwise(col("b"))?, + Operator::Gt, + lit(2i64), + ), + true, + ), ]; - let column_a = Column::from_name("a"); for (predicate, expected) in test_cases { - let join_cols_of_predicate = std::iter::once(&column_a); - let actual = - is_restrict_null_predicate(predicate.clone(), join_cols_of_predicate)?; + let join_cols_of_predicate = predicate.column_refs(); + let actual = 
is_restrict_null_predicate( + predicate.clone(), + join_cols_of_predicate.iter().copied(), + )?; assert_eq!(actual, expected, "{predicate}"); } + // Keep coverage for the fast path that rejects predicates referencing + // columns outside the provided join key set. + let predicate = binary_expr(col("a"), Operator::Gt, col("b")); + let column_a = Column::from_name("a"); + let actual = + is_restrict_null_predicate(predicate.clone(), std::iter::once(&column_a))?; + assert!(!actual, "{predicate}"); + + Ok(()) + } + + #[test] + fn syntactic_fast_path_matches_authoritative_evaluator() -> Result<()> { + let test_cases = vec![ + is_null(col("a")), + Expr::IsNotNull(Box::new(col("a"))), + binary_expr(col("a"), Operator::Gt, lit(8i64)), + binary_expr(col("a"), Operator::Eq, lit(ScalarValue::Null)), + binary_expr(col("a"), Operator::And, lit(true)), + binary_expr(col("a"), Operator::Or, lit(false)), + Expr::Not(Box::new(col("a").is_true())), + col("a").is_true(), + col("a").is_false(), + col("a").is_unknown(), + col("a").is_not_true(), + col("a").is_not_false(), + col("a").is_not_unknown(), + col("a").between(lit(1i64), lit(10i64)), + binary_expr( + when(Expr::IsNotNull(Box::new(col("a"))), col("a")) + .otherwise(col("b"))?, + Operator::Gt, + lit(2i64), + ), + case(col("a")) + .when(lit(1i64), lit(true)) + .otherwise(lit(false))?, + case(col("a")) + .when(lit(0i64), lit(false)) + .otherwise(lit(true))?, + binary_expr( + case(col("a")) + .when(lit(0i64), lit(true)) + .otherwise(lit(false))?, + Operator::Or, + lit(false), + ), + binary_expr( + case(lit(1i64)) + .when(lit(1i64), lit(ScalarValue::Null)) + .otherwise(lit(false))?, + Operator::IsNotDistinctFrom, + lit(true), + ), + ]; + + for predicate in test_cases { + let join_cols = predicate.column_refs(); + if let Some(syntactic) = null_restriction::syntactic_restrict_null_predicate( + &predicate, &join_cols, + ) { + let authoritative = authoritative_restrict_null_predicate( + predicate.clone(), + join_cols.iter().copied(), + ) 
+ .unwrap_or_else(|error| { + panic!( + "authoritative evaluator failed for predicate `{predicate}`: {error}" + ) + }); + assert_eq!( + syntactic, authoritative, + "syntactic fast path disagrees with authoritative evaluator for predicate: {predicate}", + ); + } + } + + Ok(()) + } + + #[test] + fn null_restriction_eval_mode_auto_vs_authoritative_only() -> Result<()> { + let predicate = binary_expr(col("a"), Operator::Gt, lit(8i64)); + let join_cols_of_predicate = predicate.column_refs(); + + set_null_restriction_eval_mode_for_test(NullRestrictionEvalMode::Auto); + let auto_result = is_restrict_null_predicate( + predicate.clone(), + join_cols_of_predicate.iter().copied(), + )?; + + set_null_restriction_eval_mode_for_test( + NullRestrictionEvalMode::AuthoritativeOnly, + ); + let authoritative_result = is_restrict_null_predicate( + predicate.clone(), + join_cols_of_predicate.iter().copied(), + )?; + + assert_eq!(auto_result, authoritative_result); + Ok(()) } } diff --git a/datafusion/optimizer/src/utils/null_restriction.rs b/datafusion/optimizer/src/utils/null_restriction.rs new file mode 100644 index 0000000000000..28e32a26ed1da --- /dev/null +++ b/datafusion/optimizer/src/utils/null_restriction.rs @@ -0,0 +1,260 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Syntactic null-restriction evaluator used by optimizer fast paths. + +use std::collections::HashSet; + +use datafusion_common::{Column, ScalarValue}; +use datafusion_expr::{BinaryExpr, Expr, Operator}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum NullSubstitutionValue { + /// SQL NULL after substituting join columns with NULL. + Null, + /// Known to be non-null, but value is otherwise unknown. + NonNull, + /// A known boolean outcome from SQL three-valued logic. + Boolean(bool), +} + +pub(super) fn syntactic_restrict_null_predicate( + predicate: &Expr, + join_cols: &HashSet<&Column>, +) -> Option<bool> { + match syntactic_null_substitution_value(predicate, join_cols) { + Some(NullSubstitutionValue::Boolean(true)) => Some(false), + Some(NullSubstitutionValue::Boolean(false) | NullSubstitutionValue::Null) => { + Some(true) + } + Some(NullSubstitutionValue::NonNull) | None => None, + } +} + +pub(super) fn predicate_uses_only_columns( + predicate: &Expr, + allowed_columns: &HashSet<&Column>, +) -> bool { + predicate + .column_refs() + .iter() + .all(|column| allowed_columns.contains(*column)) +} + +fn contains_null( + values: impl IntoIterator<Item = Option<NullSubstitutionValue>>, +) -> bool { + values + .into_iter() + .any(|value| matches!(value, Some(NullSubstitutionValue::Null))) +} + +fn not(value: Option<NullSubstitutionValue>) -> Option<NullSubstitutionValue> { + match value { + Some(NullSubstitutionValue::Boolean(value)) => { + Some(NullSubstitutionValue::Boolean(!value)) + } + Some(NullSubstitutionValue::Null) => Some(NullSubstitutionValue::Null), + Some(NullSubstitutionValue::NonNull) | None => None, + } +} + +fn binary_boolean_value( + left: Option<NullSubstitutionValue>, + right: Option<NullSubstitutionValue>, + when_short_circuit: bool, +) -> Option<NullSubstitutionValue> { + let short_circuit = Some(NullSubstitutionValue::Boolean(when_short_circuit)); + let identity = Some(NullSubstitutionValue::Boolean(!when_short_circuit)); + + if left == short_circuit || right == short_circuit { 
return short_circuit; + } + + match (left, right) { + (value, other) if value == identity => other, + (other, value) if value == identity => other, + (Some(NullSubstitutionValue::Null), Some(NullSubstitutionValue::Null)) => { + Some(NullSubstitutionValue::Null) + } + (Some(NullSubstitutionValue::NonNull), _) + | (_, Some(NullSubstitutionValue::NonNull)) + | (None, _) + | (_, None) => None, + (left, right) => { + debug_assert_eq!(left, right); + left + } + } +} + +fn null_check_value( + value: Option<NullSubstitutionValue>, + is_not_null: bool, +) -> Option<NullSubstitutionValue> { + match value { + Some(NullSubstitutionValue::Null) => { + Some(NullSubstitutionValue::Boolean(!is_not_null)) + } + Some(NullSubstitutionValue::NonNull | NullSubstitutionValue::Boolean(_)) => { + Some(NullSubstitutionValue::Boolean(is_not_null)) + } + None => None, + } +} + +fn null_if_contains_null( + values: impl IntoIterator<Item = Option<NullSubstitutionValue>>, +) -> Option<NullSubstitutionValue> { + contains_null(values).then_some(NullSubstitutionValue::Null) +} + +fn syntactic_null_substitution_value( + expr: &Expr, + join_cols: &HashSet<&Column>, +) -> Option<NullSubstitutionValue> { + match expr { + Expr::Alias(alias) => { + syntactic_null_substitution_value(alias.expr.as_ref(), join_cols) + } + Expr::Column(column) => join_cols + .contains(column) + .then_some(NullSubstitutionValue::Null), + Expr::Literal(value, _) => Some(scalar_to_null_substitution_value(value)), + Expr::BinaryExpr(binary_expr) => syntactic_binary_value(binary_expr, join_cols), + Expr::Not(expr) => { + not(syntactic_null_substitution_value(expr.as_ref(), join_cols)) + } + Expr::IsNull(expr) => null_check_value( + syntactic_null_substitution_value(expr.as_ref(), join_cols), + false, + ), + Expr::IsNotNull(expr) => null_check_value( + syntactic_null_substitution_value(expr.as_ref(), join_cols), + true, + ), + Expr::Between(between) => null_if_contains_null([ + syntactic_null_substitution_value(between.expr.as_ref(), join_cols), + syntactic_null_substitution_value(between.low.as_ref(), join_cols), + 
syntactic_null_substitution_value(between.high.as_ref(), join_cols), + ]), + Expr::Cast(cast) => strict_null_passthrough(cast.expr.as_ref(), join_cols), + Expr::TryCast(try_cast) => { + strict_null_passthrough(try_cast.expr.as_ref(), join_cols) + } + Expr::Negative(expr) => strict_null_passthrough(expr.as_ref(), join_cols), + Expr::Like(like) | Expr::SimilarTo(like) => null_if_contains_null([ + syntactic_null_substitution_value(like.expr.as_ref(), join_cols), + syntactic_null_substitution_value(like.pattern.as_ref(), join_cols), + ]), + Expr::Exists { .. } + | Expr::InList(_) + | Expr::InSubquery(_) + | Expr::SetComparison(_) + | Expr::ScalarSubquery(_) + | Expr::OuterReferenceColumn(_, _) + | Expr::Placeholder(_) + | Expr::ScalarVariable(_, _) + | Expr::Unnest(_) + | Expr::GroupingSet(_) + | Expr::WindowFunction(_) + | Expr::ScalarFunction(_) + | Expr::Case(_) + | Expr::IsTrue(_) + | Expr::IsFalse(_) + | Expr::IsUnknown(_) + | Expr::IsNotTrue(_) + | Expr::IsNotFalse(_) + | Expr::IsNotUnknown(_) => None, + Expr::AggregateFunction(_) => None, + // TODO: remove the next line after `Expr::Wildcard` is removed + #[expect(deprecated)] + Expr::Wildcard { .. 
} => None, + } +} + +fn scalar_to_null_substitution_value(value: &ScalarValue) -> NullSubstitutionValue { + match value { + _ if value.is_null() => NullSubstitutionValue::Null, + ScalarValue::Boolean(Some(value)) => NullSubstitutionValue::Boolean(*value), + _ => NullSubstitutionValue::NonNull, + } +} + +fn strict_null_passthrough( + expr: &Expr, + join_cols: &HashSet<&Column>, +) -> Option<NullSubstitutionValue> { + matches!( + syntactic_null_substitution_value(expr, join_cols), + Some(NullSubstitutionValue::Null) + ) + .then_some(NullSubstitutionValue::Null) +} + +fn syntactic_binary_value( + binary_expr: &BinaryExpr, + join_cols: &HashSet<&Column>, +) -> Option<NullSubstitutionValue> { + let left = syntactic_null_substitution_value(binary_expr.left.as_ref(), join_cols); + let right = syntactic_null_substitution_value(binary_expr.right.as_ref(), join_cols); + + match binary_expr.op { + Operator::And => binary_boolean_value(left, right, false), + Operator::Or => binary_boolean_value(left, right, true), + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::LtEq + | Operator::Gt + | Operator::GtEq + | Operator::Plus + | Operator::Minus + | Operator::Multiply + | Operator::Divide + | Operator::Modulo + | Operator::RegexMatch + | Operator::RegexIMatch + | Operator::RegexNotMatch + | Operator::RegexNotIMatch + | Operator::LikeMatch + | Operator::ILikeMatch + | Operator::NotLikeMatch + | Operator::NotILikeMatch + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor + | Operator::BitwiseShiftRight + | Operator::BitwiseShiftLeft + | Operator::StringConcat + | Operator::AtArrow + | Operator::ArrowAt + | Operator::Arrow + | Operator::LongArrow + | Operator::HashArrow + | Operator::HashLongArrow + | Operator::AtAt + | Operator::IntegerDivide + | Operator::HashMinus + | Operator::AtQuestion + | Operator::Question + | Operator::QuestionAnd + | Operator::QuestionPipe + | Operator::Colon => null_if_contains_null([left, right]), + Operator::IsDistinctFrom | Operator::IsNotDistinctFrom => None, + 
} +}