Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 89 additions & 3 deletions datafusion/sql/src/set_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_common::{
DataFusionError, Diagnostic, Result, Span, not_impl_err, plan_err,
DFSchemaRef, DataFusionError, Diagnostic, Result, Span, not_impl_err, plan_err,
};
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder};
use sqlparser::ast::{SetExpr, SetOperator, SetQuantifier, Spanned};
use sqlparser::ast::{
Expr as SQLExpr, Ident, SelectItem, SetExpr, SetOperator, SetQuantifier, Spanned,
};

impl<S: ContextProvider> SqlToRel<'_, S> {
#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
Expand All @@ -36,13 +38,33 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
SetExpr::SetOperation {
op,
left,
right,
mut right,
set_quantifier,
} => {
let left_span = Span::try_from_sqlparser_span(left.span());
let right_span = Span::try_from_sqlparser_span(right.span());
let left_plan = self.set_expr_to_plan(*left, planner_context);

// For non-*ByName operations, add missing aliases to right side using left schema's
// column names. This allows queries like
// `SELECT 1 a, 1 b UNION ALL SELECT 2, 2`
// where the right side has duplicate literal values.
// We only do this if the left side succeeded.
if let Ok(plan) = &left_plan
&& plan.schema().fields().len() > 1
&& matches!(
set_quantifier,
SetQuantifier::All
| SetQuantifier::Distinct
| SetQuantifier::None
)
{
alias_set_expr(&mut right, plan.schema())
}

let right_plan = self.set_expr_to_plan(*right, planner_context);

// Handle errors from both sides, collecting them if both failed
let (left_plan, right_plan) = match (left_plan, right_plan) {
(Ok(left_plan), Ok(right_plan)) => (left_plan, right_plan),
(Err(left_err), Err(right_err)) => {
Expand Down Expand Up @@ -160,3 +182,67 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}
}
}

// Propagates column names from `schema` onto the SELECT list found at the
// leftmost leaf of `set_expr`.
//
// Unnamed expressions on the right side of a UNION/INTERSECT/EXCEPT thereby
// pick up the left side's column names, so that queries such as
// `SELECT 1 AS a, 0 AS b, 0 AS c UNION ALL SELECT 2, 0, 0` can be planned.
//
// - `SELECT`: its projection is handed to `alias_select_items`.
// - Nested set operation: only the left operand is visited.
// - Parenthesized query, e.g. `(SELECT ... UNION ALL SELECT ...)`: its body is visited.
// - Anything else (VALUES, etc.): left untouched.
fn alias_set_expr(set_expr: &mut SetExpr, schema: &DFSchemaRef) {
    // Pick the nested expression to descend into, or finish right here.
    let nested: &mut SetExpr = match set_expr {
        SetExpr::Select(select) => {
            alias_select_items(&mut select.projection, schema);
            return;
        }
        SetExpr::SetOperation { left, .. } => &mut **left,
        SetExpr::Query(query) => &mut *query.body,
        _ => return,
    };
    alias_set_expr(nested, schema)
}

