From 78835f55234d1cace82e68c28c05cfc5c7e99053 Mon Sep 17 00:00:00 2001 From: yash Date: Sat, 21 Mar 2026 12:28:21 +0530 Subject: [PATCH 1/7] perf: default multi COUNT(DISTINCT) logical optimizer rewrite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add MultiDistinctCountRewrite in datafusion-optimizer and register it in Optimizer::new() after SingleDistinctToGroupBy. Rewrites 2+ simple COUNT(DISTINCT) on different args into a join of two-phase aggregates; filter distinct_arg IS NOT NULL on each branch for correct NULL semantics. ✅ Unit tests in datafusion-optimizer; ✅ SQL integration test (NULLs) in core_integration. --- datafusion/core/tests/sql/aggregates/mod.rs | 1 + .../multi_distinct_count_rewrite.rs | 58 +++ datafusion/optimizer/src/lib.rs | 1 + .../src/multi_distinct_count_rewrite.rs | 395 ++++++++++++++++++ datafusion/optimizer/src/optimizer.rs | 2 + 5 files changed, 457 insertions(+) create mode 100644 datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs create mode 100644 datafusion/optimizer/src/multi_distinct_count_rewrite.rs diff --git a/datafusion/core/tests/sql/aggregates/mod.rs b/datafusion/core/tests/sql/aggregates/mod.rs index ede40d5c4ceca..ffdb886e89c0b 100644 --- a/datafusion/core/tests/sql/aggregates/mod.rs +++ b/datafusion/core/tests/sql/aggregates/mod.rs @@ -1024,3 +1024,4 @@ pub fn split_fuzz_timestamp_data_into_batches( pub mod basic; pub mod dict_nulls; +pub mod multi_distinct_count_rewrite; diff --git a/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs b/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs new file mode 100644 index 0000000000000..4a398cd3cd9d7 --- /dev/null +++ b/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! End-to-end SQL tests for the multi-`COUNT(DISTINCT)` logical optimizer rewrite. + +use super::*; +use arrow::array::{Int32Array, StringArray}; +use datafusion::common::test_util::batches_to_sort_string; +use datafusion_catalog::MemTable; + +#[tokio::test] +async fn multi_count_distinct_matches_expected_with_nulls() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("g", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Utf8, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1, 1])), + Arc::new(StringArray::from(vec![Some("x"), None, Some("x")])), + Arc::new(StringArray::from(vec![None, Some("y"), Some("y")])), + ], + )?; + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + let sql = + "SELECT g, COUNT(DISTINCT b) AS cb, COUNT(DISTINCT c) AS cc FROM t GROUP BY g"; + let batches = ctx.sql(sql).await?.collect().await?; + let out = batches_to_sort_string(&batches); + + assert_eq!( + out, + "+---+----+----+\n\ + | g | cb | cc |\n\ + +---+----+----+\n\ + | 1 | 1 | 1 |\n\ + +---+----+----+" + ); + Ok(()) +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index e610091824092..bdec72eba7e18 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -59,6 +59,7 @@ pub mod eliminate_outer_join; pub mod extract_equijoin_predicate; pub mod extract_leaf_expressions; pub mod filter_null_join_keys; +pub mod multi_distinct_count_rewrite; pub mod optimize_projections; pub mod optimize_unions; pub mod optimizer; diff --git a/datafusion/optimizer/src/multi_distinct_count_rewrite.rs b/datafusion/optimizer/src/multi_distinct_count_rewrite.rs new file mode 100644 index 0000000000000..b9669fdcb6f7f --- /dev/null +++ b/datafusion/optimizer/src/multi_distinct_count_rewrite.rs @@ -0,0 +1,395 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Rewrites a single [`Aggregate`] that contains multiple `COUNT(DISTINCT ...)` +//! into a join of smaller aggregates so each distinct is computed with one +//! accumulator set, reducing peak memory for high-cardinality distincts. + +use std::sync::Arc; + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use datafusion_common::{ + Column, JoinConstraint, NullEquality, Result, tree_node::Transformed, +}; +use datafusion_expr::builder::project; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, ScalarFunction}; +use datafusion_expr::logical_plan::{ + Aggregate, Join, JoinType, LogicalPlan, SubqueryAlias, +}; +use datafusion_expr::{Expr, col, lit, logical_plan::LogicalPlanBuilder}; + +const MAX_DISTINCT_REWRITE_BRANCHES: usize = 8; + +/// Optimizer rule: multiple `COUNT(DISTINCT ...)` → join of per-distinct sub-aggregates. +#[derive(Default, Debug)] +pub struct MultiDistinctCountRewrite {} + +impl MultiDistinctCountRewrite { + /// Create a new rule instance. + pub fn new() -> Self { + Self {} + } + + fn is_simple_count_distinct( + e: &Expr, + ) -> Option<(Arc, Expr)> { + if let Expr::AggregateFunction(AggregateFunction { func, params }) = e { + let AggregateFunctionParams { + distinct, + args, + filter, + order_by, + .. + } = ¶ms; + if func.name().eq_ignore_ascii_case("count") + && *distinct + && args.len() == 1 + && filter.is_none() + && order_by.is_empty() + { + let arg = args.first().cloned()?; + if Self::is_safe_distinct_arg(&arg) { + return Some((Arc::clone(func), arg)); + } + } + } + None + } + + fn is_safe_distinct_arg(e: &Expr) -> bool { + if e.is_volatile() { + return false; + } + match e { + Expr::Column(_) => true, + Expr::ScalarFunction(ScalarFunction { func, args }) => { + matches!(func.name().to_ascii_lowercase().as_str(), "lower" | "upper") + && args.len() == 1 + && matches!(args.first(), Some(Expr::Column(_))) + } + Expr::Cast(cast) => matches!(cast.expr.as_ref(), Expr::Column(_)), + _ => false, + } + } + + fn is_simple_group_expr(e: &Expr) -> bool { + matches!(e, Expr::Column(_)) + } + + fn contains_grouping_set(group_expr: &[Expr]) -> bool { + group_expr + .first() + .is_some_and(|e| matches!(e, Expr::GroupingSet(_))) + } +} + +impl OptimizerRule for MultiDistinctCountRewrite { + fn name(&self) -> &str { + "multi_distinct_count_rewrite" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + let LogicalPlan::Aggregate(Aggregate { + input, + aggr_expr, + schema, + group_expr, + .. + }) = plan + else { + return Ok(Transformed::no(plan)); + }; + + if Self::contains_grouping_set(&group_expr) { + return Ok(Transformed::no(LogicalPlan::Aggregate(Aggregate::try_new( + input, group_expr, aggr_expr, + )?))); + } + + if !group_expr.iter().all(Self::is_simple_group_expr) { + return Ok(Transformed::no(LogicalPlan::Aggregate(Aggregate::try_new( + input, group_expr, aggr_expr, + )?))); + } + + let group_size = group_expr.len(); + let mut distinct_list: Vec<(Expr, usize, Arc)> = + vec![]; + let mut other_list: Vec<(Expr, usize)> = vec![]; + + for (i, e) in aggr_expr.iter().enumerate() { + if let Some((func, arg)) = Self::is_simple_count_distinct(e) { + distinct_list.push((arg, group_size + i, func)); + } else { + other_list.push((e.clone(), group_size + i)); + } + } + + if distinct_list.len() < 2 { + return Ok(Transformed::no(LogicalPlan::Aggregate(Aggregate::try_new( + input, group_expr, aggr_expr, + )?))); + } + + if distinct_list.len() > MAX_DISTINCT_REWRITE_BRANCHES { + return Ok(Transformed::no(LogicalPlan::Aggregate(Aggregate::try_new( + input, group_expr, aggr_expr, + )?))); + } + + { + use std::collections::HashSet; + let mut seen: HashSet<&Expr> = HashSet::new(); + for (arg, _, _) in distinct_list.iter() { + if !seen.insert(arg) { + return Ok(Transformed::no(LogicalPlan::Aggregate( + Aggregate::try_new(input, group_expr, aggr_expr)?, + ))); + } + } + } + + let count_udaf = Arc::clone(&distinct_list[0].2); + + let count_star = Expr::AggregateFunction(AggregateFunction::new_udf( + Arc::clone(&count_udaf), + vec![lit(1i64)], + false, + None, + vec![], + None, + )) + .alias("_cnt"); + + let base_aggr_exprs: Vec = other_list + .iter() + .map(|(e, schema_idx)| { + let (q, f) = schema.qualified_field(*schema_idx); + e.clone().alias_qualified(q.cloned(), f.name()) + }) + .collect(); + let base_plan = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::clone(&input), + group_expr.clone(), + base_aggr_exprs, + )?); + + let base_alias = config.alias_generator().next("mdc_base"); + let base_aliased = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( + Arc::new(base_plan), + &base_alias, + )?); + + let mut current = Arc::new(base_aliased); + + for (distinct_arg, schema_aggr_idx, _) in distinct_list.iter() { + // COUNT(DISTINCT x) ignores NULLs; filter before grouping by x. + let filtered_input = LogicalPlanBuilder::from(input.as_ref().clone()) + .filter(distinct_arg.clone().is_not_null())? + .build()?; + + let inner_group: Vec = group_expr + .iter() + .cloned() + .chain(std::iter::once(distinct_arg.clone())) + .collect(); + let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(filtered_input), + inner_group, + vec![count_star.clone()], + )?); + + let (_, field) = schema.qualified_field(*schema_aggr_idx); + let outer_name = field.name().clone(); + let outer_aggr = Expr::AggregateFunction(AggregateFunction::new_udf( + Arc::clone(&count_udaf), + vec![col("_cnt")], + false, + None, + vec![], + None, + )) + .alias(outer_name.clone()); + + let branch_plan = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(inner_agg), + group_expr.clone(), + vec![outer_aggr], + )?); + + let alias_name = config.alias_generator().next("mdc_d"); + let branch_aliased = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( + Arc::new(branch_plan), + &alias_name, + )?); + + let left_schema = current.schema(); + let right_schema = branch_aliased.schema(); + let join_keys: Vec<(Expr, Expr)> = (0..group_size) + .map(|i| { + let (lq, lf) = left_schema.qualified_field(i); + let (rq, rf) = right_schema.qualified_field(i); + ( + Expr::Column(Column::new(lq.cloned(), lf.name())), + Expr::Column(Column::new(rq.cloned(), rf.name())), + ) + }) + .collect(); + + let join = Join::try_new( + current, + Arc::new(branch_aliased), + join_keys, + None, + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + false, + )?; + current = Arc::new(LogicalPlan::Join(join)); + } + + let join_schema = current.schema(); + + let mut proj_exprs: Vec = vec![]; + for i in 0..group_size { + let (q, f) = schema.qualified_field(i); + let orig_name = f.name(); + let (join_q, join_f) = join_schema.qualified_field(i); + let c = Expr::Column(Column::new(join_q.cloned(), join_f.name())); + proj_exprs.push(c.alias_qualified(q.cloned(), orig_name)); + } + for (field_idx, (_, schema_aggr_idx)) in other_list.iter().enumerate() { + let (q, f) = schema.qualified_field(*schema_aggr_idx); + let orig_name = f.name(); + let join_idx = group_size + field_idx; + let (join_q, join_f) = join_schema.qualified_field(join_idx); + let c = Expr::Column(Column::new(join_q.cloned(), join_f.name())); + proj_exprs.push(c.alias_qualified(q.cloned(), orig_name)); + } + let base_field_count = group_size + other_list.len(); + for (idx, (_, schema_aggr_idx, _)) in distinct_list.iter().enumerate() { + let (q, f) = schema.qualified_field(*schema_aggr_idx); + let orig_name = f.name(); + let branch_start_idx = base_field_count + idx * (group_size + 1); + let branch_aggr_idx = branch_start_idx + group_size; + let (join_q, join_f) = join_schema.qualified_field(branch_aggr_idx); + let c = Expr::Column(Column::new(join_q.cloned(), join_f.name())); + proj_exprs.push(c.alias_qualified(q.cloned(), orig_name)); + } + + let out = project((*current).clone(), proj_exprs)?; + Ok(Transformed::yes(out)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Optimizer; + use crate::OptimizerContext; + use crate::OptimizerRule; + use crate::test::*; + use datafusion_expr::LogicalPlan; + use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; + use datafusion_functions_aggregate::expr_fn::{count, count_distinct}; + + fn optimize_with_rule( + plan: LogicalPlan, + rule: Arc, + ) -> Result { + Optimizer::with_rules(vec![rule]).optimize( + plan, + &OptimizerContext::new(), + |_, _| {}, + ) + } + + #[test] + fn rewrites_two_count_distinct() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![count_distinct(col("b")), count_distinct(col("c"))], + )? + .build()?; + + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let s = optimized.display_indent_schema().to_string(); + assert!(s.contains("Inner Join"), "expected join rewrite, got:\n{s}"); + assert!( + s.contains("Filter: test.b IS NOT NULL"), + "expected null filter on b, got:\n{s}" + ); + assert!( + s.contains("Filter: test.c IS NOT NULL"), + "expected null filter on c, got:\n{s}" + ); + assert!( + s.contains("SubqueryAlias: mdc_base"), + "expected base alias, got:\n{s}" + ); + assert!( + s.matches("SubqueryAlias: mdc_d").count() >= 2, + "expected distinct branches, got:\n{s}" + ); + Ok(()) + } + + #[test] + fn does_not_rewrite_single_count_distinct() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![count_distinct(col("b"))])? + .build()?; + let before = plan.display_indent_schema().to_string(); + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let after = optimized.display_indent_schema().to_string(); + assert_eq!(before, after); + Ok(()) + } + + #[test] + fn does_not_rewrite_mixed_agg() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![count_distinct(col("b")), count(col("c"))], + )? + .build()?; + let before = plan.display_indent_schema().to_string(); + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let after = optimized.display_indent_schema().to_string(); + assert_eq!(before, after); + Ok(()) + } +} diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index bdea6a83072cd..08d5bbfbb424e 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -45,6 +45,7 @@ use crate::eliminate_outer_join::EliminateOuterJoin; use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; use crate::extract_leaf_expressions::{ExtractLeafExpressions, PushDownLeafProjections}; use crate::filter_null_join_keys::FilterNullJoinKeys; +use crate::multi_distinct_count_rewrite::MultiDistinctCountRewrite; use crate::optimize_projections::OptimizeProjections; use crate::optimize_unions::OptimizeUnions; use crate::plan_signature::LogicalPlanSignature; @@ -298,6 +299,7 @@ impl Optimizer { Arc::new(PushDownLimit::new()), Arc::new(PushDownFilter::new()), Arc::new(SingleDistinctToGroupBy::new()), + Arc::new(MultiDistinctCountRewrite::new()), // The previous optimizations added expressions and projections, // that might benefit from the following rules Arc::new(EliminateGroupByConstant::new()), From 6140cd60d0352f402dd8b643ffd4ae7ab5ee82ab Mon Sep 17 00:00:00 2001 From: yash Date: Sat, 21 Mar 2026 15:37:05 +0530 Subject: [PATCH 2/7] fix: global multi-COUNT(DISTINCT) rewrite without invalid empty Aggregate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ✅ Omit base Aggregate when GROUP BY is empty and only COUNT(DISTINCT) branches exist (matches clickbench extended global queries). ✅ First distinct branch seeds the plan; subsequent branches join (empty keys → Cross Join in plan). ✅ Add rewrites_global_three_count_distinct unit test. ❌ Previous shape could error: Aggregate with no grouping and no aggregate expressions. --- .../src/multi_distinct_count_rewrite.rs | 113 ++++++++++++------ 1 file changed, 77 insertions(+), 36 deletions(-) diff --git a/datafusion/optimizer/src/multi_distinct_count_rewrite.rs b/datafusion/optimizer/src/multi_distinct_count_rewrite.rs index b9669fdcb6f7f..caaf4a615f3c9 100644 --- a/datafusion/optimizer/src/multi_distinct_count_rewrite.rs +++ b/datafusion/optimizer/src/multi_distinct_count_rewrite.rs @@ -192,19 +192,26 @@ impl OptimizerRule for MultiDistinctCountRewrite { e.clone().alias_qualified(q.cloned(), f.name()) }) .collect(); - let base_plan = LogicalPlan::Aggregate(Aggregate::try_new( - Arc::clone(&input), - group_expr.clone(), - base_aggr_exprs, - )?); - let base_alias = config.alias_generator().next("mdc_base"); - let base_aliased = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( - Arc::new(base_plan), - &base_alias, - )?); - - let mut current = Arc::new(base_aliased); + // `Aggregate` must have at least one of grouping exprs or aggregate exprs. + // Global multi-`COUNT(DISTINCT)` (no GROUP BY, no other aggs) has neither — skip a base node. + let base_plan_opt: Option> = + if group_expr.is_empty() && other_list.is_empty() { + None + } else { + let base_plan = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::clone(&input), + group_expr.clone(), + base_aggr_exprs, + )?); + let base_alias = config.alias_generator().next("mdc_base"); + Some(Arc::new(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( + Arc::new(base_plan), + &base_alias, + )?))) + }; + + let mut current = base_plan_opt; for (distinct_arg, schema_aggr_idx, _) in distinct_list.iter() { // COUNT(DISTINCT x) ignores NULLs; filter before grouping by x. @@ -247,32 +254,38 @@ impl OptimizerRule for MultiDistinctCountRewrite { &alias_name, )?); - let left_schema = current.schema(); - let right_schema = branch_aliased.schema(); - let join_keys: Vec<(Expr, Expr)> = (0..group_size) - .map(|i| { - let (lq, lf) = left_schema.qualified_field(i); - let (rq, rf) = right_schema.qualified_field(i); - ( - Expr::Column(Column::new(lq.cloned(), lf.name())), - Expr::Column(Column::new(rq.cloned(), rf.name())), - ) - }) - .collect(); - - let join = Join::try_new( - current, - Arc::new(branch_aliased), - join_keys, - None, - JoinType::Inner, - JoinConstraint::On, - NullEquality::NullEqualsNothing, - false, - )?; - current = Arc::new(LogicalPlan::Join(join)); + current = match current { + None => Some(Arc::new(branch_aliased)), + Some(prev) => { + let left_schema = prev.schema(); + let right_schema = branch_aliased.schema(); + let join_keys: Vec<(Expr, Expr)> = (0..group_size) + .map(|i| { + let (lq, lf) = left_schema.qualified_field(i); + let (rq, rf) = right_schema.qualified_field(i); + ( + Expr::Column(Column::new(lq.cloned(), lf.name())), + Expr::Column(Column::new(rq.cloned(), rf.name())), + ) + }) + .collect(); + + let join = Join::try_new( + prev, + Arc::new(branch_aliased), + join_keys, + None, + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + false, + )?; + Some(Arc::new(LogicalPlan::Join(join))) + } + }; } + let current = current.expect("distinct_list non-empty implies at least one branch"); let join_schema = current.schema(); let mut proj_exprs: Vec = vec![]; @@ -362,6 +375,34 @@ mod tests { Ok(()) } + #[test] + fn rewrites_global_three_count_distinct() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + Vec::::new(), + vec![ + count_distinct(col("a")), + count_distinct(col("b")), + count_distinct(col("c")), + ], + )? + .build()?; + + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let s = optimized.display_indent_schema().to_string(); + assert!( + s.contains("Cross Join") || s.contains("Inner Join"), + "expected join rewrite for global multi-distinct, got:\n{s}" + ); + assert!( + !s.contains("mdc_base"), + "global-only rewrite should not use mdc_base, got:\n{s}" + ); + Ok(()) + } + #[test] fn does_not_rewrite_single_count_distinct() -> Result<()> { let table_scan = test_table_scan()?; From 153bb55372bfcb630cdc2dda05d6a4629e4c05c3 Mon Sep 17 00:00:00 2001 From: yash Date: Sat, 21 Mar 2026 19:52:35 +0530 Subject: [PATCH 3/7] fix: preserve aggregate order in multi-distinct COUNT rewrite and add tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ✅ Fix projection after join so output columns match the original aggregate list when COUNT(DISTINCT …) and non-distinct aggs are interleaved (schema-compatible with mixed BI-style queries). ✅ Add internal_err guard for inconsistent aggregate index mapping. ✅ Optimizer tests: three grouped COUNT(DISTINCT), non-distinct between distincts, CAST(distinct) args, no rewrite for GROUPING SETS. ✅ SQL integration: COUNT(*) + two COUNT(DISTINCT); two GROUP BY keys with expected results. ❌ Grouping-set / filtered-distinct cases remain explicitly out of scope for this rule (covered by unchanged-plan tests where applicable). Made-with: Cursor --- .../multi_distinct_count_rewrite.rs | 75 +++++++ .../src/multi_distinct_count_rewrite.rs | 212 ++++++++++++++++-- 2 files changed, 265 insertions(+), 22 deletions(-) diff --git a/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs b/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs index 4a398cd3cd9d7..c807bd64bee31 100644 --- a/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs +++ b/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs @@ -56,3 +56,78 @@ async fn multi_count_distinct_matches_expected_with_nulls() -> Result<()> { ); Ok(()) } + +/// `COUNT(*)` + two `COUNT(DISTINCT …)` per group (BI-style); must match non-rewritten semantics. +#[tokio::test] +async fn multi_count_distinct_with_count_star_matches_expected() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("g", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1, 1])), + Arc::new(Int32Array::from(vec![1, 2, 1])), + Arc::new(Int32Array::from(vec![10, 20, 30])), + ], + )?; + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + let sql = "SELECT g, COUNT(*) AS n, COUNT(DISTINCT b) AS db, COUNT(DISTINCT c) AS dc \ + FROM t GROUP BY g"; + let batches = ctx.sql(sql).await?.collect().await?; + let out = batches_to_sort_string(&batches); + + assert_eq!( + out, + "+---+---+----+----+\n\ + | g | n | db | dc |\n\ + +---+---+----+----+\n\ + | 1 | 3 | 2 | 3 |\n\ + +---+---+----+----+" + ); + Ok(()) +} + +/// Multiple `GROUP BY` keys: join must align on all keys. +#[tokio::test] +async fn multi_count_distinct_two_group_keys_matches_expected() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("g1", DataType::Int32, false), + Field::new("g2", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1, 1])), + Arc::new(Int32Array::from(vec![1, 1, 2])), + Arc::new(Int32Array::from(vec![1, 1, 3])), + Arc::new(Int32Array::from(vec![1, 2, 3])), + ], + )?; + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + let sql = "SELECT g1, g2, COUNT(DISTINCT b) AS db, COUNT(DISTINCT c) AS dc \ + FROM t GROUP BY g1, g2"; + let batches = ctx.sql(sql).await?.collect().await?; + let out = batches_to_sort_string(&batches); + + assert_eq!( + out, + "+----+----+----+----+\n\ + | g1 | g2 | db | dc |\n\ + +----+----+----+----+\n\ + | 1 | 1 | 1 | 2 |\n\ + | 1 | 2 | 1 | 1 |\n\ + +----+----+----+----+" + ); + Ok(()) +} diff --git a/datafusion/optimizer/src/multi_distinct_count_rewrite.rs b/datafusion/optimizer/src/multi_distinct_count_rewrite.rs index caaf4a615f3c9..57d2f9d589fbb 100644 --- a/datafusion/optimizer/src/multi_distinct_count_rewrite.rs +++ b/datafusion/optimizer/src/multi_distinct_count_rewrite.rs @@ -25,7 +25,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{ - Column, JoinConstraint, NullEquality, Result, tree_node::Transformed, + Column, JoinConstraint, NullEquality, Result, internal_err, tree_node::Transformed, }; use datafusion_expr::builder::project; use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, ScalarFunction}; @@ -205,10 +205,9 @@ impl OptimizerRule for MultiDistinctCountRewrite { base_aggr_exprs, )?); let base_alias = config.alias_generator().next("mdc_base"); - Some(Arc::new(LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( - Arc::new(base_plan), - &base_alias, - )?))) + Some(Arc::new(LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(Arc::new(base_plan), &base_alias)?, + ))) }; let mut current = base_plan_opt; @@ -285,9 +284,12 @@ impl OptimizerRule for MultiDistinctCountRewrite { }; } - let current = current.expect("distinct_list non-empty implies at least one branch"); + let current = + current.expect("distinct_list non-empty implies at least one branch"); let join_schema = current.schema(); + let base_field_count = group_size + other_list.len(); + let mut proj_exprs: Vec = vec![]; for i in 0..group_size { let (q, f) = schema.qualified_field(i); @@ -296,23 +298,36 @@ impl OptimizerRule for MultiDistinctCountRewrite { let c = Expr::Column(Column::new(join_q.cloned(), join_f.name())); proj_exprs.push(c.alias_qualified(q.cloned(), orig_name)); } - for (field_idx, (_, schema_aggr_idx)) in other_list.iter().enumerate() { - let (q, f) = schema.qualified_field(*schema_aggr_idx); + // Preserve original aggregate column order (distinct and non-distinct may be interleaved). + for aggr_i in 0..aggr_expr.len() { + let schema_idx = group_size + aggr_i; + let (q, f) = schema.qualified_field(schema_idx); let orig_name = f.name(); - let join_idx = group_size + field_idx; - let (join_q, join_f) = join_schema.qualified_field(join_idx); - let c = Expr::Column(Column::new(join_q.cloned(), join_f.name())); - proj_exprs.push(c.alias_qualified(q.cloned(), orig_name)); - } - let base_field_count = group_size + other_list.len(); - for (idx, (_, schema_aggr_idx, _)) in distinct_list.iter().enumerate() { - let (q, f) = schema.qualified_field(*schema_aggr_idx); - let orig_name = f.name(); - let branch_start_idx = base_field_count + idx * (group_size + 1); - let branch_aggr_idx = branch_start_idx + group_size; - let (join_q, join_f) = join_schema.qualified_field(branch_aggr_idx); - let c = Expr::Column(Column::new(join_q.cloned(), join_f.name())); - proj_exprs.push(c.alias_qualified(q.cloned(), orig_name)); + + if let Some((dist_idx, (_, _, _))) = distinct_list + .iter() + .enumerate() + .find(|(_, (_, idx, _))| *idx == schema_idx) + { + let branch_start_idx = base_field_count + dist_idx * (group_size + 1); + let branch_aggr_idx = branch_start_idx + group_size; + let (join_q, join_f) = join_schema.qualified_field(branch_aggr_idx); + let c = Expr::Column(Column::new(join_q.cloned(), join_f.name())); + proj_exprs.push(c.alias_qualified(q.cloned(), orig_name)); + } else if let Some((other_idx, _)) = other_list + .iter() + .enumerate() + .find(|(_, (_, idx))| *idx == schema_idx) + { + let join_idx = group_size + other_idx; + let (join_q, join_f) = join_schema.qualified_field(join_idx); + let c = Expr::Column(Column::new(join_q.cloned(), join_f.name())); + proj_exprs.push(c.alias_qualified(q.cloned(), orig_name)); + } else { + return internal_err!( + "aggregate index {aggr_i} (schema index {schema_idx}) is neither distinct nor other" + ); + } } let out = project((*current).clone(), proj_exprs)?; @@ -327,8 +342,13 @@ mod tests { use crate::OptimizerContext; use crate::OptimizerRule; use crate::test::*; + use arrow::datatypes::DataType; + use datafusion_expr::GroupingSet; use datafusion_expr::LogicalPlan; + use datafusion_expr::expr_fn::cast; + use datafusion_expr::logical_plan::Aggregate; use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; + use datafusion_expr::{Expr, col}; use datafusion_functions_aggregate::expr_fn::{count, count_distinct}; fn optimize_with_rule( @@ -403,6 +423,53 @@ mod tests { Ok(()) } + /// Grouped query with multiple `COUNT(DISTINCT …)` **and** non-distinct aggregates (typical BI). + /// Non-distinct aggs live in `mdc_base`; each distinct column gets a branch + join on keys. + #[test] + fn rewrites_two_count_distinct_with_non_distinct_count() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![ + count_distinct(col("b")), + count_distinct(col("c")), + count(col("a")), + ], + )? + .build()?; + + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let s = optimized.display_indent_schema().to_string(); + assert!(s.contains("Inner Join"), "expected join rewrite, got:\n{s}"); + assert!( + s.contains("SubqueryAlias: mdc_base"), + "expected base aggregate for non-distinct aggs, got:\n{s}" + ); + Ok(()) + } + + #[test] + fn does_not_rewrite_two_count_distinct_same_column() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![ + count_distinct(col("b")).alias("cd1"), + count_distinct(col("b")).alias("cd2"), + ], + )? + .build()?; + let before = plan.display_indent_schema().to_string(); + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let after = optimized.display_indent_schema().to_string(); + assert_eq!(before, after); + Ok(()) + } + #[test] fn does_not_rewrite_single_count_distinct() -> Result<()> { let table_scan = test_table_scan()?; @@ -417,6 +484,107 @@ mod tests { Ok(()) } + #[test] + fn rewrites_three_count_distinct_grouped() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![ + count_distinct(col("b")), + count_distinct(col("c")), + count_distinct(col("a")), + ], + )? + .build()?; + + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let s = optimized.display_indent_schema().to_string(); + assert!( + s.matches("Inner Join").count() >= 2, + "expected two joins for three branches, got:\n{s}" + ); + assert!( + s.contains("SubqueryAlias: mdc_base"), + "expected base aggregate, got:\n{s}" + ); + Ok(()) + } + + #[test] + fn rewrites_interleaved_non_distinct_between_distincts() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![ + count_distinct(col("b")), + count(col("a")), + count_distinct(col("c")), + ], + )? + .build()?; + + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let s = optimized.display_indent_schema().to_string(); + assert!(s.contains("Inner Join"), "expected join rewrite, got:\n{s}"); + assert!( + s.contains("SubqueryAlias: mdc_base"), + "expected base for middle count(a), got:\n{s}" + ); + Ok(()) + } + + #[test] + fn rewrites_count_distinct_on_cast_exprs() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![ + count_distinct(cast(col("b"), DataType::Int64)), + count_distinct(cast(col("c"), DataType::Int64)), + ], + )? + .build()?; + + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let s = optimized.display_indent_schema().to_string(); + assert!(s.contains("Inner Join"), "expected join rewrite, got:\n{s}"); + assert!( + s.contains("Filter: CAST(test.b AS Int64) IS NOT NULL"), + "expected null filter on cast(b), got:\n{s}" + ); + assert!( + s.contains("Filter: CAST(test.c AS Int64) IS NOT NULL"), + "expected null filter on cast(c), got:\n{s}" + ); + Ok(()) + } + + #[test] + fn does_not_rewrite_grouping_sets_multi_distinct() -> Result<()> { + let table_scan = test_table_scan()?; + let group_expr = vec![Expr::GroupingSet(GroupingSet::GroupingSets(vec![vec![ + col("a"), + ]]))]; + let aggr_expr = vec![count_distinct(col("b")), count_distinct(col("c"))]; + let plan = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(table_scan), + group_expr, + aggr_expr, + )?); + let before = plan.display_indent_schema().to_string(); + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let after = optimized.display_indent_schema().to_string(); + assert_eq!(before, after); + Ok(()) + } + #[test] fn does_not_rewrite_mixed_agg() -> Result<()> { let table_scan = test_table_scan()?; From 0b9585178364abca47dbe151688efe6b6ea3cd92 Mon Sep 17 00:00:00 2001 From: yash Date: Sat, 21 Mar 2026 20:21:34 +0530 Subject: [PATCH 4/7] test: SQL coverage for COUNT(DISTINCT lower/CAST) with multi-distinct rewrite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ✅ End-to-end: COUNT(DISTINCT lower(b)) with 'Abc'/'aBC' plus second distinct on c (case collapse = 1). ✅ End-to-end: COUNT(DISTINCT CAST(x AS INT)) with 1.2/1.3 vs second CAST distinct on y (int collision = 1). ✅ docs: REPLY_PR_20940_DANDANDAN.md — integration table rows + note on safe distinct args (lower/upper/CAST). ❌ No optimizer or engine behavior change; asserts semantics match non-rewritten aggregation. Made-with: Cursor --- .../multi_distinct_count_rewrite.rs | 76 ++++++++++++++++- docs/improvements/REPLY_PR_20940_DANDANDAN.md | 82 +++++++++++++++++++ 2 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 docs/improvements/REPLY_PR_20940_DANDANDAN.md diff --git a/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs b/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs index c807bd64bee31..bef4cd41f060e 100644 --- a/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs +++ b/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs @@ -18,7 +18,7 @@ //! End-to-end SQL tests for the multi-`COUNT(DISTINCT)` logical optimizer rewrite. use super::*; -use arrow::array::{Int32Array, StringArray}; +use arrow::array::{Float64Array, Int32Array, StringArray}; use datafusion::common::test_util::batches_to_sort_string; use datafusion_catalog::MemTable; @@ -131,3 +131,77 @@ async fn multi_count_distinct_two_group_keys_matches_expected() -> Result<()> { ); Ok(()) } + +/// `COUNT(DISTINCT lower(b))` with `'Abc'` / `'aBC'`: distinct is on the **lowered** value (one bucket). +/// Two `COUNT(DISTINCT …)` so the rewrite applies; semantics match plain aggregation. +#[tokio::test] +async fn multi_count_distinct_lower_matches_expected_case_collapsing() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("g", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(StringArray::from(vec!["Abc", "aBC"])), + Arc::new(StringArray::from(vec!["x", "y"])), + ], + )?; + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + let sql = "SELECT g, COUNT(DISTINCT lower(b)) AS lb, COUNT(DISTINCT c) AS cc \ + FROM t GROUP BY g"; + let batches = ctx.sql(sql).await?.collect().await?; + let out = batches_to_sort_string(&batches); + + assert_eq!( + out, + "+---+----+----+\n\ + | g | lb | cc |\n\ + +---+----+----+\n\ + | 1 | 1 | 2 |\n\ + +---+----+----+" + ); + Ok(()) +} + +/// `COUNT(DISTINCT CAST(x AS INT))` with `1.2` and `1.3`: both truncate to `1` → one distinct. +/// Exercises the same “expression in distinct, not raw column” path as `CAST` in the rule. +#[tokio::test] +async fn multi_count_distinct_cast_float_to_int_collapses_nearby_values() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ + Field::new("g", DataType::Int32, false), + Field::new("x", DataType::Float64, false), + Field::new("y", DataType::Float64, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Float64Array::from(vec![1.2, 1.3])), + Arc::new(Float64Array::from(vec![10.0, 20.0])), + ], + )?; + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + let sql = "SELECT g, COUNT(DISTINCT CAST(x AS INT)) AS cx, COUNT(DISTINCT CAST(y AS INT)) AS cy \ + FROM t GROUP BY g"; + let batches = ctx.sql(sql).await?.collect().await?; + let out = batches_to_sort_string(&batches); + + assert_eq!( + out, + "+---+----+----+\n\ + | g | cx | cy |\n\ + +---+----+----+\n\ + | 1 | 1 | 2 |\n\ + +---+----+----+" + ); + Ok(()) +} diff --git a/docs/improvements/REPLY_PR_20940_DANDANDAN.md b/docs/improvements/REPLY_PR_20940_DANDANDAN.md new file mode 100644 index 0000000000000..eb3524a609a0f --- /dev/null +++ b/docs/improvements/REPLY_PR_20940_DANDANDAN.md @@ -0,0 +1,82 @@ +# Draft reply: PR #20940 (`MultiDistinctToCrossJoin`) vs `MultiDistinctCountRewrite` + +Paste into [apache/datafusion#20940](https://github.com/apache/datafusion/pull/20940) as a comment (GitHub-flavored Markdown). + +--- + +Hi @Dandandan — thanks for this work; the cross-join split for **multiple distinct aggregates with no `GROUP BY`** is a strong fit for workloads like ClickBench extended. + +I’ve been working on a related but **different** pattern: **`GROUP BY` + several `COUNT(DISTINCT …)`** in the same aggregate (typical BI). In that situation, your rule **does not apply**, because `MultiDistinctToCrossJoin` needs an **empty** `GROUP BY` and **all** aggregates to be distinct on different columns. + +A concrete example from our benchmark suite (**category `08_complex_analytical`, query `Q8.3`**) on an `orders_data` table: + +```sql +-- Q8.3: Seller performance analysis +SELECT + seller_name, + COUNT(*) as total_orders, + COUNT(DISTINCT delivery_city) as cities_served, + COUNT(DISTINCT state) as states_served, + SUM(CASE WHEN order_status = 'Completed' THEN 1 ELSE 0 END) as completed_orders, + SUM(CASE WHEN order_status = 'Cancelled' THEN 1 ELSE 0 END) as cancelled_orders, + ROUND(100.0 * SUM(CASE WHEN order_status = 'Completed' THEN 1 ELSE 0 END) / COUNT(*), 2) as success_rate +FROM orders_data +GROUP BY seller_name +HAVING COUNT(*) > 100 +ORDER BY total_orders DESC +LIMIT 100; +``` + +This is **not** “global” multi-distinct: it’s **per `seller_name`**, with **multiple `COUNT(DISTINCT …)`** plus other aggregates. That’s the class my optimizer rule (`MultiDistinctCountRewrite`) targets — rewriting the **`COUNT(DISTINCT …)`** pieces into **joinable sub-aggregates aligned on the same `GROUP BY` keys**, with correct `NULL` handling where needed. + +So in simple terms: + +| | **Your PR (`MultiDistinctToCrossJoin`)** | **My work (`MultiDistinctCountRewrite`)** | +|---|------------------------------------------|-------------------------------------------| +| **Typical SQL** | `SELECT COUNT(DISTINCT a), COUNT(DISTINCT b) FROM t` (no `GROUP BY`) | `SELECT …, COUNT(DISTINCT x), COUNT(DISTINCT y), … FROM t GROUP BY …` | +| **Example workload** | ClickBench extended–style **Q0 / Q1** | Our **Q8.3** (and similar grouped BI queries) | + +They’re **complementary**: different predicates, different plans, and they can **coexist** in the optimizer pipeline (we’d want to sanity-check rule order so we don’t double-rewrite the same node). + +--- + +## Tests for `MultiDistinctCountRewrite` (what they cover) + +### Optimizer unit tests — `datafusion/optimizer/src/multi_distinct_count_rewrite.rs` + +| Test | What it asserts | +|------|-----------------| +| `rewrites_two_count_distinct` | `GROUP BY a` + `COUNT(DISTINCT b)`, `COUNT(DISTINCT c)` → inner joins, per-branch null filters on `b`/`c`, `mdc_base` + two `mdc_d` aliases. | +| `rewrites_global_three_count_distinct` | No `GROUP BY`, three `COUNT(DISTINCT …)` → cross/inner join rewrite; **no** `mdc_base` (global-only path). | +| `rewrites_two_count_distinct_with_non_distinct_count` | Grouped BI-style: two distincts + `COUNT(a)` → join rewrite with **`mdc_base`** holding the non-distinct agg. | +| `does_not_rewrite_two_count_distinct_same_column` | Two `COUNT(DISTINCT b)` with different aliases → **no** rewrite (duplicate distinct key). | +| `does_not_rewrite_single_count_distinct` | Only one `COUNT(DISTINCT …)` → **no** rewrite (rule needs ≥2 distincts). | +| `rewrites_three_count_distinct_grouped` | Three grouped `COUNT(DISTINCT …)` on `b`, `c`, `a` → **two** inner joins + `mdc_base`. | +| `rewrites_interleaved_non_distinct_between_distincts` | Order `COUNT(DISTINCT b)`, `COUNT(a)`, `COUNT(DISTINCT c)` → rewrite + `mdc_base` for the middle non-distinct agg (projection order / interleaving). | +| `rewrites_count_distinct_on_cast_exprs` | `COUNT(DISTINCT CAST(b AS Int64))`, same for `c` → rewrite + null filters on the **cast** expressions. | +| `does_not_rewrite_grouping_sets_multi_distinct` | `GROUPING SETS` aggregate with two `COUNT(DISTINCT …)` → **no** rewrite (rule bails on grouping sets). | +| `does_not_rewrite_mixed_agg` | `COUNT(DISTINCT b)` + `COUNT(c)` → **no** rewrite (only **one** `COUNT(DISTINCT …)`; rule requires at least two). | + +### SQL integration — `datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs` + +| Test | What it asserts | +|------|-----------------| +| `multi_count_distinct_matches_expected_with_nulls` | End-to-end grouped two `COUNT(DISTINCT …)` with **NULLs** in distinct columns; exact sorted batch string vs expected counts. | +| `multi_count_distinct_with_count_star_matches_expected` | `COUNT(*)` plus two `COUNT(DISTINCT …)` per group (BI-style); exact result table. | +| `multi_count_distinct_two_group_keys_matches_expected` | **`GROUP BY g1, g2`** + two distincts; verifies joins line up on **all** group keys and numerics match. | +| `multi_count_distinct_lower_matches_expected_case_collapsing` | `COUNT(DISTINCT lower(b))` with `'Abc'` / `'aBC'` plus a second distinct on `c` → **one** distinct lowered value, **two** raw `c` values (semantics follow the expression inside `COUNT(DISTINCT …)`, not raw `b`). | +| `multi_count_distinct_cast_float_to_int_collapses_nearby_values` | `COUNT(DISTINCT CAST(x AS INT))` with `1.2` / `1.3` (both → `1`) vs a second distinct on `y` → exercises **cast collision** the same way as the logical-plan `CAST(column)` tests. | + +### Note: `lower(b)` and `CAST` inside `COUNT(DISTINCT …)` (reviewer question) + +The rule only rewrites when each distinct aggregate is a **simple** `COUNT(DISTINCT expr)` with `expr` that is: + +- a column, +- `lower`/`upper` of one column, or +- `CAST` of one column (non-volatile). + +The rewrite **does not change SQL semantics**: distinct is computed on the **evaluated** values of that expression (so `'Abc'` and `'aBC'` under `lower(b)` collapse to one distinct; `1.2` and `1.3` under `CAST(x AS INT)` collapse to one distinct). The two SQL tests above lock that in end-to-end alongside the multi-distinct rewrite. + +--- + +Happy to align naming, tests, and placement with you and the maintainers. From bfbe2865755858e6708a3faf6ef58da7540574ad Mon Sep 17 00:00:00 2001 From: yash Date: Mon, 23 Mar 2026 10:55:26 +0530 Subject: [PATCH 5/7] feat: gate multi-distinct COUNT rewrite behind session config (default off) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ✅ Add datafusion.optimizer.enable_multi_distinct_count_rewrite (default false). ✅ MultiDistinctCountRewrite no-ops when disabled; OptimizerContext::with_enable_multi_distinct_count_rewrite for tests. ✅ SQL integration tests enable the flag via session helper; unit test skips_rewrite_when_config_disabled. ✅ Document option in user-guide configs.md. ❌ Does not change rewrite semantics when enabled. Made-with: Cursor --- datafusion/common/src/config.rs | 5 +++ .../multi_distinct_count_rewrite.rs | 19 ++++++--- .../src/multi_distinct_count_rewrite.rs | 41 ++++++++++++++++++- datafusion/optimizer/src/optimizer.rs | 8 ++++ docs/source/user-guide/configs.md | 1 + 5 files changed, 67 insertions(+), 7 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index e6d5fbfc50a21..ace242930da6f 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -969,6 +969,11 @@ config_namespace! { /// predicate push down. pub filter_null_join_keys: bool, default = false + /// When `true`, rewrite one grouped aggregate that has multiple `COUNT(DISTINCT …)` into + /// joins of per-distinct sub-aggregates (can lower peak memory; adds join work). Default + /// `false` until workload benchmarks justify enabling broadly. + pub enable_multi_distinct_count_rewrite: bool, default = false + /// Should DataFusion repartition data using the aggregate keys to execute aggregates /// in parallel using the provided `target_partitions` level pub repartition_aggregations: bool, default = true diff --git a/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs b/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs index bef4cd41f060e..64e3680cf1639 100644 --- a/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs +++ b/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs @@ -20,11 +20,20 @@ use super::*; use arrow::array::{Float64Array, Int32Array, StringArray}; use datafusion::common::test_util::batches_to_sort_string; +use datafusion::execution::config::SessionConfig; +use datafusion::execution::context::SessionContext; use datafusion_catalog::MemTable; +fn session_with_multi_distinct_count_rewrite() -> SessionContext { + SessionContext::new_with_config(SessionConfig::new().set_bool( + "datafusion.optimizer.enable_multi_distinct_count_rewrite", + true, + )) +} + #[tokio::test] async fn multi_count_distinct_matches_expected_with_nulls() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = session_with_multi_distinct_count_rewrite(); let schema = Arc::new(Schema::new(vec![ Field::new("g", DataType::Int32, false), Field::new("b", DataType::Utf8, true), @@ -60,7 +69,7 @@ async fn multi_count_distinct_matches_expected_with_nulls() -> Result<()> { /// `COUNT(*)` + two `COUNT(DISTINCT …)` per group (BI-style); must match non-rewritten semantics. #[tokio::test] async fn multi_count_distinct_with_count_star_matches_expected() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = session_with_multi_distinct_count_rewrite(); let schema = Arc::new(Schema::new(vec![ Field::new("g", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -96,7 +105,7 @@ async fn multi_count_distinct_with_count_star_matches_expected() -> Result<()> { /// Multiple `GROUP BY` keys: join must align on all keys. #[tokio::test] async fn multi_count_distinct_two_group_keys_matches_expected() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = session_with_multi_distinct_count_rewrite(); let schema = Arc::new(Schema::new(vec![ Field::new("g1", DataType::Int32, false), Field::new("g2", DataType::Int32, false), @@ -136,7 +145,7 @@ async fn multi_count_distinct_two_group_keys_matches_expected() -> Result<()> { /// Two `COUNT(DISTINCT …)` so the rewrite applies; semantics match plain aggregation. #[tokio::test] async fn multi_count_distinct_lower_matches_expected_case_collapsing() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = session_with_multi_distinct_count_rewrite(); let schema = Arc::new(Schema::new(vec![ Field::new("g", DataType::Int32, false), Field::new("b", DataType::Utf8, false), @@ -173,7 +182,7 @@ async fn multi_count_distinct_lower_matches_expected_case_collapsing() -> Result /// Exercises the same “expression in distinct, not raw column” path as `CAST` in the rule. #[tokio::test] async fn multi_count_distinct_cast_float_to_int_collapses_nearby_values() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = session_with_multi_distinct_count_rewrite(); let schema = Arc::new(Schema::new(vec![ Field::new("g", DataType::Int32, false), Field::new("x", DataType::Float64, false), diff --git a/datafusion/optimizer/src/multi_distinct_count_rewrite.rs b/datafusion/optimizer/src/multi_distinct_count_rewrite.rs index 57d2f9d589fbb..12daebff83d27 100644 --- a/datafusion/optimizer/src/multi_distinct_count_rewrite.rs +++ b/datafusion/optimizer/src/multi_distinct_count_rewrite.rs @@ -113,6 +113,14 @@ impl OptimizerRule for MultiDistinctCountRewrite { plan: LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { + if !config + .options() + .optimizer + .enable_multi_distinct_count_rewrite + { + return Ok(Transformed::no(plan)); + } + let LogicalPlan::Aggregate(Aggregate { input, aggr_expr, @@ -351,17 +359,27 @@ mod tests { use datafusion_expr::{Expr, col}; use datafusion_functions_aggregate::expr_fn::{count, count_distinct}; - fn optimize_with_rule( + fn optimize_with_rule_config( plan: LogicalPlan, rule: Arc, + enable_multi_distinct_count_rewrite: bool, ) -> Result { Optimizer::with_rules(vec![rule]).optimize( plan, - &OptimizerContext::new(), + &OptimizerContext::new().with_enable_multi_distinct_count_rewrite( + enable_multi_distinct_count_rewrite, + ), |_, _| {}, ) } + fn optimize_with_rule( + plan: LogicalPlan, + rule: Arc, + ) -> Result { + optimize_with_rule_config(plan, rule, true) + } + #[test] fn rewrites_two_count_distinct() -> Result<()> { let table_scan = test_table_scan()?; @@ -585,6 +603,25 @@ mod tests { Ok(()) } + #[test] + fn skips_rewrite_when_config_disabled() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![count_distinct(col("b")), count_distinct(col("c"))], + )? + .build()?; + let before = plan.display_indent_schema().to_string(); + let optimized = optimize_with_rule_config( + plan, + Arc::new(MultiDistinctCountRewrite::new()), + false, + )?; + assert_eq!(before, optimized.display_indent_schema().to_string()); + Ok(()) + } + #[test] fn does_not_rewrite_mixed_agg() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 08d5bbfbb424e..9d626d789dafd 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -218,6 +218,14 @@ impl OptimizerContext { Arc::make_mut(&mut self.options).optimizer.max_passes = v as usize; self } + + /// Enable [`crate::multi_distinct_count_rewrite::MultiDistinctCountRewrite`] (default off). + pub fn with_enable_multi_distinct_count_rewrite(mut self, enable: bool) -> Self { + Arc::make_mut(&mut self.options) + .optimizer + .enable_multi_distinct_count_rewrite = enable; + self + } } impl Default for OptimizerContext { diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 56ab4d1539f92..99214645420a2 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -144,6 +144,7 @@ The following configuration settings are available: | datafusion.optimizer.enable_aggregate_dynamic_filter_pushdown | true | When set to true, the optimizer will attempt to push down Aggregate dynamic filters into the file scan phase. | | datafusion.optimizer.enable_dynamic_filter_pushdown | true | When set to true attempts to push down dynamic filters generated by operators (TopK, Join & Aggregate) into the file scan phase. For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. This means that if we already have 10 timestamps in the year 2025 any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. The config will suppress `enable_join_dynamic_filter_pushdown`, `enable_topk_dynamic_filter_pushdown` & `enable_aggregate_dynamic_filter_pushdown` So if you disable `enable_topk_dynamic_filter_pushdown`, then enable `enable_dynamic_filter_pushdown`, the `enable_topk_dynamic_filter_pushdown` will be overridden. | | datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | +| datafusion.optimizer.enable_multi_distinct_count_rewrite | false | When set to true, the optimizer may rewrite a single aggregate with multiple `COUNT(DISTINCT …)` (with `GROUP BY`) into joins of per-distinct sub-aggregates. This can reduce peak memory but adds join work; default off until benchmarks support enabling broadly. | | datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | | datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. | | datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level | From c64db4b522ea6dff024f3cc9d24f7858789eed0e Mon Sep 17 00:00:00 2001 From: Yash Gandhi Date: Mon, 23 Mar 2026 14:34:13 +0530 Subject: [PATCH 6/7] Delete docs/improvements/REPLY_PR_20940_DANDANDAN.md --- docs/improvements/REPLY_PR_20940_DANDANDAN.md | 82 ------------------- 1 file changed, 82 deletions(-) delete mode 100644 docs/improvements/REPLY_PR_20940_DANDANDAN.md diff --git a/docs/improvements/REPLY_PR_20940_DANDANDAN.md b/docs/improvements/REPLY_PR_20940_DANDANDAN.md deleted file mode 100644 index eb3524a609a0f..0000000000000 --- a/docs/improvements/REPLY_PR_20940_DANDANDAN.md +++ /dev/null @@ -1,82 +0,0 @@ -# Draft reply: PR #20940 (`MultiDistinctToCrossJoin`) vs `MultiDistinctCountRewrite` - -Paste into [apache/datafusion#20940](https://github.com/apache/datafusion/pull/20940) as a comment (GitHub-flavored Markdown). - ---- - -Hi @Dandandan — thanks for this work; the cross-join split for **multiple distinct aggregates with no `GROUP BY`** is a strong fit for workloads like ClickBench extended. - -I’ve been working on a related but **different** pattern: **`GROUP BY` + several `COUNT(DISTINCT …)`** in the same aggregate (typical BI). In that situation, your rule **does not apply**, because `MultiDistinctToCrossJoin` needs an **empty** `GROUP BY` and **all** aggregates to be distinct on different columns. - -A concrete example from our benchmark suite (**category `08_complex_analytical`, query `Q8.3`**) on an `orders_data` table: - -```sql --- Q8.3: Seller performance analysis -SELECT - seller_name, - COUNT(*) as total_orders, - COUNT(DISTINCT delivery_city) as cities_served, - COUNT(DISTINCT state) as states_served, - SUM(CASE WHEN order_status = 'Completed' THEN 1 ELSE 0 END) as completed_orders, - SUM(CASE WHEN order_status = 'Cancelled' THEN 1 ELSE 0 END) as cancelled_orders, - ROUND(100.0 * SUM(CASE WHEN order_status = 'Completed' THEN 1 ELSE 0 END) / COUNT(*), 2) as success_rate -FROM orders_data -GROUP BY seller_name -HAVING COUNT(*) > 100 -ORDER BY total_orders DESC -LIMIT 100; -``` - -This is **not** “global” multi-distinct: it’s **per `seller_name`**, with **multiple `COUNT(DISTINCT …)`** plus other aggregates. That’s the class my optimizer rule (`MultiDistinctCountRewrite`) targets — rewriting the **`COUNT(DISTINCT …)`** pieces into **joinable sub-aggregates aligned on the same `GROUP BY` keys**, with correct `NULL` handling where needed. - -So in simple terms: - -| | **Your PR (`MultiDistinctToCrossJoin`)** | **My work (`MultiDistinctCountRewrite`)** | -|---|------------------------------------------|-------------------------------------------| -| **Typical SQL** | `SELECT COUNT(DISTINCT a), COUNT(DISTINCT b) FROM t` (no `GROUP BY`) | `SELECT …, COUNT(DISTINCT x), COUNT(DISTINCT y), … FROM t GROUP BY …` | -| **Example workload** | ClickBench extended–style **Q0 / Q1** | Our **Q8.3** (and similar grouped BI queries) | - -They’re **complementary**: different predicates, different plans, and they can **coexist** in the optimizer pipeline (we’d want to sanity-check rule order so we don’t double-rewrite the same node). - ---- - -## Tests for `MultiDistinctCountRewrite` (what they cover) - -### Optimizer unit tests — `datafusion/optimizer/src/multi_distinct_count_rewrite.rs` - -| Test | What it asserts | -|------|-----------------| -| `rewrites_two_count_distinct` | `GROUP BY a` + `COUNT(DISTINCT b)`, `COUNT(DISTINCT c)` → inner joins, per-branch null filters on `b`/`c`, `mdc_base` + two `mdc_d` aliases. | -| `rewrites_global_three_count_distinct` | No `GROUP BY`, three `COUNT(DISTINCT …)` → cross/inner join rewrite; **no** `mdc_base` (global-only path). | -| `rewrites_two_count_distinct_with_non_distinct_count` | Grouped BI-style: two distincts + `COUNT(a)` → join rewrite with **`mdc_base`** holding the non-distinct agg. | -| `does_not_rewrite_two_count_distinct_same_column` | Two `COUNT(DISTINCT b)` with different aliases → **no** rewrite (duplicate distinct key). | -| `does_not_rewrite_single_count_distinct` | Only one `COUNT(DISTINCT …)` → **no** rewrite (rule needs ≥2 distincts). | -| `rewrites_three_count_distinct_grouped` | Three grouped `COUNT(DISTINCT …)` on `b`, `c`, `a` → **two** inner joins + `mdc_base`. | -| `rewrites_interleaved_non_distinct_between_distincts` | Order `COUNT(DISTINCT b)`, `COUNT(a)`, `COUNT(DISTINCT c)` → rewrite + `mdc_base` for the middle non-distinct agg (projection order / interleaving). | -| `rewrites_count_distinct_on_cast_exprs` | `COUNT(DISTINCT CAST(b AS Int64))`, same for `c` → rewrite + null filters on the **cast** expressions. | -| `does_not_rewrite_grouping_sets_multi_distinct` | `GROUPING SETS` aggregate with two `COUNT(DISTINCT …)` → **no** rewrite (rule bails on grouping sets). | -| `does_not_rewrite_mixed_agg` | `COUNT(DISTINCT b)` + `COUNT(c)` → **no** rewrite (only **one** `COUNT(DISTINCT …)`; rule requires at least two). | - -### SQL integration — `datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs` - -| Test | What it asserts | -|------|-----------------| -| `multi_count_distinct_matches_expected_with_nulls` | End-to-end grouped two `COUNT(DISTINCT …)` with **NULLs** in distinct columns; exact sorted batch string vs expected counts. | -| `multi_count_distinct_with_count_star_matches_expected` | `COUNT(*)` plus two `COUNT(DISTINCT …)` per group (BI-style); exact result table. | -| `multi_count_distinct_two_group_keys_matches_expected` | **`GROUP BY g1, g2`** + two distincts; verifies joins line up on **all** group keys and numerics match. | -| `multi_count_distinct_lower_matches_expected_case_collapsing` | `COUNT(DISTINCT lower(b))` with `'Abc'` / `'aBC'` plus a second distinct on `c` → **one** distinct lowered value, **two** raw `c` values (semantics follow the expression inside `COUNT(DISTINCT …)`, not raw `b`). | -| `multi_count_distinct_cast_float_to_int_collapses_nearby_values` | `COUNT(DISTINCT CAST(x AS INT))` with `1.2` / `1.3` (both → `1`) vs a second distinct on `y` → exercises **cast collision** the same way as the logical-plan `CAST(column)` tests. | - -### Note: `lower(b)` and `CAST` inside `COUNT(DISTINCT …)` (reviewer question) - -The rule only rewrites when each distinct aggregate is a **simple** `COUNT(DISTINCT expr)` with `expr` that is: - -- a column, -- `lower`/`upper` of one column, or -- `CAST` of one column (non-volatile). - -The rewrite **does not change SQL semantics**: distinct is computed on the **evaluated** values of that expression (so `'Abc'` and `'aBC'` under `lower(b)` collapse to one distinct; `1.2` and `1.3` under `CAST(x AS INT)` collapse to one distinct). The two SQL tests above lock that in end-to-end alongside the multi-distinct rewrite. - ---- - -Happy to align naming, tests, and placement with you and the maintainers. From 1637b2d5213d6520131ab24d46e89557528d61d4 Mon Sep 17 00:00:00 2001 From: yash Date: Thu, 26 Mar 2026 17:27:56 +0530 Subject: [PATCH 7/7] docs: clarify multi DISTINCT rewrite memory trade-offs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ✅ Document `datafusion.optimizer.enable_multi_distinct_count_rewrite` as potentially reducing *aggregate-state* memory in some cases ❌ Avoid implying peak memory reduction is guaranteed (extra joins can add overhead) Made-with: Cursor --- docs/source/user-guide/configs.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 99214645420a2..519978074fcec 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -144,7 +144,7 @@ The following configuration settings are available: | datafusion.optimizer.enable_aggregate_dynamic_filter_pushdown | true | When set to true, the optimizer will attempt to push down Aggregate dynamic filters into the file scan phase. | | datafusion.optimizer.enable_dynamic_filter_pushdown | true | When set to true attempts to push down dynamic filters generated by operators (TopK, Join & Aggregate) into the file scan phase. For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. This means that if we already have 10 timestamps in the year 2025 any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. The config will suppress `enable_join_dynamic_filter_pushdown`, `enable_topk_dynamic_filter_pushdown` & `enable_aggregate_dynamic_filter_pushdown` So if you disable `enable_topk_dynamic_filter_pushdown`, then enable `enable_dynamic_filter_pushdown`, the `enable_topk_dynamic_filter_pushdown` will be overridden. | | datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | -| datafusion.optimizer.enable_multi_distinct_count_rewrite | false | When set to true, the optimizer may rewrite a single aggregate with multiple `COUNT(DISTINCT …)` (with `GROUP BY`) into joins of per-distinct sub-aggregates. This can reduce peak memory but adds join work; default off until benchmarks support enabling broadly. | +| datafusion.optimizer.enable_multi_distinct_count_rewrite | false | When set to true, the optimizer may rewrite a single aggregate with multiple `COUNT(DISTINCT …)` (with `GROUP BY`) into joins of per-distinct sub-aggregates. This can reduce aggregate-state memory in some cases by splitting distincts into separate sub-aggregates, at the cost of additional joins; default off until benchmarks support enabling broadly. | | datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | | datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. | | datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level |