diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 07e0eb1a77aa..4f73169ad282 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -45,7 +45,7 @@ use crate::utils::{ grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction, }; use crate::{ - BinaryExpr, CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable, + BinaryExpr, CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable, GroupingSet, LogicalPlanBuilder, Operator, Prepare, TableProviderFilterPushDown, TableSource, WindowFunctionDefinition, build_join_schema, expr_vec_fmt, requalify_sides_if_needed, }; @@ -3595,11 +3595,12 @@ impl Aggregate { .into_iter() .map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into())) .collect::>(); + let max_ordinal = max_grouping_set_duplicate_ordinal(&group_expr); qualified_fields.push(( None, Field::new( Self::INTERNAL_GROUPING_ID, - Self::grouping_id_type(qualified_fields.len()), + Self::grouping_id_type(qualified_fields.len(), max_ordinal), false, ) .into(), @@ -3685,15 +3686,24 @@ impl Aggregate { } /// Returns the data type of the grouping id. - /// The grouping ID value is a bitmask where each set bit - /// indicates that the corresponding grouping expression is - /// null - pub fn grouping_id_type(group_exprs: usize) -> DataType { - if group_exprs <= 8 { + /// + /// The grouping ID packs two pieces of information into a single integer: + /// - The low `group_exprs` bits are the semantic bitmask (a set bit means the + /// corresponding grouping expression is NULL for this grouping set). + /// - The bits above position `group_exprs` encode a duplicate ordinal that + /// distinguishes multiple occurrences of the same grouping set pattern. + /// + /// `max_ordinal` is the highest ordinal value that will appear (0 when there + /// are no duplicate grouping sets). 
The type is chosen to be the smallest + /// unsigned integer that can represent both parts. + pub fn grouping_id_type(group_exprs: usize, max_ordinal: usize) -> DataType { + let ordinal_bits = usize::BITS as usize - max_ordinal.leading_zeros() as usize; + let total_bits = group_exprs + ordinal_bits; + if total_bits <= 8 { DataType::UInt8 - } else if group_exprs <= 16 { + } else if total_bits <= 16 { DataType::UInt16 - } else if group_exprs <= 32 { + } else if total_bits <= 32 { DataType::UInt32 } else { DataType::UInt64 @@ -3702,21 +3712,36 @@ impl Aggregate { /// Internal column used when the aggregation is a grouping set. /// - /// This column contains a bitmask where each bit represents a grouping - /// expression. The least significant bit corresponds to the rightmost - /// grouping expression. A bit value of 0 indicates that the corresponding - /// column is included in the grouping set, while a value of 1 means it is excluded. + /// This column packs two values into a single unsigned integer: + /// + /// - **Low bits (positions 0 .. n-1)**: a semantic bitmask where each bit + /// represents one of the `n` grouping expressions. The least significant + /// bit corresponds to the rightmost grouping expression. A `1` bit means + /// the corresponding column is replaced with `NULL` for this grouping set; + /// a `0` bit means it is included. + /// - **High bits (positions n and above)**: a *duplicate ordinal* that + /// distinguishes multiple occurrences of the same semantic grouping set + /// pattern within a single query. The ordinal is `0` for the first + /// occurrence, `1` for the second, and so on. + /// + /// The integer type is chosen by [`Self::grouping_id_type`] to be the + /// smallest `UInt8 / UInt16 / UInt32 / UInt64` that can represent both + /// parts. 
/// - /// For example, for the grouping expressions CUBE(a, b), the grouping ID - /// column will have the following values: + /// For example, for the grouping expressions CUBE(a, b) (no duplicates), + /// the grouping ID column will have the following values: /// 0b00: Both `a` and `b` are included /// 0b01: `b` is excluded /// 0b10: `a` is excluded /// 0b11: Both `a` and `b` are excluded /// - /// This internal column is necessary because excluded columns are replaced - /// with `NULL` values. To handle these cases correctly, we must distinguish - /// between an actual `NULL` value in a column and a column being excluded from the set. + /// When the same set appears twice and `n = 2`, the duplicate ordinal is + /// packed into bit 2: + /// first occurrence: `0b0_01` (ordinal = 0, mask = 0b01) + /// second occurrence: `0b1_01` (ordinal = 1, mask = 0b01) + /// + /// The GROUPING function always masks the value with `(1 << n) - 1` before + /// interpreting it so the ordinal bits are invisible to user-facing SQL. pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id"; } @@ -3737,6 +3762,24 @@ impl PartialOrd for Aggregate { } } +/// Returns the highest duplicate ordinal across all grouping sets in `group_expr`. +/// +/// The ordinal for each occurrence of a grouping set pattern is its 0-based +/// index among identical entries. For example, if the same set appears three +/// times, the ordinals are 0, 1, 2 and this function returns 2. +/// Returns 0 when no grouping set is duplicated. +fn max_grouping_set_duplicate_ordinal(group_expr: &[Expr]) -> usize { + if let Some(Expr::GroupingSet(GroupingSet::GroupingSets(sets))) = group_expr.first() { + let mut counts: HashMap<&[Expr], usize> = HashMap::new(); + for set in sets { + *counts.entry(set).or_insert(0) += 1; + } + counts.into_values().max().unwrap_or(0).saturating_sub(1) + } else { + 0 + } +} + /// Checks whether any expression in `group_expr` contains `Expr::GroupingSet`. 
#[test]
fn grouping_id_type_accounts_for_duplicate_ordinal_bits() {
    // Eight grouping columns fill a UInt8 exactly while no ordinal bits are
    // needed; a single ordinal bit pushes the total to nine bits, which
    // requires UInt16.
    let cases = [(0, DataType::UInt8), (1, DataType::UInt16)];
    for (max_ordinal, expected) in cases {
        assert_eq!(Aggregate::grouping_id_type(8, max_ordinal), expected);
    }
}
+ grouping_id_type: Option, ) -> Result { validate_args(function, group_by_expr)?; let args = &function.params.args; // Postgres allows grouping function for group by without grouping sets, the result is then // always 0 - if !is_grouping_set { + let Some(grouping_id_type) = grouping_id_type else { return Ok(Expr::Literal(ScalarValue::from(0i32), None)); - } - - let group_by_expr_count = group_by_expr.len(); - let literal = |value: usize| { - if group_by_expr_count < 8 { - Expr::Literal(ScalarValue::from(value as u8), None) - } else if group_by_expr_count < 16 { - Expr::Literal(ScalarValue::from(value as u16), None) - } else if group_by_expr_count < 32 { - Expr::Literal(ScalarValue::from(value as u32), None) - } else { - Expr::Literal(ScalarValue::from(value as u64), None) - } }; + // Use the actual __grouping_id column type to size literals correctly. This + // accounts for duplicate-ordinal bits that `Aggregate::grouping_id_type` + // packs into the high bits of the column, which a simple count of grouping + // expressions would miss. + let literal = |value: usize| match &grouping_id_type { + DataType::UInt8 => Expr::Literal(ScalarValue::from(value as u8), None), + DataType::UInt16 => Expr::Literal(ScalarValue::from(value as u16), None), + DataType::UInt32 => Expr::Literal(ScalarValue::from(value as u32), None), + DataType::UInt64 => Expr::Literal(ScalarValue::from(value as u64), None), + other => panic!("unexpected __grouping_id type: {other}"), + }; let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); - // The grouping call is exactly our internal grouping id - if args.len() == group_by_expr_count + if args.len() == group_by_expr.len() && args .iter() .rev() .enumerate() .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx)) { - return Ok(cast(grouping_id_column, DataType::Int32)); + let n = group_by_expr.len(); + // Mask the ordinal bits above position `n` so only the semantic bitmask is visible. 
+ // checked_shl returns None when n >= 64 (all bits are semantic), mapping to u64::MAX. + let semantic_mask: u64 = 1u64.checked_shl(n as u32).map_or(u64::MAX, |m| m - 1); + let masked_id = + bitwise_and(grouping_id_column.clone(), literal(semantic_mask as usize)); + return Ok(cast(masked_id, DataType::Int32)); } args.iter() diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 24bf2265ff05..76de3e0cba67 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -37,7 +37,7 @@ use crate::{ use datafusion_common::config::ConfigOptions; use datafusion_physical_expr::utils::collect_columns; use parking_lot::Mutex; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; @@ -396,6 +396,15 @@ impl PhysicalGroupBy { self.expr.len() + usize::from(self.has_grouping_set) } + /// Returns the Arrow data type of the `__grouping_id` column. + /// + /// The type is chosen to be wide enough to hold both the semantic bitmask + /// (in the low `n` bits, where `n` is the number of grouping expressions) + /// and the duplicate ordinal (in the high bits). 
+ fn grouping_id_data_type(&self) -> DataType { + Aggregate::grouping_id_type(self.expr.len(), max_duplicate_ordinal(&self.groups)) + } + pub fn group_schema(&self, schema: &Schema) -> Result { Ok(Arc::new(Schema::new(self.group_fields(schema)?))) } @@ -420,7 +429,7 @@ impl PhysicalGroupBy { fields.push( Field::new( Aggregate::INTERNAL_GROUPING_ID, - Aggregate::grouping_id_type(self.expr.len()), + self.grouping_id_data_type(), false, ) .into(), @@ -1937,27 +1946,72 @@ fn evaluate_optional( .collect() } -fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result { - if group.len() > 64 { +/// Builds the internal `__grouping_id` array for a single grouping set. +/// +/// The returned array packs two values into a single integer: +/// +/// - Low `n` bits (positions 0 .. n-1): the semantic bitmask. A `1` bit +/// at position `i` means that the `i`-th grouping column (counting from the +/// least significant bit, i.e. the *last* column in the `group` slice) is +/// `NULL` for this grouping set. +/// - High bits (positions n and above): the duplicate `ordinal`, which +/// distinguishes multiple occurrences of the same grouping-set pattern. The +/// ordinal is `0` for the first occurrence, `1` for the second, and so on. +/// +/// The integer type is chosen to be the smallest `UInt8 / UInt16 / UInt32 / +/// UInt64` that can represent both parts. It matches the type returned by +/// [`Aggregate::grouping_id_type`]. 
+fn group_id_array( + group: &[bool], + ordinal: usize, + max_ordinal: usize, + batch: &RecordBatch, +) -> Result { + let n = group.len(); + if n > 64 { return not_impl_err!( "Grouping sets with more than 64 columns are not supported" ); } - let group_id = group.iter().fold(0u64, |acc, &is_null| { + let ordinal_bits = usize::BITS as usize - max_ordinal.leading_zeros() as usize; + let total_bits = n + ordinal_bits; + if total_bits > 64 { + return not_impl_err!( + "Grouping sets with {n} columns and a maximum duplicate ordinal of \ + {max_ordinal} require {total_bits} bits, which exceeds 64" + ); + } + let semantic_id = group.iter().fold(0u64, |acc, &is_null| { (acc << 1) | if is_null { 1 } else { 0 } }); + let full_id = semantic_id | ((ordinal as u64) << n); let num_rows = batch.num_rows(); - if group.len() <= 8 { - Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows]))) - } else if group.len() <= 16 { - Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows]))) - } else if group.len() <= 32 { - Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows]))) + if total_bits <= 8 { + Ok(Arc::new(UInt8Array::from(vec![full_id as u8; num_rows]))) + } else if total_bits <= 16 { + Ok(Arc::new(UInt16Array::from(vec![full_id as u16; num_rows]))) + } else if total_bits <= 32 { + Ok(Arc::new(UInt32Array::from(vec![full_id as u32; num_rows]))) } else { - Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows]))) + Ok(Arc::new(UInt64Array::from(vec![full_id; num_rows]))) } } +/// Returns the highest duplicate ordinal across all grouping sets. +/// +/// At the call-site, the ordinal is the 0-based index assigned to each +/// occurrence of a repeated grouping-set pattern: the first occurrence gets +/// ordinal 0, the second gets 1, and so on. If the same `Vec` appears +/// three times the ordinals are 0, 1, 2 and this function returns 2. +/// Returns 0 when no grouping set is duplicated. 
/// Returns the highest duplicate ordinal across all grouping sets.
///
/// Each entry of `groups` is one grouping set's boolean pattern. Repeated
/// patterns are numbered from 0 in order of occurrence, so a pattern that
/// appears three times has ordinals 0, 1, 2 and this function returns 2.
/// Returns 0 when `groups` is empty or no pattern is repeated.
fn max_duplicate_ordinal(groups: &[Vec<bool>]) -> usize {
    let mut occurrences: HashMap<&[bool], usize> = HashMap::new();
    let mut most_repeated = 0usize;
    for pattern in groups {
        let seen = occurrences.entry(pattern.as_slice()).or_default();
        *seen += 1;
        most_repeated = most_repeated.max(*seen);
    }
    // The highest ordinal is one less than the largest occurrence count.
    most_repeated.saturating_sub(1)
}
(deptno, sal), (deptno, job)) +order by deptno, job, sal, grouping(deptno), grouping(job), grouping(sal); +---- +10 CLERK NULL NULL 0 0 1 +10 CLERK NULL NULL 0 0 1 +10 NULL 1300 NULL 0 1 0 +20 MANAGER NULL NULL 0 0 1 +20 MANAGER NULL NULL 0 0 1 +20 NULL 3000 NULL 0 1 0 + +query ITII +select deptno, job, sal, grouping(deptno, job, sal) +from duplicate_grouping_sets +group by grouping sets ((deptno, job), (deptno, sal), (deptno, job)) +order by deptno, job, sal, grouping(deptno, job, sal); +---- +10 CLERK NULL 1 +10 CLERK NULL 1 +10 NULL 1300 2 +20 MANAGER NULL 1 +20 MANAGER NULL 1 +20 NULL 3000 2 + +statement ok +drop table duplicate_grouping_sets; + # test multi group by for binary type without nulls statement ok create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, 0xb), (3, 0xb), (3, 0xb);