Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1420,14 +1420,65 @@ object Expand {
}
}

val taggedChildOutput =
if (SQLConf.get.getConf(SQLConf.EXPAND_TAG_PASSTHROUGH_DUPLICATES_ENABLED)) {
tagPassthroughDuplicates(childOutput, groupByAliases)
} else {
childOutput
}

val output = if (hasDuplicateGroupingSets) {
val gpos = AttributeReference("_gen_grouping_pos", IntegerType, false)()
childOutput ++ groupByAttrs.map(_.newInstance()) :+ gid :+ gpos
taggedChildOutput ++ groupByAttrs.map(_.newInstance()) :+ gid :+ gpos
} else {
childOutput ++ groupByAttrs.map(_.newInstance()) :+ gid
taggedChildOutput ++ groupByAttrs.map(_.newInstance()) :+ gid
}
Expand(projections, output, Project(childOutput ++ groupByAliases, child))
}

/**
 * Marks the child output attributes that the Expand output repeats as grouping attributes
 * by adding the `__is_duplicate` metadata key to them.
 *
 * A `groupByAlias` over a plain attribute (e.g., `Alias(c1#0, "c1")`) makes the Expand
 * output carry the pass-through child attribute (`c1#0`) next to a fresh grouping instance
 * obtained through `newInstance()` (e.g., `c1#5`). The two share a name, so resolving that
 * name against the Expand output fails with `AMBIGUOUS_REFERENCE`.
 *
 * `AttributeSeq.getCandidatesForResolution` drops candidates carrying `__is_duplicate`
 * when more than one candidate matches by name, so the produced grouping attribute wins.
 * Resolution by ExprId (used by already-resolved expressions such as aggregate functions)
 * does not look at this metadata and keeps working as before.
 *
 * Matching is done on ExprId: a child attribute is tagged only if some `groupByAlias` wraps
 * a simple `Attribute` with the same ExprId. Aliases over complex expressions (e.g.,
 * `c1 + 1`) are ignored because their names differ from every child column name.
 *
 * @param childOutput output attributes of the Expand child
 * @param groupByAliases aliases of the grouping expressions
 * @return `childOutput` itself when no alias wraps a simple attribute; otherwise a copy in
 *         which the matching attributes carry the `__is_duplicate` key on top of their
 *         existing metadata
 */
private def tagPassthroughDuplicates(
    childOutput: Seq[Attribute],
    groupByAliases: Seq[Alias]): Seq[Attribute] = {
  // ExprIds of the child attributes that are grouped on directly, without any wrapping
  // expression.
  val passthroughIds: Set[ExprId] = groupByAliases.collect {
    case Alias(attr: Attribute, _) => attr.exprId
  }.toSet

  if (passthroughIds.isEmpty) {
    // Nothing to tag: hand back the input sequence untouched.
    childOutput
  } else {
    childOutput.map {
      case dup if passthroughIds.contains(dup.exprId) =>
        // Keep whatever metadata the attribute already has and add the marker key.
        val taggedMetadata = new MetadataBuilder()
          .withMetadata(dup.metadata)
          .putNull("__is_duplicate")
          .build()
        dup.withMetadata(taggedMetadata)
      case other => other
    }
  }
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,17 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

// Internal analyzer flag (enabled by default) read when an Expand node's output is built.
// When true, the pass-through child attributes are run through the `__is_duplicate` tagging
// step before being placed in the Expand output; when false, the child output is used as is.
val EXPAND_TAG_PASSTHROUGH_DUPLICATES_ENABLED =
buildConf("spark.sql.analyzer.expandTagPassthroughDuplicates")
.internal()
.version("4.2.0")
.doc(
"When true, Expand tags pass-through child attributes that share a name with a " +
"grouping attribute using __is_duplicate metadata, so that name-based resolution " +
"against the Expand output does not produce AMBIGUOUS_REFERENCE errors.")
.booleanConf
.createWithDefault(true)

val ONLY_NECESSARY_AND_UNIQUE_METADATA_COLUMNS =
buildConf("spark.sql.analyzer.uniqueNecessaryMetadataColumns")
.internal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.analysis

import java.util.TimeZone

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

class ResolveGroupingAnalyticsSuite extends AnalysisTest {
Expand Down Expand Up @@ -279,6 +281,186 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup"))
}

test("Expand tags pass-through duplicates for simple attribute grouping") {
  // ROLLUP over the single plain column `a`.
  val aliases = Seq(Alias(a, "a")())
  val groupingAttrs = aliases.map(_.toAttribute)
  val groupingId =
    AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, nullable = false)()
  val groupingSets = BaseGroupingSets
    .rollupExprs(Seq(Seq(groupingAttrs.head)))
    .map { set => set.map(expr => groupingAttrs.find(_.semanticEquals(expr)).get) }

  val plan = Expand(groupingSets, aliases, groupingAttrs, groupingId, r1)
  // The leading entries of the Expand output are the child's pass-through attributes
  // (presumably a, b, c for r1 -- matches the assertion messages below).
  val passThrough = plan.output.take(r1.output.length)

  def isTagged(attr: Attribute): Boolean = attr.metadata.contains("__is_duplicate")

  assert(isTagged(passThrough.head),
    "pass-through attribute 'a' should be tagged with __is_duplicate")
  assert(!isTagged(passThrough(1)),
    "non-grouped attribute 'b' should not be tagged")
  assert(!isTagged(passThrough(2)),
    "non-grouped attribute 'c' should not be tagged")
}

