-
Notifications
You must be signed in to change notification settings - Fork 2k
fix: preserve duplicate GROUPING SETS rows #21058
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
3bdd9e7
c9453ef
2a401c5
c04403e
546de18
3e9b4f0
f6ef0fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,7 +37,7 @@ use crate::{ | |
| use datafusion_common::config::ConfigOptions; | ||
| use datafusion_physical_expr::utils::collect_columns; | ||
| use parking_lot::Mutex; | ||
| use std::collections::HashSet; | ||
| use std::collections::{HashMap, HashSet}; | ||
|
|
||
| use arrow::array::{ArrayRef, UInt8Array, UInt16Array, UInt32Array, UInt64Array}; | ||
| use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; | ||
|
|
@@ -396,6 +396,15 @@ impl PhysicalGroupBy { | |
| self.expr.len() + usize::from(self.has_grouping_set) | ||
| } | ||
|
|
||
| /// Returns the Arrow data type of the `__grouping_id` column. | ||
| /// | ||
| /// The type is chosen to be wide enough to hold both the semantic bitmask | ||
| /// (in the low `n` bits, where `n` is the number of grouping expressions) | ||
| /// and the duplicate ordinal (in the high bits). | ||
| fn grouping_id_data_type(&self) -> DataType { | ||
| Aggregate::grouping_id_type(self.expr.len(), max_duplicate_ordinal(&self.groups)) | ||
| } | ||
|
|
||
| pub fn group_schema(&self, schema: &Schema) -> Result<SchemaRef> { | ||
| Ok(Arc::new(Schema::new(self.group_fields(schema)?))) | ||
| } | ||
|
|
@@ -420,7 +429,7 @@ impl PhysicalGroupBy { | |
| fields.push( | ||
| Field::new( | ||
| Aggregate::INTERNAL_GROUPING_ID, | ||
| Aggregate::grouping_id_type(self.expr.len()), | ||
| self.grouping_id_data_type(), | ||
| false, | ||
| ) | ||
| .into(), | ||
|
|
@@ -1937,27 +1946,72 @@ fn evaluate_optional( | |
| .collect() | ||
| } | ||
|
|
||
| fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> { | ||
| if group.len() > 64 { | ||
| /// Builds the internal `__grouping_id` array for a single grouping set. | ||
| /// | ||
| /// The returned array packs two values into a single integer: | ||
| /// | ||
| /// - Low `n` bits (positions 0 .. n-1): the semantic bitmask. A `1` bit | ||
| /// at position `i` means that the `i`-th grouping column (counting from the | ||
| /// least significant bit, i.e. the *last* column in the `group` slice) is | ||
| /// `NULL` for this grouping set. | ||
| /// - High bits (positions n and above): the duplicate `ordinal`, which | ||
| /// distinguishes multiple occurrences of the same grouping-set pattern. The | ||
| /// ordinal is `0` for the first occurrence, `1` for the second, and so on. | ||
| /// | ||
| /// The integer type is chosen to be the smallest `UInt8 / UInt16 / UInt32 / | ||
| /// UInt64` that can represent both parts. It matches the type returned by | ||
| /// [`Aggregate::grouping_id_type`]. | ||
| fn group_id_array( | ||
| group: &[bool], | ||
| ordinal: usize, | ||
| max_ordinal: usize, | ||
| batch: &RecordBatch, | ||
| ) -> Result<ArrayRef> { | ||
| let n = group.len(); | ||
| if n > 64 { | ||
| return not_impl_err!( | ||
| "Grouping sets with more than 64 columns are not supported" | ||
| ); | ||
| } | ||
| let group_id = group.iter().fold(0u64, |acc, &is_null| { | ||
| let ordinal_bits = usize::BITS as usize - max_ordinal.leading_zeros() as usize; | ||
| let total_bits = n + ordinal_bits; | ||
| if total_bits > 64 { | ||
| return not_impl_err!( | ||
| "Grouping sets with {n} columns and a maximum duplicate ordinal of \ | ||
| {max_ordinal} require {total_bits} bits, which exceeds 64" | ||
| ); | ||
| } | ||
| let semantic_id = group.iter().fold(0u64, |acc, &is_null| { | ||
| (acc << 1) | if is_null { 1 } else { 0 } | ||
| }); | ||
| let full_id = semantic_id | ((ordinal as u64) << n); | ||
| let num_rows = batch.num_rows(); | ||
| if group.len() <= 8 { | ||
| Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows]))) | ||
| } else if group.len() <= 16 { | ||
| Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows]))) | ||
| } else if group.len() <= 32 { | ||
| Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows]))) | ||
| if total_bits <= 8 { | ||
| Ok(Arc::new(UInt8Array::from(vec![full_id as u8; num_rows]))) | ||
| } else if total_bits <= 16 { | ||
| Ok(Arc::new(UInt16Array::from(vec![full_id as u16; num_rows]))) | ||
| } else if total_bits <= 32 { | ||
| Ok(Arc::new(UInt32Array::from(vec![full_id as u32; num_rows]))) | ||
| } else { | ||
| Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows]))) | ||
| Ok(Arc::new(UInt64Array::from(vec![full_id; num_rows]))) | ||
| } | ||
| } | ||
|
|
||
| /// Returns the highest duplicate ordinal across all grouping sets. | ||
| /// | ||
| /// At the call-site, the ordinal is the 0-based index assigned to each | ||
| /// occurrence of a repeated grouping-set pattern: the first occurrence gets | ||
| /// ordinal 0, the second gets 1, and so on. If the same `Vec<bool>` appears | ||
| /// three times the ordinals are 0, 1, 2 and this function returns 2. | ||
| /// Returns 0 when no grouping set is duplicated. | ||
| fn max_duplicate_ordinal(groups: &[Vec<bool>]) -> usize { | ||
| let mut counts: HashMap<&[bool], usize> = HashMap::new(); | ||
| for group in groups { | ||
| *counts.entry(group).or_insert(0) += 1; | ||
| } | ||
| counts.into_values().max().unwrap_or(0).saturating_sub(1) | ||
| } | ||
|
|
||
| /// Evaluate a group by expression against a `RecordBatch` | ||
| /// | ||
| /// Arguments: | ||
|
|
@@ -1972,6 +2026,8 @@ pub fn evaluate_group_by( | |
| group_by: &PhysicalGroupBy, | ||
| batch: &RecordBatch, | ||
| ) -> Result<Vec<Vec<ArrayRef>>> { | ||
| let max_ordinal = max_duplicate_ordinal(&group_by.groups); | ||
| let mut ordinal_per_pattern: HashMap<&Vec<bool>, usize> = HashMap::new(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| let exprs = evaluate_expressions_to_arrays( | ||
| group_by.expr.iter().map(|(expr, _)| expr), | ||
| batch, | ||
|
|
@@ -1985,6 +2041,10 @@ pub fn evaluate_group_by( | |
| .groups | ||
| .iter() | ||
| .map(|group| { | ||
| let ordinal = ordinal_per_pattern.entry(group).or_insert(0); | ||
| let current_ordinal = *ordinal; | ||
| *ordinal += 1; | ||
|
|
||
| let mut group_values = Vec::with_capacity(group_by.num_group_exprs()); | ||
| group_values.extend(group.iter().enumerate().map(|(idx, is_null)| { | ||
| if *is_null { | ||
|
|
@@ -1994,7 +2054,12 @@ pub fn evaluate_group_by( | |
| } | ||
| })); | ||
| if !group_by.is_single() { | ||
| group_values.push(group_id_array(group, batch)?); | ||
| group_values.push(group_id_array( | ||
| group, | ||
| current_ordinal, | ||
| max_ordinal, | ||
| batch, | ||
| )?); | ||
| } | ||
| Ok(group_values) | ||
| }) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5203,6 +5203,41 @@ NULL NULL 1 | |
| statement ok | ||
| drop table t; | ||
|
|
||
| # regression: duplicate grouping sets must not be collapsed into one | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

I think it's worth having a test case for the situation where adding the duplicate ordinal widens the size of the grouping ID field; that's a bit tricky. |
||
| statement ok | ||
| create table duplicate_grouping_sets(deptno int, job varchar, sal int, comm int) as values | ||
| (10, 'CLERK', 1300, null), | ||
| (20, 'MANAGER', 3000, null); | ||
|
|
||
| query ITIIIII | ||
| select deptno, job, sal, sum(comm), grouping(deptno), grouping(job), grouping(sal) | ||
| from duplicate_grouping_sets | ||
| group by grouping sets ((deptno, job), (deptno, sal), (deptno, job)) | ||
| order by deptno, job, sal, grouping(deptno), grouping(job), grouping(sal); | ||
| ---- | ||
| 10 CLERK NULL NULL 0 0 1 | ||
| 10 CLERK NULL NULL 0 0 1 | ||
| 10 NULL 1300 NULL 0 1 0 | ||
| 20 MANAGER NULL NULL 0 0 1 | ||
| 20 MANAGER NULL NULL 0 0 1 | ||
| 20 NULL 3000 NULL 0 1 0 | ||
|
|
||
| query ITII | ||
| select deptno, job, sal, grouping(deptno, job, sal) | ||
| from duplicate_grouping_sets | ||
| group by grouping sets ((deptno, job), (deptno, sal), (deptno, job)) | ||
| order by deptno, job, sal, grouping(deptno, job, sal); | ||
| ---- | ||
| 10 CLERK NULL 1 | ||
| 10 CLERK NULL 1 | ||
| 10 NULL 1300 2 | ||
| 20 MANAGER NULL 1 | ||
| 20 MANAGER NULL 1 | ||
| 20 NULL 3000 2 | ||
|
|
||
| statement ok | ||
| drop table duplicate_grouping_sets; | ||
|
|
||
| # test multi group by for binary type without nulls | ||
| statement ok | ||
| create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, 0xb), (3, 0xb), (3, 0xb); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.

`n >= 64` is rejected earlier, right? I wonder if it would be cleaner to just assert `n < 64` and then remove the conditional.