diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 960aff8702..5c14c74164 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -211,7 +211,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[WeekDay] -> CometWeekDay, classOf[DayOfYear] -> CometDayOfYear, classOf[WeekOfYear] -> CometWeekOfYear, - classOf[Quarter] -> CometQuarter) + classOf[Quarter] -> CometQuarter, + classOf[Years] -> CometYears) private val conversionExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( classOf[Cast] -> CometCast) diff --git a/spark/src/main/scala/org/apache/comet/serde/datetime.scala b/spark/src/main/scala/org/apache/comet/serde/datetime.scala index d36b6a3b40..698ea5338f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/datetime.scala +++ b/spark/src/main/scala/org/apache/comet/serde/datetime.scala @@ -21,8 +21,8 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateDiff, DateFormatClass, DateSub, DayOfMonth, DayOfWeek, DayOfYear, GetDateField, Hour, LastDay, Literal, MakeDate, Minute, Month, NextDay, Quarter, Second, TruncDate, TruncTimestamp, UnixDate, UnixTimestamp, WeekDay, WeekOfYear, Year} -import org.apache.spark.sql.types.{DateType, IntegerType, StringType, TimestampType} +import org.apache.spark.sql.catalyst.expressions.{Attribute, DateAdd, DateDiff, DateFormatClass, DateSub, DayOfMonth, DayOfWeek, DayOfYear, GetDateField, Hour, LastDay, Literal, MakeDate, Minute, Month, NextDay, Quarter, Second, TruncDate, TruncTimestamp, UnixDate, UnixTimestamp, WeekDay, WeekOfYear, Year, Years} +import org.apache.spark.sql.types.{DateType, IntegerType, StringType, TimestampNTZType, TimestampType} import org.apache.spark.unsafe.types.UTF8String import 
/**
 * Serde for Spark's `Years` partition-transform expression.
 *
 * The expression is translated to the native `datepart("Year", child)` scalar
 * function, and the function's result is wrapped in a LEGACY-mode cast to
 * `IntegerType` so the produced value matches the integer output of `Years`.
 */
object CometYears extends CometExpressionSerde[Years] {

  /** Only date/timestamp inputs are supported; everything else is rejected. */
  override def getSupportLevel(expr: Years): SupportLevel = {
    expr.child.dataType match {
      case DateType => Compatible()
      case TimestampType => Compatible()
      case TimestampNTZType => Compatible()
      case other => Unsupported(Some(s"Years does not support type: $other"))
    }
  }

  /**
   * Serializes `expr` to the Comet protobuf representation.
   *
   * Returns `None` (with diagnostic info attached via `optExprWithInfo`) when
   * either the child expression or the scalar function cannot be serialized.
   */
  override def convert(
      expr: Years,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // The field to extract is passed to `datepart` as a string literal argument.
    val fieldArg =
      exprToProtoInternal(Literal(CometGetDateField.Year.toString), inputs, binding)
    val childArg = exprToProtoInternal(expr.child, inputs, binding)
    // Build datepart(<field>, <child>), then cast the result to INT.
    val datePartFn = scalarFunctionExprToProto("datepart", fieldArg, childArg)
    val optExpr = datePartFn.map { fn =>
      val castToInt = ExprOuterClass.Cast
        .newBuilder()
        .setChild(fn)
        .setDatatype(serializeDataType(IntegerType).get)
        .setEvalMode(ExprOuterClass.EvalMode.LEGACY)
        .setAllowIncompat(false)
        .build()
      Expr
        .newBuilder()
        .setCast(castToInt)
        .build()
    }
    optExprWithInfo(optExpr, expr, expr.child)
  }
}
-- Config: spark.sql.catalog.test_cat=org.apache.iceberg.spark.SparkCatalog
-- Config: spark.sql.catalog.test_cat.type=hadoop
-- Config: spark.sql.catalog.test_cat.warehouse=/tmp/iceberg-warehouse
-- Config: spark.comet.scan.icebergNative.enabled=true

