diff --git a/native-engine/auron-planner/proto/auron.proto b/native-engine/auron-planner/proto/auron.proto index b0618b971..7d1b972fa 100644 --- a/native-engine/auron-planner/proto/auron.proto +++ b/native-engine/auron-planner/proto/auron.proto @@ -129,6 +129,7 @@ enum WindowFunction { ROW_NUMBER = 0; RANK = 1; DENSE_RANK = 2; + PERCENT_RANK = 3; } enum AggFunction { diff --git a/native-engine/auron-planner/src/planner.rs b/native-engine/auron-planner/src/planner.rs index 84a625734..c8d3c67c4 100644 --- a/native-engine/auron-planner/src/planner.rs +++ b/native-engine/auron-planner/src/planner.rs @@ -636,6 +636,9 @@ impl PhysicalPlanner { protobuf::WindowFunction::DenseRank => { WindowFunction::RankLike(WindowRankType::DenseRank) } + protobuf::WindowFunction::PercentRank => { + WindowFunction::PercentRank + } }, protobuf::WindowFunctionType::Agg => match w.agg_func() { protobuf::AggFunction::Min => WindowFunction::Agg(AggFunction::Min), diff --git a/native-engine/datafusion-ext-plans/src/window/mod.rs b/native-engine/datafusion-ext-plans/src/window/mod.rs index a9e9da29d..9040bd60b 100644 --- a/native-engine/datafusion-ext-plans/src/window/mod.rs +++ b/native-engine/datafusion-ext-plans/src/window/mod.rs @@ -23,8 +23,8 @@ use crate::{ agg::{AggFunction, agg::create_agg}, window::{ processors::{ - agg_processor::AggProcessor, rank_processor::RankProcessor, - row_number_processor::RowNumberProcessor, + agg_processor::AggProcessor, percent_rank_processor::PercentRankProcessor, + rank_processor::RankProcessor, row_number_processor::RowNumberProcessor, }, window_context::WindowContext, }, @@ -36,6 +36,7 @@ pub mod window_context; #[derive(Debug, Clone, Copy)] pub enum WindowFunction { RankLike(WindowRankType), + PercentRank, Agg(AggFunction), } @@ -87,6 +88,7 @@ impl WindowExpr { WindowFunction::RankLike(WindowRankType::DenseRank) => { Ok(Box::new(RankProcessor::new(true))) } + WindowFunction::PercentRank => Ok(Box::new(PercentRankProcessor::new())), 
WindowFunction::Agg(agg_func) => { let agg = create_agg( agg_func.clone(), @@ -98,4 +100,8 @@ impl WindowExpr { } } } + + pub fn requires_full_partition(&self) -> bool { + matches!(self.func, WindowFunction::PercentRank) + } } diff --git a/native-engine/datafusion-ext-plans/src/window/processors/mod.rs b/native-engine/datafusion-ext-plans/src/window/processors/mod.rs index 7d4a72b55..8010f8ab0 100644 --- a/native-engine/datafusion-ext-plans/src/window/processors/mod.rs +++ b/native-engine/datafusion-ext-plans/src/window/processors/mod.rs @@ -14,5 +14,6 @@ // limitations under the License. pub mod agg_processor; +pub mod percent_rank_processor; pub mod rank_processor; pub mod row_number_processor; diff --git a/native-engine/datafusion-ext-plans/src/window/processors/percent_rank_processor.rs b/native-engine/datafusion-ext-plans/src/window/processors/percent_rank_processor.rs new file mode 100644 index 000000000..7cff18771 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/window/processors/percent_rank_processor.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::sync::Arc; + +use arrow::{ + array::{ArrayRef, Float64Builder}, + record_batch::RecordBatch, +}; +use datafusion::common::Result; + +use crate::window::{WindowFunctionProcessor, window_context::WindowContext}; + +pub struct PercentRankProcessor; + +impl PercentRankProcessor { + pub fn new() -> Self { + Self + } +} + +impl Default for PercentRankProcessor { + fn default() -> Self { + Self::new() + } +} + +impl WindowFunctionProcessor for PercentRankProcessor { + fn process_batch(&mut self, context: &WindowContext, batch: &RecordBatch) -> Result<ArrayRef> { + let partition_rows = context.get_partition_rows(batch)?; + let order_rows = context.get_order_rows(batch)?; + let mut builder = Float64Builder::with_capacity(batch.num_rows()); + + let mut row_idx = 0usize; + while row_idx < batch.num_rows() { + let partition_start = row_idx; + row_idx += 1; + while row_idx < batch.num_rows() + && (!context.has_partition() + || partition_rows.row(row_idx).as_ref() + == partition_rows.row(partition_start).as_ref()) + { + row_idx += 1; + } + + let partition_end = row_idx; + let partition_size = partition_end - partition_start; + let denominator = (partition_size.saturating_sub(1)) as f64; + + let mut rank = 1usize; + let mut peer_group_size = 1usize; + for current_idx in partition_start..partition_end { + if current_idx > partition_start { + let prev_idx = current_idx - 1; + if order_rows.row(current_idx).as_ref() == order_rows.row(prev_idx).as_ref() { + peer_group_size += 1; + } else { + rank += peer_group_size; + peer_group_size = 1; + } + } + + let percent_rank = if partition_size <= 1 { + 0.0 + } else { + (rank - 1) as f64 / denominator + }; + builder.append_value(percent_rank); + } + } + + Ok(Arc::new(builder.finish())) + } +} diff --git a/native-engine/datafusion-ext-plans/src/window/window_context.rs b/native-engine/datafusion-ext-plans/src/window/window_context.rs index a76eb1253..1c5f68f12 100644 --- a/native-engine/datafusion-ext-plans/src/window/window_context.rs +++ 
b/native-engine/datafusion-ext-plans/src/window/window_context.rs @@ -167,4 +167,10 @@ impl WindowContext { .collect::<Result<Vec<_>>>()?, )?) } + + pub fn requires_full_partition(&self) -> bool { + self.window_exprs + .iter() + .any(|expr| expr.requires_full_partition()) + } } diff --git a/native-engine/datafusion-ext-plans/src/window_exec.rs b/native-engine/datafusion-ext-plans/src/window_exec.rs index 5bb698eec..8676a02cf 100644 --- a/native-engine/datafusion-ext-plans/src/window_exec.rs +++ b/native-engine/datafusion-ext-plans/src/window_exec.rs @@ -17,6 +17,7 @@ use std::{any::Any, fmt::Formatter, sync::Arc}; use arrow::{ array::{Array, ArrayRef, Int32Array}, + compute::concat_batches, datatypes::SchemaRef, record_batch::{RecordBatch, RecordBatchOptions}, }; @@ -37,7 +38,7 @@ use once_cell::sync::OnceCell; use crate::{ common::execution_context::ExecutionContext, - window::{WindowExpr, window_context::WindowContext}, + window::{WindowExpr, WindowFunctionProcessor, window_context::WindowContext}, }; #[derive(Debug)] @@ -217,45 +218,29 @@ fn execute_window( .map(|expr: &WindowExpr| expr.create_processor(&window_ctx)) .collect::<Result<Vec<_>>>()?; - while let Some(mut batch) = input.next().await.transpose()? { - let _timer = exec_ctx.baseline_metrics().elapsed_compute().timer(); - let mut window_cols: Vec<ArrayRef> = processors - .iter_mut() - .map(|processor| processor.process_batch(&window_ctx, &batch)) - .collect::<Result<_>>()?; - - if let Some(group_limit) = window_ctx.group_limit { - assert_eq!(window_cols.len(), 1); - let limited = arrow::compute::kernels::cmp::lt_eq( - &window_cols[0], - &Int32Array::new_scalar(group_limit as i32), - )?; - window_cols[0] = arrow::compute::filter(&window_cols[0], &limited)?; - batch = arrow::compute::filter_record_batch(&batch, &limited)?; + if window_ctx.requires_full_partition() { + let mut staging_batches = vec![]; + while let Some(batch) = input.next().await.transpose()? 
{ + staging_batches.push(batch); } - let outputs: Vec<ArrayRef> = batch - .columns() - .iter() - .cloned() - .chain(if window_ctx.output_window_cols { - window_cols - } else { - vec![] - }) - .zip(window_ctx.output_schema.fields()) - .map(|(array, field)| { - if array.data_type() != field.data_type() { - return cast(&array, field.data_type()); - } - Ok(array.clone()) - }) - .collect::<Result<_>>()?; - let output_batch = RecordBatch::try_new_with_options( - window_ctx.output_schema.clone(), - outputs, - &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), - )?; + if !staging_batches.is_empty() { + let _timer = exec_ctx.baseline_metrics().elapsed_compute().timer(); + let batch = concat_batches(&window_ctx.input_schema, &staging_batches)?; + let output_batch = + process_window_batch(batch, &window_ctx, processors.as_mut_slice())?; + exec_ctx + .baseline_metrics() + .record_output(output_batch.num_rows()); + sender.send(output_batch).await; + } + return Ok(()); + } + + while let Some(batch) = input.next().await.transpose()? 
{ + let _timer = exec_ctx.baseline_metrics().elapsed_compute().timer(); + let output_batch = + process_window_batch(batch, &window_ctx, processors.as_mut_slice())?; exec_ctx .baseline_metrics() .record_output(output_batch.num_rows()); @@ -265,6 +250,50 @@ })) } +fn process_window_batch( + mut batch: RecordBatch, + window_ctx: &WindowContext, + processors: &mut [Box<dyn WindowFunctionProcessor>], +) -> Result<RecordBatch> { + let mut window_cols: Vec<ArrayRef> = processors + .iter_mut() + .map(|processor| processor.process_batch(window_ctx, &batch)) + .collect::<Result<_>>()?; + + if let Some(group_limit) = window_ctx.group_limit { + assert_eq!(window_cols.len(), 1); + let limited = arrow::compute::kernels::cmp::lt_eq( + &window_cols[0], + &Int32Array::new_scalar(group_limit as i32), + )?; + window_cols[0] = arrow::compute::filter(&window_cols[0], &limited)?; + batch = arrow::compute::filter_record_batch(&batch, &limited)?; + } + + let outputs: Vec<ArrayRef> = batch + .columns() + .iter() + .cloned() + .chain(if window_ctx.output_window_cols { + window_cols + } else { + vec![] + }) + .zip(window_ctx.output_schema.fields()) + .map(|(array, field)| { + if array.data_type() != field.data_type() { + return cast(&array, field.data_type()); + } + Ok(array.clone()) + }) + .collect::<Result<_>>()?; + Ok(RecordBatch::try_new_with_options( + window_ctx.output_schema.clone(), + outputs, + &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())), + )?) 
+} + #[cfg(test)] mod test { use std::sync::Arc; @@ -447,6 +476,49 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_percent_rank_window() -> Result<(), Box<dyn std::error::Error>> { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + + let input = build_table( + ("grp", &vec![1, 1, 1, 2]), + ("id", &vec![1, 1, 2, 5]), + ("v", &vec![10, 20, 30, 40]), + )?; + let window_exprs = vec![WindowExpr::new( + WindowFunction::PercentRank, + vec![], + Arc::new(Field::new("percent_rank", DataType::Float64, false)), + DataType::Float64, + )]; + let window = Arc::new(WindowExec::try_new( + input, + window_exprs, + vec![Arc::new(Column::new("grp", 0))], + vec![PhysicalSortExpr { + expr: Arc::new(Column::new("id", 1)), + options: Default::default(), + }], + None, + true, + )?); + let stream = window.execute(0, task_ctx)?; + let batches = datafusion::physical_plan::common::collect(stream).await?; + let expected = vec![ + "+-----+----+----+--------------+", + "| grp | id | v | percent_rank |", + "+-----+----+----+--------------+", + "| 1 | 1 | 10 | 0.0 |", + "| 1 | 1 | 20 | 0.0 |", + "| 1 | 2 | 30 | 1.0 |", + "| 2 | 5 | 40 | 0.0 |", + "+-----+----+----+--------------+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn test_window_group_limit() -> Result<(), Box<dyn std::error::Error>> { let session_ctx = SessionContext::new(); diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala index 40eecda4e..c0bb81eef 100644 --- a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala +++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronQuerySuite.scala @@ -574,6 +574,32 @@ class AuronQuerySuite extends AuronQueryTest with BaseAuronSQLSuite with AuronSQ } + test("percent_rank window") { + withTable("t_percent_rank") { + sql(""" + |create table t_percent_rank using parquet as + |select * from 
values + | (1, 1, 10), + | (1, 1, 20), + | (1, 2, 30), + | (2, 5, 40) + |as t(grp, id, v) + |""".stripMargin) + + checkSparkAnswerAndOperator(""" + |select + | grp, + | id, + | v, + | percent_rank() over ( + | partition by grp + | order by id + | ) as percent_rank_v + |from t_percent_rank + |order by grp, id, v + |""".stripMargin) + } + } test("standard LEFT ANTI JOIN includes NULL keys") { // This test verifies that standard LEFT ANTI JOIN correctly includes NULL keys // NULL keys should be in the result because NULL never matches anything diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala index fad61ff09..072c901cd 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.DenseRank import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.expressions.NullsFirst +import org.apache.spark.sql.catalyst.expressions.PercentRank import org.apache.spark.sql.catalyst.expressions.Rank import org.apache.spark.sql.catalyst.expressions.RowNumber import org.apache.spark.sql.catalyst.expressions.SortOrder @@ -118,6 +119,12 @@ abstract class NativeWindowBase( windowExprBuilder.setFuncType(pb.WindowFunctionType.Window) windowExprBuilder.setWindowFunc(pb.WindowFunction.DENSE_RANK) + case e: PercentRank => + assert( + spec.frameSpecification == e.frame, + s"window frame not supported: ${spec.frameSpecification}") + windowExprBuilder.setFuncType(pb.WindowFunctionType.Window) + windowExprBuilder.setWindowFunc(pb.WindowFunction.PERCENT_RANK) case e: Sum => assert( spec.frameSpecification == RowNumber().frame, // 
only supports RowFrame(Unbounde, CurrentRow)