diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index e6d5fbfc50a21..ace242930da6f 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -969,6 +969,11 @@ config_namespace! { /// predicate push down. pub filter_null_join_keys: bool, default = false + /// When `true`, rewrite one grouped aggregate that has multiple `COUNT(DISTINCT …)` into + /// joins of per-distinct sub-aggregates (can lower peak memory; adds join work). Default + /// `false` until workload benchmarks justify enabling broadly. + pub enable_multi_distinct_count_rewrite: bool, default = false + /// Should DataFusion repartition data using the aggregate keys to execute aggregates /// in parallel using the provided `target_partitions` level pub repartition_aggregations: bool, default = true diff --git a/datafusion/core/tests/sql/aggregates/mod.rs b/datafusion/core/tests/sql/aggregates/mod.rs index ede40d5c4ceca..ffdb886e89c0b 100644 --- a/datafusion/core/tests/sql/aggregates/mod.rs +++ b/datafusion/core/tests/sql/aggregates/mod.rs @@ -1024,3 +1024,4 @@ pub fn split_fuzz_timestamp_data_into_batches( pub mod basic; pub mod dict_nulls; +pub mod multi_distinct_count_rewrite; diff --git a/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs b/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs new file mode 100644 index 0000000000000..64e3680cf1639 --- /dev/null +++ b/datafusion/core/tests/sql/aggregates/multi_distinct_count_rewrite.rs @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! End-to-end SQL tests for the multi-`COUNT(DISTINCT)` logical optimizer rewrite. + +use super::*; +use arrow::array::{Float64Array, Int32Array, StringArray}; +use datafusion::common::test_util::batches_to_sort_string; +use datafusion::execution::config::SessionConfig; +use datafusion::execution::context::SessionContext; +use datafusion_catalog::MemTable; + +fn session_with_multi_distinct_count_rewrite() -> SessionContext { + SessionContext::new_with_config(SessionConfig::new().set_bool( + "datafusion.optimizer.enable_multi_distinct_count_rewrite", + true, + )) +} + +#[tokio::test] +async fn multi_count_distinct_matches_expected_with_nulls() -> Result<()> { + let ctx = session_with_multi_distinct_count_rewrite(); + let schema = Arc::new(Schema::new(vec![ + Field::new("g", DataType::Int32, false), + Field::new("b", DataType::Utf8, true), + Field::new("c", DataType::Utf8, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1, 1])), + Arc::new(StringArray::from(vec![Some("x"), None, Some("x")])), + Arc::new(StringArray::from(vec![None, Some("y"), Some("y")])), + ], + )?; + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + let sql = + "SELECT g, COUNT(DISTINCT b) AS cb, COUNT(DISTINCT c) AS cc FROM t GROUP BY g"; + let batches = ctx.sql(sql).await?.collect().await?; + let out = batches_to_sort_string(&batches); + + assert_eq!( + out, + "+---+----+----+\n\ + | g | cb | cc |\n\ + +---+----+----+\n\ + | 1 | 1 | 1 |\n\ + 
+---+----+----+" + ); + Ok(()) +} + +/// `COUNT(*)` + two `COUNT(DISTINCT …)` per group (BI-style); must match non-rewritten semantics. +#[tokio::test] +async fn multi_count_distinct_with_count_star_matches_expected() -> Result<()> { + let ctx = session_with_multi_distinct_count_rewrite(); + let schema = Arc::new(Schema::new(vec![ + Field::new("g", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1, 1])), + Arc::new(Int32Array::from(vec![1, 2, 1])), + Arc::new(Int32Array::from(vec![10, 20, 30])), + ], + )?; + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + let sql = "SELECT g, COUNT(*) AS n, COUNT(DISTINCT b) AS db, COUNT(DISTINCT c) AS dc \ + FROM t GROUP BY g"; + let batches = ctx.sql(sql).await?.collect().await?; + let out = batches_to_sort_string(&batches); + + assert_eq!( + out, + "+---+---+----+----+\n\ + | g | n | db | dc |\n\ + +---+---+----+----+\n\ + | 1 | 3 | 2 | 3 |\n\ + +---+---+----+----+" + ); + Ok(()) +} + +/// Multiple `GROUP BY` keys: join must align on all keys. 
+#[tokio::test] +async fn multi_count_distinct_two_group_keys_matches_expected() -> Result<()> { + let ctx = session_with_multi_distinct_count_rewrite(); + let schema = Arc::new(Schema::new(vec![ + Field::new("g1", DataType::Int32, false), + Field::new("g2", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1, 1])), + Arc::new(Int32Array::from(vec![1, 1, 2])), + Arc::new(Int32Array::from(vec![1, 1, 3])), + Arc::new(Int32Array::from(vec![1, 2, 3])), + ], + )?; + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + let sql = "SELECT g1, g2, COUNT(DISTINCT b) AS db, COUNT(DISTINCT c) AS dc \ + FROM t GROUP BY g1, g2"; + let batches = ctx.sql(sql).await?.collect().await?; + let out = batches_to_sort_string(&batches); + + assert_eq!( + out, + "+----+----+----+----+\n\ + | g1 | g2 | db | dc |\n\ + +----+----+----+----+\n\ + | 1 | 1 | 1 | 2 |\n\ + | 1 | 2 | 1 | 1 |\n\ + +----+----+----+----+" + ); + Ok(()) +} + +/// `COUNT(DISTINCT lower(b))` with `'Abc'` / `'aBC'`: distinct is on the **lowered** value (one bucket). +/// Two `COUNT(DISTINCT …)` so the rewrite applies; semantics match plain aggregation. 
+#[tokio::test] +async fn multi_count_distinct_lower_matches_expected_case_collapsing() -> Result<()> { + let ctx = session_with_multi_distinct_count_rewrite(); + let schema = Arc::new(Schema::new(vec![ + Field::new("g", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Utf8, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(StringArray::from(vec!["Abc", "aBC"])), + Arc::new(StringArray::from(vec!["x", "y"])), + ], + )?; + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + let sql = "SELECT g, COUNT(DISTINCT lower(b)) AS lb, COUNT(DISTINCT c) AS cc \ + FROM t GROUP BY g"; + let batches = ctx.sql(sql).await?.collect().await?; + let out = batches_to_sort_string(&batches); + + assert_eq!( + out, + "+---+----+----+\n\ + | g | lb | cc |\n\ + +---+----+----+\n\ + | 1 | 1 | 2 |\n\ + +---+----+----+" + ); + Ok(()) +} + +/// `COUNT(DISTINCT CAST(x AS INT))` with `1.2` and `1.3`: both truncate to `1` → one distinct. +/// Exercises the same “expression in distinct, not raw column” path as `CAST` in the rule. 
+#[tokio::test] +async fn multi_count_distinct_cast_float_to_int_collapses_nearby_values() -> Result<()> { + let ctx = session_with_multi_distinct_count_rewrite(); + let schema = Arc::new(Schema::new(vec![ + Field::new("g", DataType::Int32, false), + Field::new("x", DataType::Float64, false), + Field::new("y", DataType::Float64, false), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Float64Array::from(vec![1.2, 1.3])), + Arc::new(Float64Array::from(vec![10.0, 20.0])), + ], + )?; + let provider = MemTable::try_new(schema, vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + let sql = "SELECT g, COUNT(DISTINCT CAST(x AS INT)) AS cx, COUNT(DISTINCT CAST(y AS INT)) AS cy \ + FROM t GROUP BY g"; + let batches = ctx.sql(sql).await?.collect().await?; + let out = batches_to_sort_string(&batches); + + assert_eq!( + out, + "+---+----+----+\n\ + | g | cx | cy |\n\ + +---+----+----+\n\ + | 1 | 1 | 2 |\n\ + +---+----+----+" + ); + Ok(()) +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index e610091824092..bdec72eba7e18 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -59,6 +59,7 @@ pub mod eliminate_outer_join; pub mod extract_equijoin_predicate; pub mod extract_leaf_expressions; pub mod filter_null_join_keys; +pub mod multi_distinct_count_rewrite; pub mod optimize_projections; pub mod optimize_unions; pub mod optimizer; diff --git a/datafusion/optimizer/src/multi_distinct_count_rewrite.rs b/datafusion/optimizer/src/multi_distinct_count_rewrite.rs new file mode 100644 index 0000000000000..12daebff83d27 --- /dev/null +++ b/datafusion/optimizer/src/multi_distinct_count_rewrite.rs @@ -0,0 +1,641 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Rewrites a single [`Aggregate`] that contains multiple `COUNT(DISTINCT ...)` +//! into a join of smaller aggregates so each distinct is computed with one +//! accumulator set, reducing peak memory for high-cardinality distincts. + +use std::sync::Arc; + +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use datafusion_common::{ + Column, JoinConstraint, NullEquality, Result, internal_err, tree_node::Transformed, +}; +use datafusion_expr::builder::project; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, ScalarFunction}; +use datafusion_expr::logical_plan::{ + Aggregate, Join, JoinType, LogicalPlan, SubqueryAlias, +}; +use datafusion_expr::{Expr, col, lit, logical_plan::LogicalPlanBuilder}; + +const MAX_DISTINCT_REWRITE_BRANCHES: usize = 8; + +/// Optimizer rule: multiple `COUNT(DISTINCT ...)` → join of per-distinct sub-aggregates. +#[derive(Default, Debug)] +pub struct MultiDistinctCountRewrite {} + +impl MultiDistinctCountRewrite { + /// Create a new rule instance. + pub fn new() -> Self { + Self {} + } + + fn is_simple_count_distinct( + e: &Expr, + ) -> Option<(Arc, Expr)> { + if let Expr::AggregateFunction(AggregateFunction { func, params }) = e { + let AggregateFunctionParams { + distinct, + args, + filter, + order_by, + .. 
+ } = ¶ms; + if func.name().eq_ignore_ascii_case("count") + && *distinct + && args.len() == 1 + && filter.is_none() + && order_by.is_empty() + { + let arg = args.first().cloned()?; + if Self::is_safe_distinct_arg(&arg) { + return Some((Arc::clone(func), arg)); + } + } + } + None + } + + fn is_safe_distinct_arg(e: &Expr) -> bool { + if e.is_volatile() { + return false; + } + match e { + Expr::Column(_) => true, + Expr::ScalarFunction(ScalarFunction { func, args }) => { + matches!(func.name().to_ascii_lowercase().as_str(), "lower" | "upper") + && args.len() == 1 + && matches!(args.first(), Some(Expr::Column(_))) + } + Expr::Cast(cast) => matches!(cast.expr.as_ref(), Expr::Column(_)), + _ => false, + } + } + + fn is_simple_group_expr(e: &Expr) -> bool { + matches!(e, Expr::Column(_)) + } + + fn contains_grouping_set(group_expr: &[Expr]) -> bool { + group_expr + .first() + .is_some_and(|e| matches!(e, Expr::GroupingSet(_))) + } +} + +impl OptimizerRule for MultiDistinctCountRewrite { + fn name(&self) -> &str { + "multi_distinct_count_rewrite" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + if !config + .options() + .optimizer + .enable_multi_distinct_count_rewrite + { + return Ok(Transformed::no(plan)); + } + + let LogicalPlan::Aggregate(Aggregate { + input, + aggr_expr, + schema, + group_expr, + .. 
+ }) = plan + else { + return Ok(Transformed::no(plan)); + }; + + if Self::contains_grouping_set(&group_expr) { + return Ok(Transformed::no(LogicalPlan::Aggregate(Aggregate::try_new( + input, group_expr, aggr_expr, + )?))); + } + + if !group_expr.iter().all(Self::is_simple_group_expr) { + return Ok(Transformed::no(LogicalPlan::Aggregate(Aggregate::try_new( + input, group_expr, aggr_expr, + )?))); + } + + let group_size = group_expr.len(); + let mut distinct_list: Vec<(Expr, usize, Arc)> = + vec![]; + let mut other_list: Vec<(Expr, usize)> = vec![]; + + for (i, e) in aggr_expr.iter().enumerate() { + if let Some((func, arg)) = Self::is_simple_count_distinct(e) { + distinct_list.push((arg, group_size + i, func)); + } else { + other_list.push((e.clone(), group_size + i)); + } + } + + if distinct_list.len() < 2 { + return Ok(Transformed::no(LogicalPlan::Aggregate(Aggregate::try_new( + input, group_expr, aggr_expr, + )?))); + } + + if distinct_list.len() > MAX_DISTINCT_REWRITE_BRANCHES { + return Ok(Transformed::no(LogicalPlan::Aggregate(Aggregate::try_new( + input, group_expr, aggr_expr, + )?))); + } + + { + use std::collections::HashSet; + let mut seen: HashSet<&Expr> = HashSet::new(); + for (arg, _, _) in distinct_list.iter() { + if !seen.insert(arg) { + return Ok(Transformed::no(LogicalPlan::Aggregate( + Aggregate::try_new(input, group_expr, aggr_expr)?, + ))); + } + } + } + + let count_udaf = Arc::clone(&distinct_list[0].2); + + let count_star = Expr::AggregateFunction(AggregateFunction::new_udf( + Arc::clone(&count_udaf), + vec![lit(1i64)], + false, + None, + vec![], + None, + )) + .alias("_cnt"); + + let base_aggr_exprs: Vec = other_list + .iter() + .map(|(e, schema_idx)| { + let (q, f) = schema.qualified_field(*schema_idx); + e.clone().alias_qualified(q.cloned(), f.name()) + }) + .collect(); + + // `Aggregate` must have at least one of grouping exprs or aggregate exprs. 
+ // Global multi-`COUNT(DISTINCT)` (no GROUP BY, no other aggs) has neither — skip a base node. + let base_plan_opt: Option> = + if group_expr.is_empty() && other_list.is_empty() { + None + } else { + let base_plan = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::clone(&input), + group_expr.clone(), + base_aggr_exprs, + )?); + let base_alias = config.alias_generator().next("mdc_base"); + Some(Arc::new(LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(Arc::new(base_plan), &base_alias)?, + ))) + }; + + let mut current = base_plan_opt; + + for (distinct_arg, schema_aggr_idx, _) in distinct_list.iter() { + // COUNT(DISTINCT x) ignores NULLs; filter before grouping by x. + let filtered_input = LogicalPlanBuilder::from(input.as_ref().clone()) + .filter(distinct_arg.clone().is_not_null())? + .build()?; + + let inner_group: Vec = group_expr + .iter() + .cloned() + .chain(std::iter::once(distinct_arg.clone())) + .collect(); + let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(filtered_input), + inner_group, + vec![count_star.clone()], + )?); + + let (_, field) = schema.qualified_field(*schema_aggr_idx); + let outer_name = field.name().clone(); + let outer_aggr = Expr::AggregateFunction(AggregateFunction::new_udf( + Arc::clone(&count_udaf), + vec![col("_cnt")], + false, + None, + vec![], + None, + )) + .alias(outer_name.clone()); + + let branch_plan = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(inner_agg), + group_expr.clone(), + vec![outer_aggr], + )?); + + let alias_name = config.alias_generator().next("mdc_d"); + let branch_aliased = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( + Arc::new(branch_plan), + &alias_name, + )?); + + current = match current { + None => Some(Arc::new(branch_aliased)), + Some(prev) => { + let left_schema = prev.schema(); + let right_schema = branch_aliased.schema(); + let join_keys: Vec<(Expr, Expr)> = (0..group_size) + .map(|i| { + let (lq, lf) = left_schema.qualified_field(i); + let (rq, rf) = 
right_schema.qualified_field(i); + ( + Expr::Column(Column::new(lq.cloned(), lf.name())), + Expr::Column(Column::new(rq.cloned(), rf.name())), + ) + }) + .collect(); + + let join = Join::try_new( + prev, + Arc::new(branch_aliased), + join_keys, + None, + JoinType::Inner, + JoinConstraint::On, + NullEquality::NullEqualsNothing, + false, + )?; + Some(Arc::new(LogicalPlan::Join(join))) + } + }; + } + + let current = + current.expect("distinct_list non-empty implies at least one branch"); + let join_schema = current.schema(); + + let base_field_count = group_size + other_list.len(); + + let mut proj_exprs: Vec = vec![]; + for i in 0..group_size { + let (q, f) = schema.qualified_field(i); + let orig_name = f.name(); + let (join_q, join_f) = join_schema.qualified_field(i); + let c = Expr::Column(Column::new(join_q.cloned(), join_f.name())); + proj_exprs.push(c.alias_qualified(q.cloned(), orig_name)); + } + // Preserve original aggregate column order (distinct and non-distinct may be interleaved). 
+ for aggr_i in 0..aggr_expr.len() { + let schema_idx = group_size + aggr_i; + let (q, f) = schema.qualified_field(schema_idx); + let orig_name = f.name(); + + if let Some((dist_idx, (_, _, _))) = distinct_list + .iter() + .enumerate() + .find(|(_, (_, idx, _))| *idx == schema_idx) + { + let branch_start_idx = base_field_count + dist_idx * (group_size + 1); + let branch_aggr_idx = branch_start_idx + group_size; + let (join_q, join_f) = join_schema.qualified_field(branch_aggr_idx); + let c = Expr::Column(Column::new(join_q.cloned(), join_f.name())); + proj_exprs.push(c.alias_qualified(q.cloned(), orig_name)); + } else if let Some((other_idx, _)) = other_list + .iter() + .enumerate() + .find(|(_, (_, idx))| *idx == schema_idx) + { + let join_idx = group_size + other_idx; + let (join_q, join_f) = join_schema.qualified_field(join_idx); + let c = Expr::Column(Column::new(join_q.cloned(), join_f.name())); + proj_exprs.push(c.alias_qualified(q.cloned(), orig_name)); + } else { + return internal_err!( + "aggregate index {aggr_i} (schema index {schema_idx}) is neither distinct nor other" + ); + } + } + + let out = project((*current).clone(), proj_exprs)?; + Ok(Transformed::yes(out)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Optimizer; + use crate::OptimizerContext; + use crate::OptimizerRule; + use crate::test::*; + use arrow::datatypes::DataType; + use datafusion_expr::GroupingSet; + use datafusion_expr::LogicalPlan; + use datafusion_expr::expr_fn::cast; + use datafusion_expr::logical_plan::Aggregate; + use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; + use datafusion_expr::{Expr, col}; + use datafusion_functions_aggregate::expr_fn::{count, count_distinct}; + + fn optimize_with_rule_config( + plan: LogicalPlan, + rule: Arc, + enable_multi_distinct_count_rewrite: bool, + ) -> Result { + Optimizer::with_rules(vec![rule]).optimize( + plan, + &OptimizerContext::new().with_enable_multi_distinct_count_rewrite( + 
enable_multi_distinct_count_rewrite, + ), + |_, _| {}, + ) + } + + fn optimize_with_rule( + plan: LogicalPlan, + rule: Arc, + ) -> Result { + optimize_with_rule_config(plan, rule, true) + } + + #[test] + fn rewrites_two_count_distinct() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![count_distinct(col("b")), count_distinct(col("c"))], + )? + .build()?; + + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let s = optimized.display_indent_schema().to_string(); + assert!(s.contains("Inner Join"), "expected join rewrite, got:\n{s}"); + assert!( + s.contains("Filter: test.b IS NOT NULL"), + "expected null filter on b, got:\n{s}" + ); + assert!( + s.contains("Filter: test.c IS NOT NULL"), + "expected null filter on c, got:\n{s}" + ); + assert!( + s.contains("SubqueryAlias: mdc_base"), + "expected base alias, got:\n{s}" + ); + assert!( + s.matches("SubqueryAlias: mdc_d").count() >= 2, + "expected distinct branches, got:\n{s}" + ); + Ok(()) + } + + #[test] + fn rewrites_global_three_count_distinct() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + Vec::::new(), + vec![ + count_distinct(col("a")), + count_distinct(col("b")), + count_distinct(col("c")), + ], + )? + .build()?; + + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let s = optimized.display_indent_schema().to_string(); + assert!( + s.contains("Cross Join") || s.contains("Inner Join"), + "expected join rewrite for global multi-distinct, got:\n{s}" + ); + assert!( + !s.contains("mdc_base"), + "global-only rewrite should not use mdc_base, got:\n{s}" + ); + Ok(()) + } + + /// Grouped query with multiple `COUNT(DISTINCT …)` **and** non-distinct aggregates (typical BI). + /// Non-distinct aggs live in `mdc_base`; each distinct column gets a branch + join on keys. 
+ #[test] + fn rewrites_two_count_distinct_with_non_distinct_count() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![ + count_distinct(col("b")), + count_distinct(col("c")), + count(col("a")), + ], + )? + .build()?; + + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let s = optimized.display_indent_schema().to_string(); + assert!(s.contains("Inner Join"), "expected join rewrite, got:\n{s}"); + assert!( + s.contains("SubqueryAlias: mdc_base"), + "expected base aggregate for non-distinct aggs, got:\n{s}" + ); + Ok(()) + } + + #[test] + fn does_not_rewrite_two_count_distinct_same_column() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![ + count_distinct(col("b")).alias("cd1"), + count_distinct(col("b")).alias("cd2"), + ], + )? + .build()?; + let before = plan.display_indent_schema().to_string(); + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let after = optimized.display_indent_schema().to_string(); + assert_eq!(before, after); + Ok(()) + } + + #[test] + fn does_not_rewrite_single_count_distinct() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a")], vec![count_distinct(col("b"))])? 
+ .build()?; + let before = plan.display_indent_schema().to_string(); + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let after = optimized.display_indent_schema().to_string(); + assert_eq!(before, after); + Ok(()) + } + + #[test] + fn rewrites_three_count_distinct_grouped() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![ + count_distinct(col("b")), + count_distinct(col("c")), + count_distinct(col("a")), + ], + )? + .build()?; + + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let s = optimized.display_indent_schema().to_string(); + assert!( + s.matches("Inner Join").count() >= 2, + "expected two joins for three branches, got:\n{s}" + ); + assert!( + s.contains("SubqueryAlias: mdc_base"), + "expected base aggregate, got:\n{s}" + ); + Ok(()) + } + + #[test] + fn rewrites_interleaved_non_distinct_between_distincts() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![ + count_distinct(col("b")), + count(col("a")), + count_distinct(col("c")), + ], + )? + .build()?; + + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let s = optimized.display_indent_schema().to_string(); + assert!(s.contains("Inner Join"), "expected join rewrite, got:\n{s}"); + assert!( + s.contains("SubqueryAlias: mdc_base"), + "expected base for middle count(a), got:\n{s}" + ); + Ok(()) + } + + #[test] + fn rewrites_count_distinct_on_cast_exprs() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![ + count_distinct(cast(col("b"), DataType::Int64)), + count_distinct(cast(col("c"), DataType::Int64)), + ], + )? 
+ .build()?; + + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let s = optimized.display_indent_schema().to_string(); + assert!(s.contains("Inner Join"), "expected join rewrite, got:\n{s}"); + assert!( + s.contains("Filter: CAST(test.b AS Int64) IS NOT NULL"), + "expected null filter on cast(b), got:\n{s}" + ); + assert!( + s.contains("Filter: CAST(test.c AS Int64) IS NOT NULL"), + "expected null filter on cast(c), got:\n{s}" + ); + Ok(()) + } + + #[test] + fn does_not_rewrite_grouping_sets_multi_distinct() -> Result<()> { + let table_scan = test_table_scan()?; + let group_expr = vec![Expr::GroupingSet(GroupingSet::GroupingSets(vec![vec![ + col("a"), + ]]))]; + let aggr_expr = vec![count_distinct(col("b")), count_distinct(col("c"))]; + let plan = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(table_scan), + group_expr, + aggr_expr, + )?); + let before = plan.display_indent_schema().to_string(); + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let after = optimized.display_indent_schema().to_string(); + assert_eq!(before, after); + Ok(()) + } + + #[test] + fn skips_rewrite_when_config_disabled() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![count_distinct(col("b")), count_distinct(col("c"))], + )? + .build()?; + let before = plan.display_indent_schema().to_string(); + let optimized = optimize_with_rule_config( + plan, + Arc::new(MultiDistinctCountRewrite::new()), + false, + )?; + assert_eq!(before, optimized.display_indent_schema().to_string()); + Ok(()) + } + + #[test] + fn does_not_rewrite_mixed_agg() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![count_distinct(col("b")), count(col("c"))], + )? 
+ .build()?; + let before = plan.display_indent_schema().to_string(); + let optimized = + optimize_with_rule(plan, Arc::new(MultiDistinctCountRewrite::new()))?; + let after = optimized.display_indent_schema().to_string(); + assert_eq!(before, after); + Ok(()) + } +} diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index bdea6a83072cd..9d626d789dafd 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -45,6 +45,7 @@ use crate::eliminate_outer_join::EliminateOuterJoin; use crate::extract_equijoin_predicate::ExtractEquijoinPredicate; use crate::extract_leaf_expressions::{ExtractLeafExpressions, PushDownLeafProjections}; use crate::filter_null_join_keys::FilterNullJoinKeys; +use crate::multi_distinct_count_rewrite::MultiDistinctCountRewrite; use crate::optimize_projections::OptimizeProjections; use crate::optimize_unions::OptimizeUnions; use crate::plan_signature::LogicalPlanSignature; @@ -217,6 +218,14 @@ impl OptimizerContext { Arc::make_mut(&mut self.options).optimizer.max_passes = v as usize; self } + + /// Enable [`crate::multi_distinct_count_rewrite::MultiDistinctCountRewrite`] (default off). 
+ pub fn with_enable_multi_distinct_count_rewrite(mut self, enable: bool) -> Self { + Arc::make_mut(&mut self.options) + .optimizer + .enable_multi_distinct_count_rewrite = enable; + self + } } impl Default for OptimizerContext { @@ -298,6 +307,7 @@ impl Optimizer { Arc::new(PushDownLimit::new()), Arc::new(PushDownFilter::new()), Arc::new(SingleDistinctToGroupBy::new()), + Arc::new(MultiDistinctCountRewrite::new()), // The previous optimizations added expressions and projections, // that might benefit from the following rules Arc::new(EliminateGroupByConstant::new()), diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 56ab4d1539f92..519978074fcec 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -144,6 +144,7 @@ The following configuration settings are available: | datafusion.optimizer.enable_aggregate_dynamic_filter_pushdown | true | When set to true, the optimizer will attempt to push down Aggregate dynamic filters into the file scan phase. | | datafusion.optimizer.enable_dynamic_filter_pushdown | true | When set to true attempts to push down dynamic filters generated by operators (TopK, Join & Aggregate) into the file scan phase. For example, for a query such as `SELECT * FROM t ORDER BY timestamp DESC LIMIT 10`, the optimizer will attempt to push down the current top 10 timestamps that the TopK operator references into the file scans. This means that if we already have 10 timestamps in the year 2025 any files that only have timestamps in the year 2024 can be skipped / pruned at various stages in the scan. The config will suppress `enable_join_dynamic_filter_pushdown`, `enable_topk_dynamic_filter_pushdown` & `enable_aggregate_dynamic_filter_pushdown` So if you disable `enable_topk_dynamic_filter_pushdown`, then enable `enable_dynamic_filter_pushdown`, the `enable_topk_dynamic_filter_pushdown` will be overridden. 
| | datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | +| datafusion.optimizer.enable_multi_distinct_count_rewrite | false | When set to true, the optimizer may rewrite a single aggregate with multiple `COUNT(DISTINCT …)` (with or without `GROUP BY`) into joins of per-distinct sub-aggregates. This can reduce aggregate-state memory in some cases by splitting distincts into separate sub-aggregates, at the cost of additional joins; default off until benchmarks support enabling broadly. | | datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | | datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. | | datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level |