test("Expand does not tag pass-through for complex grouping expressions") {
  // Group on the expression `a + 1` instead of a bare column: the alias does not wrap a
  // simple attribute, so no child attribute should be marked.
  val aliases = Seq(Alias(a + Literal(1), "(a + 1)")())
  val groupingAttrs = aliases.map(_.toAttribute)
  val groupingId =
    AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, nullable = false)()
  val groupingSets = BaseGroupingSets
    .rollupExprs(Seq(Seq(groupingAttrs.head)))
    .map { set => set.map(expr => groupingAttrs.find(_.semanticEquals(expr)).get) }

  val plan = Expand(groupingSets, aliases, groupingAttrs, groupingId, r1)
  val passThrough = plan.output.take(r1.output.length)

  for (attr <- passThrough) {
    assert(!attr.metadata.contains("__is_duplicate"),
      s"attribute '${attr.name}' should not be tagged for complex grouping expression")
  }
}

test("Expand tags multiple pass-through duplicates for multi-column grouping") {
  // ROLLUP over two plain columns, `a` and `b`.
  val aliases = Seq(Alias(a, "a")(), Alias(b, "b")())
  val groupingAttrs = aliases.map(_.toAttribute)
  val groupingId =
    AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, nullable = false)()
  val groupingSets = BaseGroupingSets
    .rollupExprs(groupingAttrs.map(Seq(_)))
    .map { set => set.map(expr => groupingAttrs.find(_.semanticEquals(expr)).get) }

  val plan = Expand(groupingSets, aliases, groupingAttrs, groupingId, r1)
  // Pass-through child attributes come first in the Expand output.
  val passThrough = plan.output.take(r1.output.length)

  def isTagged(attr: Attribute): Boolean = attr.metadata.contains("__is_duplicate")

  assert(isTagged(passThrough.head),
    "pass-through attribute 'a' should be tagged")
  assert(isTagged(passThrough(1)),
    "pass-through attribute 'b' should be tagged")
  assert(!isTagged(passThrough(2)),
    "non-grouped attribute 'c' should not be tagged")
}

test("Expand tagged duplicates preserve ExprId") {
  // ROLLUP over the single plain column `a`.
  val groupByAliases = Seq(Alias(a, "a")())
  val groupByAttrs = groupByAliases.map(_.toAttribute)
  val gidAttr =
    AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, nullable = false)()
  val groupingSetsAttrs = BaseGroupingSets.rollupExprs(Seq(Seq(groupByAttrs.head)))
    .map(_.map(e => groupByAttrs.find(_.semanticEquals(e)).get))

  val expand = Expand(groupingSetsAttrs, groupByAliases, groupByAttrs, gidAttr, r1)
  // The first Expand output attribute is the pass-through copy of `a`.
  val taggedAttr = expand.output.head

  // Guard the premise of this test: without this check the ExprId/name assertions below
  // would also pass for an untagged attribute and the test would prove nothing.
  assert(taggedAttr.metadata.contains("__is_duplicate"),
    "pass-through attribute 'a' should be tagged with __is_duplicate")

  // Tagging must only touch metadata; identity and name stay the same so that
  // ExprId-based references to the child column keep binding to it.
  assert(taggedAttr.exprId == a.exprId,
    "tagged attribute should preserve the original ExprId")
  assert(taggedAttr.name == a.name,
    "tagged attribute should preserve the original name")
}

