diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs
index bf212a8219d02..92d8e90ac372e 100644
--- a/datafusion/spark/src/function/math/mod.rs
+++ b/datafusion/spark/src/function/math/mod.rs
@@ -20,6 +20,7 @@ pub mod expm1;
 pub mod factorial;
 pub mod hex;
 pub mod modulus;
+pub mod negative;
 pub mod rint;
 pub mod trigonometry;
 pub mod unhex;
@@ -40,6 +41,7 @@ make_udf_function!(unhex::SparkUnhex, unhex);
 make_udf_function!(width_bucket::SparkWidthBucket, width_bucket);
 make_udf_function!(trigonometry::SparkCsc, csc);
 make_udf_function!(trigonometry::SparkSec, sec);
+make_udf_function!(negative::SparkNegative, negative);
 
 pub mod expr_fn {
     use datafusion_functions::export_functions;
@@ -63,6 +65,11 @@ pub mod expr_fn {
     export_functions!((width_bucket, "Returns the bucket number into which the value of this expression would fall after being evaluated.", arg1 arg2 arg3 arg4));
     export_functions!((csc, "Returns the cosecant of expr.", arg1));
     export_functions!((sec, "Returns the secant of expr.", arg1));
+    export_functions!((
+        negative,
+        "Returns the negation of expr (unary minus).",
+        arg1
+    ));
 }
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
@@ -78,5 +85,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
         width_bucket(),
         csc(),
         sec(),
+        negative(),
     ]
 }
diff --git a/datafusion/spark/src/function/math/negative.rs b/datafusion/spark/src/function/math/negative.rs
new file mode 100644
index 0000000000000..f1803d2d771a2
--- /dev/null
+++ b/datafusion/spark/src/function/math/negative.rs
@@ -0,0 +1,293 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::types::*;
+use arrow::array::*;
+use arrow::datatypes::{DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit};
+use bigdecimal::num_traits::WrappingNeg;
+use datafusion_common::utils::take_function_args;
+use datafusion_common::{Result, ScalarValue, not_impl_err};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
+    Volatility,
+};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Spark-compatible `negative` expression
+///
+/// Returns the negation of the input (equivalent to unary minus).
+/// Returns NULL if the input is NULL and NaN if the input is NaN.
+///
+/// ANSI mode support:
+/// - When Spark's ANSI mode is off (i.e. `spark.sql.ansi.enabled=false`),
+///   negating the minimum value of a signed integer wraps around.
+///   For example: negative(i32::MIN) returns i32::MIN (wraps instead of erroring).
+///   This is the current implementation (legacy mode only).
+/// - Spark's ANSI mode (when `spark.sql.ansi.enabled=true`) should throw an
+///   ARITHMETIC_OVERFLOW error on integer overflow instead of wrapping.
+///   This is not yet implemented; all operations currently use wrapping behavior.
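+///
+/// A minimal sketch of the wrapping semantics that legacy mode relies on
+/// (plain std Rust, independent of this crate):
+///
+/// ```
+/// // i32::MIN has no positive counterpart in i32, so wrapping negation returns it unchanged.
+/// assert_eq!(i32::MIN.wrapping_neg(), i32::MIN);
+/// // Ordinary values negate as expected.
+/// assert_eq!(42_i32.wrapping_neg(), -42);
+/// ```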
+///
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkNegative {
+    signature: Signature,
+}
+
+impl Default for SparkNegative {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkNegative {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature {
+                type_signature: TypeSignature::OneOf(vec![
+                    // Numeric types: signed integers, floats, decimals
+                    TypeSignature::Numeric(1),
+                    // Interval types: YearMonth, DayTime, MonthDayNano
+                    TypeSignature::Uniform(
+                        1,
+                        vec![
+                            DataType::Interval(IntervalUnit::YearMonth),
+                            DataType::Interval(IntervalUnit::DayTime),
+                            DataType::Interval(IntervalUnit::MonthDayNano),
+                        ],
+                    ),
+                ]),
+                volatility: Volatility::Immutable,
+                parameter_names: None,
+            },
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkNegative {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "negative"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(arg_types[0].clone())
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        spark_negative(&args.args)
+    }
+}
+
+/// Core implementation of Spark's `negative` function
+fn spark_negative(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    let [arg] = take_function_args("negative", args)?;
+
+    match arg {
+        ColumnarValue::Array(array) => match array.data_type() {
+            DataType::Null => Ok(arg.clone()),
+
+            // Signed integers - use wrapping negation (Spark legacy mode behavior)
+            DataType::Int8 => {
+                let array = array.as_primitive::<Int8Type>();
+                let result: PrimitiveArray<Int8Type> = array.unary(|x| x.wrapping_neg());
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Int16 => {
+                let array = array.as_primitive::<Int16Type>();
+                let result: PrimitiveArray<Int16Type> = array.unary(|x| x.wrapping_neg());
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Int32 => {
+                let array = array.as_primitive::<Int32Type>();
+                let result: PrimitiveArray<Int32Type> = array.unary(|x| x.wrapping_neg());
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Int64 => {
+                let array = array.as_primitive::<Int64Type>();
+                let result: PrimitiveArray<Int64Type> = array.unary(|x| x.wrapping_neg());
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+
+            // Floating point - simple negation (no overflow possible)
+            DataType::Float16 => {
+                let array = array.as_primitive::<Float16Type>();
+                let result: PrimitiveArray<Float16Type> = array.unary(|x| -x);
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Float32 => {
+                let array = array.as_primitive::<Float32Type>();
+                let result: PrimitiveArray<Float32Type> = array.unary(|x| -x);
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Float64 => {
+                let array = array.as_primitive::<Float64Type>();
+                let result: PrimitiveArray<Float64Type> = array.unary(|x| -x);
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+
+            // Decimal types - wrapping negation
+            DataType::Decimal32(_, _) => {
+                let array = array.as_primitive::<Decimal32Type>();
+                let result: PrimitiveArray<Decimal32Type> =
+                    array.unary(|x| x.wrapping_neg());
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Decimal64(_, _) => {
+                let array = array.as_primitive::<Decimal64Type>();
+                let result: PrimitiveArray<Decimal64Type> =
+                    array.unary(|x| x.wrapping_neg());
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Decimal128(_, _) => {
+                let array = array.as_primitive::<Decimal128Type>();
+                let result: PrimitiveArray<Decimal128Type> =
+                    array.unary(|x| x.wrapping_neg());
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Decimal256(_, _) => {
+                let array = array.as_primitive::<Decimal256Type>();
+                let result: PrimitiveArray<Decimal256Type> =
+                    array.unary(|x| x.wrapping_neg());
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+
+            // interval type
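+            // The interval arms below negate each stored component independently
+            // (months, days, milliseconds, nanoseconds), reusing the same wrapping
+            // semantics as the integer arms above.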
+            DataType::Interval(IntervalUnit::YearMonth) => {
+                let array = array.as_primitive::<IntervalYearMonthType>();
+                let result: PrimitiveArray<IntervalYearMonthType> =
+                    array.unary(|x| x.wrapping_neg());
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Interval(IntervalUnit::DayTime) => {
+                let array = array.as_primitive::<IntervalDayTimeType>();
+                let result: PrimitiveArray<IntervalDayTimeType> =
+                    array.unary(|x| IntervalDayTime {
+                        days: x.days.wrapping_neg(),
+                        milliseconds: x.milliseconds.wrapping_neg(),
+                    });
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+            DataType::Interval(IntervalUnit::MonthDayNano) => {
+                let array = array.as_primitive::<IntervalMonthDayNanoType>();
+                let result: PrimitiveArray<IntervalMonthDayNanoType> =
+                    array.unary(|x| IntervalMonthDayNano {
+                        months: x.months.wrapping_neg(),
+                        days: x.days.wrapping_neg(),
+                        nanoseconds: x.nanoseconds.wrapping_neg(),
+                    });
+                Ok(ColumnarValue::Array(Arc::new(result)))
+            }
+
+            dt => not_impl_err!("Unsupported data type for Spark negative(): {dt}"),
+        },
+        ColumnarValue::Scalar(sv) => match sv {
+            ScalarValue::Null => Ok(arg.clone()),
+            _ if sv.is_null() => Ok(arg.clone()),
+
+            // Signed integers - wrapping negation
+            ScalarValue::Int8(Some(v)) => {
+                let result = v.wrapping_neg();
+                Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(result))))
+            }
+            ScalarValue::Int16(Some(v)) => {
+                let result = v.wrapping_neg();
+                Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(result))))
+            }
+            ScalarValue::Int32(Some(v)) => {
+                let result = v.wrapping_neg();
+                Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result))))
+            }
+            ScalarValue::Int64(Some(v)) => {
+                let result = v.wrapping_neg();
+                Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result))))
+            }
+
+            // Floating point - simple negation
+            ScalarValue::Float16(Some(v)) => {
+                Ok(ColumnarValue::Scalar(ScalarValue::Float16(Some(-v))))
+            }
+            ScalarValue::Float32(Some(v)) => {
+                Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(-v))))
+            }
+            ScalarValue::Float64(Some(v)) => {
+                Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(-v))))
+            }
+
+            // Decimal types - wrapping negation
+            ScalarValue::Decimal32(Some(v), precision, scale) => {
+                let result = v.wrapping_neg();
+                Ok(ColumnarValue::Scalar(ScalarValue::Decimal32(
+                    Some(result),
+                    *precision,
+                    *scale,
+                )))
+            }
+            ScalarValue::Decimal64(Some(v), precision, scale) => {
+                let result = v.wrapping_neg();
+                Ok(ColumnarValue::Scalar(ScalarValue::Decimal64(
+                    Some(result),
+                    *precision,
+                    *scale,
+                )))
+            }
+            ScalarValue::Decimal128(Some(v), precision, scale) => {
+                let result = v.wrapping_neg();
+                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                    Some(result),
+                    *precision,
+                    *scale,
+                )))
+            }
+            ScalarValue::Decimal256(Some(v), precision, scale) => {
+                let result = v.wrapping_neg();
+                Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
+                    Some(result),
+                    *precision,
+                    *scale,
+                )))
+            }
+
+            // interval type
+            ScalarValue::IntervalYearMonth(Some(v)) => Ok(ColumnarValue::Scalar(
+                ScalarValue::IntervalYearMonth(Some(v.wrapping_neg())),
+            )),
+            ScalarValue::IntervalDayTime(Some(v)) => Ok(ColumnarValue::Scalar(
+                ScalarValue::IntervalDayTime(Some(IntervalDayTime {
+                    days: v.days.wrapping_neg(),
+                    milliseconds: v.milliseconds.wrapping_neg(),
+                })),
+            )),
+            ScalarValue::IntervalMonthDayNano(Some(v)) => Ok(ColumnarValue::Scalar(
+                ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano {
+                    months: v.months.wrapping_neg(),
+                    days: v.days.wrapping_neg(),
+                    nanoseconds: v.nanoseconds.wrapping_neg(),
+                })),
+            )),
+
+            dt => not_impl_err!("Unsupported data type for Spark negative(): {dt}"),
+        },
+    }
+}
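+
+#[cfg(test)]
+mod tests {
+    // A minimal unit-test sketch of the wrapping semantics documented above; the
+    // sqllogictest file in this change exercises the same behavior end-to-end.
+    use super::*;
+
+    #[test]
+    fn negative_wraps_i32_min() -> Result<()> {
+        let args = vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(i32::MIN)))];
+        match spark_negative(&args)? {
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(v))) => assert_eq!(v, i32::MIN),
+            other => panic!("unexpected result: {other:?}"),
+        }
+        Ok(())
+    }
+}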
diff --git a/datafusion/sqllogictest/test_files/spark/math/negative.slt b/datafusion/sqllogictest/test_files/spark/math/negative.slt
index aa8e558e9895e..c62267e4963ff 100644
--- a/datafusion/sqllogictest/test_files/spark/math/negative.slt
+++ b/datafusion/sqllogictest/test_files/spark/math/negative.slt
@@ -23,5 +23,257 @@
 ## Original Query: SELECT negative(1);
 ## PySpark 3.5.5 Result: {'negative(1)': -1, 'typeof(negative(1))': 'int', 'typeof(1)': 'int'}
-#query
-#SELECT negative(1::int);
+
+# Test negative with integer
+query I
+SELECT negative(1::int);
+----
+-1
+
+# Test negative with positive integer
+query I
+SELECT negative(42::int);
+----
+-42
+
+# Test negative with negative integer
+query I
+SELECT negative(-10::int);
+----
+10
+
+# Test negative with zero
+query I
+SELECT negative(0::int);
+----
+0
+
+# Test negative with bigint
+query I
+SELECT negative(9223372036854775807::bigint);
+----
+-9223372036854775807
+
+# Test negative with negative bigint
+query I
+SELECT negative(-100::bigint);
+----
+100
+
+# Test negative with smallint
+query I
+SELECT negative(32767::smallint);
+----
+-32767
+
+# Test negative with float
+query R
+SELECT negative(3.14::float);
+----
+-3.14
+
+# Test negative with negative float
+query R
+SELECT negative(-2.5::float);
+----
+2.5
+
+# Test negative with double
+query R
+SELECT negative(3.14159265358979::double);
+----
+-3.14159265358979
+
+# Test negative with negative double
+query R
+SELECT negative(-1.5::double);
+----
+1.5
+
+# Test negative with decimal
+query R
+SELECT negative(123.456::decimal(10,3));
+----
+-123.456
+
+# Test negative with negative decimal
+query R
+SELECT negative(-99.99::decimal(10,2));
+----
+99.99
+
+# Test negative with NULL
+query I
+SELECT negative(NULL::int);
+----
+NULL
+
+# Test negative with column values
+statement ok
+CREATE TABLE test_negative (id int, value int) AS VALUES (1, 10), (2, -20), (3, 0), (4, NULL);
+
+query II rowsort
+SELECT id, negative(value) FROM test_negative;
+----
+1 -10
+2 20
+3 0
+4 NULL
+
+statement ok
+DROP TABLE test_negative;
+
+# Test negative in expressions
+query I
+SELECT negative(5) + 3;
+----
+-2
+
+# Test nested negative
+query I
+SELECT negative(negative(7));
+----
+7
+
+# Test negative with large numbers
+query R
+SELECT negative(1234567890.123456::double);
+----
+-1234567890.123456
+
+# Test wrap-around: negative of minimum int (should wrap to same value)
+# Use a table to avoid constant-folding overflow during optimization
+statement ok
+CREATE TABLE min_values_int AS VALUES (-2147483648);
+
+query I
+SELECT negative(column1::int) FROM min_values_int;
+----
+-2147483648
+
+statement ok
+DROP TABLE min_values_int;
+
+# Test wrap-around: negative of minimum bigint (should wrap to same value)
+statement ok
+CREATE TABLE min_values_bigint AS VALUES (-9223372036854775808);
+
+query I
+SELECT negative(column1::bigint) FROM min_values_bigint;
+----
+-9223372036854775808
+
+statement ok
+DROP TABLE min_values_bigint;
+
+# Test wrap-around: negative of minimum smallint (should wrap to same value)
+statement ok
+CREATE TABLE min_values_smallint AS VALUES (-32768);
+
+query I
+SELECT negative(column1::smallint) FROM min_values_smallint;
+----
+-32768
+
+statement ok
+DROP TABLE min_values_smallint;
+
+# Test wrap-around: negative of minimum tinyint (should wrap to same value)
+statement ok
+CREATE TABLE min_values_tinyint AS VALUES (-128);
+
+query I
+SELECT negative(column1::tinyint) FROM min_values_tinyint;
+----
+-128
+
+statement ok
+DROP TABLE min_values_tinyint;
+
+# Test negative of positive infinity (float)
+query R
+SELECT negative('Infinity'::float);
+----
+-Infinity
+
+# Test negative of negative infinity (float)
+query R
+SELECT negative('-Infinity'::float);
+----
+Infinity
+
+# Test negative of positive infinity (double)
+query R
+SELECT negative('Infinity'::double);
+----
+-Infinity
+
+# Test negative of negative infinity (double)
+query R
+SELECT negative('-Infinity'::double);
+----
+Infinity
+
+# Test negative of NaN (float)
+query R
+SELECT negative('NaN'::float);
+----
+NaN
+
+# Test negative of NaN (double)
+query R
+SELECT negative('NaN'::double);
+----
+NaN
+
+# Test negative of maximum float value
+query R
+SELECT negative(3.4028235e38::float);
+----
+-340282350000000000000000000000000000000
+
+# Test negative of minimum float value
+query R
+SELECT negative(-3.4028235e38::float);
+----
+340282350000000000000000000000000000000
+
+# Test negative of maximum double value
+query R
+SELECT negative(1.7976931348623157e308::double);
+----
+-179769313486231570000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
+
+# Test negative of minimum double value
+query R
+SELECT negative(-1.7976931348623157e308::double);
+----
+179769313486231570000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
+
+# Test negative with CalendarIntervalType (IntervalMonthDayNano)
+# Spark make_interval creates CalendarInterval
+query ?
+SELECT negative(make_interval(1, 2, 3, 4, 5, 6, 7.5));
+----
+-14 mons -25 days -5 hours -6 mins -7.500000000 secs
+
+# Test negative with negative CalendarIntervalType
+query ?
+SELECT negative(make_interval(-2, -5, -1, -10, -3, -30, -15.25));
+----
+29 mons 17 days 3 hours 30 mins 15.250000000 secs
+
+# Test negative with CalendarInterval from table
+statement ok
+CREATE TABLE interval_test AS VALUES
+  (make_interval(1, 2, 0, 5, 0, 0, 0.0)),
+  (make_interval(-3, -1, 0, -2, 0, 0, 0.0));
+
+query ? rowsort
+SELECT negative(column1) FROM interval_test;
+----
+-14 mons -5 days
+37 mons 2 days
+
+statement ok
+DROP TABLE interval_test;