diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 2292f5855bfde..da0868f9a09a5 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -51,12 +51,14 @@ use arrow::compute::{cast, concat}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow_schema::FieldRef; use datafusion_common::config::{CsvOptions, JsonOptions}; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ Column, DFSchema, DataFusionError, ParamValues, ScalarValue, SchemaError, TableReference, UnnestOptions, exec_err, internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, unqualified_field_not_found, }; use datafusion_expr::select_expr::SelectExpr; +use datafusion_expr::utils::find_aggregate_exprs; use datafusion_expr::{ ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, case, dml::InsertOp, @@ -410,21 +412,95 @@ impl DataFrame { expr_list: impl IntoIterator>, ) -> Result { let expr_list: Vec = - expr_list.into_iter().map(|e| e.into()).collect::>(); + expr_list.into_iter().map(|e| e.into()).collect(); + // Extract expressions let expressions = expr_list.iter().filter_map(|e| match e { SelectExpr::Expression(expr) => Some(expr), _ => None, }); - let window_func_exprs = find_window_exprs(expressions); - let plan = if window_func_exprs.is_empty() { + // Apply window functions first + let window_func_exprs = find_window_exprs(expressions.clone()); + + let mut plan = if window_func_exprs.is_empty() { self.plan } else { LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)? }; - let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?; + // Collect aggregate expressions + let aggr_exprs = find_aggregate_exprs(expressions.clone()); + + // Check for non-aggregate expressions + let has_non_aggregate_expr = expressions + .clone() + .any(|expr| find_aggregate_exprs(std::iter::once(expr)).is_empty()); + + // Fallback to projection: + // - already aggregated + // - contains non-aggregate expressions + // - no aggregates + if matches!(plan, LogicalPlan::Aggregate(_)) + || has_non_aggregate_expr + || aggr_exprs.is_empty() + { + let project_plan = + LogicalPlanBuilder::from(plan).project(expr_list)?.build()?; + + return Ok(DataFrame { + session_state: self.session_state, + plan: project_plan, + projection_requires_validation: false, + }); + } + + // Assign aliases to aggregate expressions + let mut aggr_map: HashMap = HashMap::new(); + let aggr_exprs_with_alias: Vec = aggr_exprs + .into_iter() + .enumerate() + .map(|(i, expr)| { + let alias = format!("__df_agg_{i}"); + let aliased = expr.clone().alias(alias.clone()); + let col = Expr::Column(Column::from_name(alias)); + aggr_map.insert(expr, col); + aliased + }) + .collect(); + + // Build aggregate plan + plan = LogicalPlanBuilder::from(plan) + .aggregate(Vec::::new(), aggr_exprs_with_alias)? + .build()?; + + // Rewrite expressions to use aggregate outputs + let rewrite_expr = |expr: Expr, aggr_map: &HashMap| -> Result { + expr.transform(|e| { + Ok(match aggr_map.get(&e) { + Some(replacement) => Transformed::yes(replacement.clone()), + None => Transformed::no(e), + }) + }) + .map(|t| t.data) + }; + + let mut rewritten_exprs = Vec::with_capacity(expr_list.len()); + for select_expr in expr_list.into_iter() { + match select_expr { + SelectExpr::Expression(expr) => { + let rewritten = rewrite_expr(expr.clone(), &aggr_map)?; + let alias = expr.name_for_alias()?; + rewritten_exprs.push(SelectExpr::Expression(rewritten.alias(alias))); + } + other => rewritten_exprs.push(other), + } + } + + // Final projection + let project_plan = LogicalPlanBuilder::from(plan) + .project(rewritten_exprs)? + .build()?; Ok(DataFrame { session_state: self.session_state, diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 80bbde1f6ba14..9a0f96cfa2e5f 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -34,6 +34,7 @@ use arrow::util::pretty::pretty_format_batches; use arrow_schema::{SortOptions, TimeUnit}; use datafusion::{assert_batches_eq, dataframe}; use datafusion_common::metadata::FieldMetadata; +use datafusion_expr::select_expr::SelectExpr; use datafusion_functions_aggregate::count::{count_all, count_all_window}; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum, @@ -72,7 +73,9 @@ use datafusion_common_runtime::SpawnedTask; use datafusion_datasource::file_format::format_as_file_type; use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnv; -use datafusion_expr::expr::{GroupingSet, NullTreatment, Sort, WindowFunction}; +use datafusion_expr::expr::{ + GroupingSet, NullTreatment, Sort, WildcardOptions, WindowFunction, +}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, LogicalPlanBuilder, @@ -6854,3 +6857,74 @@ async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_dataframe_api_aggregate_fn_in_select() -> Result<()> { + let df = test_table().await?; + + // Multiple aggregates + let res = df.clone().select(vec![ + count(col("c9")).alias("count_c9"), + count(cast(col("c9"), DataType::Utf8View)).alias("count_c9_str"), + sum(col("c9")).alias("sum_c9"), + count(col("c8")).alias("count_c8"), + (sum(col("c9")) + count(col("c8"))).alias("total1"), + ((count(col("c9")) + lit(1)) * lit(2)).alias("total2"), + (count(col("c9")) + lit(1)).alias("count_c9_add_1"), + ])?; + + assert_batches_eq!( + &[ + "+----------+--------------+--------------+----------+--------------+--------+----------------+", + "| count_c9 | count_c9_str | sum_c9 | count_c8 | total1 | total2 | count_c9_add_1 |", + "+----------+--------------+--------------+----------+--------------+--------+----------------+", + "| 100 | 100 | 222089770060 | 100 | 222089770160 | 202 | 101 |", + "+----------+--------------+--------------+----------+--------------+--------+----------------+", + ], + &res.collect().await? + ); + + // Test duplicate aggregate aliases + let res = df.clone().select(vec![ + count(col("c9")).alias("count_c9"), + count(col("c9")).alias("count_c9_2"), + ])?; + + assert_batches_eq!( + &[ + "+----------+------------+", + "| count_c9 | count_c9_2 |", + "+----------+------------+", + "| 100 | 100 |", + "+----------+------------+", + ], + &res.collect().await? + ); + + // Wildcard + let res = df + .clone() + .select(vec![ + SelectExpr::Wildcard(WildcardOptions::default()), + lit(42).into(), + ])? + .limit(0, None)?; + + let batches = res.collect().await?; + assert_eq!(batches[0].num_rows(), 100); + assert_eq!(batches[0].num_columns(), 14); + + let res = df.clone().select(vec![ + SelectExpr::QualifiedWildcard( + "aggregate_test_100".into(), + WildcardOptions::default(), + ), + lit(42).into(), + ])?; + + let batches = res.collect().await?; + assert_eq!(batches[0].num_rows(), 100); + assert_eq!(batches[0].num_columns(), 14); + + Ok(()) +}