// Aliases unnamed expressions in `items` with the positionally matching field
// names of `schema`.
//
// This helps set-expression queries where the right side repeats an expression
// (e.g. `SELECT 2, 2`) while the left side has unique column names.
//
// Per item:
// - an unnamed plain or compound identifier is left untouched;
// - any other unnamed expression is wrapped in an alias carrying the name of
//   the schema field at the same position, if such a field exists;
// - an already aliased expression is left untouched;
// - a single unqualified wildcard is not rewritten; it only advances the
//   position by the number of schema fields not accounted for by the other
//   items.
//
// If the list contains any qualified wildcard, or more than one unqualified
// wildcard, positions cannot be determined and nothing is modified.
fn alias_select_items(items: &mut [SelectItem], schema: &DFSchemaRef) {
    // Count the wildcards up front; bail out on the shapes we cannot handle.
    let mut wildcard_count = 0usize;
    for item in items.iter() {
        match item {
            SelectItem::QualifiedWildcard(_, _) => return,
            SelectItem::Wildcard(_) => wildcard_count += 1,
            _ => {}
        }
    }
    if wildcard_count > 1 {
        return;
    }

    // Number of schema fields attributed to the (single) wildcard: all fields
    // minus the non-wildcard items. Written as `(fields + 1) - items` with a
    // saturating subtraction so that an empty `items` slice cannot underflow
    // `usize`, which `items.len() - 1` would do. The value is only consumed
    // when a wildcard item is actually encountered below.
    let wildcard_expansion = (schema.fields().len() + 1).saturating_sub(items.len());

    let mut col_idx = 0;
    for item in items.iter_mut() {
        match item {
            SelectItem::UnnamedExpr(expr) => {
                if !matches!(
                    expr,
                    SQLExpr::Identifier(_) | SQLExpr::CompoundIdentifier(_)
                ) && let Some(field) = schema.fields().get(col_idx)
                {
                    // Quote the alias so the schema's field name is carried
                    // over verbatim.
                    // NOTE(review): this assumes the planner lowercases
                    // unquoted aliases during identifier normalization, which
                    // would make left-side names such as `"A"` and `a` collide
                    // on the right side — confirm against the normalizer.
                    let aliased = SelectItem::ExprWithAlias {
                        expr: expr.clone(),
                        alias: Ident::with_quote('"', field.name()),
                    };
                    *item = aliased;
                }
                col_idx += 1;
            }
            SelectItem::ExprWithAlias { .. } => {
                col_idx += 1;
            }
            SelectItem::Wildcard(_) => {
                col_idx += wildcard_expansion;
            }
            SelectItem::QualifiedWildcard(_, _) => {
                unreachable!("qualified wildcards are handled above")
            }
        }
    }
}
100 changes: 100 additions & 0 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2655,6 +2655,106 @@ fn union_all_by_name_same_column_names() {
);
}

// Plans a chained UNION ALL in which only the first SELECT names its columns
// (`a`, `b`) while the later SELECTs repeat an unaliased expression.
// The snapshot expects every branch's projection to carry the aliases `a` and `b`.
#[test]
fn union_all_with_duplicate_expressions() {
let sql = "\
SELECT 0 a, 0 b \
UNION ALL SELECT 1, 1 \
UNION ALL SELECT count(*), count(*) FROM orders";
// Planning must succeed; `unwrap` fails the test otherwise.
let plan = logical_plan(sql).unwrap();
assert_snapshot!(
plan,
@r"
Union
Union
Projection: Int64(0) AS a, Int64(0) AS b
EmptyRelation: rows=1
Projection: Int64(1) AS a, Int64(1) AS b
EmptyRelation: rows=1
Projection: count(*) AS a, count(*) AS b
Aggregate: groupBy=[[]], aggr=[[count(*)]]
TableScan: orders
"
);
}

// Plans a UNION whose right side mixes duplicate literals with a wildcard:
// `SELECT 1, *, 1`. The snapshot expects the two literals to be aliased `a`
// and `d` (the first and last left-side column names), while the columns
// produced by the wildcard stay unaliased.
// NOTE(review): despite the test name, the query uses an unqualified `*`.
#[test]
fn union_with_qualified_and_duplicate_expressions() {
let sql = "\
SELECT 0 a, id b, price c, 0 d FROM test_decimal \
UNION SELECT 1, *, 1 FROM test_decimal";
// Planning must succeed; `unwrap` fails the test otherwise.
let plan = logical_plan(sql).unwrap();
assert_snapshot!(
plan,
@"
Distinct:
Union
Projection: Int64(0) AS a, test_decimal.id AS b, test_decimal.price AS c, Int64(0) AS d
TableScan: test_decimal
Projection: Int64(1) AS a, test_decimal.id, test_decimal.price, Int64(1) AS d
TableScan: test_decimal
"
);
}

