Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions native/spark-expr/benches/cast_from_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,31 @@ fn criterion_benchmark(c: &mut Criterion) {
b.iter(|| cast_to_i64.evaluate(&decimal_batch).unwrap());
});
group.finish();

// str -> decimal benchmark
let decimal_string_batch = create_decimal_cast_string_batch();
for (mode, mode_name) in [
(EvalMode::Legacy, "legacy"),
(EvalMode::Ansi, "ansi"),
(EvalMode::Try, "try"),
] {
let spark_cast_options = SparkCastOptions::new(mode, "", false);
let cast_to_decimal_38_10 = Cast::new(
expr.clone(),
DataType::Decimal128(38, 10),
spark_cast_options,
);

let mut group = c.benchmark_group(format!("cast_string_to_decimal/{}", mode_name));
group.bench_function("decimal_38_10", |b| {
b.iter(|| {
cast_to_decimal_38_10
.evaluate(&decimal_string_batch)
.unwrap()
});
});
group.finish();
}
}

/// Create batch with small integer strings that fit in i8 range (for i8/i16 benchmarks)
Expand Down Expand Up @@ -118,6 +143,51 @@ fn create_decimal_string_batch() -> RecordBatch {
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
}

/// Create batch with decimal strings for string-to-decimal cast perf evaluation
fn create_decimal_cast_string_batch() -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
let mut b = StringBuilder::new();
for i in 0..1000 {
if i % 10 == 0 {
b.append_null();
} else {
// Generate various decimal formats
match i % 5 {
0 => {
// gen simple decimals (ex : "123.45"
let int_part: u32 = rand::random::<u32>() % 1000000;
let dec_part: u32 = rand::random::<u32>() % 100000;
b.append_value(format!("{}.{}", int_part, dec_part));
}
1 => {
// gen scientific notation like "123e5"
let mantissa: u32 = rand::random::<u32>() % 1000;
let exp: i8 = (rand::random::<i8>() % 10).abs();
b.append_value(format!("{}.{}E{}", mantissa / 100, mantissa % 100, exp));
}
2 => {
// Negative numbers
let int_part: u32 = rand::random::<u32>() % 1000000;
let dec_part: u32 = rand::random::<u32>() % 100000;
b.append_value(format!("-{}.{}", int_part, dec_part));
}
3 => {
// Ints only
let val: i32 = rand::random::<i32>() % 1000000;
b.append_value(format!("{}", val));
}
_ => {
// Small decimals (ex : 0.001)
let dec_part: u32 = rand::random::<u32>() % 100000;
b.append_value(format!("0.{:05}", dec_part));
}
}
}
}
let array = b.finish();
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
}

fn config() -> Criterion {
Criterion::default()
}
Expand Down
166 changes: 97 additions & 69 deletions native/spark-expr/src/conversion_funcs/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2317,8 +2317,8 @@ fn cast_string_to_decimal256_impl(
}

/// Parse a string to decimal following Spark's behavior
fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
let string_bytes = s.as_bytes();
fn parse_string_to_decimal(input_str: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
let string_bytes = input_str.as_bytes();
let mut start = 0;
let mut end = string_bytes.len();

Expand All @@ -2330,7 +2330,7 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
end -= 1;
}

let trimmed = &s[start..end];
let trimmed = &input_str[start..end];

if trimmed.is_empty() {
return Ok(None);
Expand All @@ -2347,73 +2347,101 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
return Ok(None);
}

