diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 5db07c4cea527..639bc07a68b1e 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3647,6 +3647,11 @@ "expects one of binary formats 'base64', 'hex', 'utf-8', but got ." ] }, + "BITS_RANGE" : { + "message" : [ + "expects an integer value in [1, 64], but got ." + ] + }, "BIT_POSITION_RANGE" : { "message" : [ "expects an integer value in [0, ), but got ." diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 330e323b5b205..c38ad0c1ae945 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -304,8 +304,13 @@ def bitwise_not(col: "ColumnOrName") -> Column: bitwise_not.__doc__ = pysparkfuncs.bitwise_not.__doc__ -def bit_count(col: "ColumnOrName") -> Column: - return _invoke_function_over_columns("bit_count", col) +def bit_count(col: "ColumnOrName", bits: Optional[Union[Column, int]] = None) -> Column: + if bits is None: + return _invoke_function_over_columns("bit_count", col) + else: + bits = _enum_to_value(bits) + bits = lit(bits) if isinstance(bits, int) else bits + return _invoke_function_over_columns("bit_count", col, bits) bit_count.__doc__ = pysparkfuncs.bit_count.__doc__ diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index e2f85a7533063..2f3da48e63387 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -3763,17 +3763,27 @@ def bitwise_not(col: "ColumnOrName") -> Column: @_try_remote_functions -def bit_count(col: "ColumnOrName") -> Column: +def bit_count(col: "ColumnOrName", bits: Optional[Union[Column, int]] = None) -> Column: """ Returns the number of bits that are set in the argument expr as an unsigned 64-bit integer, or NULL if the argument is NULL. + If bits is specified, treats expr as a bits-bit signed integer in 2's complement + representation before counting set bits. + .. versionadded:: 3.5.0 + .. versionchanged:: 4.1.0 + Added optional `bits` parameter for Trino-compatible bit width control. + Parameters ---------- col : :class:`~pyspark.sql.Column` or column name target column to compute on. + bits : :class:`~pyspark.sql.Column` or int, optional + the bit width to use for counting. Must be between 1 and 64. + + .. versionadded:: 4.1.0 Returns ------- @@ -3800,8 +3810,21 @@ def bit_count(col: "ColumnOrName") -> Column: | 3| 2| | NULL| NULL| +-----+----------------+ + + >>> from pyspark.sql import functions as sf + >>> spark.range(1).select(sf.bit_count(sf.lit(-7), 8)).show() + +----------------+ + |bit_count(-7, 8)| + +----------------+ + | 6| + +----------------+ """ - return _invoke_function_over_columns("bit_count", col) + if bits is None: + return _invoke_function_over_columns("bit_count", col) + else: + bits = _enum_to_value(bits) + bits = lit(bits) if isinstance(bits, int) else bits + return _invoke_function_over_columns("bit_count", col, bits) @_try_remote_functions diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index fce3662c36674..a54c54d3765be 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -3218,6 +3218,16 @@ object functions { */ def bit_count(e: Column): Column = Column.fn("bit_count", e) + /** + * Returns the number of bits that are set in the argument expr. + * If bits is specified, treats expr as a bits-bit signed integer in 2's complement + * representation before counting set bits. + * + * @group bitwise_funcs + * @since 4.1.0 + */ + def bit_count(e: Column, bits: Column): Column = Column.fn("bit_count", e, bits) + /** * Returns the value of the bit (0 or 1) at the specified position. The positions are numbered * from right to left, starting at zero. The position argument cannot be negative. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ff27fc5625f09..348ed83d47eed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -907,7 +907,7 @@ object FunctionRegistry { expression[ShiftLeft]("<<", true, Some("4.0.0")), expression[ShiftRight](">>", true, Some("4.0.0")), expression[ShiftRightUnsigned](">>>", true, Some("4.0.0")), - expression[BitwiseCount]("bit_count"), + expressionBuilder("bit_count", BitCountExpressionBuilder), expression[BitAndAgg]("bit_and"), expression[BitOrAgg]("bit_or"), expression[BitXorAgg]("bit_xor"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 5fac0a93bf9bf..7ced5ab67a283 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry} import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types._ /** @@ -214,38 +214,118 @@ case class BitwiseNot(child: Expression) } @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the number of bits that are set in the argument expr as an" + - " unsigned 64-bit integer, or NULL if the argument is NULL.", + usage = "_FUNC_(expr[, bits]) - Returns the number of bits that are set in the argument expr." + + " If bits is specified, treats expr as a bits-bit signed integer in 2's complement." + + " If bits is not specified, uses the natural bit width of the input type.", examples = """ Examples: > SELECT _FUNC_(0); 0 + > SELECT _FUNC_(9, 64); + 2 + > SELECT _FUNC_(-7, 8); + 6 + > SELECT _FUNC_(-7, 64); + 62 """, since = "3.0.0", group = "bitwise_funcs") -case class BitwiseCount(child: Expression) +object BitCountExpressionBuilder extends ExpressionBuilder { + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + expressions.length match { + case 1 => BitwiseCount(expressions.head) + case 2 => + val bitsExpr = expressions(1) + if (!bitsExpr.foldable || bitsExpr.dataType != IntegerType) { + throw QueryCompilationErrors.nonFoldableArgumentError(funcName, "bits", IntegerType) + } + val bitsVal = bitsExpr.eval() + if (bitsVal == null) { + throw QueryCompilationErrors.nonFoldableArgumentError(funcName, "bits", IntegerType) + } + val bits = bitsVal.asInstanceOf[Int] + if (bits < 1 || bits > 64) { + throw QueryCompilationErrors.bitsRangeError(funcName, bits) + } + BitwiseCount(expressions.head, bits) + case n => + throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(1, 2), n) + } + } +} + +case class BitwiseCount(child: Expression, bits: Int = 0) extends UnaryExpression with ExpectsInputTypes { + require(bits == 0 || (bits >= 1 && bits <= 64), + s"bits must be 0 (natural width) or in [1, 64], got $bits") override def nullIntolerant: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegralType, BooleanType)) override def dataType: DataType = IntegerType - override def toString: String = s"bit_count($child)" + override def toString: String = + if (bits > 0) s"bit_count($child, $bits)" else s"bit_count($child)" override def prettyName: String = "bit_count" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = child.dataType match { - case BooleanType => defineCodeGen(ctx, ev, c => s"($c) ? 1 : 0") - case _ => defineCodeGen(ctx, ev, c => s"java.lang.Long.bitCount($c)") + override def sql: String = + if (bits > 0) s"${prettyName}(${child.sql}, $bits)" else s"${prettyName}(${child.sql})" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + if (bits > 0) { + // Explicit bit width: cast to long, apply mask, use Long.bitCount + val maskStr = if (bits == 64) "-1L" else s"((1L << $bits) - 1)" + child.dataType match { + case BooleanType => defineCodeGen(ctx, ev, + c => s"(int) java.lang.Long.bitCount(($c ? 1L : 0L) & $maskStr)") + case ByteType | ShortType | IntegerType => defineCodeGen(ctx, ev, + c => s"(int) java.lang.Long.bitCount(((long) $c) & $maskStr)") + case LongType => defineCodeGen(ctx, ev, + c => s"(int) java.lang.Long.bitCount($c & $maskStr)") + } + } else { + // Natural bit width: use type-specific bitCount + child.dataType match { + case BooleanType => defineCodeGen(ctx, ev, c => s"($c) ? 1 : 0") + case ByteType => defineCodeGen(ctx, ev, c => + s"java.lang.Integer.bitCount(java.lang.Byte.toUnsignedInt($c))") + case ShortType => defineCodeGen(ctx, ev, c => + s"java.lang.Integer.bitCount(java.lang.Short.toUnsignedInt($c))") + case IntegerType => + defineCodeGen(ctx, ev, c => s"java.lang.Integer.bitCount($c)") + case LongType => + defineCodeGen(ctx, ev, c => s"java.lang.Long.bitCount($c)") + } + } } - protected override def nullSafeEval(input: Any): Any = child.dataType match { - case BooleanType => if (input.asInstanceOf[Boolean]) 1 else 0 - case ByteType => java.lang.Long.bitCount(input.asInstanceOf[Byte]) - case ShortType => java.lang.Long.bitCount(input.asInstanceOf[Short]) - case IntegerType => java.lang.Long.bitCount(input.asInstanceOf[Int]) - case LongType => java.lang.Long.bitCount(input.asInstanceOf[Long]) + protected override def nullSafeEval(input: Any): Any = { + if (bits > 0) { + // Explicit bit width: cast to long, apply mask, use Long.bitCount + val mask = if (bits == 64) -1L else (1L << bits) - 1 + val longVal = child.dataType match { + case BooleanType => if (input.asInstanceOf[Boolean]) 1L else 0L + case ByteType => input.asInstanceOf[Byte].toLong + case ShortType => input.asInstanceOf[Short].toLong + case IntegerType => input.asInstanceOf[Int].toLong + case LongType => input.asInstanceOf[Long] + } + java.lang.Long.bitCount(longVal & mask).toInt + } else { + // Natural bit width: use type-specific bitCount + child.dataType match { + case BooleanType => if (input.asInstanceOf[Boolean]) 1 else 0 + case ByteType => java.lang.Integer.bitCount( + java.lang.Byte.toUnsignedInt(input.asInstanceOf[Byte])) + case ShortType => java.lang.Integer.bitCount( + java.lang.Short.toUnsignedInt(input.asInstanceOf[Short])) + case IntegerType => + java.lang.Integer.bitCount(input.asInstanceOf[Int]) + case LongType => + java.lang.Long.bitCount(input.asInstanceOf[Long]) + } + } } override protected def withNewChildInternal(newChild: Expression): BitwiseCount = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index a784276200c0d..a750f53c20dfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1448,6 +1448,15 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "paramType" -> toSQLType(paramType))) } + def bitsRangeError(funcName: String, invalidValue: Int): Throwable = { + new AnalysisException( + errorClass = "INVALID_PARAMETER_VALUE.BITS_RANGE", + messageParameters = Map( + "parameter" -> toSQLId("bits"), + "functionName" -> toSQLId(funcName), + "invalidValue" -> invalidValue.toString)) + } + def literalTypeUnsupportedForSourceTypeError(field: String, source: Expression): Throwable = { new AnalysisException( errorClass = "INVALID_EXTRACT_FIELD", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala index 63602d04b5c79..d8677b09a57ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala @@ -17,13 +17,17 @@ package org.apache.spark.sql.catalyst.expressions +import org.scalacheck.Gen +import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks + import org.apache.spark.SparkFunSuite import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ -class BitwiseExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { +class BitwiseExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper + with ScalaCheckPropertyChecks { import IntegralLiteralTestUtils._ @@ -175,6 +179,135 @@ class BitwiseExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BitwiseCount(Literal(-9223372036854775808L)), 1) } + test("BitCount should respect input integer type bit width") { + // BIT_COUNT(-1) should return the number of bits in the type, not always 64 + checkEvaluation(BitwiseCount(Literal((-1).toByte)), 8) + checkEvaluation(BitwiseCount(Literal((-1).toShort)), 16) + checkEvaluation(BitwiseCount(Literal(-1)), 32) + checkEvaluation(BitwiseCount(Literal(-1L)), 64) + + // Additional cases with specific negative values + checkEvaluation(BitwiseCount(Literal(Int.MinValue)), 1) + checkEvaluation(BitwiseCount(Literal(-65536)), 16) + checkEvaluation(BitwiseCount(Literal(-256)), 24) + checkEvaluation(BitwiseCount(Literal(Byte.MinValue)), 1) + checkEvaluation(BitwiseCount(Literal((-256).toShort)), 8) + checkEvaluation(BitwiseCount(Literal(Short.MinValue)), 1) + checkEvaluation(BitwiseCount(Literal(Long.MinValue)), 1) + + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen((e: Expression) => BitwiseCount(e), dt) + } + } + + test("BitCount with explicit bits parameter") { + // Trino-compatible examples + checkEvaluation(BitwiseCount(Literal(9), 64), 2) + checkEvaluation(BitwiseCount(Literal(9), 8), 2) + checkEvaluation(BitwiseCount(Literal(-7), 64), 62) + checkEvaluation(BitwiseCount(Literal(-7), 8), 6) + checkEvaluation(BitwiseCount(Literal(0), 8), 0) + checkEvaluation(BitwiseCount(Literal(-1), 8), 8) + checkEvaluation(BitwiseCount(Literal(-1), 64), 64) + checkEvaluation(BitwiseCount(Literal(-1), 1), 1) + checkEvaluation(BitwiseCount(Literal(0), 1), 0) + + // null first argument + checkEvaluation(BitwiseCount(Literal.create(null, IntegerType), 8), null) + checkEvaluation(BitwiseCount(Literal.create(null, LongType), 64), null) + + // boolean with bits + checkEvaluation(BitwiseCount(Literal(true), 8), 1) + checkEvaluation(BitwiseCount(Literal(false), 8), 0) + + // codegen consistency for two-argument form + DataTypeTestUtils.integralType.foreach { dt => + checkConsistencyBetweenInterpretedAndCodegen( + (e: Expression) => BitwiseCount(e, 32), dt) + } + } + + test("BitCount builder validation") { + // non-foldable bits + val ex1 = intercept[Exception] { + BitCountExpressionBuilder.build("bit_count", + Seq(Literal(1), $"col".int)) + } + assert(ex1.getMessage.contains("NON_FOLDABLE_ARGUMENT") || + ex1.getMessage.contains("foldable")) + + // null bits + val ex2 = intercept[Exception] { + BitCountExpressionBuilder.build("bit_count", + Seq(Literal(1), Literal.create(null, IntegerType))) + } + assert(ex2.getMessage.contains("NON_FOLDABLE_ARGUMENT") || + ex2.getMessage.contains("foldable")) + + // bits out of range + val ex3 = intercept[Exception] { + BitCountExpressionBuilder.build("bit_count", + Seq(Literal(1), Literal(0))) + } + assert(ex3.getMessage.contains("BITS_RANGE") || + ex3.getMessage.contains("[1, 64]")) + + val ex4 = intercept[Exception] { + BitCountExpressionBuilder.build("bit_count", + Seq(Literal(1), Literal(65))) + } + assert(ex4.getMessage.contains("BITS_RANGE") || + ex4.getMessage.contains("[1, 64]")) + + // wrong number of arguments + val ex5 = intercept[Exception] { + BitCountExpressionBuilder.build("bit_count", + Seq(Literal(1), Literal(8), Literal(3))) + } + assert(ex5.getMessage.contains("WRONG_NUM_ARGS") || + ex5.getMessage.contains("number")) + + // direct construction with invalid bits triggers require + intercept[IllegalArgumentException] { BitwiseCount(Literal(1), -1) } + intercept[IllegalArgumentException] { BitwiseCount(Literal(1), 65) } + intercept[IllegalArgumentException] { BitwiseCount(Literal(1), 100) } + } + + test("BitCount property: two-argument matches Long.bitCount with mask") { + val genValue = Gen.choose(Long.MinValue, Long.MaxValue) + val genBits = Gen.choose(1, 64) + forAll(genValue, genBits, minSuccessful(200)) { (value: Long, bits: Int) => + val mask = if (bits == 64) -1L else (1L << bits) - 1 + val expected = java.lang.Long.bitCount(value & mask) + val result = BitwiseCount(Literal(value), bits).eval(null) + assert(result === expected, + s"bit_count($value, $bits): expected $expected but got $result") + } + } + + test("BitCount property: single-argument matches natural bit width") { + // Long + forAll(Gen.choose(Long.MinValue, Long.MaxValue), minSuccessful(100)) { (value: Long) => + val expected = java.lang.Long.bitCount(value) + assert(BitwiseCount(Literal(value)).eval(null) === expected) + } + // Int + forAll(Gen.choose(Int.MinValue, Int.MaxValue), minSuccessful(100)) { (value: Int) => + val expected = java.lang.Integer.bitCount(value) + assert(BitwiseCount(Literal(value)).eval(null) === expected) + } + // Short + forAll(Gen.choose(Short.MinValue, Short.MaxValue), minSuccessful(100)) { (value: Short) => + val expected = java.lang.Integer.bitCount(java.lang.Short.toUnsignedInt(value)) + assert(BitwiseCount(Literal(value)).eval(null) === expected) + } + // Byte + forAll(Gen.choose(Byte.MinValue, Byte.MaxValue), minSuccessful(100)) { (value: Byte) => + val expected = java.lang.Integer.bitCount(java.lang.Byte.toUnsignedInt(value)) + assert(BitwiseCount(Literal(value)).eval(null) === expected) + } + } + test("BitGet") { val nullLongLiteral = Literal.create(null, LongType) val nullIntLiteral = Literal.create(null, IntegerType) diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index f2c72fa18ed6d..8db0f4610d998 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -47,6 +47,7 @@ | org.apache.spark.sql.catalyst.expressions.Base64 | base64 | SELECT base64('Spark SQL') | struct | | org.apache.spark.sql.catalyst.expressions.Between | between | SELECT 0.5 between 0.1 AND 1.0 | struct | | org.apache.spark.sql.catalyst.expressions.Bin | bin | SELECT bin(13) | struct | +| org.apache.spark.sql.catalyst.expressions.BitCountExpressionBuilder | bit_count | SELECT bit_count(0) | struct | | org.apache.spark.sql.catalyst.expressions.BitLength | bit_length | SELECT bit_length('Spark SQL') | struct | | org.apache.spark.sql.catalyst.expressions.BitmapAndAgg | bitmap_and_agg | SELECT substring(hex(bitmap_and_agg(col)), 0, 6) FROM VALUES (X 'F0'), (X '70'), (X '30') AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.BitmapBitPosition | bitmap_bit_position | SELECT bitmap_bit_position(1) | struct | @@ -55,7 +56,6 @@ | org.apache.spark.sql.catalyst.expressions.BitmapCount | bitmap_count | SELECT bitmap_count(X '1010') | struct | | org.apache.spark.sql.catalyst.expressions.BitmapOrAgg | bitmap_or_agg | SELECT substring(hex(bitmap_or_agg(col)), 0, 6) FROM VALUES (X '10'), (X '20'), (X '40') AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.BitwiseAnd | & | SELECT 3 & 5 | struct<(3 & 5):int> | -| org.apache.spark.sql.catalyst.expressions.BitwiseCount | bit_count | SELECT bit_count(0) | struct | | org.apache.spark.sql.catalyst.expressions.BitwiseGet | bit_get | SELECT bit_get(11, 0) | struct | | org.apache.spark.sql.catalyst.expressions.BitwiseGet | getbit | SELECT getbit(11, 0) | struct | | org.apache.spark.sql.catalyst.expressions.BitwiseNot | ~ | SELECT ~ 0 | struct<~0:int> | diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/bitwise.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/bitwise.sql.out index 35033e4a2d967..bb95ef066d350 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/bitwise.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/bitwise.sql.out @@ -173,6 +173,166 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +select bit_count(9, 64) +-- !query analysis +Project [bit_count(9, 64) AS bit_count(9, 64)#x] ++- OneRowRelation + + +-- !query +select bit_count(9, 8) +-- !query analysis +Project [bit_count(9, 8) AS bit_count(9, 8)#x] ++- OneRowRelation + + +-- !query +select bit_count(-7, 64) +-- !query analysis +Project [bit_count(-7, 64) AS bit_count(-7, 64)#x] ++- OneRowRelation + + +-- !query +select bit_count(-7, 8) +-- !query analysis +Project [bit_count(-7, 8) AS bit_count(-7, 8)#x] ++- OneRowRelation + + +-- !query +select bit_count(0, 8) +-- !query analysis +Project [bit_count(0, 8) AS bit_count(0, 8)#x] ++- OneRowRelation + + +-- !query +select bit_count(-1, 8) +-- !query analysis +Project [bit_count(-1, 8) AS bit_count(-1, 8)#x] ++- OneRowRelation + + +-- !query +select bit_count(-1, 64) +-- !query analysis +Project [bit_count(-1, 64) AS bit_count(-1, 64)#x] ++- OneRowRelation + + +-- !query +select bit_count(9, 0) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "INVALID_PARAMETER_VALUE.BITS_RANGE", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "`bit_count`", + "invalidValue" : "0", + "parameter" : "`bits`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 22, + "fragment" : "bit_count(9, 0)" + } ] +} + + +-- !query +select bit_count(9, 65) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "INVALID_PARAMETER_VALUE.BITS_RANGE", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "`bit_count`", + "invalidValue" : "65", + "parameter" : "`bits`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 23, + "fragment" : "bit_count(9, 65)" + } ] +} + + +-- !query +select bit_count(9, null) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "42K08", + "messageParameters" : { + "funcName" : "`bit_count`", + "paramName" : "`bits`", + "paramType" : "\"INT\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 25, + "fragment" : "bit_count(9, null)" + } ] +} + + +-- !query +select bit_count(9, 'a') +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "42K08", + "messageParameters" : { + "funcName" : "`bit_count`", + "paramName" : "`bits`", + "paramType" : "\"INT\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "bit_count(9, 'a')" + } ] +} + + +-- !query +select bit_count(9, 8, 3) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "3", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[1, 2]", + "functionName" : "`bit_count`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 25, + "fragment" : "bit_count(9, 8, 3)" + } ] +} + + -- !query CREATE OR REPLACE TEMPORARY VIEW bitwise_test AS SELECT * FROM VALUES (1, 1, 1, 1L), diff --git a/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql b/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql index e080fdd32a4aa..0b316a5a63e8e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/bitwise.sql @@ -38,6 +38,22 @@ select bit_count(-9223372036854775808L); select bit_count("bit count"); select bit_count('a'); +-- two-argument bit_count (Trino-compatible) +select bit_count(9, 64); +select bit_count(9, 8); +select bit_count(-7, 64); +select bit_count(-7, 8); +select bit_count(0, 8); +select bit_count(-1, 8); +select bit_count(-1, 64); + +-- two-argument bit_count error cases +select bit_count(9, 0); +select bit_count(9, 65); +select bit_count(9, null); +select bit_count(9, 'a'); +select bit_count(9, 8, 3); + -- test for bit_xor -- CREATE OR REPLACE TEMPORARY VIEW bitwise_test AS SELECT * FROM VALUES diff --git a/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out b/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out index 7233b0d0ae499..ba03b79205da0 100644 --- a/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/bitwise.sql.out @@ -195,6 +195,183 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +select bit_count(9, 64) +-- !query schema +struct +-- !query output +2 + + +-- !query +select bit_count(9, 8) +-- !query schema +struct +-- !query output +2 + + +-- !query +select bit_count(-7, 64) +-- !query schema +struct +-- !query output +62 + + +-- !query +select bit_count(-7, 8) +-- !query schema +struct +-- !query output +6 + + +-- !query +select bit_count(0, 8) +-- !query schema +struct +-- !query output +0 + + +-- !query +select bit_count(-1, 8) +-- !query schema +struct +-- !query output +8 + + +-- !query +select bit_count(-1, 64) +-- !query schema +struct +-- !query output +64 + + +-- !query +select bit_count(9, 0) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "INVALID_PARAMETER_VALUE.BITS_RANGE", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "`bit_count`", + "invalidValue" : "0", + "parameter" : "`bits`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 22, + "fragment" : "bit_count(9, 0)" + } ] +} + + +-- !query +select bit_count(9, 65) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "INVALID_PARAMETER_VALUE.BITS_RANGE", + "sqlState" : "22023", + "messageParameters" : { + "functionName" : "`bit_count`", + "invalidValue" : "65", + "parameter" : "`bits`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 23, + "fragment" : "bit_count(9, 65)" + } ] +} + + +-- !query +select bit_count(9, null) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "42K08", + "messageParameters" : { + "funcName" : "`bit_count`", + "paramName" : "`bits`", + "paramType" : "\"INT\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 25, + "fragment" : "bit_count(9, null)" + } ] +} + + +-- !query +select bit_count(9, 'a') +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "NON_FOLDABLE_ARGUMENT", + "sqlState" : "42K08", + "messageParameters" : { + "funcName" : "`bit_count`", + "paramName" : "`bits`", + "paramType" : "\"INT\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 24, + "fragment" : "bit_count(9, 'a')" + } ] +} + + +-- !query +select bit_count(9, 8, 3) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + "sqlState" : "42605", + "messageParameters" : { + "actualNum" : "3", + "docroot" : "https://spark.apache.org/docs/latest", + "expectedNum" : "[1, 2]", + "functionName" : "`bit_count`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 25, + "fragment" : "bit_count(9, 8, 3)" + } ] +} + + -- !query CREATE OR REPLACE TEMPORARY VIEW bitwise_test AS SELECT * FROM VALUES (1, 1, 1, 1L),