test("Expand pass-through tagging prevents AMBIGUOUS_REFERENCE on name-based resolution") {
  // For ROLLUP(a) the Expand output holds two attributes called "a": the child's
  // pass-through attribute (a#original) and the freshly created grouping instance (a#new).
  // An operator resolving "a" by name against that output would find two candidates and
  // fail with AMBIGUOUS_REFERENCE.
  //
  // Such an operator can appear directly above the Expand with an unresolved reference:
  // e.g. a Filter or Project that a custom analysis rule places between the Expand and its
  // parent Aggregate, or a correlated subquery whose outer reference is resolved against
  // the Expand output.
  val aliases = Seq(Alias(a, "a")())
  val groupingAttrs = aliases.map(_.toAttribute)
  val groupingId =
    AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, nullable = false)()
  val groupingSets = BaseGroupingSets
    .rollupExprs(Seq(Seq(groupingAttrs.head)))
    .map { set => set.map(expr => groupingAttrs.find(_.semanticEquals(expr)).get) }

  val plan = Expand(groupingSets, aliases, groupingAttrs, groupingId, r1)
  val sameNameCount = plan.output.count(_.name == "a")
  assert(sameNameCount == 2,
    "Expand output should have 2 attributes named 'a'")

  // Tagged output (the fix): name lookup yields exactly one attribute, the grouping one.
  val lookup = plan.output.resolve(Seq("a"), caseSensitiveResolution)
  assert(lookup.isDefined, "should resolve 'a' successfully with tagging")
  assert(!lookup.get.toAttribute.metadata.contains("__is_duplicate"),
    "resolved attribute should be the grouping instance, not the tagged pass-through")

  // Strip the metadata from the tagged attributes to emulate the pre-fix output: both
  // candidates now match with equal priority and the lookup must fail.
  val withoutTags = plan.output.map {
    case tagged if tagged.metadata.contains("__is_duplicate") =>
      tagged.withMetadata(Metadata.empty)
    case untouched => untouched
  }
  checkError(
    exception = intercept[AnalysisException] {
      withoutTags.resolve(Seq("a"), caseSensitiveResolution)
    },
    condition = "AMBIGUOUS_REFERENCE",
    parameters = Map("name" -> "`a`", "referenceNames" -> "[`a`, `a`]")
  )
}

test("Expand does not tag pass-through duplicates when flag is disabled") {
  val flagKey = SQLConf.EXPAND_TAG_PASSTHROUGH_DUPLICATES_ENABLED.key
  withSQLConf(flagKey -> "false") {
    // ROLLUP over the single plain column `a`, built while tagging is switched off.
    val aliases = Seq(Alias(a, "a")())
    val groupingAttrs = aliases.map(_.toAttribute)
    val groupingId =
      AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, nullable = false)()
    val groupingSets = BaseGroupingSets
      .rollupExprs(Seq(Seq(groupingAttrs.head)))
      .map { set => set.map(expr => groupingAttrs.find(_.semanticEquals(expr)).get) }

    val plan = Expand(groupingSets, aliases, groupingAttrs, groupingId, r1)
    val passThrough = plan.output.take(r1.output.length)

    // No pass-through attribute may carry the marker key.
    for (attr <- passThrough) {
      assert(!attr.metadata.contains("__is_duplicate"),
        s"attribute '${attr.name}' should not be tagged when flag is disabled")
    }

    // Both "a" attributes are plain candidates again, so name lookup is ambiguous.
    assert(plan.output.count(_.name == "a") == 2)
    checkError(
      exception = intercept[AnalysisException] {
        plan.output.resolve(Seq("a"), caseSensitiveResolution)
      },
      condition = "AMBIGUOUS_REFERENCE",
      parameters = Map("name" -> "`a`", "referenceNames" -> "[`a`, `a`]")
    )
  }
}

test("Expand does not tag multi-column pass-through duplicates when flag is disabled") {
  val flagKey = SQLConf.EXPAND_TAG_PASSTHROUGH_DUPLICATES_ENABLED.key
  withSQLConf(flagKey -> "false") {
    // ROLLUP over the plain columns `a` and `b`, built while tagging is switched off.
    val aliases = Seq(Alias(a, "a")(), Alias(b, "b")())
    val groupingAttrs = aliases.map(_.toAttribute)
    val groupingId =
      AttributeReference(VirtualColumn.groupingIdName, GroupingID.dataType, nullable = false)()
    val groupingSets = BaseGroupingSets
      .rollupExprs(groupingAttrs.map(Seq(_)))
      .map { set => set.map(expr => groupingAttrs.find(_.semanticEquals(expr)).get) }

    val plan = Expand(groupingSets, aliases, groupingAttrs, groupingId, r1)
    val passThrough = plan.output.take(r1.output.length)

    // No pass-through attribute may carry the marker key.
    for (attr <- passThrough) {
      assert(!attr.metadata.contains("__is_duplicate"),
        s"attribute '${attr.name}' should not be tagged when flag is disabled")
    }

    // Each grouped column appears twice in the output and, with no tagging to break the
    // tie, resolving it by name is ambiguous. Checked for "a" first, then "b".
    Seq("a", "b").foreach { column =>
      assert(plan.output.count(_.name == column) == 2)
      checkError(
        exception = intercept[AnalysisException] {
          plan.output.resolve(Seq(column), caseSensitiveResolution)
        },
        condition = "AMBIGUOUS_REFERENCE",
        parameters = Map(
          "name" -> s"`$column`",
          "referenceNames" -> s"[`$column`, `$column`]")
      )
    }
  }
}

test("sort with grouping function") {
// Sort with Grouping function
val originalPlan = Sort(
Expand Down