// Plans a chained INTERSECT in which only the first SELECT names its columns
// (`a`, `b`) while the later SELECTs repeat an unaliased expression.
// The snapshot expects LeftSemi joins on `a` and `b`, with every branch's
// projection carrying the aliases `a` and `b`.
#[test]
fn intersect_with_duplicate_expressions() {
let sql = "\
SELECT 0 a, 0 b \
INTERSECT SELECT 1, 1 \
INTERSECT SELECT count(*), count(*) FROM orders";
// Planning must succeed; `unwrap` fails the test otherwise.
let plan = logical_plan(sql).unwrap();
assert_snapshot!(
plan,
@r"
LeftSemi Join: left.a = right.a, left.b = right.b
Distinct:
SubqueryAlias: left
LeftSemi Join: left.a = right.a, left.b = right.b
Distinct:
SubqueryAlias: left
Projection: Int64(0) AS a, Int64(0) AS b
EmptyRelation: rows=1
SubqueryAlias: right
Projection: Int64(1) AS a, Int64(1) AS b
EmptyRelation: rows=1
SubqueryAlias: right
Projection: count(*) AS a, count(*) AS b
Aggregate: groupBy=[[]], aggr=[[count(*)]]
TableScan: orders
"
);
}

// Plans a chained EXCEPT in which only the first SELECT names its columns
// (`a`, `b`) while the later SELECTs repeat an unaliased expression.
// The snapshot expects LeftAnti joins on `a` and `b`, with every branch's
// projection carrying the aliases `a` and `b`.
#[test]
fn except_with_duplicate_expressions() {
let sql = "\
SELECT 0 a, 0 b \
EXCEPT SELECT 1, 1 \
EXCEPT SELECT count(*), count(*) FROM orders";
// Planning must succeed; `unwrap` fails the test otherwise.
let plan = logical_plan(sql).unwrap();
assert_snapshot!(
plan,
@r"
LeftAnti Join: left.a = right.a, left.b = right.b
Distinct:
SubqueryAlias: left
LeftAnti Join: left.a = right.a, left.b = right.b
Distinct:
SubqueryAlias: left
Projection: Int64(0) AS a, Int64(0) AS b
EmptyRelation: rows=1
SubqueryAlias: right
Projection: Int64(1) AS a, Int64(1) AS b
EmptyRelation: rows=1
SubqueryAlias: right
Projection: count(*) AS a, count(*) AS b
Aggregate: groupBy=[[]], aggr=[[count(*)]]
TableScan: orders
"
);
}

#[test]
fn empty_over() {
let sql = "SELECT order_id, MAX(order_id) OVER () from orders";
Expand Down
24 changes: 24 additions & 0 deletions datafusion/sqllogictest/test_files/union.slt
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,30 @@ Bob_new
John
John_new

# Test UNION ALL with unaliased duplicate literal values on the right side.
# The second projection will inherit field names from the first one, and so
# pass the unique projection expression name check.
query TII rowsort
SELECT name, 1 as table, 1 as row FROM t1 WHERE id = 1
UNION ALL
SELECT name, 2, 2 FROM t2 WHERE id = 2
----
Alex 1 1
Bob 2 2

# Test nested UNION, EXCEPT, INTERSECT with duplicate unaliased literals.
# Only the first SELECT has column aliases, which should propagate to all projections.
query III rowsort
SELECT 1 as a, 0 as b, 0 as c
UNION ALL
((SELECT 2, 0, 0 UNION ALL SELECT 3, 0, 0) EXCEPT SELECT 3, 0, 0)
UNION ALL
(SELECT 4, 0, 0 INTERSECT SELECT 4, 0, 0)
----
1 0 0
2 0 0
4 0 0

# Plan is unnested
query TT
EXPLAIN SELECT name FROM t1 UNION ALL (SELECT name from t2 UNION ALL SELECT name || '_new' from t2)
Expand Down
Loading