diff --git a/native/spark-expr/src/datetime_funcs/extract_date_part.rs b/native/spark-expr/src/datetime_funcs/extract_date_part.rs
index acb7d2266e..7289f45d28 100644
--- a/native/spark-expr/src/datetime_funcs/extract_date_part.rs
+++ b/native/spark-expr/src/datetime_funcs/extract_date_part.rs
@@ -18,7 +18,7 @@
 use crate::utils::array_with_timezone;
 use arrow::compute::{date_part, DatePart};
 use arrow::datatypes::{DataType, TimeUnit::Microsecond};
-use datafusion::common::{internal_datafusion_err, DataFusionError};
+use datafusion::common::internal_datafusion_err;
 use datafusion::logical_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
 };
@@ -86,9 +86,23 @@ macro_rules! extract_date_part {
                 let result = date_part(&array, DatePart::$date_part_variant)?;
                 Ok(ColumnarValue::Array(result))
             }
-            _ => Err(DataFusionError::Execution(
-                concat!($fn_name, "(scalar) should be fold in Spark JVM side.").to_string(),
-            )),
+            [ColumnarValue::Scalar(scalar)] => {
+                // When Spark's ConstantFolding is disabled, literal-only expressions like
+                // hour can reach the native engine as scalar inputs.
+                // Instead of failing and requiring JVM folding, we evaluate the scalar
+                // natively by broadcasting it to a single-element array.
+                let array = scalar.clone().to_array_of_size(1)?;
+                let array = array_with_timezone(
+                    array,
+                    self.timezone.clone(),
+                    Some(&DataType::Timestamp(
+                        Microsecond,
+                        Some(self.timezone.clone().into()),
+                    )),
+                )?;
+                let result = date_part(&array, DatePart::$date_part_variant)?;
+                Ok(ColumnarValue::Array(result))
+            }
         }
     }
diff --git a/native/spark-expr/src/datetime_funcs/unix_timestamp.rs b/native/spark-expr/src/datetime_funcs/unix_timestamp.rs
index c4f1576293..3eb518aaec 100644
--- a/native/spark-expr/src/datetime_funcs/unix_timestamp.rs
+++ b/native/spark-expr/src/datetime_funcs/unix_timestamp.rs
@@ -16,10 +16,10 @@
 // under the License.

 use crate::utils::array_with_timezone;
-use arrow::array::{Array, AsArray, PrimitiveArray};
+use arrow::array::{Array, AsArray, Int64Array, PrimitiveArray};
 use arrow::compute::cast;
 use arrow::datatypes::{DataType, Int64Type, TimeUnit::Microsecond};
-use datafusion::common::{internal_datafusion_err, DataFusionError};
+use datafusion::common::{internal_datafusion_err, DataFusionError, ScalarValue};
 use datafusion::logical_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
 };
@@ -77,79 +77,37 @@ impl ScalarUDFImpl for SparkUnixTimestamp {
             .map_err(|_| internal_datafusion_err!("unix_timestamp expects exactly one argument"))?;

         match args {
-            [ColumnarValue::Array(array)] => match array.data_type() {
-                DataType::Timestamp(_, _) => {
-                    let is_utc = self.timezone == "UTC";
-                    let array = if is_utc
-                        && matches!(array.data_type(), DataType::Timestamp(Microsecond, Some(tz)) if tz.as_ref() == "UTC")
-                    {
-                        array
-                    } else {
-                        array_with_timezone(
-                            array,
-                            self.timezone.clone(),
-                            Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))),
-                        )?
-                    };
-
-                    let timestamp_array =
-                        array.as_primitive::<TimestampMicrosecondType>();
-
-                    let result: PrimitiveArray<Int64Type> = if timestamp_array.null_count() == 0 {
-                        timestamp_array
-                            .values()
-                            .iter()
-                            .map(|&micros| micros / MICROS_PER_SECOND)
-                            .collect()
-                    } else {
-                        timestamp_array
-                            .iter()
-                            .map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND)))
-                            .collect()
-                    };
-
-                    Ok(ColumnarValue::Array(Arc::new(result)))
-                }
-                DataType::Date32 => {
-                    let timestamp_array = cast(&array, &DataType::Timestamp(Microsecond, None))?;
-
-                    let is_utc = self.timezone == "UTC";
-                    let array = if is_utc {
-                        timestamp_array
-                    } else {
-                        array_with_timezone(
-                            timestamp_array,
-                            self.timezone.clone(),
-                            Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))),
-                        )?
-                    };
-
-                    let timestamp_array =
-                        array.as_primitive::<TimestampMicrosecondType>();
-
-                    let result: PrimitiveArray<Int64Type> = if timestamp_array.null_count() == 0 {
-                        timestamp_array
-                            .values()
-                            .iter()
-                            .map(|&micros| micros / MICROS_PER_SECOND)
-                            .collect()
-                    } else {
-                        timestamp_array
-                            .iter()
-                            .map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND)))
-                            .collect()
-                    };
-
-                    Ok(ColumnarValue::Array(Arc::new(result)))
-                }
-                _ => Err(DataFusionError::Execution(format!(
-                    "unix_timestamp does not support input type: {:?}",
-                    array.data_type()
-                ))),
-            },
-            _ => Err(DataFusionError::Execution(
-                "unix_timestamp(scalar) should be fold in Spark JVM side.".to_string(),
-            )),
+            [ColumnarValue::Array(array)] => self.eval_array(&array),
+            [ColumnarValue::Scalar(scalar)] => {
+                // When Spark's ConstantFolding is disabled, literal-only expressions like
+                // unix_timestamp can reach the native engine as scalar inputs.
+                // Evaluate the scalar natively by broadcasting it to a single-element
+                // array and converting the result back to a scalar.
+                let array = scalar.clone().to_array_of_size(1)?;
+                let result = self.eval_array(&array)?;
+
+                let result_array = match result {
+                    ColumnarValue::Array(array) => array,
+                    ColumnarValue::Scalar(_) => {
+                        return Err(DataFusionError::Internal(
+                            "unix_timestamp: expected array result from eval_array".to_string(),
+                        ))
+                    }
+                };
+
+                let int64_array = result_array
+                    .as_any()
+                    .downcast_ref::<Int64Array>()
+                    .expect("unix_timestamp should return Int64Array");
+
+                let scalar_result = if int64_array.is_null(0) {
+                    ScalarValue::Int64(None)
+                } else {
+                    ScalarValue::Int64(Some(int64_array.value(0)))
+                };
+
+                Ok(ColumnarValue::Scalar(scalar_result))
+            }
         }
     }
