5 changes: 0 additions & 5 deletions docs/source/user-guide/latest/compatibility.md
Original file line number Diff line number Diff line change
@@ -76,11 +76,6 @@ the [Comet Supported Expressions Guide](expressions.md) for more information on
timezone is UTC.
[#2649](https://github.com/apache/datafusion-comet/issues/2649)

### Aggregate Expressions

- **Corr**: Returns null instead of NaN in some edge cases.
[#2646](https://github.com/apache/datafusion-comet/issues/2646)

### Struct Expressions

- **StructsToJson (to_json)**: Does not support `+Infinity` and `-Infinity` for numeric types (float, double).
42 changes: 21 additions & 21 deletions docs/source/user-guide/latest/expressions.md
@@ -195,27 +195,27 @@ Expressions that are not Spark-compatible will fall back to Spark by default and

## Aggregate Expressions

| Expression | SQL | Spark-Compatible? | Compatibility Notes |
| ------------- | ---------- | ------------------------- | ---------------------------------------------------------------------------------------------------------------- |
| Average | | Yes, except for ANSI mode | |
| BitAndAgg | | Yes | |
| BitOrAgg | | Yes | |
| BitXorAgg | | Yes | |
| BoolAnd | `bool_and` | Yes | |
| BoolOr | `bool_or` | Yes | |
| Corr | | No | Returns null instead of NaN in some edge cases ([#2646](https://github.com/apache/datafusion-comet/issues/2646)) |
| Count | | Yes | |
| CovPopulation | | Yes | |
| CovSample | | Yes | |
| First | | No | This function is not deterministic. Results may not match Spark. |
| Last | | No | This function is not deterministic. Results may not match Spark. |
| Max | | Yes | |
| Min | | Yes | |
| StddevPop | | Yes | |
| StddevSamp | | Yes | |
| Sum | | Yes, except for ANSI mode | |
| VariancePop | | Yes | |
| VarianceSamp | | Yes | |
| Expression | SQL | Spark-Compatible? | Compatibility Notes |
| ------------- | ---------- | ------------------------- | ---------------------------------------------------------------- |
| Average | | Yes, except for ANSI mode | |
| BitAndAgg | | Yes | |
| BitOrAgg | | Yes | |
| BitXorAgg | | Yes | |
| BoolAnd | `bool_and` | Yes | |
| BoolOr | `bool_or` | Yes | |
| Corr | | Yes | |
| Count | | Yes | |
| CovPopulation | | Yes | |
| CovSample | | Yes | |
| First | | No | This function is not deterministic. Results may not match Spark. |
| Last | | No | This function is not deterministic. Results may not match Spark. |
| Max | | Yes | |
| Min | | Yes | |
| StddevPop | | Yes | |
| StddevSamp | | Yes | |
| Sum | | Yes, except for ANSI mode | |
| VariancePop | | Yes | |
| VarianceSamp | | Yes | |

## Window Functions

17 changes: 10 additions & 7 deletions native/spark-expr/src/agg_funcs/correlation.rs
@@ -221,19 +221,22 @@ impl Accumulator for CorrelationAccumulator {
let stddev1 = self.stddev1.evaluate()?;
let stddev2 = self.stddev2.evaluate()?;

if self.covar.get_count() == 0.0 {
return Ok(ScalarValue::Float64(None));
} else if self.covar.get_count() == 1.0 {
if self.null_on_divide_by_zero {
return Ok(ScalarValue::Float64(None));
} else {
return Ok(ScalarValue::Float64(Some(f64::NAN)));
}
}
match (covar, stddev1, stddev2) {
(
ScalarValue::Float64(Some(c)),
ScalarValue::Float64(Some(s1)),
ScalarValue::Float64(Some(s2)),
) if s1 != 0.0 && s2 != 0.0 => Ok(ScalarValue::Float64(Some(c / (s1 * s2)))),
_ if self.null_on_divide_by_zero => Ok(ScalarValue::Float64(None)),
_ => {
if self.covar.get_count() == 1.0 {
return Ok(ScalarValue::Float64(Some(f64::NAN)));
}
Ok(ScalarValue::Float64(None))
}
_ => Ok(ScalarValue::Float64(None)),
}
}
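The evaluation semantics above can be condensed into a standalone sketch (a hypothetical helper; the name, signature, and free-standing form are illustrative, not the actual accumulator API):

```rust
/// Sketch of the correlation evaluation shown above, assuming the same
/// arm ordering: valid division first, then the null_on_divide_by_zero
/// guard, then the legacy single-row NaN case.
fn corr_result(
    count: f64,
    covar: Option<f64>,
    stddev1: Option<f64>,
    stddev2: Option<f64>,
    null_on_divide_by_zero: bool,
) -> Option<f64> {
    match (covar, stddev1, stddev2) {
        // Well-defined case: corr = cov(x, y) / (stddev(x) * stddev(y))
        (Some(c), Some(s1), Some(s2)) if s1 != 0.0 && s2 != 0.0 => Some(c / (s1 * s2)),
        // Division by zero with the guard enabled: return null
        _ if null_on_divide_by_zero => None,
        // Legacy semantics: a single row yields NaN, anything else null
        _ if count == 1.0 => Some(f64::NAN),
        _ => None,
    }
}

fn main() {
    // cov = 2.0, stddevs 1.0 and 2.0 => corr = 1.0
    assert_eq!(corr_result(3.0, Some(2.0), Some(1.0), Some(2.0), true), Some(1.0));
    // n = 1 with the guard disabled (legacy mode) yields NaN
    assert!(corr_result(1.0, Some(0.0), Some(0.0), Some(0.0), false).unwrap().is_nan());
    // n = 1 with the guard enabled yields null
    assert_eq!(corr_result(1.0, Some(0.0), Some(0.0), Some(0.0), true), None);
}
```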

7 changes: 0 additions & 7 deletions spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -584,13 +584,6 @@ object CometStddevPop extends CometAggregateExpressionSerde[StddevPop] with Come
}

object CometCorr extends CometAggregateExpressionSerde[Corr] {

override def getSupportLevel(expr: Corr): SupportLevel =
Contributor:
Claude flagged some edge cases we can document:

1. Legacy mode: When `spark.sql.legacy.statisticalAggregate=true`, `nullOnDivideByZero` is false and Spark returns NaN for the n=1 case. With this workaround, Comet would return null instead (because the NaN row gets skipped → n=0). Should we add a `getSupportLevel` guard that returns `Incompatible` when `corr.nullOnDivideByZero` is false? Or at least document this?
2. Mixed groups: For a group containing (NaN, NaN) alongside valid pairs like (1.0, 2.0), Spark returns NaN (NaN contaminates the accumulator), while this workaround would skip the NaN row and compute a valid correlation over the remaining rows. Is that a known limitation we're OK with?
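The contaminate-vs-skip distinction in point 2 can be illustrated with a small sketch (a plain two-pass Pearson helper; this is illustrative, not the Spark or Comet accumulator):

```rust
/// Plain two-pass Pearson correlation over the given (x, y) pairs.
fn pearson(pairs: &[(f64, f64)]) -> f64 {
    let n = pairs.len() as f64;
    let (sx, sy) = pairs.iter().fold((0.0, 0.0), |(a, b), &(x, y)| (a + x, b + y));
    let (mx, my) = (sx / n, sy / n);
    let (mut cov, mut vx, mut vy) = (0.0, 0.0, 0.0);
    for &(x, y) in pairs {
        cov += (x - mx) * (y - my);
        vx += (x - mx).powi(2);
        vy += (y - my).powi(2);
    }
    cov / (vx.sqrt() * vy.sqrt())
}

fn main() {
    let group = [(f64::NAN, f64::NAN), (1.0, 2.0), (3.0, 4.0)];
    // Spark-style: the NaN pair flows through the running moments and
    // contaminates the whole group's result.
    assert!(pearson(&group).is_nan());
    // Workaround-style: skipping NaN rows yields a valid correlation
    // over the remaining (perfectly correlated) pairs.
    let filtered: Vec<_> = group
        .iter()
        .copied()
        .filter(|(x, y)| !x.is_nan() && !y.is_nan())
        .collect();
    assert!((pearson(&filtered) - 1.0).abs() < 1e-9);
}
```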

Member:

Same.

Worth double-checking: the original incompatibility note said "returns null instead of NaN in some edge cases." That also describes the behavior in correlation.rs:evaluate() when stddev is zero (constant values produce stddev=0, Spark returns NaN from 0/0, Comet returns null from the null_on_divide_by_zero guard). Is that case also resolved, or should Incompatible remain for that scenario?
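For reference, the constant-column case reduces to a 0/0 evaluation under IEEE semantics (a minimal sketch of the scenario being discussed, not the Comet code):

```rust
fn main() {
    // A constant column has zero variance, hence stddev = 0.
    let xs = [5.0_f64, 5.0, 5.0];
    let mean = xs.iter().sum::<f64>() / xs.len() as f64;
    let stddev =
        (xs.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / xs.len() as f64).sqrt();
    assert_eq!(stddev, 0.0);
    // The covariance of a constant against anything is also 0, so Spark
    // evaluates 0.0 / 0.0, which is NaN in IEEE arithmetic...
    assert!((0.0_f64 / (stddev * 1.0)).is_nan());
    // ...while a null_on_divide_by_zero guard short-circuits to null instead.
}
```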

Contributor Author:

Ended up fixing properly for the legacy mode as well. Added tests too

Incompatible(
Some(
"Returns null instead of NaN in some edge cases" +
" (https://github.com/apache/datafusion-comet/issues/2646)"))

override def convert(
aggExpr: AggregateExpression,
corr: Corr,
@@ -15,7 +15,6 @@
-- specific language governing permissions and limitations
-- under the License.

-- Config: spark.comet.expression.Corr.allowIncompatible=true

statement
CREATE TABLE test_corr(x double, y double, grp string) USING parquet
@@ -28,3 +27,13 @@ SELECT corr(x, y) FROM test_corr

query tolerance=1e-6
SELECT grp, corr(x, y) FROM test_corr GROUP BY grp ORDER BY grp

-- Test permutations of NULL and NaN
statement
CREATE TABLE test_corr_nan(x double, y double, grp string) USING parquet

statement
INSERT INTO test_corr_nan VALUES (cast('NaN' as double), cast('NaN' as double), 'both_nan'), (cast('NaN' as double), 1.0, 'nan_val'), (1.0, cast('NaN' as double), 'val_nan'), (NULL, cast('NaN' as double), 'null_nan'), (cast('NaN' as double), NULL, 'nan_null'), (NULL, NULL, 'both_null'), (NULL, 1.0, 'null_val'), (1.0, NULL, 'val_null'), (cast('NaN' as double), cast('NaN' as double), 'mixed'), (1.0, 2.0, 'mixed'), (3.0, 4.0, 'mixed'), (cast('NaN' as double), cast('NaN' as double), 'multi_nan'), (cast('NaN' as double), cast('NaN' as double), 'multi_nan'), (cast('NaN' as double), cast('NaN' as double), 'multi_nan')

query tolerance=1e-6
SELECT grp, corr(x, y) FROM test_corr_nan GROUP BY grp ORDER BY grp
@@ -24,7 +24,6 @@ import scala.util.Random
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.expressions.aggregate.Corr
import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -1306,9 +1305,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}

test("covariance & correlation") {
withSQLConf(
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.getExprAllowIncompatConfigKey(classOf[Corr]) -> "true") {
withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
Seq("jvm", "native").foreach { cometShuffleMode =>
withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> cometShuffleMode) {
Seq(true, false).foreach { dictionary =>
@@ -1379,6 +1376,31 @@
}
}

test("corr - nan/null") {
Seq(true, false).foreach { nullOnDivideByZero =>
withSQLConf("spark.sql.legacy.statisticalAggregate" -> nullOnDivideByZero.toString) {
withTable("t") {
sql("""create table t using parquet as
select cast(null as float) f1, CAST('NaN' AS float) f2, cast(null as double) d1, CAST('NaN' AS double) d2
from range(1)
""")

checkSparkAnswerAndOperator("""
|select
| corr(f1, f2) c1,
| corr(f1, f1) c2,
| corr(f2, f1) c3,
| corr(f2, f2) c4,
| corr(d1, d2) c5,
| corr(d1, d1) c6,
| corr(d2, d1) c7,
| corr(d2, d2) c8
| FROM t""".stripMargin)
}
}
}
}

test("var_pop and var_samp") {
withSQLConf(CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true") {
Seq("native", "jvm").foreach { cometShuffleMode =>