// validate and parse mantissa and exponent
match parse_decimal_str(trimmed) {
Ok((mantissa, exponent)) => {
// Convert to target scale
let target_scale = scale as i32;
let scale_adjustment = target_scale - exponent;
// validate and parse mantissa and exponent or bubble up the error
let (mantissa, exponent) = parse_decimal_str(trimmed, input_str, precision, scale)?;

let scaled_value = if scale_adjustment >= 0 {
// Need to multiply (increase scale) but return None if scale is too high to fit i128
if scale_adjustment > 38 {
return Ok(None);
}
mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
} else {
// Need to multiply (increase scale) but return None if scale is too high to fit i128
let abs_scale_adjustment = (-scale_adjustment) as u32;
if abs_scale_adjustment > 38 {
return Ok(Some(0));
}
// Early return mantissa 0, Spark checks if it fits digits and throw error in ansi
if mantissa == 0 {
if exponent < -37 {
return Err(SparkError::NumericOutOfRange {
value: input_str.to_string(),
});
}
return Ok(Some(0));
}

let divisor = 10_i128.pow(abs_scale_adjustment);
let quotient_opt = mantissa.checked_div(divisor);
// Check if divisor is 0
if quotient_opt.is_none() {
return Ok(None);
}
let quotient = quotient_opt.unwrap();
let remainder = mantissa % divisor;

// Round half up: if abs(remainder) >= divisor/2, round away from zero
let half_divisor = divisor / 2;
let rounded = if remainder.abs() >= half_divisor {
if mantissa >= 0 {
quotient + 1
} else {
quotient - 1
}
} else {
quotient
};
Some(rounded)
};
// scale adjustment
let target_scale = scale as i32;
let scale_adjustment = target_scale - exponent;

match scaled_value {
Some(value) => {
// Check if it fits target precision
if is_validate_decimal_precision(value, precision) {
Ok(Some(value))
} else {
Ok(None)
}
}
None => {
// Overflow while scaling
Ok(None)
}
let scaled_value = if scale_adjustment >= 0 {
// Need to multiply (increase scale) but return None if scale is too high to fit i128
if scale_adjustment > 38 {
return Ok(None);
}
mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
} else {
// Need to divide (decrease scale)
let abs_scale_adjustment = (-scale_adjustment) as u32;
if abs_scale_adjustment > 38 {
return Ok(Some(0));
}

let divisor = 10_i128.pow(abs_scale_adjustment);
let quotient_opt = mantissa.checked_div(divisor);
// Check if divisor is 0
if quotient_opt.is_none() {
return Ok(None);
}
let quotient = quotient_opt.unwrap();
let remainder = mantissa % divisor;

// Round half up: if abs(remainder) >= divisor/2, round away from zero
let half_divisor = divisor / 2;
let rounded = if remainder.abs() >= half_divisor {
if mantissa >= 0 {
quotient + 1
} else {
quotient - 1
}
} else {
quotient
};
Some(rounded)
};

match scaled_value {
Some(value) => {
if is_validate_decimal_precision(value, precision) {
Ok(Some(value))
} else {
// Value ok but exceeds precision mentioned . THrow error
Err(SparkError::NumericValueOutOfRange {
value: trimmed.to_string(),
precision,
scale,
})
}
}
Err(_) => Ok(None),
None => {
// Overflow when scaling raise exception
Err(SparkError::NumericValueOutOfRange {
value: trimmed.to_string(),
precision,
scale,
})
}
}
}

fn invalid_decimal_cast(value: &str, precision: u8, scale: i8) -> SparkError {
invalid_value(
value,
"STRING",
&format!("DECIMAL({},{})", precision, scale),
)
}

/// Parse a decimal string into mantissa and scale
/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3) , 0e50 -> (0,50) etc
fn parse_decimal_str(
s: &str,
original_str: &str,
precision: u8,
scale: i8,
) -> SparkResult<(i128, i32)> {
if s.is_empty() {
return Err("Empty string".to_string());
return Err(invalid_decimal_cast(original_str, precision, scale));
}

let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) {
Expand All @@ -2422,7 +2450,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
// Parse exponent
let exp: i32 = exponent_part
.parse()
.map_err(|e| format!("Invalid exponent: {}", e))?;
.map_err(|_| invalid_decimal_cast(original_str, precision, scale))?;

(mantissa_part, exp)
} else {
Expand All @@ -2437,29 +2465,29 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
};

if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
return Err("Invalid sign format".to_string());
return Err(invalid_decimal_cast(original_str, precision, scale));
}

let (integral_part, fractional_part) = match mantissa_str.find('.') {
Some(dot_pos) => {
if mantissa_str[dot_pos + 1..].contains('.') {
return Err("Multiple decimal points".to_string());
return Err(invalid_decimal_cast(original_str, precision, scale));
}
(&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
}
None => (mantissa_str, ""),
};

if integral_part.is_empty() && fractional_part.is_empty() {
return Err("No digits found".to_string());
return Err(invalid_decimal_cast(original_str, precision, scale));
}

if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) {
return Err("Invalid integral part".to_string());
return Err(invalid_decimal_cast(original_str, precision, scale));
}

if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) {
return Err("Invalid fractional part".to_string());
return Err(invalid_decimal_cast(original_str, precision, scale));
}

// Parse integral part
Expand All @@ -2469,7 +2497,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
} else {
integral_part
.parse()
.map_err(|_| "Invalid integral part".to_string())?
.map_err(|_| invalid_decimal_cast(original_str, precision, scale))?
};

// Parse fractional part
Expand All @@ -2479,14 +2507,14 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
} else {
fractional_part
.parse()
.map_err(|_| "Invalid fractional part".to_string())?
.map_err(|_| invalid_decimal_cast(original_str, precision, scale))?
};

// Combine: value = integral * 10^fractional_scale + fractional
let mantissa = integral_value
.checked_mul(10_i128.pow(fractional_scale as u32))
.and_then(|v| v.checked_add(fractional_value))
.ok_or("Overflow in mantissa calculation")?;
.ok_or_else(|| invalid_decimal_cast(original_str, precision, scale))?;

let final_mantissa = if negative { -mantissa } else { mantissa };
// final scale = fractional_scale - exponent
Expand Down
3 changes: 3 additions & 0 deletions native/spark-expr/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ pub enum SparkError {
scale: i8,
},

#[error("[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be interpreted as a numeric since it has more than 38 digits.")]
NumericOutOfRange { value: String },

#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
Expand Down
Loading
Loading