@@ -158,6 +116,82 @@ impl ScalarUDFImpl for SparkUnixTimestamp {
     }
 }

+impl SparkUnixTimestamp {
+    fn eval_array(&self, array: &Arc<dyn Array>) -> datafusion::common::Result<ColumnarValue> {
+        match array.data_type() {
+            DataType::Timestamp(_, _) => {
+                let is_utc = self.timezone == "UTC";
+                let array = if is_utc
+                    && matches!(array.data_type(), DataType::Timestamp(Microsecond, Some(tz)) if tz.as_ref() == "UTC")
+                {
+                    Arc::clone(array)
+                } else {
+                    array_with_timezone(
+                        Arc::clone(array),
+                        self.timezone.clone(),
+                        Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))),
+                    )?
+ }; + + let timestamp_array = + array.as_primitive::(); + + let result: PrimitiveArray = if timestamp_array.null_count() == 0 { + timestamp_array + .values() + .iter() + .map(|µs| micros / MICROS_PER_SECOND) + .collect() + } else { + timestamp_array + .iter() + .map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND))) + .collect() + }; + + Ok(ColumnarValue::Array(Arc::new(result))) + } + DataType::Date32 => { + let timestamp_array = + cast(array.as_ref(), &DataType::Timestamp(Microsecond, None))?; + + let is_utc = self.timezone == "UTC"; + let array = if is_utc { + timestamp_array + } else { + array_with_timezone( + timestamp_array, + self.timezone.clone(), + Some(&DataType::Timestamp(Microsecond, Some("UTC".into()))), + )? + }; + + let timestamp_array = + array.as_primitive::(); + + let result: PrimitiveArray = if timestamp_array.null_count() == 0 { + timestamp_array + .values() + .iter() + .map(|µs| micros / MICROS_PER_SECOND) + .collect() + } else { + timestamp_array + .iter() + .map(|v| v.map(|micros| div_floor(micros, MICROS_PER_SECOND))) + .collect() + }; + + Ok(ColumnarValue::Array(Arc::new(result))) + } + _ => Err(DataFusionError::Execution(format!( + "unix_timestamp does not support input type: {:?}", + array.data_type() + ))), + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 2999d8bfe5..ace8a0dce8 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -618,6 +618,27 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("scalar hour/minute/second/unix_timestamp with ConstantFolding disabled") { + withSQLConf( + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") { + val df = spark.sql(""" + |SELECT + | hour(timestamp('2026-04-18 04:18:45')) AS h, + | minute(timestamp('2026-04-18 04:18:45')) AS m, + | second(timestamp('2026-04-18 04:18:45')) AS s, + | unix_timestamp(timestamp('2020-01-01 00:00:00')) AS u + |""".stripMargin) + + val Row(h: Int, m: Int, s: Int, u: Long) = df.head() + + assert(h == 4) + assert(m == 18) + assert(s == 45) + assert(u == 1577836800L) + } + } + test("hour on int96 timestamp column") { import testImplicits._