-- Exercises the Iceberg `years(...)` partition transform end-to-end with the
-- Comet native Iceberg scan: writes rows spanning multiple years into a table
-- partitioned by years(event_date), then reads back via a full scan, point and
-- range predicates on the partition column, and a year-based aggregation.

statement
CREATE DATABASE IF NOT EXISTS test_cat.db

-- Table partitioned by the years() transform over a DATE column.
statement
CREATE TABLE test_cat.db.test_years_iceberg (
  id INT,
  event_date DATE,
  value STRING
) USING iceberg
PARTITIONED BY (years(event_date))

-- Rows cover 2022 (1), 2023 (2), and 2024 (3) so multiple partitions exist.
statement
INSERT INTO test_cat.db.test_years_iceberg VALUES
  (1, DATE '2022-06-15', 'a'),
  (2, DATE '2023-03-20', 'b'),
  (3, DATE '2023-11-10', 'c'),
  (4, DATE '2024-01-05', 'd'),
  (5, DATE '2024-07-20', 'e'),
  (6, DATE '2024-12-31', 'f')

-- Full scan across all partitions.
query
SELECT * FROM test_cat.db.test_years_iceberg ORDER BY id

-- Point predicate on the partition source column (should allow partition pruning).
query
SELECT * FROM test_cat.db.test_years_iceberg WHERE event_date = DATE '2023-03-20'

-- Range predicate covering exactly one year partition.
query
SELECT * FROM test_cat.db.test_years_iceberg WHERE event_date >= DATE '2023-01-01' AND event_date < DATE '2024-01-01' ORDER BY id

-- Aggregation grouped by extracted year; expected counts: 2022->1, 2023->2, 2024->3.
query
SELECT year(event_date) as yr, COUNT(*) as cnt FROM test_cat.db.test_years_iceberg GROUP BY year(event_date) ORDER BY yr

-- PURGE so the hadoop-catalog data files are removed, not just the metadata.
statement
DROP TABLE test_cat.db.test_years_iceberg PURGE
  // Verifies that a Spark `Years` expression serializes to the expected Comet
  // protobuf shape: a Cast (to Integer) wrapping a `datepart` scalar function
  // call. This checks serialization only — no native execution is involved.
  test("support Years partition transform (serialization only)") {
    val input = Seq(java.sql.Date.valueOf("2024-01-15")).toDF("col")
    val inputAttrs = input.queryExecution.analyzed.output
    val yearsExpr = Years(input.col("col").expr)
    val proto = exprToProto(yearsExpr, inputAttrs, binding = false)

    assert(proto.isDefined, "Comet failed to serialize the Years expression!")

    val expr = proto.get
    assert(expr.hasCast, "Expected the result to be a Cast (to Integer)")
    assert(expr.getCast.getChild.hasScalarFunc, "Expected Cast child to be a Scalar Function")
    assert(
      expr.getCast.getChild.getScalarFunc.getFunc == "datepart",
      "Expected function to be 'datepart'")
  }

  // Checks CometYears.getSupportLevel directly: date/timestamp child types are
  // reported Compatible, any other type is reported Unsupported.
  test("Years support level") {
    val supportedTypes = Seq(DateType, TimestampType, TimestampNTZType)
    val unsupportedTypes = Seq(StringType, IntegerType, LongType)

    supportedTypes.foreach { dt =>
      // Literal.default produces a placeholder literal of the requested type,
      // which is enough for the type-based support check.
      val child = Literal.default(dt)
      val expr = Years(child)
      val result = CometYears.getSupportLevel(expr)

      assert(result.isInstanceOf[Compatible], s"Expected $dt to be Compatible")
    }

    unsupportedTypes.foreach { dt =>
      val child = Literal.default(dt)
      val expr = Years(child)
      val result = CometYears.getSupportLevel(expr)

      assert(result.isInstanceOf[Unsupported], s"Expected $dt to be Unsupported")
    }
  }