From ff0ada71eccc9b827fc6f3b2f88b4d82ed483fea Mon Sep 17 00:00:00 2001 From: Sergey Zhukov Date: Wed, 18 Mar 2026 10:06:23 +0400 Subject: [PATCH 1/4] DataFrame API: allow aggregate functions in select() (#17874) --- ...es@explain_plan_environment_overrides.snap | 12 ++-- datafusion/core/src/dataframe/mod.rs | 64 +++++++++++++++++-- datafusion/core/tests/dataframe/mod.rs | 24 +++++++ 3 files changed, 90 insertions(+), 10 deletions(-) diff --git a/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap b/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap index 1359cefbe71c7..5f43ca88dc9d7 100644 --- a/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap +++ b/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap @@ -18,19 +18,19 @@ exit_code: 0 | logical_plan | [ | | | { | | | "Plan": { | -| | "Expressions": [ | -| | "Int64(123)" | -| | ], | | | "Node Type": "Projection", | -| | "Output": [ | +| | "Expressions": [ | | | "Int64(123)" | | | ], | | | "Plans": [ | | | { | | | "Node Type": "EmptyRelation", | -| | "Output": [], | -| | "Plans": [] | +| | "Plans": [], | +| | "Output": [] | | | } | +| | ], | +| | "Output": [ | +| | "Int64(123)" | | | ] | | | } | | | } | diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 2292f5855bfde..1947e25adb467 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -57,6 +57,7 @@ use datafusion_common::{ plan_datafusion_err, plan_err, unqualified_field_not_found, }; use datafusion_expr::select_expr::SelectExpr; +use datafusion_expr::utils::find_aggregate_exprs; use datafusion_expr::{ ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, case, dml::InsertOp, @@ -410,21 +411,76 @@ impl DataFrame { expr_list: impl IntoIterator>, ) -> Result { let expr_list: Vec = - expr_list.into_iter().map(|e| e.into()).collect::>(); + expr_list.into_iter().map(|e| e.into()).collect(); + // Extract plain expressions let expressions = expr_list.iter().filter_map(|e| match e { SelectExpr::Expression(expr) => Some(expr), _ => None, }); - let window_func_exprs = find_window_exprs(expressions); - let plan = if window_func_exprs.is_empty() { + // Apply window functions first + let window_func_exprs = find_window_exprs(expressions.clone()); + + let mut plan = if window_func_exprs.is_empty() { self.plan } else { LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)? }; - let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?; + // Collect aggregate expressions + let aggr_exprs = find_aggregate_exprs(expressions.clone()); + + // Check if any expression is non-aggregate + let has_non_aggregate_expr = expressions + .clone() + .any(|expr| find_aggregate_exprs(std::iter::once(expr)).is_empty()); + + // Fallback to projection: + // - already aggregated + // - contains non-aggregate expressions + // - no aggregates at all + if matches!(plan, LogicalPlan::Aggregate(_)) + || has_non_aggregate_expr + || aggr_exprs.is_empty() + { + let project_plan = + LogicalPlanBuilder::from(plan).project(expr_list)?.build()?; + + return Ok(DataFrame { + session_state: self.session_state, + plan: project_plan, + projection_requires_validation: false, + }); + } + + // Build Aggregate node + let aggr_exprs: Vec = aggr_exprs + .into_iter() + .enumerate() + .map(|(i, expr)| expr.alias(format!("__agg_{i}"))) + .collect(); + + plan = LogicalPlanBuilder::from(plan) + .aggregate(Vec::::new(), aggr_exprs)? + .build()?; + + // Replace aggregates with their aliases + let mut rewritten_exprs = Vec::with_capacity(expr_list.len()); + for (i, select_expr) in expr_list.into_iter().enumerate() { + match select_expr { + SelectExpr::Expression(expr) => { + let column = Expr::Column(Column::from_name(format!("__agg_{i}"))); + let alias = expr.name_for_alias()?; + rewritten_exprs.push(SelectExpr::Expression(column.alias(alias))); + } + other => rewritten_exprs.push(other), + } + } + + let project_plan = LogicalPlanBuilder::from(plan) + .project(rewritten_exprs)? + .build()?; Ok(DataFrame { session_state: self.session_state, diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 80bbde1f6ba14..5fc67b18b06ed 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -34,6 +34,7 @@ use arrow::util::pretty::pretty_format_batches; use arrow_schema::{SortOptions, TimeUnit}; use datafusion::{assert_batches_eq, dataframe}; use datafusion_common::metadata::FieldMetadata; +use datafusion_functions_aggregate::approx_distinct::approx_distinct; use datafusion_functions_aggregate::count::{count_all, count_all_window}; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum, @@ -6854,3 +6855,26 @@ async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_dataframe_api_aggregate_fn_in_select() -> Result<()> { + let df = test_table().await?; + + let res = df.select(vec![ + approx_distinct(col("c9")).alias("count_c9"), + approx_distinct(cast(col("c9"), DataType::Utf8View)).alias("count_c9_str"), + ])?; + + assert_batches_eq!( + &[ + "+----------+--------------+", + "| count_c9 | count_c9_str |", + "+----------+--------------+", + "| 100 | 100 |", + "+----------+--------------+", + ], + &res.collect().await? + ); + + Ok(()) +} From 1659fa7c8a69242a59695370d56b59f893e02a7d Mon Sep 17 00:00:00 2001 From: Sergey Zhukov Date: Wed, 18 Mar 2026 10:59:44 +0400 Subject: [PATCH 2/4] use count instead of approx_distinct in test --- datafusion/core/tests/dataframe/mod.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 5fc67b18b06ed..9dcc147339166 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -34,7 +34,6 @@ use arrow::util::pretty::pretty_format_batches; use arrow_schema::{SortOptions, TimeUnit}; use datafusion::{assert_batches_eq, dataframe}; use datafusion_common::metadata::FieldMetadata; -use datafusion_functions_aggregate::approx_distinct::approx_distinct; use datafusion_functions_aggregate::count::{count_all, count_all_window}; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum, @@ -6861,8 +6860,8 @@ async fn test_dataframe_api_aggregate_fn_in_select() -> Result<()> { let df = test_table().await?; let res = df.select(vec![ - approx_distinct(col("c9")).alias("count_c9"), - approx_distinct(cast(col("c9"), DataType::Utf8View)).alias("count_c9_str"), + count(col("c9")).alias("count_c9"), + count(cast(col("c9"), DataType::Utf8View)).alias("count_c9_str"), ])?; assert_batches_eq!( From f9f351e0adf7cff8b1adce5f7005a1efa42d71e4 Mon Sep 17 00:00:00 2001 From: Sergey Zhukov Date: Wed, 18 Mar 2026 16:29:49 +0400 Subject: [PATCH 3/4] Update CLI snapshot --- ...overrides@explain_plan_environment_overrides.snap | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap b/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap index 5f43ca88dc9d7..1359cefbe71c7 100644 --- a/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap +++ b/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap @@ -18,19 +18,19 @@ exit_code: 0 | logical_plan | [ | | | { | | | "Plan": { | -| | "Node Type": "Projection", | | | "Expressions": [ | | | "Int64(123)" | | | ], | +| | "Node Type": "Projection", | +| | "Output": [ | +| | "Int64(123)" | +| | ], | | | "Plans": [ | | | { | | | "Node Type": "EmptyRelation", | -| | "Plans": [], | -| | "Output": [] | +| | "Output": [], | +| | "Plans": [] | | | } | -| | ], | -| | "Output": [ | -| | "Int64(123)" | | | ] | | | } | | | } | From 0262b63d4251227b7b37aa04cc30c5afb1bc738b Mon Sep 17 00:00:00 2001 From: Sergey Zhukov Date: Thu, 19 Mar 2026 17:06:59 +0400 Subject: [PATCH 4/4] fix found bugs and add tests --- datafusion/core/src/dataframe/mod.rs | 42 ++++++++++++----- datafusion/core/tests/dataframe/mod.rs | 65 +++++++++++++++++++++++--- 2 files changed, 89 insertions(+), 18 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 1947e25adb467..da0868f9a09a5 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -51,6 +51,7 @@ use arrow::compute::{cast, concat}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow_schema::FieldRef; use datafusion_common::config::{CsvOptions, JsonOptions}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ Column, DFSchema, DataFusionError, ParamValues, ScalarValue, SchemaError, TableReference, UnnestOptions, exec_err, internal_datafusion_err, not_impl_err, @@ -413,7 +414,7 @@ impl DataFrame { let expr_list: Vec = expr_list.into_iter().map(|e| e.into()).collect(); - // Extract plain expressions + // Extract expressions let expressions = expr_list.iter().filter_map(|e| match e { SelectExpr::Expression(expr) => Some(expr), _ => None, @@ -431,7 +432,7 @@ impl DataFrame { // Collect aggregate expressions let aggr_exprs = find_aggregate_exprs(expressions.clone()); - // Check if any expression is non-aggregate + // Check for non-aggregate expressions let has_non_aggregate_expr = expressions .clone() .any(|expr| find_aggregate_exprs(std::iter::once(expr)).is_empty()); @@ -439,7 +440,7 @@ impl DataFrame { // Fallback to projection: // - already aggregated // - contains non-aggregate expressions - // - no aggregates at all + // - no aggregates if matches!(plan, LogicalPlan::Aggregate(_)) || has_non_aggregate_expr || aggr_exprs.is_empty() @@ -454,30 +455,49 @@ impl DataFrame { }); } - // Build Aggregate node - let aggr_exprs: Vec = aggr_exprs + // Assign aliases to aggregate expressions + let mut aggr_map: HashMap = HashMap::new(); + let aggr_exprs_with_alias: Vec = aggr_exprs .into_iter() .enumerate() - .map(|(i, expr)| expr.alias(format!("__agg_{i}"))) + .map(|(i, expr)| { + let alias = format!("__df_agg_{i}"); + let aliased = expr.clone().alias(alias.clone()); + let col = Expr::Column(Column::from_name(alias)); + aggr_map.insert(expr, col); + aliased + }) .collect(); + // Build aggregate plan plan = LogicalPlanBuilder::from(plan) - .aggregate(Vec::::new(), aggr_exprs)? + .aggregate(Vec::::new(), aggr_exprs_with_alias)? .build()?; - // Replace aggregates with their aliases + // Rewrite expressions to use aggregate outputs + let rewrite_expr = |expr: Expr, aggr_map: &HashMap| -> Result { + expr.transform(|e| { + Ok(match aggr_map.get(&e) { + Some(replacement) => Transformed::yes(replacement.clone()), + None => Transformed::no(e), + }) + }) + .map(|t| t.data) + }; + let mut rewritten_exprs = Vec::with_capacity(expr_list.len()); - for (i, select_expr) in expr_list.into_iter().enumerate() { + for select_expr in expr_list.into_iter() { match select_expr { SelectExpr::Expression(expr) => { - let column = Expr::Column(Column::from_name(format!("__agg_{i}"))); + let rewritten = rewrite_expr(expr.clone(), &aggr_map)?; let alias = expr.name_for_alias()?; - rewritten_exprs.push(SelectExpr::Expression(column.alias(alias))); + rewritten_exprs.push(SelectExpr::Expression(rewritten.alias(alias))); } other => rewritten_exprs.push(other), } } + // Final projection let project_plan = LogicalPlanBuilder::from(plan) .project(rewritten_exprs)? .build()?; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 9dcc147339166..9a0f96cfa2e5f 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -34,6 +34,7 @@ use arrow::util::pretty::pretty_format_batches; use arrow_schema::{SortOptions, TimeUnit}; use datafusion::{assert_batches_eq, dataframe}; use datafusion_common::metadata::FieldMetadata; +use datafusion_expr::select_expr::SelectExpr; use datafusion_functions_aggregate::count::{count_all, count_all_window}; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum, @@ -72,7 +73,9 @@ use datafusion_common_runtime::SpawnedTask; use datafusion_datasource::file_format::format_as_file_type; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_expr::expr::{GroupingSet, NullTreatment, Sort, WindowFunction}; +use datafusion_expr::expr::{ + GroupingSet, NullTreatment, Sort, WildcardOptions, WindowFunction, +}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, LogicalPlanBuilder, @@ -6859,21 +6862,69 @@ async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> { async fn test_dataframe_api_aggregate_fn_in_select() -> Result<()> { let df = test_table().await?; - let res = df.select(vec![ + // Multiple aggregates + let res = df.clone().select(vec![ count(col("c9")).alias("count_c9"), count(cast(col("c9"), DataType::Utf8View)).alias("count_c9_str"), + sum(col("c9")).alias("sum_c9"), + count(col("c8")).alias("count_c8"), + (sum(col("c9")) + count(col("c8"))).alias("total1"), + ((count(col("c9")) + lit(1)) * lit(2)).alias("total2"), + (count(col("c9")) + lit(1)).alias("count_c9_add_1"), + ])?; + + assert_batches_eq!( + &[ + "+----------+--------------+--------------+----------+--------------+--------+----------------+", + "| count_c9 | count_c9_str | sum_c9 | count_c8 | total1 | total2 | count_c9_add_1 |", + "+----------+--------------+--------------+----------+--------------+--------+----------------+", + "| 100 | 100 | 222089770060 | 100 | 222089770160 | 202 | 101 |", + "+----------+--------------+--------------+----------+--------------+--------+----------------+", + ], + &res.collect().await? + ); + + // Test duplicate aggregate aliases + let res = df.clone().select(vec![ + count(col("c9")).alias("count_c9"), + count(col("c9")).alias("count_c9_2"), ])?; assert_batches_eq!( &[ - "+----------+--------------+", - "| count_c9 | count_c9_str |", - "+----------+--------------+", - "| 100 | 100 |", - "+----------+--------------+", + "+----------+------------+", + "| count_c9 | count_c9_2 |", + "+----------+------------+", + "| 100 | 100 |", + "+----------+------------+", ], &res.collect().await? ); + // Wildcard + let res = df + .clone() + .select(vec![ + SelectExpr::Wildcard(WildcardOptions::default()), + lit(42).into(), + ])? + .limit(0, None)?; + + let batches = res.collect().await?; + assert_eq!(batches[0].num_rows(), 100); + assert_eq!(batches[0].num_columns(), 14); + + let res = df.clone().select(vec![ + SelectExpr::QualifiedWildcard( + "aggregate_test_100".into(), + WildcardOptions::default(), + ), + lit(42).into(), + ])?; + + let batches = res.collect().await?; + assert_eq!(batches[0].num_rows(), 100); + assert_eq!(batches[0].num_columns(), 14); + Ok(()) }