diff --git a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/FilterSelectivityEstimator.java b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/FilterSelectivityEstimator.java index b18c525c8849..7b3196f3f7c7 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/FilterSelectivityEstimator.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/optimizer/calcite/stats/FilterSelectivityEstimator.java @@ -22,14 +22,19 @@ import java.util.Collections; import java.util.GregorianCalendar; import java.util.List; +import java.util.Objects; +import java.util.Optional; import java.util.Set; +import com.google.common.collect.BoundType; +import com.google.common.collect.Range; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.plan.RelOptUtil.InputReferencedVisitor; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; @@ -184,91 +189,383 @@ public Double visitCall(RexCall call) { return selectivity; } + /** + * Return whether the expression is a removable cast based on stats and type bounds. + * + *

+ * In Hive, if a value cannot be represented by the cast, the result of the cast is NULL, + * and therefore cannot fulfill the predicate. So the range of values that can satisfy the predicate + * is limited by the value range of the cast's target type. + *

+ * + * @param exp the expression to check + * @param tableScan the table that provides the statistics + * @return true if the expression is a removable cast, false otherwise + */ + private boolean isRemovableCast(RexNode exp, HiveTableScan tableScan) { + if(SqlKind.CAST != exp.getKind()) { + return false; + } + RexCall cast = (RexCall) exp; + RexNode op0 = cast.getOperands().getFirst(); + if (!(op0 instanceof RexInputRef)) { + return false; + } + int index = ((RexInputRef) op0).getIndex(); + final List colStats = tableScan.getColStat(Collections.singletonList(index)); + if (colStats.isEmpty()) { + return false; + } + + SqlTypeName sourceType = op0.getType().getSqlTypeName(); + SqlTypeName targetType = cast.getType().getSqlTypeName(); + + switch (sourceType) { + case TINYINT, SMALLINT, INTEGER, BIGINT: + switch (targetType) {// additional checks are needed + case TINYINT, SMALLINT, INTEGER, BIGINT: + return isRemovableIntegerCast(cast, op0, colStats); + case FLOAT, DOUBLE, DECIMAL: + return true; + default: + return false; + } + case FLOAT, DOUBLE, DECIMAL: + switch (targetType) { + // these CASTs do not show a modulo behavior, so it's ok to remove such a cast + case TINYINT, SMALLINT, INTEGER, BIGINT, FLOAT, DOUBLE, DECIMAL: + return true; + default: + return false; + } + case TIMESTAMP, DATE: + switch (targetType) { + case TIMESTAMP, DATE: + return true; + default: + return false; + } + // unknown type, do not remove the cast + default: + return false; + } + } + + private static boolean isRemovableIntegerCast(RexCall cast, RexNode op0, List colStats) { + // If the source type is completely within the target type, the cast is lossless + Range targetRange = getRangeOfType(cast.getType(), BoundType.CLOSED, BoundType.CLOSED); + Range sourceRange = getRangeOfType(op0.getType(), BoundType.CLOSED, BoundType.CLOSED); + if (targetRange.encloses(sourceRange)) { + return true; + } + + // Check that the possible values of the input column are all within the type range of the 
cast + // otherwise the CAST introduces some modulo-like behavior + ColStatistics colStat = colStats.getFirst(); + ColStatistics.Range colRange = colStat.getRange(); + if (colRange == null || colRange.minValue == null || colRange.maxValue == null) { + return false; + } + + // are all values of the input column accepted by the cast? + SqlTypeName targetType = cast.getType().getSqlTypeName(); + double min = ((Number) targetType.getLimit(false, SqlTypeName.Limit.OVERFLOW, false, -1, -1)).doubleValue(); + double max = ((Number) targetType.getLimit(true, SqlTypeName.Limit.OVERFLOW, false, -1, -1)).doubleValue(); + return min < colRange.minValue.doubleValue() && colRange.maxValue.doubleValue() < max; + } + + /** + * Get the range of values that are rounded to valid values of a type. + * + * @param type the type + * @param lowerBound the lower bound type of the result + * @param upperBound the upper bound type of the result + * @return the range of the type + */ + private static Range getRangeOfType(RelDataType type, BoundType lowerBound, BoundType upperBound) { + switch (type.getSqlTypeName()) { + // in case of integer types, + case TINYINT: + return Range.closed(-128.99998f, 127.99999f); + case SMALLINT: + return Range.closed(-32768.996f, 32767.998f); + case INTEGER: + return Range.closed(-2.1474836E9f, 2.1474836E9f); + case BIGINT, DATE, TIMESTAMP: + return Range.closed(-9.223372E18f, 9.223372E18f); + case DECIMAL: + return getRangeOfDecimalType(type, lowerBound, upperBound); + case FLOAT, DOUBLE: + return Range.closed(-Float.MAX_VALUE, Float.MAX_VALUE); + default: + throw new IllegalStateException("Unsupported type: " + type); + } + } + + private static Range getRangeOfDecimalType(RelDataType type, BoundType lowerBound, BoundType upperBound) { + // values outside the representable range are cast to NULL, so adapt the boundaries + int digits = type.getPrecision() - type.getScale(); + // the cast does some rounding, i.e., CAST(99.9499 AS DECIMAL(3,1)) = 99.9 + // but 
CAST(99.95 AS DECIMAL(3,1)) = NULL + float adjust = (float) (5 * Math.pow(10, -(type.getScale() + 1))); + // the range of values supported by the type is interval [-typeRangeExtent, typeRangeExtent] (both inclusive) + // e.g., the typeRangeExt is 99.94999 for DECIMAL(3,1) + float typeRangeExtent = Math.nextDown((float) (Math.pow(10, digits) - adjust)); + + // the resulting value of +- adjust would be rounded up, so in some cases we need to use Math.nextDown + boolean lowerInclusive = BoundType.CLOSED.equals(lowerBound); + boolean upperInclusive = BoundType.CLOSED.equals(upperBound); + float lowerUniverse = lowerInclusive ? -typeRangeExtent : Math.nextDown(-typeRangeExtent); + float upperUniverse = upperInclusive ? typeRangeExtent : Math.nextUp(typeRangeExtent); + return makeRange(lowerUniverse, lowerBound, upperUniverse, upperBound); + } + + /** + * Adjust the type boundaries if necessary. + * + * @param predicateRange boundaries of the range predicate + * @param type the type + * @param typeRange the boundaries of the type range + * @return the adjusted boundary + */ + private static Range adjustRangeToType(Range predicateRange, RelDataType type, + Range typeRange) { + boolean lowerInclusive = BoundType.CLOSED.equals(predicateRange.lowerBoundType()); + boolean upperInclusive = BoundType.CLOSED.equals(predicateRange.upperBoundType()); + switch (type.getSqlTypeName()) { + case TINYINT, SMALLINT, INTEGER, BIGINT: { + // when casting a floating point, its values are rounded towards 0 + // i.e, 10.99 is rounded to 10, and -10.99 is rounded to -10 + // to take this into account, the predicate range is transformed in the following ways + // [10.0, 15.0] -> [10, 15.99999] + // (10.0, 15.0) -> [11, 14.99999] + // [10.2, 15.2] -> [11, 15.99999] + // (10.2, 15.2) -> [11, 15.99999] + + // [-15.0, -10.0] -> [-15.9999, -10] + // (-15.0, -10.0) -> [-14.9999, -11] + // [-15.2, -10.2] -> [-15.9999, -11] + // (-15.2, -10.2) -> [-15.9999, -11] + + // normalize the range to make the 
formulas easier + Range range = convertRangeToClosedOpen(predicateRange); + typeRange = convertRangeToClosedOpen(typeRange); + float adjustedLower = (range.lowerEndpoint() >= 0 ? (float) Math.ceil(range.lowerEndpoint()) + : Math.nextUp(-(float) Math.ceil(Math.nextUp(-range.lowerEndpoint())))); + float adjustedUpper = range.upperEndpoint() >= 0 ? Math.nextDown((float) Math.ceil(range.upperEndpoint())) + : Math.nextUp((float) -Math.ceil(-range.upperEndpoint())); + float lower = Math.max(adjustedLower, typeRange.lowerEndpoint()); + float upper = Math.min(adjustedUpper, typeRange.upperEndpoint()); + return makeRange(lower, BoundType.CLOSED, upper, BoundType.OPEN); + } + case DECIMAL: { + // The cast to DECIMAL rounds the value the same way as {@link RoundingMode#HALF_UP}. + // The boundaries are adjusted accordingly. + float adjust = (float) (5 * Math.pow(10, -(type.getScale() + 1))); + // the resulting value of +- adjust would be rounded up, so in some cases we need to use Math.nextDown + float adjustedLower = + lowerInclusive ? predicateRange.lowerEndpoint() - adjust : addAndDown(predicateRange.lowerEndpoint(), adjust); + float adjustedUpper = upperInclusive ? addAndDown(predicateRange.upperEndpoint(), adjust) + : predicateRange.upperEndpoint() - adjust; + float lower = Math.max(adjustedLower, typeRange.lowerEndpoint()); + float upper = Math.min(adjustedUpper, typeRange.upperEndpoint()); + return makeRange(lower, predicateRange.lowerBoundType(), upper, predicateRange.upperBoundType()); + } + case TIMESTAMP, DATE: + return predicateRange; + default: + return typeRange.isConnected(predicateRange) ? 
typeRange.intersection(predicateRange) : Range.closedOpen(0f, 0f); + } + } + + private static float addAndDown(float v, float positiveSummand) { + float r = v + positiveSummand; + if (r == v) { + // the result is below the resolution of float; do not return a value smaller than v + return r; + } else { + return Math.nextDown(r); + } + } + + /** + * If the arguments lead to a valid range, it is returned, otherwise an empty range is returned. + */ + private static Range makeRange(float lower, BoundType lowerType, float upper, BoundType upperType) { + if (lower > upper) { + return Range.closedOpen(0f, 0f); + } + if (lower == upper && lowerType == BoundType.OPEN && upperType == BoundType.OPEN) { + return Range.closedOpen(0f, 0f); + } + + return Range.range(lower, lowerType, upper, upperType); + } + private double computeRangePredicateSelectivity(RexCall call, SqlKind op) { - final boolean isLiteralLeft = call.getOperands().get(0).getKind().equals(SqlKind.LITERAL); - final boolean isLiteralRight = call.getOperands().get(1).getKind().equals(SqlKind.LITERAL); - final boolean isInputRefLeft = call.getOperands().get(0).getKind().equals(SqlKind.INPUT_REF); - final boolean isInputRefRight = call.getOperands().get(1).getKind().equals(SqlKind.INPUT_REF); + double defaultSelectivity = ((double) 1 / (double) 3); + if (!(childRel instanceof HiveTableScan)) { + return defaultSelectivity; + } - if (childRel instanceof HiveTableScan && isLiteralLeft != isLiteralRight && isInputRefLeft != isInputRefRight) { - final HiveTableScan t = (HiveTableScan) childRel; - final int inputRefIndex = ((RexInputRef) call.getOperands().get(isInputRefLeft ? 
0 : 1)).getIndex(); - final List colStats = t.getColStat(Collections.singletonList(inputRefIndex)); + // search for the literal + List operands = call.getOperands(); + final Optional leftLiteral = extractLiteral(operands.get(0)); + final Optional rightLiteral = extractLiteral(operands.get(1)); + // ensure that there's exactly one literal + if ((leftLiteral.isPresent()) == (rightLiteral.isPresent())) { + return defaultSelectivity; + } + int literalOpIdx = leftLiteral.isPresent() ? 0 : 1; + + // analyze the predicate + float value = leftLiteral.orElseGet(rightLiteral::get); + int boundaryIdx; + boolean openBound = op == SqlKind.LESS_THAN || op == SqlKind.GREATER_THAN; + switch (op) { + case LESS_THAN, LESS_THAN_OR_EQUAL: + boundaryIdx = literalOpIdx; + break; + case GREATER_THAN, GREATER_THAN_OR_EQUAL: + boundaryIdx = 1 - literalOpIdx; + break; + default: + return defaultSelectivity; + } + float[] boundaryValues = new float[] { Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY }; + BoundType[] inclusive = new BoundType[] { BoundType.CLOSED, BoundType.CLOSED }; + boundaryValues[boundaryIdx] = value; + inclusive[boundaryIdx] = openBound ? 
BoundType.OPEN : BoundType.CLOSED; + Range boundaries = Range.range(boundaryValues[0], inclusive[0], boundaryValues[1], inclusive[1]); + + // extract the column index from the other operator + final HiveTableScan scan = (HiveTableScan) childRel; + int inputRefOpIndex = 1 - literalOpIdx; + RexNode node = operands.get(inputRefOpIndex); + if (isRemovableCast(node, scan)) { + Range typeRange = + getRangeOfType(node.getType(), boundaries.lowerBoundType(), boundaries.upperBoundType()); + boundaries = adjustRangeToType(boundaries, node.getType(), typeRange); + node = RexUtil.removeCast(node); + } - if (!colStats.isEmpty() && isHistogramAvailable(colStats.get(0))) { - final KllFloatsSketch kll = KllFloatsSketch.heapify(Memory.wrap(colStats.get(0).getHistogram())); - final Object boundValueObject = ((RexLiteral) call.getOperands().get(isLiteralLeft ? 0 : 1)).getValue(); - final SqlTypeName typeName = call.getOperands().get(isInputRefLeft ? 0 : 1).getType().getSqlTypeName(); - float value = extractLiteral(typeName, boundValueObject); - boolean closedBound = op.equals(SqlKind.LESS_THAN_OR_EQUAL) || op.equals(SqlKind.GREATER_THAN_OR_EQUAL); - - double selectivity; - if (op.equals(SqlKind.LESS_THAN_OR_EQUAL) || op.equals(SqlKind.LESS_THAN)) { - selectivity = closedBound ? lessThanOrEqualSelectivity(kll, value) : lessThanSelectivity(kll, value); - } else { - selectivity = closedBound ? greaterThanOrEqualSelectivity(kll, value) : greaterThanSelectivity(kll, value); - } + int inputRefIndex = -1; + if (node.getKind().equals(SqlKind.INPUT_REF)) { + inputRefIndex = ((RexInputRef) node).getIndex(); + } - // selectivity does not account for null values, we multiply for the number of non-null values (getN) - // and we divide by the total (non-null + null values) to get the overall selectivity. 
- // - // Example: consider a filter "col < 3", and the following table rows: - // _____ - // | col | - // |_____| - // |1 | - // |null | - // |null | - // |3 | - // |4 | - // ------- - // kll.getN() would be 3, selectivity 1/3, t.getTable().getRowCount() 5 - // so the final result would be 3 * 1/3 / 5 = 1/5, as expected. - return kll.getN() * selectivity / t.getTable().getRowCount(); - } + if (inputRefIndex < 0) { + return defaultSelectivity; } - return ((double) 1 / (double) 3); + + final List colStats = scan.getColStat(Collections.singletonList(inputRefIndex)); + if (colStats.isEmpty() || !isHistogramAvailable(colStats.get(0))) { + return defaultSelectivity; + } + + final KllFloatsSketch kll = KllFloatsSketch.heapify(Memory.wrap(colStats.get(0).getHistogram())); + double rawSelectivity = rangedSelectivity(kll, boundaries); + return scaleSelectivityToNullableValues(kll, rawSelectivity, scan); + } + + /** + * Adjust the selectivity estimate to take NULL values into account. + *

+ * The rawSelectivity does not account for null values. We multiply by the number of non-null values (getN) + * and divide by the total number of rows (non-null + null values) to get the overall selectivity. + *

+ * Example: consider a filter {@code col < 3}, and the following table rows: + * <pre>

+   *  _____
+   * | col |
+   * |_____|
+   * |1    |
+   * |null |
+   * |null |
+   * |3    |
+   * |4    |
+   * -------
+   * </pre>
+ * kll.getN() would be 3, rawSelectivity 1/3, scan.getTable().getRowCount() 5 + * so the final result would be 3 * 1/3 / 5 = 1/5, as expected. + */ + private static double scaleSelectivityToNullableValues(KllFloatsSketch kll, double rawSelectivity, + HiveTableScan scan) { + if (scan.getTable() == null) { + return rawSelectivity; + } + return kll.getN() * rawSelectivity / scan.getTable().getRowCount(); } private Double computeBetweenPredicateSelectivity(RexCall call) { - final boolean hasLiteralBool = call.getOperands().get(0).getKind().equals(SqlKind.LITERAL); - final boolean hasInputRef = call.getOperands().get(1).getKind().equals(SqlKind.INPUT_REF); - final boolean hasLiteralLeft = call.getOperands().get(2).getKind().equals(SqlKind.LITERAL); - final boolean hasLiteralRight = call.getOperands().get(3).getKind().equals(SqlKind.LITERAL); + if (!(childRel instanceof HiveTableScan)) { + return computeFunctionSelectivity(call); + } + + List operands = call.getOperands(); + final boolean hasLiteralBool = operands.get(0).getKind().equals(SqlKind.LITERAL); + Optional leftLiteral = extractLiteral(operands.get(2)); + Optional rightLiteral = extractLiteral(operands.get(3)); + + if (hasLiteralBool && leftLiteral.isPresent() && rightLiteral.isPresent()) { + final HiveTableScan scan = (HiveTableScan) childRel; + float leftValue = leftLiteral.get(); + float rightValue = rightLiteral.get(); - if (childRel instanceof HiveTableScan && hasLiteralBool && hasInputRef && hasLiteralLeft && hasLiteralRight) { - final HiveTableScan t = (HiveTableScan) childRel; - final int inputRefIndex = ((RexInputRef) call.getOperands().get(1)).getIndex(); - final List colStats = t.getColStat(Collections.singletonList(inputRefIndex)); + boolean inverseBool = RexLiteral.booleanValue(operands.getFirst()); + // when they are equal it's an equality predicate, we cannot handle it as "BETWEEN" + if (Objects.equals(leftValue, rightValue)) { + return inverseBool ? 
computeNotEqualitySelectivity(call) : computeFunctionSelectivity(call); + } + + Range rangeBoundaries = makeRange(leftValue, BoundType.CLOSED, rightValue, BoundType.CLOSED); + Range typeBoundaries = inverseBool ? Range.closed(Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY) : null; + + RexNode expr = operands.get(1); // expr to be checked by the BETWEEN + if (isRemovableCast(expr, scan)) { + typeBoundaries = + getRangeOfType(expr.getType(), rangeBoundaries.lowerBoundType(), rangeBoundaries.upperBoundType()); + rangeBoundaries = adjustRangeToType(rangeBoundaries, expr.getType(), typeBoundaries); + expr = RexUtil.removeCast(expr); + } + + int inputRefIndex = -1; + if (expr.getKind().equals(SqlKind.INPUT_REF)) { + inputRefIndex = ((RexInputRef) expr).getIndex(); + } + + if (inputRefIndex < 0) { + return computeFunctionSelectivity(call); + } + final List colStats = scan.getColStat(Collections.singletonList(inputRefIndex)); if (!colStats.isEmpty() && isHistogramAvailable(colStats.get(0))) { final KllFloatsSketch kll = KllFloatsSketch.heapify(Memory.wrap(colStats.get(0).getHistogram())); - final SqlTypeName typeName = call.getOperands().get(1).getType().getSqlTypeName(); - final Object inverseBoolValueObject = ((RexLiteral) call.getOperands().get(0)).getValue(); - boolean inverseBool = Boolean.parseBoolean(inverseBoolValueObject.toString()); - final Object leftBoundValueObject = ((RexLiteral) call.getOperands().get(2)).getValue(); - float leftValue = extractLiteral(typeName, leftBoundValueObject); - final Object rightBoundValueObject = ((RexLiteral) call.getOperands().get(3)).getValue(); - float rightValue = extractLiteral(typeName, rightBoundValueObject); - // when inverseBool == true, this is a NOT_BETWEEN and selectivity must be inverted + double rawSelectivity = rangedSelectivity(kll, rangeBoundaries); if (inverseBool) { - if (rightValue == leftValue) { - return computeNotEqualitySelectivity(call); - } else if (rightValue < leftValue) { - return 1.0; - } - return 
1.0 - (kll.getN() * betweenSelectivity(kll, leftValue, rightValue) / t.getTable().getRowCount()); - } - // when they are equal it's an equality predicate, we cannot handle it as "between" - if (Double.compare(leftValue, rightValue) != 0) { - return kll.getN() * betweenSelectivity(kll, leftValue, rightValue) / t.getTable().getRowCount(); + // when inverseBool == true, this is a NOT_BETWEEN and selectivity must be inverted + // if there's a cast, the inversion is with respect to its codomain (range of the values of the cast) + double typeRangeSelectivity = rangedSelectivity(kll, typeBoundaries); + rawSelectivity = typeRangeSelectivity - rawSelectivity; } + return scaleSelectivityToNullableValues(kll, rawSelectivity, scan); } } return computeFunctionSelectivity(call); } - private float extractLiteral(SqlTypeName typeName, Object boundValueObject) { + private Optional extractLiteral(RexNode node) { + if (node.getKind() != SqlKind.LITERAL) { + return Optional.empty(); + } + RexLiteral literal = (RexLiteral) node; + if (literal.getValue() == null) { + return Optional.empty(); + } + return extractLiteral(literal.getTypeName(), literal.getValue()); + } + + private Optional extractLiteral(SqlTypeName typeName, Object boundValueObject) { final String boundValueString = boundValueObject.toString(); float value; @@ -299,10 +596,9 @@ private float extractLiteral(SqlTypeName typeName, Object boundValueObject) { value = ((GregorianCalendar) boundValueObject).toInstant().getEpochSecond(); break; default: - throw new IllegalStateException( - "Unsupported type for comparator selectivity evaluation using histogram: " + typeName); + return Optional.empty(); } - return value; + return Optional.of(value); } /** @@ -470,7 +766,7 @@ private boolean isPartitionPredicate(RexNode expr, RelNode r) { } else if (r instanceof Filter) { return isPartitionPredicate(expr, ((Filter) r).getInput()); } else if (r instanceof HiveTableScan) { - RelOptHiveTable table = (RelOptHiveTable) ((HiveTableScan) 
r).getTable(); + RelOptHiveTable table = (RelOptHiveTable) r.getTable(); ImmutableBitSet cols = RelOptUtil.InputFinder.bits(expr); return table.containsPartitionColumnsOnly(cols); } @@ -489,7 +785,43 @@ public Double visitLiteral(RexLiteral literal) { return null; } - private static double rangedSelectivity(KllFloatsSketch kll, float val1, float val2) { + /** + * Returns the selectivity of a predicate "val1 <= column < val2". + * @param kll the sketch + * @param boundaries the boundaries + * @return the selectivity of "val1 <= column < val2" + */ + private static double rangedSelectivity(KllFloatsSketch kll, Range boundaries) { + // convert the condition to a range val1 <= x < val2 + Range closedOpen = convertRangeToClosedOpen(boundaries); + return rangedSelectivity(kll, closedOpen.lowerEndpoint(), closedOpen.upperEndpoint()); + } + + /** + * Normalizes the range to the form "val1 <= column < val2". + */ + private static Range convertRangeToClosedOpen(Range boundaries) { + boolean leftClosed = BoundType.CLOSED.equals(boundaries.lowerBoundType()); + boolean rightOpen = BoundType.OPEN.equals(boundaries.upperBoundType()); + if (leftClosed && rightOpen) { + return boundaries; + } + float newLower = leftClosed ? boundaries.lowerEndpoint() : Math.nextUp(boundaries.lowerEndpoint()); + float newUpper = rightOpen ? boundaries.upperEndpoint() : Math.nextUp(boundaries.upperEndpoint()); + return Range.closedOpen(newLower, newUpper); + } + + /** + * Returns the selectivity of a predicate "val1 <= column < val2". 
+ * @param kll the sketch + * @param val1 lower bound (inclusive) + * @param val2 upper bound (exclusive) + * @return the selectivity of "val1 <= column < val2" + */ + static double rangedSelectivity(KllFloatsSketch kll, float val1, float val2) { + if (val1 >= val2) { + return 0; + } float[] splitPoints = new float[] { val1, val2 }; double[] boundaries = kll.getCDF(splitPoints, QuantileSearchCriteria.EXCLUSIVE); return boundaries[1] - boundaries[0]; @@ -574,7 +906,7 @@ public static double betweenSelectivity(KllFloatsSketch kll, float leftValue, fl "Selectivity for BETWEEN leftValue AND rightValue when the two values coincide is not supported, found: " + "leftValue = " + leftValue + " and rightValue = " + rightValue); } - return rangedSelectivity(kll, Math.nextDown(leftValue), Math.nextUp(rightValue)); + return rangedSelectivity(kll, leftValue, Math.nextUp(rightValue)); } /** diff --git a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestFilterSelectivityEstimator.java b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestFilterSelectivityEstimator.java index 4255c756e078..81820d77241a 100644 --- a/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestFilterSelectivityEstimator.java +++ b/ql/src/test/org/apache/hadoop/hive/ql/optimizer/calcite/stats/TestFilterSelectivityEstimator.java @@ -17,7 +17,6 @@ */ package org.apache.hadoop.hive.ql.optimizer.calcite.stats; -import com.google.common.collect.ImmutableList; import org.apache.calcite.jdbc.JavaTypeFactoryImpl; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptPlanner; @@ -26,8 +25,12 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexLiteral; import 
org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexUtil; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.tools.RelBuilder; @@ -43,6 +46,7 @@ import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan; import org.apache.hadoop.hive.ql.parse.CalcitePlanner; import org.apache.hadoop.hive.ql.plan.ColStatistics; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; @@ -51,24 +55,77 @@ import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalTime; +import java.time.ZoneOffset; import java.util.Collections; +import java.util.List; +import static org.apache.calcite.sql.type.SqlTypeName.BIGINT; +import static org.apache.calcite.sql.type.SqlTypeName.INTEGER; +import static org.apache.calcite.sql.type.SqlTypeName.SMALLINT; +import static org.apache.calcite.sql.type.SqlTypeName.TIMESTAMP; +import static org.apache.calcite.sql.type.SqlTypeName.TINYINT; import static org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.betweenSelectivity; import static org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.greaterThanOrEqualSelectivity; import static org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.greaterThanSelectivity; import static org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.isHistogramAvailable; import static org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.lessThanOrEqualSelectivity; import static org.apache.hadoop.hive.ql.optimizer.calcite.stats.FilterSelectivityEstimator.lessThanSelectivity; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; +import static 
org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; @RunWith(MockitoJUnitRunner.class) public class TestFilterSelectivityEstimator { private static final float[] VALUES = { 1, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5, 6, 7 }; + private static final float[] VALUES2 = { + // rounding for DECIMAL(3,1) + // -99.95f and its two predecessors and successors + -99.95001f, -99.950005f, -99.95f, -99.94999f, -99.94998f, + // some values + 0f, 1f, 10f, + // rounding for DECIMAL(3,1) + // 99.95f and its two predecessors and successors + 99.94998f, 99.94999f, 99.95f, 99.950005f, 99.95001f, + // 100f and its two predecessors and successors + 99.999985f, 99.99999f, 100f, 100.00001f, 100.000015f, + // 100.05f and its two predecessors and successors + 100.04999f, 100.049995f, 100.05f, 100.05001f, 100.05002f, + // some values + 1_000f, 10_000f, 100_000f, 1_000_000f, 1e19f }; + private static final float[] VALUES3 = { + // the closest floats that are CAST to the integer types, and one below and above the range + -9.223373E18f, -9.223372E18f, 9.223372E18f, 9.223373E18f, // long + -2.147484E9f, -2.1474836E9f, 2.1474836E9f, 2.147484E9f, // integer + -32769.0f, -32768.996f, 32767.998f, 32768.0f, // short + -129f, -128.99998f, 127.99999f, 128.0f, // byte + // numbers for checking the rounding when casting to integer types + 10f, 10.0001f, 10.9999f, 11f, + // corresponding negative values + -11f, -10.9999f, -10.0001f, -10f }; + + /** + * Both dates and timestamps are converted to epoch seconds. + *

+ * See {@link org.apache.hadoop.hive.ql.udf.generic.GenericUDFToUnixTimeStamp#evaluate(GenericUDF.DeferredObject[])}. + */ + private static final float[] VALUES_TIME = { + timestamp("2020-11-01"), timestamp("2020-11-02"), timestamp("2020-11-03"), timestamp("2020-11-04"), + timestamp("2020-11-05T11:23:45Z"), timestamp("2020-11-06"), timestamp("2020-11-07") }; + private static final KllFloatsSketch KLL = StatisticsTestUtils.createKll(VALUES); - private static final float DELTA = Float.MIN_VALUE; + private static final KllFloatsSketch KLL2 = StatisticsTestUtils.createKll(VALUES2); + private static final KllFloatsSketch KLL3 = StatisticsTestUtils.createKll(VALUES3); + private static final KllFloatsSketch KLL_TIME = StatisticsTestUtils.createKll(VALUES_TIME); + private static final float DELTA = 1e-7f; private static final RexBuilder REX_BUILDER = new RexBuilder(new JavaTypeFactoryImpl(new HiveTypeSystemImpl())); private static final RelDataTypeFactory TYPE_FACTORY = REX_BUILDER.getTypeFactory(); + private static RelOptCluster relOptCluster; private static RexNode intMinus1; private static RexNode int0; @@ -85,7 +142,6 @@ public class TestFilterSelectivityEstimator { private static RexNode inputRef0; private static RexNode boolFalse; private static RexNode boolTrue; - private static ColStatistics stats; @Mock private RelOptSchema schemaMock; @@ -94,12 +150,14 @@ public class TestFilterSelectivityEstimator { @Mock private RelMetadataQuery mq; - private HiveTableScan tableScan; + private ColStatistics stats; private RelNode scan; + private RexNode currentInputRef; + private int currentValuesSize; @BeforeClass public static void beforeClass() { - RelDataType integerType = TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER); + RelDataType integerType = TYPE_FACTORY.createSqlType(INTEGER); intMinus1 = REX_BUILDER.makeLiteral(-1, integerType, true); int0 = REX_BUILDER.makeLiteral(0, integerType, true); int1 = REX_BUILDER.makeLiteral(1, integerType, true); @@ -113,25 +171,61 @@ 
public static void beforeClass() { int11 = REX_BUILDER.makeLiteral(11, integerType, true); boolFalse = REX_BUILDER.makeLiteral(false, TYPE_FACTORY.createSqlType(SqlTypeName.BOOLEAN), true); boolTrue = REX_BUILDER.makeLiteral(true, TYPE_FACTORY.createSqlType(SqlTypeName.BOOLEAN), true); - tableType = TYPE_FACTORY.createStructType(ImmutableList.of(integerType), ImmutableList.of("f1")); + RelDataTypeFactory.Builder b = new RelDataTypeFactory.Builder(TYPE_FACTORY); + b.add("f_numeric", decimalType(38, 25)); + b.add("f_decimal10s3", decimalType(10, 3)); + b.add("f_float", TYPE_FACTORY.createSqlType(SqlTypeName.FLOAT)); + b.add("f_double", TYPE_FACTORY.createSqlType(SqlTypeName.DOUBLE)); + b.add("f_tinyint", TYPE_FACTORY.createSqlType(TINYINT)); + b.add("f_smallint", TYPE_FACTORY.createSqlType(SMALLINT)); + b.add("f_integer", integerType); + b.add("f_bigint", TYPE_FACTORY.createSqlType(BIGINT)); + b.add("f_timestamp", SqlTypeName.TIMESTAMP); + b.add("f_date", SqlTypeName.DATE).build(); + tableType = b.build(); RelOptPlanner planner = CalcitePlanner.createPlanner(new HiveConf()); relOptCluster = RelOptCluster.create(planner, REX_BUILDER); + } - stats = new ColStatistics(); - stats.setHistogram(KLL.toByteArray()); + private static ColStatistics.Range rangeOf(float[] values) { + float min = Float.MAX_VALUE, max = -Float.MAX_VALUE; + for (float v : values) { + min = Math.min(min, v); + max = Math.max(max, v); + } + return new ColStatistics.Range(min, max); } @Before public void before() { + currentValuesSize = VALUES.length; doReturn(tableType).when(tableMock).getRowType(); - doReturn((double) VALUES.length).when(tableMock).getRowCount(); + when(tableMock.getRowCount()).thenAnswer(a -> (double) currentValuesSize); RelBuilder relBuilder = HiveRelFactories.HIVE_BUILDER.create(relOptCluster, schemaMock); - tableScan = new HiveTableScan(relOptCluster, relOptCluster.traitSetOf(HiveRelNode.CONVENTION), - tableMock, "table", null, false, false); + HiveTableScan tableScan = + new 
HiveTableScan(relOptCluster, relOptCluster.traitSetOf(HiveRelNode.CONVENTION), tableMock, "table", null, + false, false); scan = relBuilder.push(tableScan).build(); inputRef0 = REX_BUILDER.makeInputRef(scan, 0); + currentInputRef = inputRef0; + + stats = new ColStatistics(); + stats.setHistogram(KLL.toByteArray()); + stats.setRange(rangeOf(VALUES)); + } + + /** + * Note: call this method only at the beginning of a test method. + */ + private void useFieldWithValues(String fieldname, float[] values, KllFloatsSketch sketch) { + currentValuesSize = values.length; + stats.setHistogram(sketch.toByteArray()); + stats.setRange(rangeOf(values)); + int fieldIndex = scan.getRowType().getFieldNames().indexOf(fieldname); + currentInputRef = REX_BUILDER.makeInputRef(scan, fieldIndex); + doReturn(Collections.singletonList(stats)).when(tableMock).getColStat(Collections.singletonList(fieldIndex)); } @Test @@ -420,7 +514,7 @@ public void testComputeRangePredicateSelectivityBetweenLeftLowerThanRight() { @Test public void testComputeRangePredicateSelectivityBetweenLeftEqualsRight() { - doReturn(Collections.singletonList(stats)).when(tableMock).getColStat(Collections.singletonList(0)); + verify(tableMock, never()).getColStat(any()); doReturn(10.0).when(mq).getDistinctRowCount(scan, ImmutableBitSet.of(0), REX_BUILDER.makeLiteral(true)); RexNode filter = REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolFalse, inputRef0, int3, int3); FilterSelectivityEstimator estimator = new FilterSelectivityEstimator(scan, mq); @@ -454,7 +548,7 @@ public void testComputeRangePredicateSelectivityNotBetweenRightLowerThanLeft() { @Test public void testComputeRangePredicateSelectivityNotBetweenLeftEqualsRight() { - doReturn(Collections.singletonList(stats)).when(tableMock).getColStat(Collections.singletonList(0)); + verify(tableMock, never()).getColStat(any()); RexNode filter = REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolTrue, inputRef0, int3, int3); FilterSelectivityEstimator estimator = new 
FilterSelectivityEstimator(scan, mq); Assert.assertEquals(1, estimator.estimateSelectivity(filter), DELTA); @@ -511,6 +605,428 @@ public void testComputeRangePredicateSelectivityNotBetweenWithNULLS() { doReturn(Collections.singletonList(stats)).when(tableMock).getColStat(Collections.singletonList(0)); RexNode filter = REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolTrue, inputRef0, int1, int3); FilterSelectivityEstimator estimator = new FilterSelectivityEstimator(scan, mq); - Assert.assertEquals(0.55, estimator.estimateSelectivity(filter), DELTA); + // only the values 4, 5, 6, 7 fulfill the condition NOT BETWEEN 1 AND 3 + // (the NULL values do not fulfill the condition) + Assert.assertEquals(0.2, estimator.estimateSelectivity(filter), DELTA); + } + + @Test + public void testRangePredicateCastIntegerValuesInsideTypeRange() { + // use VALUES2, even if the tested types cannot represent its values + // we're only interested in whether the cast to a smaller integer type results in the default selectivity + useFieldWithValues("f_tinyint", VALUES, KLL); + checkSelectivity(3 / 13.f, ge(cast("f_tinyint", TINYINT), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_tinyint", SMALLINT), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_tinyint", INTEGER), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_tinyint", BIGINT), int5)); + + useFieldWithValues("f_smallint", VALUES, KLL); + checkSelectivity(3 / 13.f, ge(cast("f_smallint", TINYINT), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_smallint", SMALLINT), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_smallint", INTEGER), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_smallint", BIGINT), int5)); + + useFieldWithValues("f_integer", VALUES, KLL); + checkSelectivity(3 / 13.f, ge(cast("f_integer", TINYINT), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_integer", SMALLINT), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_integer", INTEGER), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_integer", BIGINT), int5)); + + 
useFieldWithValues("f_bigint", VALUES, KLL); + checkSelectivity(3 / 13.f, ge(cast("f_bigint", TINYINT), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_bigint", SMALLINT), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_bigint", INTEGER), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_bigint", BIGINT), int5)); + } + + @Test + public void testRangePredicateCastIntegerValuesOutsideTypeRange() { + // use VALUES2, even if the tested types cannot represent its values + // we're only interested in whether the cast to a smaller integer type results in the default selectivity + useFieldWithValues("f_tinyint", VALUES2, KLL2); + checkSelectivity(16 / 28.f, ge(cast("f_tinyint", TINYINT), int5)); + checkSelectivity(18 / 28.f, ge(cast("f_tinyint", SMALLINT), int5)); + checkSelectivity(20 / 28.f, ge(cast("f_tinyint", INTEGER), int5)); + checkSelectivity(20 / 28.f, ge(cast("f_tinyint", BIGINT), int5)); + + useFieldWithValues("f_smallint", VALUES2, KLL2); + checkSelectivity(1 / 3.f, ge(cast("f_smallint", TINYINT), int5)); + checkSelectivity(18 / 28.f, ge(cast("f_smallint", SMALLINT), int5)); + checkSelectivity(20 / 28.f, ge(cast("f_smallint", INTEGER), int5)); + checkSelectivity(20 / 28.f, ge(cast("f_smallint", BIGINT), int5)); + + useFieldWithValues("f_integer", VALUES2, KLL2); + checkSelectivity(1 / 3.f, ge(cast("f_integer", TINYINT), int5)); + checkSelectivity(1 / 3.f, ge(cast("f_integer", SMALLINT), int5)); + checkSelectivity(20 / 28.f, ge(cast("f_integer", INTEGER), int5)); + checkSelectivity(20 / 28.f, ge(cast("f_integer", BIGINT), int5)); + + useFieldWithValues("f_bigint", VALUES2, KLL2); + checkSelectivity(1 / 3.f, ge(cast("f_bigint", TINYINT), int5)); + checkSelectivity(1 / 3.f, ge(cast("f_bigint", SMALLINT), int5)); + checkSelectivity(1 / 3.f, ge(cast("f_bigint", INTEGER), int5)); + checkSelectivity(20 / 28.f, ge(cast("f_bigint", BIGINT), int5)); + } + + @Test + public void testRangePredicateTypeMatrix() { + // checks many possible combinations of types + List fields = 
tableType.getFieldList(); + for (var srcField : fields) { + if (isTemporal(srcField.getType())) { + continue; + } + + useFieldWithValues(srcField.getName(), VALUES, KLL); + + for (var tgt : fields) { + try { + if (isTemporal(tgt.getType())) { + continue; + } + + RexNode expr = cast(srcField.getName(), tgt.getType()); + checkBetweenSelectivity(3, VALUES.length, VALUES.length, expr, 5, 7); + } catch (AssertionError e) { + throw new AssertionError("Error when casting from " + srcField.getType() + " to " + tgt.getType(), e); + } + } + } + } + + private boolean isTemporal(RelDataType type) { + return type.getSqlTypeName() == TIMESTAMP || type.getSqlTypeName() == SqlTypeName.DATE; + } + + @Test + public void testRangePredicateWithCast() { + useFieldWithValues("f_numeric", VALUES, KLL); + checkSelectivity(3 / 13.f, ge(cast("f_numeric", TINYINT), int5)); + checkSelectivity(10 / 13.f, lt(cast("f_numeric", TINYINT), int5)); + checkSelectivity(2 / 13.f, gt(cast("f_numeric", TINYINT), int5)); + checkSelectivity(11 / 13.f, le(cast("f_numeric", TINYINT), int5)); + + checkSelectivity(12 / 13f, ge(cast("f_numeric", TINYINT), int2)); + checkSelectivity(1 / 13f, lt(cast("f_numeric", TINYINT), int2)); + checkSelectivity(5 / 13f, gt(cast("f_numeric", TINYINT), int2)); + checkSelectivity(8 / 13f, le(cast("f_numeric", TINYINT), int2)); + + // check some types + checkSelectivity(3 / 13.f, ge(cast("f_numeric", INTEGER), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_numeric", SMALLINT), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_numeric", BIGINT), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_numeric", SqlTypeName.FLOAT), int5)); + checkSelectivity(3 / 13.f, ge(cast("f_numeric", SqlTypeName.DOUBLE), int5)); + } + + @Test + public void testRangePredicateWithCast2() { + useFieldWithValues("f_numeric", VALUES2, KLL2); + RelDataType decimal3s1 = decimalType(3, 1); + checkSelectivity(4 / 28.f, ge(cast("f_numeric", decimal3s1), literalFloat(1))); + + // values from -99.94999 to 99.94999 
(both inclusive) + checkSelectivity(7 / 28.f, lt(cast("f_numeric", decimal3s1), literalFloat(100))); + checkSelectivity(7 / 28.f, le(cast("f_numeric", decimal3s1), literalFloat(100))); + checkSelectivity(0 / 28.f, gt(cast("f_numeric", decimal3s1), literalFloat(100))); + checkSelectivity(0 / 28.f, ge(cast("f_numeric", decimal3s1), literalFloat(100))); + + RelDataType decimal4s1 = decimalType(4, 1); + checkSelectivity(10 / 28.f, lt(cast("f_numeric", decimal4s1), literalFloat(100))); + checkSelectivity(20 / 28.f, le(cast("f_numeric", decimal4s1), literalFloat(100))); + checkSelectivity(3 / 28.f, gt(cast("f_numeric", decimal4s1), literalFloat(100))); + checkSelectivity(13 / 28.f, ge(cast("f_numeric", decimal4s1), literalFloat(100))); + + RelDataType decimal2s1 = decimalType(2, 1); + checkSelectivity(2 / 28.f, lt(cast("f_numeric", decimal2s1), literalFloat(100))); + checkSelectivity(2 / 28.f, le(cast("f_numeric", decimal2s1), literalFloat(100))); + checkSelectivity(0 / 28.f, gt(cast("f_numeric", decimal2s1), literalFloat(100))); + checkSelectivity(0 / 28.f, ge(cast("f_numeric", decimal2s1), literalFloat(100))); + + // expected: 100_000f + RelDataType decimal7s1 = decimalType(7, 1); + checkSelectivity(1 / 28.f, gt(cast("f_numeric", decimal7s1), literalFloat(10000))); + + // expected: 10_000f, 100_000f, because CAST(1_000_000 AS DECIMAL(7,1)) = NULL, and similar for even larger values + checkSelectivity(2 / 28.f, ge(cast("f_numeric", decimal7s1), literalFloat(9999))); + checkSelectivity(2 / 28.f, ge(cast("f_numeric", decimal7s1), literalFloat(10000))); + + // expected: 100_000f + checkSelectivity(1 / 28.f, gt(cast("f_numeric", decimal7s1), literalFloat(10000))); + checkSelectivity(1 / 28.f, gt(cast("f_numeric", decimal7s1), literalFloat(10001))); + + // expected 1f, 10f, 99.94998f, 99.94999f + checkSelectivity(4 / 28.f, ge(cast("f_numeric", decimal3s1), literalFloat(1))); + checkSelectivity(3 / 28.f, gt(cast("f_numeric", decimal3s1), literalFloat(1))); + // expected 
-99.94999f, -99.94998f, 0f, 1f + checkSelectivity(4 / 28.f, le(cast("f_numeric", decimal3s1), literalFloat(1))); + checkSelectivity(3 / 28.f, lt(cast("f_numeric", decimal3s1), literalFloat(1))); + } + + private void checkTimeFieldOnMidnightTimestamps(RexNode field) { + // note: use only values from VALUES_TIME that specify a date without hh:mm:ss! + checkSelectivity(7 / 7.f, ge(field, literalTimestamp("2020-11-01"))); + checkSelectivity(5 / 7.f, ge(field, literalTimestamp("2020-11-03"))); + checkSelectivity(1 / 7.f, ge(field, literalTimestamp("2020-11-07"))); + + checkSelectivity(6 / 7.f, gt(field, literalTimestamp("2020-11-01"))); + checkSelectivity(4 / 7.f, gt(field, literalTimestamp("2020-11-03"))); + checkSelectivity(0 / 7.f, gt(field, literalTimestamp("2020-11-07"))); + + checkSelectivity(1 / 7.f, le(field, literalTimestamp("2020-11-01"))); + checkSelectivity(3 / 7.f, le(field, literalTimestamp("2020-11-03"))); + checkSelectivity(7 / 7.f, le(field, literalTimestamp("2020-11-07"))); + + checkSelectivity(0 / 7.f, lt(field, literalTimestamp("2020-11-01"))); + checkSelectivity(2 / 7.f, lt(field, literalTimestamp("2020-11-03"))); + checkSelectivity(6 / 7.f, lt(field, literalTimestamp("2020-11-07"))); + } + + private void checkTimeFieldOnIntraDayTimestamps(RexNode field) { + checkSelectivity(3 / 7.f, ge(field, literalTimestamp("2020-11-05T11:23:45Z"))); + checkSelectivity(2 / 7.f, gt(field, literalTimestamp("2020-11-05T11:23:45Z"))); + checkSelectivity(5 / 7.f, le(field, literalTimestamp("2020-11-05T11:23:45Z"))); + checkSelectivity(4 / 7.f, lt(field, literalTimestamp("2020-11-05T11:23:45Z"))); + } + + @Test + public void testRangePredicateOnTimestamp() { + useFieldWithValues("f_timestamp", VALUES_TIME, KLL_TIME); + checkTimeFieldOnMidnightTimestamps(currentInputRef); + checkTimeFieldOnIntraDayTimestamps(currentInputRef); + } + + @Test + public void testRangePredicateOnTimestampWithCast() { + useFieldWithValues("f_timestamp", VALUES_TIME, KLL_TIME); + RexNode expr1 
= cast("f_timestamp", SqlTypeName.DATE); + checkTimeFieldOnMidnightTimestamps(expr1); + checkTimeFieldOnIntraDayTimestamps(expr1); + + RexNode expr2 = cast("f_timestamp", SqlTypeName.TIMESTAMP); + checkTimeFieldOnMidnightTimestamps(expr2); + checkTimeFieldOnIntraDayTimestamps(expr2); + } + + @Test + public void testRangePredicateOnDate() { + useFieldWithValues("f_date", VALUES_TIME, KLL_TIME); + checkTimeFieldOnMidnightTimestamps(currentInputRef); + + // it does not make sense to compare with "2020-11-05T11:23:45Z", + // as that value would not be stored as-is in a date column, but as "2020-11-05" instead + } + + @Test + public void testRangePredicateOnDateWithCast() { + useFieldWithValues("f_date", VALUES_TIME, KLL_TIME); + checkTimeFieldOnMidnightTimestamps(cast("f_date", SqlTypeName.DATE)); + checkTimeFieldOnMidnightTimestamps(cast("f_date", SqlTypeName.TIMESTAMP)); + + // it does not make sense to compare with "2020-11-05T11:23:45Z", + // as that value would not be stored as-is in a date column, but as "2020-11-05" instead + } + + @Test + public void testBetweenWithCastToTinyIntCheckRounding() { + useFieldWithValues("f_numeric", VALUES3, KLL3); + float total = VALUES3.length; + float universe = 10; // the number of values that "survive" the cast + RexNode cast = cast("f_numeric", TINYINT); + // check rounding of positive numbers + checkBetweenSelectivity(3, universe, total, cast, 0, 10); + checkBetweenSelectivity(4, universe, total, cast, 0, 11); + checkBetweenSelectivity(4, universe, total, cast, 10, 20); + checkBetweenSelectivity(1, universe, total, cast, 11, 20); + + // check rounding of negative numbers + checkBetweenSelectivity(4, universe, total, cast, -20, -10); + checkBetweenSelectivity(1, universe, total, cast, -20, -11); + checkBetweenSelectivity(3, universe, total, cast, -10, 0); + checkBetweenSelectivity(4, universe, total, cast, -11, 0); + } + + @Test + public void testBetweenWithCastToTinyInt() { + useFieldWithValues("f_numeric", VALUES3, KLL3); + 
float total = VALUES3.length; + float universe = 10; // the number of values that "survive" the cast + RexNode cast = cast("f_numeric", TINYINT); + checkBetweenSelectivity(5, universe, total, cast, 0, 1e20f); + checkBetweenSelectivity(5, universe, total, cast, -1e20f, 0); + checkBetweenSelectivity(0, universe, total, cast, 100f, 0f); + } + + @Test + public void testBetweenWithCastToSmallInt() { + useFieldWithValues("f_numeric", VALUES3, KLL3); + float total = VALUES3.length; + float universe = 14; // the number of values that "survive" the cast + RexNode cast = cast("f_numeric", SMALLINT); + checkBetweenSelectivity(7, universe, total, cast, 0, 1e20f); + checkBetweenSelectivity(7, universe, total, cast, -1e20f, 0); + checkBetweenSelectivity(0, universe, total, cast, 100f, 0f); + } + + @Test + public void testBetweenWithCastToInteger() { + useFieldWithValues("f_numeric", VALUES3, KLL3); + float total = VALUES3.length; + float universe = 18; // the number of values that "survive" the cast + RexNode cast = cast("f_numeric", INTEGER); + checkBetweenSelectivity(9, universe, total, cast, 0, 1e20f); + checkBetweenSelectivity(9, universe, total, cast, -1e20f, 0); + checkBetweenSelectivity(0, universe, total, cast, 100f, 0f); + } + + @Test + public void testBetweenWithCastToBigInt() { + useFieldWithValues("f_numeric", VALUES3, KLL3); + float total = VALUES3.length; + float universe = 22; // the number of values that "survive" the cast + RexNode cast = cast("f_numeric", BIGINT); + checkBetweenSelectivity(11, universe, total, cast, 0, 1e20f); + checkBetweenSelectivity(11, universe, total, cast, -1e20f, 0); + checkBetweenSelectivity(0, universe, total, cast, 100f, 0f); + } + + @Test + public void testBetweenWithCastToSmallInt2() { + useFieldWithValues("f_numeric", VALUES2, KLL2); + float total = VALUES2.length; + float universe = 23; // the number of values that "survive" the cast + RexNode cast = cast("f_numeric", TINYINT); + checkBetweenSelectivity(8, universe, total, cast, 
100f, 1000f); + checkBetweenSelectivity(17, universe, total, cast, 1f, 100f); + checkBetweenSelectivity(0, universe, total, cast, 100f, 0f); + } + + @Test + public void testBetweenWithCastToDecimal2s1() { + useFieldWithValues("f_numeric", VALUES2, KLL2); + float total = VALUES2.length; + float universe = 2; // the number of values that "survive" the cast + RexNode cast = REX_BUILDER.makeCast(decimalType(2, 1), inputRef0); + checkBetweenSelectivity(0, universe, total, cast, 100f, 1000f); + checkBetweenSelectivity(1, universe, total, cast, 1f, 100f); + checkBetweenSelectivity(0, universe, total, cast, 100f, 0f); + } + + @Test + public void testBetweenWithCastToDecimal3s1() { + useFieldWithValues("f_numeric", VALUES2, KLL2); + float total = VALUES2.length; + float universe = 7; // the number of values that "survive" the cast + RexNode cast = REX_BUILDER.makeCast(decimalType(3, 1), inputRef0); + checkBetweenSelectivity(0, universe, total, cast, 100f, 1000f); + checkBetweenSelectivity(4, universe, total, cast, 1f, 100f); + checkBetweenSelectivity(0, universe, total, cast, 100f, 0f); + } + + @Test + public void testBetweenWithCastToDecimal4s1() { + useFieldWithValues("f_numeric", VALUES2, KLL2); + float total = VALUES2.length; + float universe = 23; // the number of values that "survive" the cast + RexNode cast = REX_BUILDER.makeCast(decimalType(4, 1), inputRef0); + // the values between -999.94999... and 999.94999... 
(both inclusive) pass through the cast + // the values between 99.95 and 100 are rounded up to 100, so they fulfill the BETWEEN + checkBetweenSelectivity(13, universe, total, cast, 100, 1000); + checkBetweenSelectivity(14, universe, total, cast, 1f, 100f); + checkBetweenSelectivity(0, universe, total, cast, 100f, 0f); + } + + @Test + public void testBetweenWithCastToDecimal7s1() { + useFieldWithValues("f_numeric", VALUES2, KLL2); + float total = VALUES2.length; + float universe = 26; // the number of values that "survive" the cast + RexNode cast = REX_BUILDER.makeCast(decimalType(7, 1), inputRef0); + checkBetweenSelectivity(14, universe, total, cast, 100, 1000); + checkBetweenSelectivity(14, universe, total, cast, 1f, 100f); + checkBetweenSelectivity(0, universe, total, cast, 100f, 0f); + } + + private void checkSelectivity(float expectedSelectivity, RexNode filter) { + FilterSelectivityEstimator estimator = new FilterSelectivityEstimator(scan, mq); + Assert.assertEquals(filter.toString(), expectedSelectivity, estimator.estimateSelectivity(filter), DELTA); + + // convert "col OP value" to "value INVERSE_OP col", and check it + RexNode inverted = RexUtil.invert(REX_BUILDER, (RexCall) filter); + if (inverted != null) { + Assert.assertEquals(filter.toString(), expectedSelectivity, estimator.estimateSelectivity(inverted), DELTA); + } + } + + private void checkBetweenSelectivity(float expectedEntries, float universe, float total, RexNode value, float lower, + float upper) { + RexNode betweenFilter = + REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolFalse, value, literalFloat(lower), literalFloat(upper)); + FilterSelectivityEstimator estimator = new FilterSelectivityEstimator(scan, mq); + String between = "BETWEEN " + lower + " AND " + upper; + float expectedSelectivity = expectedEntries / total; + String message = between + ": calcite filter " + betweenFilter.toString(); + Assert.assertEquals(message, expectedSelectivity, estimator.estimateSelectivity(betweenFilter), 
DELTA); + + // invert the filter to a NOT BETWEEN + RexNode invBetween = + REX_BUILDER.makeCall(HiveBetween.INSTANCE, boolTrue, value, literalFloat(lower), literalFloat(upper)); + String invMessage = "NOT " + between + ": calcite filter " + invBetween.toString(); + float invExpectedSelectivity = (universe - expectedEntries) / total; + Assert.assertEquals(invMessage, invExpectedSelectivity, estimator.estimateSelectivity(invBetween), DELTA); + } + + private RexNode cast(String fieldname, SqlTypeName typeName) { + return cast(fieldname, type(typeName)); + } + + private RexNode cast(String fieldname, RelDataType type) { + int fieldIndex = scan.getRowType().getFieldNames().indexOf(fieldname); + RexNode column = REX_BUILDER.makeInputRef(scan, fieldIndex); + return REX_BUILDER.makeCast(type, column); + } + + private RexNode ge(RexNode expr, RexNode value) { + return REX_BUILDER.makeCall(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, expr, value); + } + + private RexNode gt(RexNode expr, RexNode value) { + return REX_BUILDER.makeCall(SqlStdOperatorTable.GREATER_THAN, expr, value); + } + + private RexNode le(RexNode expr, RexNode value) { + return REX_BUILDER.makeCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, expr, value); + } + + private RexNode lt(RexNode expr, RexNode value) { + return REX_BUILDER.makeCall(SqlStdOperatorTable.LESS_THAN, expr, value); + } + + private static RelDataType type(SqlTypeName typeName) { + return REX_BUILDER.getTypeFactory().createSqlType(typeName); + } + + private static RelDataType decimalType(int precision, int scale) { + return REX_BUILDER.getTypeFactory().createSqlType(SqlTypeName.DECIMAL, precision, scale); + } + + private static RexLiteral literalTimestamp(String timestamp) { + return REX_BUILDER.makeLiteral(timestampMillis(timestamp), + REX_BUILDER.getTypeFactory().createSqlType(SqlTypeName.TIMESTAMP)); + } + + private RexNode literalFloat(float f) { + return REX_BUILDER.makeLiteral(f, type(SqlTypeName.FLOAT)); + } + + private static long 
timestampMillis(String timestamp) { + if (!timestamp.contains(":")) { + return LocalDate.parse(timestamp).toEpochSecond(LocalTime.MIDNIGHT, ZoneOffset.UTC) * 1000; + } + return Instant.parse(timestamp).toEpochMilli(); + } + + private static long timestamp(String timestamp) { + return timestampMillis(timestamp) / 1000; } }