From a62b1f7a4b55d8220ddff4d67f4e70cf5049d5ee Mon Sep 17 00:00:00 2001 From: light-city <455954986@qq.com> Date: Sat, 4 Nov 2023 00:07:21 +0800 Subject: [PATCH 01/83] Fix: Add support for null sort option per sort key. 1. Restructure SortKey to add a NullPlacement field. 2. Remove NullPlacement from SortOptions. 3. Fix SelectK not returning null results in the null AtEnd scenario. When the limit k is greater than the number of rows in the table and the table contains Null/NaN, those rows could not be obtained and only the non-null results were returned. We now return the Null/NaN rows as well and support setting the null placement for each SortKey. 4. Add relevant unit tests and update the interfaces of the affected language bindings. --- c_glib/arrow-glib/compute.cpp | 65 +-- c_glib/arrow-glib/compute.h | 1 + c_glib/test/test-rank-options.rb | 18 +- c_glib/test/test-sort-indices.rb | 8 +- c_glib/test/test-sort-options.rb | 22 +- cpp/src/arrow/acero/order_by_node.cc | 2 +- cpp/src/arrow/acero/order_by_node_test.cc | 10 +- cpp/src/arrow/acero/plan_test.cc | 2 +- cpp/src/arrow/compute/api_vector.cc | 21 +- cpp/src/arrow/compute/api_vector.h | 14 +- cpp/src/arrow/compute/exec_test.cc | 2 +- .../arrow/compute/kernels/select_k_test.cc | 464 ++++++++++++++- cpp/src/arrow/compute/kernels/vector_rank.cc | 6 +- .../arrow/compute/kernels/vector_select_k.cc | 538 +++++++++++++----- cpp/src/arrow/compute/kernels/vector_sort.cc | 52 +- .../compute/kernels/vector_sort_internal.h | 53 +- .../arrow/compute/kernels/vector_sort_test.cc | 285 ++++++---- cpp/src/arrow/compute/ordering.cc | 29 +- cpp/src/arrow/compute/ordering.h | 19 +- .../engine/substrait/relation_internal.cc | 20 +- cpp/src/arrow/engine/substrait/serde_test.cc | 9 +- python/pyarrow/_acero.pyx | 20 +- python/pyarrow/_compute.pyx | 38 +- python/pyarrow/_dataset.pyx | 6 +- python/pyarrow/array.pxi | 14 +- python/pyarrow/compute.py | 21 +- python/pyarrow/includes/libarrow.pxd | 12 +- python/pyarrow/table.pxi | 15 +- 
python/pyarrow/tests/test_acero.py | 10 +- python/pyarrow/tests/test_compute.py | 85 ++- python/pyarrow/tests/test_dataset.py | 8 +- python/pyarrow/tests/test_table.py | 11 +- ruby/red-arrow/lib/arrow/sort-key.rb | 57 +- ruby/red-arrow/lib/arrow/sort-options.rb | 4 +- ruby/red-arrow/test/test-sort-key.rb | 14 +- ruby/red-arrow/test/test-sort-options.rb | 8 +- 36 files changed, 1354 insertions(+), 609 deletions(-) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 9692f277d18..0fcb2fefcac 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -2987,6 +2987,7 @@ typedef struct GArrowSortKeyPrivate_ { enum { PROP_SORT_KEY_TARGET = 1, PROP_SORT_KEY_ORDER, + PROP_SORT_KEY_NULL_PLACEMENT, }; G_DEFINE_TYPE_WITH_PRIVATE(GArrowSortKey, @@ -3019,6 +3020,10 @@ garrow_sort_key_set_property(GObject *object, priv->sort_key.order = static_cast(g_value_get_enum(value)); break; + case PROP_SORT_KEY_NULL_PLACEMENT: + priv->sort_key.null_placement = + static_cast(g_value_get_enum(value)); + break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); break; @@ -3047,6 +3052,10 @@ garrow_sort_key_get_property(GObject *object, case PROP_SORT_KEY_ORDER: g_value_set_enum(value, static_cast(priv->sort_key.order)); break; + case PROP_SORT_KEY_NULL_PLACEMENT: + g_value_set_enum(value, + static_cast(priv->sort_key.null_placement)); + break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); break; @@ -3103,12 +3112,29 @@ garrow_sort_key_class_init(GArrowSortKeyClass *klass) static_cast(G_PARAM_READWRITE | G_PARAM_CONSTRUCT_ONLY)); g_object_class_install_property(gobject_class, PROP_SORT_KEY_ORDER, spec); + + /** + * GArrowSortKey::null-placement: + * + * Whether nulls and NaNs are placed at the start or at the end. 
+ * + * Since: 15.0.0 + */ + spec = g_param_spec_enum("null-placement", + "Null Placement", + "Whether nulls and NaNs are placed at the start or at the end", + GARROW_TYPE_NULL_PLACEMENT, + 0, + static_cast(G_PARAM_READWRITE | + G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, PROP_SORT_KEY_NULL_PLACEMENT, spec); } /** * garrow_sort_key_new: * @target: A name or dot path for sort target. * @order: How to order by this sort key. + * @null_placement: Whether nulls and NaNs are placed at the start or at the end. * * Returns: A newly created #GArrowSortKey. * @@ -3117,6 +3143,7 @@ garrow_sort_key_class_init(GArrowSortKeyClass *klass) GArrowSortKey * garrow_sort_key_new(const gchar *target, GArrowSortOrder order, + GArrowNullPlacement null_placement, GError **error) { auto arrow_reference_result = garrow_field_reference_resolve_raw(target); @@ -3127,6 +3154,7 @@ garrow_sort_key_new(const gchar *target, } auto sort_key = g_object_new(GARROW_TYPE_SORT_KEY, "order", order, + "null-placement", null_placement, NULL); auto priv = GARROW_SORT_KEY_GET_PRIVATE(sort_key); priv->sort_key.target = *arrow_reference_result; @@ -4461,8 +4489,7 @@ garrow_index_options_new(void) enum { - PROP_RANK_OPTIONS_NULL_PLACEMENT = 1, - PROP_RANK_OPTIONS_TIEBREAKER, + PROP_RANK_OPTIONS_TIEBREAKER = 1, }; G_DEFINE_TYPE(GArrowRankOptions, @@ -4483,10 +4510,6 @@ garrow_rank_options_set_property(GObject *object, auto options = garrow_rank_options_get_raw(GARROW_RANK_OPTIONS(object)); switch (prop_id) { - case PROP_RANK_OPTIONS_NULL_PLACEMENT: - options->null_placement = - static_cast(g_value_get_enum(value)); - break; case PROP_RANK_OPTIONS_TIEBREAKER: options->tiebreaker = static_cast( @@ -4507,11 +4530,6 @@ garrow_rank_options_get_property(GObject *object, auto options = garrow_rank_options_get_raw(GARROW_RANK_OPTIONS(object)); switch (prop_id) { - case PROP_RANK_OPTIONS_NULL_PLACEMENT: - g_value_set_enum( - value, - static_cast(options->null_placement)); - break; case 
PROP_RANK_OPTIONS_TIEBREAKER: g_value_set_enum( value, @@ -4543,25 +4561,6 @@ garrow_rank_options_class_init(GArrowRankOptionsClass *klass) auto options = arrow::compute::RankOptions::Defaults(); GParamSpec *spec; - /** - * GArrowRankOptions:null-placement: - * - * Whether nulls and NaNs are placed at the start or at the end. - * - * Since: 12.0.0 - */ - spec = g_param_spec_enum("null-placement", - "Null placement", - "Whether nulls and NaNs are placed " - "at the start or at the end.", - GARROW_TYPE_NULL_PLACEMENT, - static_cast( - options.null_placement), - static_cast(G_PARAM_READWRITE)); - g_object_class_install_property(gobject_class, - PROP_RANK_OPTIONS_NULL_PLACEMENT, - spec); - /** * GArrowRankOptions:tiebreaker: * @@ -4614,9 +4613,6 @@ garrow_rank_options_equal(GArrowRankOptions *options, arrow_other_options->sort_keys)) { return FALSE; } - if (arrow_options->null_placement != arrow_other_options->null_placement) { - return FALSE; - } if (arrow_options->tiebreaker != arrow_other_options->tiebreaker) { return FALSE; } @@ -6434,8 +6430,6 @@ garrow_sort_options_new_raw( NULL)); auto arrow_new_options = garrow_sort_options_get_raw(options); arrow_new_options->sort_keys = arrow_options->sort_keys; - /* TODO: Use property when we add support for null_placement. 
*/ - arrow_new_options->null_placement = arrow_options->null_placement; return options; } @@ -6633,7 +6627,6 @@ garrow_rank_options_new_raw(const arrow::compute::RankOptions *arrow_options) { auto options = GARROW_RANK_OPTIONS( g_object_new(GARROW_TYPE_RANK_OPTIONS, - "null-placement", arrow_options->null_placement, "tiebreaker", arrow_options->tiebreaker, nullptr)); auto arrow_new_options = garrow_rank_options_get_raw(options); diff --git a/c_glib/arrow-glib/compute.h b/c_glib/arrow-glib/compute.h index 008ae2a7838..0a1ec0add79 100644 --- a/c_glib/arrow-glib/compute.h +++ b/c_glib/arrow-glib/compute.h @@ -571,6 +571,7 @@ GARROW_AVAILABLE_IN_3_0 GArrowSortKey * garrow_sort_key_new(const gchar *target, GArrowSortOrder order, + GArrowNullPlacement null_placement, GError **error); GARROW_AVAILABLE_IN_3_0 diff --git a/c_glib/test/test-rank-options.rb b/c_glib/test/test-rank-options.rb index 06806035cda..ba61d51607c 100644 --- a/c_glib/test/test-rank-options.rb +++ b/c_glib/test/test-rank-options.rb @@ -29,29 +29,23 @@ def test_equal def test_sort_keys sort_keys = [ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ] @options.sort_keys = sort_keys assert_equal(sort_keys, @options.sort_keys) end def test_add_sort_key - @options.add_sort_key(Arrow::SortKey.new("column1", :ascending)) - @options.add_sort_key(Arrow::SortKey.new("column2", :descending)) + @options.add_sort_key(Arrow::SortKey.new("column1", :ascending, :at_end)) + @options.add_sort_key(Arrow::SortKey.new("column2", :descending, :at_start)) assert_equal([ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_start), ], @options.sort_keys) end - def test_null_placement - assert_equal(Arrow::NullPlacement::AT_END, 
@options.null_placement) - @options.null_placement = :at_start - assert_equal(Arrow::NullPlacement::AT_START, @options.null_placement) - end - def test_tiebreaker assert_equal(Arrow::RankTiebreaker::FIRST, @options.tiebreaker) @options.tiebreaker = :max diff --git a/c_glib/test/test-sort-indices.rb b/c_glib/test/test-sort-indices.rb index a8c4f40c50f..a94da3a46f0 100644 --- a/c_glib/test/test-sort-indices.rb +++ b/c_glib/test/test-sort-indices.rb @@ -41,8 +41,8 @@ def test_record_batch } record_batch = build_record_batch(columns) sort_keys = [ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ] options = Arrow::SortOptions.new(sort_keys) assert_equal(build_uint64_array([4, 1, 0, 5, 3, 2]), @@ -61,8 +61,8 @@ def test_table } table = build_table(columns) options = Arrow::SortOptions.new - options.add_sort_key(Arrow::SortKey.new("column1", :ascending)) - options.add_sort_key(Arrow::SortKey.new("column2", :descending)) + options.add_sort_key(Arrow::SortKey.new("column1", :ascending, :at_end)) + options.add_sort_key(Arrow::SortKey.new("column2", :descending, :at_end)) assert_equal(build_uint64_array([4, 1, 0, 5, 3, 2]), table.sort_indices(options)) end diff --git a/c_glib/test/test-sort-options.rb b/c_glib/test/test-sort-options.rb index e57645b1cfb..78c3ef16a60 100644 --- a/c_glib/test/test-sort-options.rb +++ b/c_glib/test/test-sort-options.rb @@ -20,8 +20,8 @@ class TestSortOptions < Test::Unit::TestCase def test_new sort_keys = [ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ] options = Arrow::SortOptions.new(sort_keys) assert_equal(sort_keys, options.sort_keys) @@ -29,20 +29,20 @@ def test_new def test_add_sort_key options = Arrow::SortOptions.new - 
options.add_sort_key(Arrow::SortKey.new("column1", :ascending)) - options.add_sort_key(Arrow::SortKey.new("column2", :descending)) + options.add_sort_key(Arrow::SortKey.new("column1", :ascending, :at_end)) + options.add_sort_key(Arrow::SortKey.new("column2", :descending, :at_end)) assert_equal([ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ], options.sort_keys) end def test_set_sort_keys - options = Arrow::SortOptions.new([Arrow::SortKey.new("column3", :ascending)]) + options = Arrow::SortOptions.new([Arrow::SortKey.new("column3", :ascending, :at_end)]) sort_keys = [ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ] options.sort_keys = sort_keys assert_equal(sort_keys, options.sort_keys) @@ -50,8 +50,8 @@ def test_set_sort_keys def test_equal sort_keys = [ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_start), + Arrow::SortKey.new("column2", :descending, :at_end), ] assert_equal(Arrow::SortOptions.new(sort_keys), Arrow::SortOptions.new(sort_keys)) diff --git a/cpp/src/arrow/acero/order_by_node.cc b/cpp/src/arrow/acero/order_by_node.cc index 1811fa9f4c7..3c8a978b4ed 100644 --- a/cpp/src/arrow/acero/order_by_node.cc +++ b/cpp/src/arrow/acero/order_by_node.cc @@ -115,7 +115,7 @@ class OrderByNode : public ExecNode, public TracedNode { ARROW_ASSIGN_OR_RAISE( auto table, Table::FromRecordBatches(output_schema_, std::move(accumulation_queue_))); - SortOptions sort_options(ordering_.sort_keys(), ordering_.null_placement()); + SortOptions sort_options(ordering_.sort_keys()); ExecContext* ctx = plan_->query_context()->exec_context(); ARROW_ASSIGN_OR_RAISE(auto indices, 
SortIndices(table, sort_options, ctx)); ARROW_ASSIGN_OR_RAISE(Datum sorted, diff --git a/cpp/src/arrow/acero/order_by_node_test.cc b/cpp/src/arrow/acero/order_by_node_test.cc index d77b0f3184f..56d0773a2ac 100644 --- a/cpp/src/arrow/acero/order_by_node_test.cc +++ b/cpp/src/arrow/acero/order_by_node_test.cc @@ -77,10 +77,10 @@ void CheckOrderByInvalid(OrderByNodeOptions options, const std::string& message) } TEST(OrderByNode, Basic) { - CheckOrderBy(OrderByNodeOptions({{SortKey("up")}})); - CheckOrderBy(OrderByNodeOptions({{SortKey("down", SortOrder::Descending)}})); - CheckOrderBy( - OrderByNodeOptions({{SortKey("up"), SortKey("down", SortOrder::Descending)}})); + CheckOrderBy(OrderByNodeOptions(Ordering{{SortKey("up")}})); + CheckOrderBy(OrderByNodeOptions(Ordering({SortKey("down", SortOrder::Descending)}))); + CheckOrderBy(OrderByNodeOptions( + Ordering({SortKey("up"), SortKey("down", SortOrder::Descending)}))); } TEST(OrderByNode, Large) { @@ -95,7 +95,7 @@ TEST(OrderByNode, Large) { ->Table(ExecPlan::kMaxBatchSize, kSmallNumBatches); Declaration plan = Declaration::Sequence({ {"table_source", TableSourceNodeOptions(input)}, - {"order_by", OrderByNodeOptions({{SortKey("up", SortOrder::Descending)}})}, + {"order_by", OrderByNodeOptions(Ordering({SortKey("up", SortOrder::Descending)}))}, {"jitter", JitterNodeOptions(kSeed, kJitterMod)}, }); ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema batches_and_schema, diff --git a/cpp/src/arrow/acero/plan_test.cc b/cpp/src/arrow/acero/plan_test.cc index e74ad6a6665..30da74aac23 100644 --- a/cpp/src/arrow/acero/plan_test.cc +++ b/cpp/src/arrow/acero/plan_test.cc @@ -515,7 +515,7 @@ TEST(ExecPlan, ToString) { }); ASSERT_OK_AND_ASSIGN(std::string plan_str, DeclarationToString(declaration)); EXPECT_EQ(plan_str, R"a(ExecPlan with 6 nodes: -custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32, 2))) ASC], null_placement=AtEnd}} 
+custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32, 2))) ASC AtEnd]}} :FilterNode{filter=(sum(multiply(i32, 2)) > 10)} :GroupByNode{keys=["bool"], aggregates=[ hash_sum(multiply(i32, 2)), diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index d47ee42ebf2..67b33561fb1 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -135,9 +135,8 @@ static auto kRunEndEncodeOptionsType = GetFunctionOptionsType( DataMember("order", &ArraySortOptions::order), DataMember("null_placement", &ArraySortOptions::null_placement)); -static auto kSortOptionsType = GetFunctionOptionsType( - DataMember("sort_keys", &SortOptions::sort_keys), - DataMember("null_placement", &SortOptions::null_placement)); +static auto kSortOptionsType = + GetFunctionOptionsType(DataMember("sort_keys", &SortOptions::sort_keys)); static auto kPartitionNthOptionsType = GetFunctionOptionsType( DataMember("pivot", &PartitionNthOptions::pivot), DataMember("null_placement", &PartitionNthOptions::null_placement)); @@ -149,7 +148,6 @@ static auto kCumulativeOptionsType = GetFunctionOptionsType( DataMember("skip_nulls", &CumulativeOptions::skip_nulls)); static auto kRankOptionsType = GetFunctionOptionsType( DataMember("sort_keys", &RankOptions::sort_keys), - DataMember("null_placement", &RankOptions::null_placement), DataMember("tiebreaker", &RankOptions::tiebreaker)); static auto kPairwiseOptionsType = GetFunctionOptionsType( DataMember("periods", &PairwiseOptions::periods)); @@ -180,14 +178,10 @@ ArraySortOptions::ArraySortOptions(SortOrder order, NullPlacement null_placement null_placement(null_placement) {} constexpr char ArraySortOptions::kTypeName[]; -SortOptions::SortOptions(std::vector sort_keys, NullPlacement null_placement) - : FunctionOptions(internal::kSortOptionsType), - sort_keys(std::move(sort_keys)), - null_placement(null_placement) {} +SortOptions::SortOptions(std::vector sort_keys) + : 
FunctionOptions(internal::kSortOptionsType), sort_keys(std::move(sort_keys)) {} SortOptions::SortOptions(const Ordering& ordering) - : FunctionOptions(internal::kSortOptionsType), - sort_keys(ordering.sort_keys()), - null_placement(ordering.null_placement()) {} + : FunctionOptions(internal::kSortOptionsType), sort_keys(ordering.sort_keys()) {} constexpr char SortOptions::kTypeName[]; PartitionNthOptions::PartitionNthOptions(int64_t pivot, NullPlacement null_placement) @@ -212,11 +206,10 @@ CumulativeOptions::CumulativeOptions(std::shared_ptr start, bool skip_nu skip_nulls(skip_nulls) {} constexpr char CumulativeOptions::kTypeName[]; -RankOptions::RankOptions(std::vector sort_keys, NullPlacement null_placement, +RankOptions::RankOptions(std::vector sort_keys, RankOptions::Tiebreaker tiebreaker) : FunctionOptions(internal::kRankOptionsType), sort_keys(std::move(sort_keys)), - null_placement(null_placement), tiebreaker(tiebreaker) {} constexpr char RankOptions::kTypeName[]; @@ -299,7 +292,7 @@ Result> SortIndices(const Array& values, SortOrder order, Result> SortIndices(const ChunkedArray& chunked_array, const ArraySortOptions& array_options, ExecContext* ctx) { - SortOptions options({SortKey("", array_options.order)}, array_options.null_placement); + SortOptions options({SortKey("", array_options.order, array_options.null_placement)}); ARROW_ASSIGN_OR_RAISE( Datum result, CallFunction("sort_indices", {Datum(chunked_array)}, &options, ctx)); return result.make_array(); diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 0233090ef6f..40aacea2606 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -105,8 +105,7 @@ class ARROW_EXPORT ArraySortOptions : public FunctionOptions { class ARROW_EXPORT SortOptions : public FunctionOptions { public: - explicit SortOptions(std::vector sort_keys = {}, - NullPlacement null_placement = NullPlacement::AtEnd); + explicit SortOptions(std::vector sort_keys = 
{}); explicit SortOptions(const Ordering& ordering); static constexpr char const kTypeName[] = "SortOptions"; static SortOptions Defaults() { return SortOptions(); } @@ -115,13 +114,11 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { /// Note: Both classes contain the exact same information. However, /// sort_options should only be used in a "function options" context while Ordering /// is used more generally. - Ordering AsOrdering() && { return Ordering(std::move(sort_keys), null_placement); } - Ordering AsOrdering() const& { return Ordering(sort_keys, null_placement); } + Ordering AsOrdering() && { return Ordering(std::move(sort_keys)); } + Ordering AsOrdering() const& { return Ordering(sort_keys); } /// Column key(s) to order by and how to order by these sort keys. std::vector sort_keys; - /// Whether nulls and NaNs are placed at the start or at the end - NullPlacement null_placement; }; /// \brief SelectK options @@ -177,21 +174,18 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { }; explicit RankOptions(std::vector sort_keys = {}, - NullPlacement null_placement = NullPlacement::AtEnd, Tiebreaker tiebreaker = RankOptions::First); /// Convenience constructor for array inputs explicit RankOptions(SortOrder order, NullPlacement null_placement = NullPlacement::AtEnd, Tiebreaker tiebreaker = RankOptions::First) - : RankOptions({SortKey("", order)}, null_placement, tiebreaker) {} + : RankOptions({SortKey("", order, null_placement)}, tiebreaker) {} static constexpr char const kTypeName[] = "RankOptions"; static RankOptions Defaults() { return RankOptions(); } /// Column key(s) to order by and how to order by these sort keys. 
std::vector sort_keys; - /// Whether nulls and NaNs are placed at the start or at the end - NullPlacement null_placement; /// Tiebreaker for dealing with equal values in ranks Tiebreaker tiebreaker; }; diff --git a/cpp/src/arrow/compute/exec_test.cc b/cpp/src/arrow/compute/exec_test.cc index d661e5735fe..04b864504c3 100644 --- a/cpp/src/arrow/compute/exec_test.cc +++ b/cpp/src/arrow/compute/exec_test.cc @@ -1400,7 +1400,7 @@ TEST(Ordering, IsSuborderOf) { Ordering a{{SortKey{3}, SortKey{1}, SortKey{7}}}; Ordering b{{SortKey{3}, SortKey{1}}}; Ordering c{{SortKey{1}, SortKey{7}}}; - Ordering d{{SortKey{1}, SortKey{7}}, NullPlacement::AtEnd}; + Ordering d{{SortKey{1}, SortKey{7, NullPlacement::AtStart}}}; Ordering imp = Ordering::Implicit(); Ordering unordered = Ordering::Unordered(); diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc index c9dbe0bd4c0..9553add5043 100644 --- a/cpp/src/arrow/compute/kernels/select_k_test.cc +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -88,24 +88,27 @@ class TestSelectKBase : public ::testing::Test { protected: template - void AssertSelectKArray(const std::shared_ptr values, int k) { + void AssertSelectKArray(const std::shared_ptr values, int k, + bool check_indices = false) { std::shared_ptr select_k; ASSERT_OK_AND_ASSIGN(select_k, SelectK(Datum(*values), k)); ASSERT_EQ(select_k->data()->null_count, 0); ValidateOutput(*select_k); - ValidateSelectK(Datum(*values), *select_k, order); + ValidateSelectK(Datum(*values), *select_k, order, check_indices); } - void AssertTopKArray(const std::shared_ptr values, int n) { - AssertSelectKArray(values, n); + void AssertTopKArray(const std::shared_ptr values, int n, + bool check_indices = false) { + AssertSelectKArray(values, n, check_indices); } - void AssertBottomKArray(const std::shared_ptr values, int n) { - AssertSelectKArray(values, n); + void AssertBottomKArray(const std::shared_ptr values, int n, + bool check_indices = 
false) { + AssertSelectKArray(values, n, check_indices); } - void AssertSelectKJson(const std::string& values, int n) { - AssertTopKArray(ArrayFromJSON(type_singleton(), values), n); - AssertBottomKArray(ArrayFromJSON(type_singleton(), values), n); + void AssertSelectKJson(const std::string& values, int n, bool check_indices = false) { + AssertTopKArray(ArrayFromJSON(type_singleton(), values), n, check_indices); + AssertBottomKArray(ArrayFromJSON(type_singleton(), values), n, check_indices); } virtual std::shared_ptr type_singleton() = 0; @@ -162,9 +165,11 @@ TYPED_TEST(TestSelectKForReal, Real) { this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 1); this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 2); this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 3); - this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 4); + // The result will contain nan. By default, the comparison of NaN is not equal, so + // indices are used for comparison. + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 4, true); this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 3); - this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 4); + this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 4, true); this->AssertSelectKJson("[100, 4, 2, 7, 8, 3, NaN, 3, 1]", 4); } @@ -234,6 +239,78 @@ TYPED_TEST(TestSelectKRandom, RandomValues) { } } +class TestSelectKWithArray : public ::testing::Test { + public: + void Check(const std::shared_ptr& type, const std::string& array_json, + const SelectKOptions& options, const std::string& expected_array) { + std::shared_ptr actual; + ASSERT_OK(this->DoSelectK(type, array_json, options, &actual)); + ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected_array), *actual); + } + + void CheckIndices(const std::shared_ptr& type, const std::string& array_json, + const SelectKOptions& options, const std::string& expected_json) { + auto array = ArrayFromJSON(type, array_json); + auto expected = ArrayFromJSON(uint64(), expected_json); + auto indices = SelectKUnstable(Datum(*array), options); + 
ASSERT_OK(indices); + auto actual = indices.MoveValueUnsafe(); + ValidateOutput(*actual); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); + } + + Status DoSelectK(const std::shared_ptr& type, const std::string& array_json, + const SelectKOptions& options, std::shared_ptr* out) { + auto array = ArrayFromJSON(type, array_json); + ARROW_ASSIGN_OR_RAISE(auto indices, SelectKUnstable(Datum(*array), options)); + + ValidateOutput(*indices); + ARROW_ASSIGN_OR_RAISE( + auto select_k, Take(Datum(array), Datum(indices), TakeOptions::NoBoundsCheck())); + *out = select_k.make_array(); + return Status::OK(); + } +}; + +TEST_F(TestSelectKWithArray, PartialSelectKNull) { + auto array_input = R"([null, 30, 20, 10, null])"; + std::vector sort_keys{SortKey("a", SortOrder::Ascending)}; + auto options = SelectKOptions(3, sort_keys); + auto expected = R"([10, 20, 30])"; + Check(uint8(), array_input, options, expected); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + expected = R"([null, null, 10])"; + Check(uint8(), array_input, options, expected); +} + +TEST_F(TestSelectKWithArray, FullSelectKNull) { + auto array_input = R"([null, 30, 20, 10, null])"; + std::vector sort_keys{SortKey("a", SortOrder::Ascending)}; + auto options = SelectKOptions(10, sort_keys); + auto expected = R"([10, 20, 30, null, null])"; + Check(uint8(), array_input, options, expected); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + expected = R"([null, null, 10, 20, 30])"; + Check(uint8(), array_input, options, expected); +} + +TEST_F(TestSelectKWithArray, PartialSelectKNullNaN) { + auto array_input = R"([null, 30, NaN, 20, 10, null])"; + std::vector sort_keys{SortKey("a", SortOrder::Descending)}; + auto options = SelectKOptions(4, sort_keys); + CheckIndices(float64(), array_input, options, "[1, 3, 4, 2]"); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + CheckIndices(float64(), array_input, options, "[0, 5, 2, 1]"); +} + +TEST_F(TestSelectKWithArray, 
FullSelectKNullNaN) { + auto array_input = R"([null, 30, NaN, 20, 10, null])"; + std::vector sort_keys{SortKey("a", SortOrder::Descending)}; + auto options = SelectKOptions(10, sort_keys); + CheckIndices(float64(), array_input, options, "[1, 3, 4, 2, 0, 5]"); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + CheckIndices(float64(), array_input, options, "[0, 5, 2, 1, 3, 4]"); +} // Test basic cases for chunked array template @@ -263,6 +340,35 @@ struct TestSelectKWithChunkedArray : public ::testing::Test { void AssertBottomK(const std::shared_ptr& chunked_array, int64_t k) { AssertSelectK(chunked_array, k); } + + void Check(const std::shared_ptr& chunked_array, + const SelectKOptions& options, + const std::shared_ptr& expected_array) { + std::shared_ptr actual; + ASSERT_OK(this->DoSelectK(chunked_array, options, &actual)); + AssertChunkedEqual(*expected_array, *actual); + } + + void CheckIndices(const std::shared_ptr& chunked_array, + const SelectKOptions& options, const std::string& expected_json) { + auto expected = ArrayFromJSON(uint64(), expected_json); + auto indices = SelectKUnstable(Datum(*chunked_array), options); + ASSERT_OK(indices); + auto actual = indices.MoveValueUnsafe(); + ValidateOutput(*actual); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); + } + + Status DoSelectK(const std::shared_ptr& chunked_array, + const SelectKOptions& options, std::shared_ptr* out) { + ARROW_ASSIGN_OR_RAISE(auto indices, SelectKUnstable(Datum(*chunked_array), options)); + + ValidateOutput(*indices); + ARROW_ASSIGN_OR_RAISE(auto select_k, Take(Datum(chunked_array), Datum(indices), + TakeOptions::NoBoundsCheck())); + *out = select_k.chunked_array(); + return Status::OK(); + } }; TYPED_TEST_SUITE(TestSelectKWithChunkedArray, SelectKableTypes); @@ -283,6 +389,59 @@ TYPED_TEST(TestSelectKWithChunkedArray, RandomValuesWithSlices) { } } +TYPED_TEST(TestSelectKWithChunkedArray, PartialSelectKNull) { + auto chunked_array = ChunkedArrayFromJSON(uint8(), { 
+ "[null, 1]", + "[3, null, 2]", + "[1]", + }); + std::vector sort_keys{SortKey("a", SortOrder::Ascending)}; + auto options = SelectKOptions(3, sort_keys); + auto expected = ChunkedArrayFromJSON(uint8(), {"[1, 1, 2]"}); + this->Check(chunked_array, options, expected); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + expected = ChunkedArrayFromJSON(uint8(), {"[null, null, 1]"}); + this->Check(chunked_array, options, expected); +} + +TYPED_TEST(TestSelectKWithChunkedArray, FullSelectKNull) { + auto chunked_array = ChunkedArrayFromJSON(uint8(), { + "[null, 1]", + "[3, null, 2]", + "[1]", + }); + std::vector sort_keys{SortKey("a", SortOrder::Ascending)}; + auto options = SelectKOptions(10, sort_keys); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + auto expected = ChunkedArrayFromJSON(uint8(), {"[null, null, 1, 1, 2, 3]"}); + this->Check(chunked_array, options, expected); + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + expected = ChunkedArrayFromJSON(uint8(), {"[1, 1, 2, 3, null, null]"}); + this->Check(chunked_array, options, expected); +} + +TYPED_TEST(TestSelectKWithChunkedArray, PartialSelectKNullNaN) { + auto chunked_array = ChunkedArrayFromJSON( + float64(), {"[null, 1]", "[3, null, NaN]", "[10, NaN, 2]", "[1]"}); + std::vector sort_keys{SortKey("a", SortOrder::Descending)}; + auto options = SelectKOptions(3, sort_keys); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + this->CheckIndices(chunked_array, options, "[3, 0, 4]"); + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + this->CheckIndices(chunked_array, options, "[5, 2, 7]"); +} + +TYPED_TEST(TestSelectKWithChunkedArray, FullSelectKNullNaN) { + auto chunked_array = ChunkedArrayFromJSON( + float64(), {"[null, 1]", "[3, null, NaN]", "[10, NaN, 2]", "[1]"}); + std::vector sort_keys{SortKey("a", SortOrder::Descending)}; + auto options = SelectKOptions(10, sort_keys); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + 
this->CheckIndices(chunked_array, options, "[3, 0, 6, 4, 5, 2, 7, 8, 1]"); + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + this->CheckIndices(chunked_array, options, "[5, 2, 7, 8, 1, 6, 4, 3, 0]"); +} + template void ValidateSelectKIndices(const ArrayType& array) { ValidateOutput(array); @@ -363,6 +522,17 @@ class TestSelectKWithRecordBatch : public ::testing::Test { ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual); } + void CheckIndices(const std::shared_ptr& schm, const std::string& batch_json, + const SelectKOptions& options, const std::string& expected_json) { + auto batch = RecordBatchFromJSON(schm, batch_json); + auto expected = ArrayFromJSON(uint64(), expected_json); + auto indices = SelectKUnstable(Datum(*batch), options); + ASSERT_OK(indices); + auto actual = indices.MoveValueUnsafe(); + ValidateOutput(*actual); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); + } + Status DoSelectK(const std::shared_ptr& schm, const std::string& batch_json, const SelectKOptions& options, std::shared_ptr* out) { auto batch = RecordBatchFromJSON(schm, batch_json); @@ -539,6 +709,128 @@ TEST_F(TestSelectKWithRecordBatch, BottomKNull) { Check(schema, batch_input, options, expected_batch); } +TEST_F(TestSelectKWithRecordBatch, PartialSelectKNull) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + auto batch_input = R"([ + {"a": null, "b": 5}, + {"a": 30, "b": 3}, + {"a": null, "b": 4}, + {"a": null, "b": 6}, + {"a": 20, "b": 5}, + {"a": null, "b": 5}, + {"a": 10, "b": 3}, + {"a": null, "b": null} + ])"; + std::vector sort_keys{ + SortKey("a", SortOrder::Ascending, NullPlacement::AtStart), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(3, sort_keys); + auto expected_batch = R"([{"a": null, "b": 6}, + {"a": null, "b": 5}, + {"a": null, "b": 5} + ])"; + Check(schema, batch_input, options, expected_batch); + options.sort_keys[1].null_placement = 
NullPlacement::AtStart; + expected_batch = R"([{"a": null, "b": null}, + {"a": null, "b": 6}, + {"a": null, "b": 5} + ])"; + Check(schema, batch_input, options, expected_batch); +} + +TEST_F(TestSelectKWithRecordBatch, FullSelectKNull) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + auto batch_input = R"([ + {"a": null, "b": 5}, + {"a": 30, "b": 3}, + {"a": null, "b": 4}, + {"a": null, "b": 6}, + {"a": 20, "b": 5}, + {"a": null, "b": 5}, + {"a": 10, "b": 3}, + {"a": null, "b": null} + ])"; + std::vector sort_keys{ + SortKey("a", SortOrder::Ascending, NullPlacement::AtStart), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(10, sort_keys); + auto expected_batch = R"([{"a": null, "b": 6}, + {"a": null, "b": 5}, + {"a": null, "b": 5}, + {"a": null, "b": 4}, + {"a": null, "b": null}, + {"a": 10, "b": 3}, + {"a": 20, "b": 5}, + {"a": 30, "b": 3} + ])"; + Check(schema, batch_input, options, expected_batch); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + expected_batch = R"([{"a": null, "b": null}, + {"a": null, "b": 6}, + {"a": null, "b": 5}, + {"a": null, "b": 5}, + {"a": null, "b": 4}, + {"a": 10, "b": 3}, + {"a": 20, "b": 5}, + {"a": 30, "b": 3} + ])"; + Check(schema, batch_input, options, expected_batch); +} + +TEST_F(TestSelectKWithRecordBatch, PartialSelectKNullNaN) { + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); + auto batch_input = R"([ + {"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null}, + {"a": null, "b": null}, + {"a": 6, "b": null}, + {"a": 6, "b": NaN}, + {"a": NaN, "b": 5}, + {"a": 1, "b": 5} + ])"; + std::vector sort_keys{ + SortKey("a", SortOrder::Ascending, NullPlacement::AtStart), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(3, sort_keys); + CheckIndices(schema, batch_input, options, "[0, 3, 6]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + 
CheckIndices(schema, batch_input, options, "[3, 0, 6]"); +} + +TEST_F(TestSelectKWithRecordBatch, FullSelectKNullNaN) { + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); + auto batch_input = R"([ + {"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null}, + {"a": null, "b": null}, + {"a": 6, "b": null}, + {"a": 6, "b": NaN}, + {"a": NaN, "b": 5}, + {"a": 1, "b": 5} + ])"; + std::vector sort_keys{ + SortKey("a", SortOrder::Ascending, NullPlacement::AtStart), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(10, sort_keys); + CheckIndices(schema, batch_input, options, "[0, 3, 6, 7, 1, 2, 5, 4]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + CheckIndices(schema, batch_input, options, "[3, 0, 6, 7, 1, 2, 4, 5]"); +} + TEST_F(TestSelectKWithRecordBatch, BottomKOneColumnKey) { auto schema = ::arrow::schema({ {field("country", utf8())}, @@ -605,6 +897,18 @@ struct TestSelectKWithTable : public ::testing::Test { ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected), *actual); } + void CheckIndices(const std::shared_ptr& schm, + const std::vector& input_json, + const SelectKOptions& options, const std::string& expected_json) { + auto table = TableFromJSON(schm, input_json); + auto expected = ArrayFromJSON(uint64(), expected_json); + auto indices = SelectKUnstable(Datum(*table), options); + ASSERT_OK(indices); + auto actual = indices.MoveValueUnsafe(); + ValidateOutput(*actual); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); + } + Status DoSelectK(const std::shared_ptr& schm, const std::vector& input_json, const SelectKOptions& options, std::shared_ptr* out) { @@ -711,5 +1015,143 @@ TEST_F(TestSelectKWithTable, BottomKMultipleColumnKeys) { Check(schema, input, options, expected); } +TEST_F(TestSelectKWithTable, PartialSelectKNull) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + std::vector input = {R"([{"a": null, "b": 5}, + 
{"a": 1, "b": 3}, + {"a": 3, "b": null} + ])", + R"([{"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5}, + {"a": 3, "b": 5} + ])"}; + + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(3, sort_keys); + std::vector expected = {R"([{"a": 1, "b": 5}, + {"a": 1, "b": 3}, + {"a": 2, "b": 5} + ])"}; + Check(schema, input, options, expected); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + expected = {R"([{"a": null, "b": 5}, + {"a": null, "b": null}, + {"a": 1, "b": 5} + ])"}; + Check(schema, input, options, expected); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + expected = {R"([{"a": null, "b": null}, + {"a": null, "b": 5}, + {"a": 1, "b": 5} + ])"}; + Check(schema, input, options, expected); +} + +TEST_F(TestSelectKWithTable, FullSelectKNull) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + std::vector input = {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null} + ])", + R"([{"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5}, + {"a": 3, "b": 5} + ])"}; + + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(10, sort_keys); + std::vector expected = {R"([{"a": 1, "b": 5}, + {"a": 1, "b": 3}, + {"a": 2, "b": 5}, + {"a": 3, "b": 5}, + {"a": 3, "b": null}, + {"a": null, "b": 5}, + {"a": null, "b": null} + ])"}; + Check(schema, input, options, expected); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + expected = {R"([{"a": null, "b": 5}, + {"a": null, "b": null}, + {"a": 1, "b": 5}, + {"a": 1, "b": 3}, + {"a": 2, "b": 5}, + {"a": 3, "b": 5}, + {"a": 3, "b": null} + ])"}; + Check(schema, input, options, expected); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + expected = {R"([{"a": null, "b": null}, + {"a": null, "b": 5}, + {"a": 1, "b": 5}, + {"a": 1, "b": 3}, 
+ {"a": 2, "b": 5}, + {"a": 3, "b": null}, + {"a": 3, "b": 5} + ])"}; + Check(schema, input, options, expected); +} + +TEST_F(TestSelectKWithTable, PartialSelectKNullNaN) { + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); + std::vector input = {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null} + ])", + R"([{"a": null, "b": null}, + {"a": 6, "b": null}, + {"a": 6, "b": NaN}, + {"a": NaN, "b": 5}, + {"a": 1, "b": 5} + ])"}; + + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(3, sort_keys); + CheckIndices(schema, input, options, "[7, 1, 2]"); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + CheckIndices(schema, input, options, "[0, 3, 6]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + CheckIndices(schema, input, options, "[3, 0, 6]"); +} + +TEST_F(TestSelectKWithTable, FullSelectKNullNaN) { + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); + std::vector input = {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null} + ])", + R"([{"a": null, "b": null}, + {"a": 6, "b": null}, + {"a": 6, "b": NaN}, + {"a": NaN, "b": 5}, + {"a": 1, "b": 5} + ])"}; + + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(10, sort_keys); + CheckIndices(schema, input, options, "[7, 1, 2, 5, 4, 6, 0, 3]"); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + CheckIndices(schema, input, options, "[0, 3, 6, 7, 1, 2, 5, 4]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + CheckIndices(schema, input, options, "[3, 0, 6, 7, 1, 2, 4, 5]"); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_rank.cc b/cpp/src/arrow/compute/kernels/vector_rank.cc index 780ae25d963..1b0d76285c8 100644 --- 
a/cpp/src/arrow/compute/kernels/vector_rank.cc +++ b/cpp/src/arrow/compute/kernels/vector_rank.cc @@ -293,8 +293,10 @@ class RankMetaFunction : public MetaFunction { static Result Rank(const T& input, const RankOptions& options, ExecContext* ctx) { SortOrder order = SortOrder::Ascending; + NullPlacement null_placement = NullPlacement::AtEnd; if (!options.sort_keys.empty()) { order = options.sort_keys[0].order; + null_placement = options.sort_keys[0].null_placement; } int64_t length = input.length(); @@ -305,8 +307,8 @@ class RankMetaFunction : public MetaFunction { std::iota(indices_begin, indices_end, 0); Datum output; - Ranker ranker(ctx, indices_begin, indices_end, input, order, - options.null_placement, options.tiebreaker, &output); + Ranker ranker(ctx, indices_begin, indices_end, input, order, null_placement, + options.tiebreaker, &output); ARROW_RETURN_NOT_OK(ranker.Run()); return output; } diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 5000de89962..9a5a66d287e 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -71,6 +71,120 @@ class SelectKComparator { } }; +struct ExtractCounter { + int64_t extract_non_null_count; + int64_t extract_nan_count; + int64_t extract_null_count; +}; + +class HeapSorter { + public: + using HeapPusherFunction = + std::function; + + HeapSorter(int64_t k, NullPlacement null_placement, MemoryPool* pool) + : k_(k), null_placement_(null_placement), pool_(pool) {} + + Result> HeapSort(HeapPusherFunction heap_pusher, + NullPartitionResult p, + NullPartitionResult q) { + ExtractCounter counter = ComputeExtractCounter(p, q); + return HeapSortInternal(counter, heap_pusher, p, q); + } + + ExtractCounter ComputeExtractCounter(NullPartitionResult p, NullPartitionResult q) { + int64_t extract_non_null_count = 0; + int64_t extract_nan_count = 0; + int64_t extract_null_count = 0; + int64_t non_null_count = 
q.non_null_count(); + int64_t nan_count = q.null_count(); + int64_t null_count = p.null_count(); + // non-null nan null + if (null_placement_ == NullPlacement::AtEnd) { + extract_non_null_count = non_null_count <= k_ ? non_null_count : k_; + extract_nan_count = extract_non_null_count >= k_ + ? 0 + : std::min(nan_count, k_ - extract_non_null_count); + extract_null_count = extract_non_null_count + extract_nan_count >= k_ + ? 0 + : (k_ - (extract_non_null_count + extract_nan_count)); + } else { // null nan non-null + extract_null_count = null_count <= k_ ? null_count : k_; + extract_nan_count = + extract_null_count >= k_ ? 0 : std::min(nan_count, k_ - extract_null_count); + extract_non_null_count = extract_null_count + extract_nan_count >= k_ + ? 0 + : (k_ - (extract_null_count + extract_nan_count)); + } + return {extract_non_null_count, extract_nan_count, extract_null_count}; + } + + Result> HeapSortInternal(ExtractCounter counter, + HeapPusherFunction heap_pusher, + NullPartitionResult p, + NullPartitionResult q) { + int64_t out_size = counter.extract_non_null_count + counter.extract_nan_count + + counter.extract_null_count; + ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(out_size, pool_)); + // [extrat_count....extract_nan_count...extract_null_count] + if (null_placement_ == NullPlacement::AtEnd) { + if (counter.extract_non_null_count) { + auto* out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_non_null_count - 1; + auto kth_begin = std::min(q.non_nulls_begin + k_, q.non_nulls_end); + heap_pusher(q.non_nulls_begin, kth_begin, q.non_nulls_end, out_cbegin); + } + + if (counter.extract_nan_count) { + auto* out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_non_null_count + counter.extract_nan_count - 1; + auto kth_begin = + std::min(q.nulls_begin + k_ - counter.extract_non_null_count, q.nulls_end); + heap_pusher(q.nulls_begin, kth_begin, q.nulls_end, out_cbegin); + } + + if 
(counter.extract_null_count) { + auto* out_cbegin = + take_indices->template GetMutableValues(1) + out_size - 1; + auto kth_begin = std::min(p.nulls_begin + k_ - counter.extract_non_null_count - + counter.extract_nan_count, + p.nulls_end); + heap_pusher(p.nulls_begin, kth_begin, p.nulls_end, out_cbegin); + } + } else { // [extract_null_count....extract_nan_count...extrat_count] + if (counter.extract_null_count) { + auto* out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_null_count - 1; + auto kth_begin = std::min(p.nulls_begin + k_, p.nulls_end); + heap_pusher(p.nulls_begin, kth_begin, p.nulls_end, out_cbegin); + } + + if (counter.extract_nan_count) { + auto* out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_null_count + counter.extract_nan_count - 1; + auto kth_begin = + std::min(q.nulls_begin + k_ - counter.extract_null_count, q.nulls_end); + heap_pusher(q.nulls_begin, kth_begin, q.nulls_end, out_cbegin); + } + + if (counter.extract_non_null_count) { + auto* out_cbegin = + take_indices->template GetMutableValues(1) + out_size - 1; + auto kth_begin = std::min(q.non_nulls_begin + k_ - counter.extract_null_count - + counter.extract_nan_count, + q.non_nulls_end); + heap_pusher(q.non_nulls_begin, kth_begin, q.non_nulls_end, out_cbegin); + } + } + return take_indices; + } + + private: + int64_t k_; + NullPlacement null_placement_; + MemoryPool* pool_; +}; + class ArraySelecter : public TypeVisitor { public: ArraySelecter(ExecContext* ctx, const Array& array, const SelectKOptions& options, @@ -80,6 +194,7 @@ class ArraySelecter : public TypeVisitor { array_(array), k_(options.k), order_(options.sort_keys[0].order), + null_placement_(options.sort_keys[0].null_placement), physical_type_(GetPhysicalType(array.type())), output_(output) {} @@ -112,11 +227,10 @@ class ArraySelecter : public TypeVisitor { k_ = arr.length(); } - const auto p = PartitionNulls( - indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); - const 
auto end_iter = p.non_nulls_end; - - auto kth_begin = std::min(indices_begin + k_, end_iter); + const auto p = PartitionNullsOnly(indices_begin, indices_end, + arr, 0, null_placement_); + const auto q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, arr, 0, null_placement_); SelectKComparator comparator; auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) { @@ -126,24 +240,24 @@ class ArraySelecter : public TypeVisitor { }; using HeapContainer = std::priority_queue, decltype(cmp)>; - HeapContainer heap(indices_begin, kth_begin, cmp); - for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - if (cmp(x_index, heap.top())) { + auto HeapPusher = [&](uint64_t* indices_begin, uint64_t* kth_begin, + uint64_t* end_iter, uint64_t* out_cbegin) { + HeapContainer heap(indices_begin, kth_begin, cmp); + for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + if (cmp(x_index, heap.top())) { + heap.pop(); + heap.push(x_index); + } + } + while (heap.size() > 0) { + *out_cbegin = heap.top(); heap.pop(); - heap.push(x_index); + --out_cbegin; } - } - auto out_size = static_cast(heap.size()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, - MakeMutableUInt64Array(out_size, ctx_->memory_pool())); - - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.size() > 0) { - *out_cbegin = heap.top(); - heap.pop(); - --out_cbegin; - } + }; + HeapSorter h(k_, null_placement_, ctx_->memory_pool()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, h.HeapSort(HeapPusher, p, q)); *output_ = Datum(take_indices); return Status::OK(); } @@ -152,6 +266,7 @@ class ArraySelecter : public TypeVisitor { const Array& array_; int64_t k_; SortOrder order_; + NullPlacement null_placement_; const std::shared_ptr physical_type_; Datum* output_; }; @@ -163,6 +278,188 @@ struct TypedHeapItem { ArrayType* array; }; +template +class ChunkedHeapSorter { + public: + using GetView = 
GetViewType; + using ArrayType = typename TypeTraits::ArrayType; + using HeapItem = TypedHeapItem; + + ChunkedHeapSorter(int64_t k, NullPlacement null_placement, MemoryPool* pool) + : k_(k), null_placement_(null_placement), pool_(pool) {} + + Result> HeapSort(const ArrayVector physical_chunks) { + std::vector> chunks_null_partions; + std::vector> chunks_holder; + std::vector> chunks_indices_holder; + chunks_null_partions.reserve(physical_chunks.size()); + ExtractCounter counter = ComputeExtractCounter(physical_chunks, chunks_null_partions, + chunks_holder, chunks_indices_holder); + return HeapSortInternal(chunks_holder, counter, chunks_null_partions); + } + + // Extract the total count of non-nulls, nans, and nulls for all chunks + ExtractCounter ComputeExtractCounter( + const ArrayVector physical_chunks, + std::vector>& + chunks_null_partions, + std::vector>& chunks_holder, + std::vector>& chunks_indices_holder) { + int64_t all_non_null_count = 0; + int64_t all_nan_count = 0; + int64_t all_null_count = 0; + int64_t extract_non_null_count = 0; + int64_t extract_nan_count = 0; + int64_t extract_null_count = 0; + for (size_t i = 0; i < physical_chunks.size(); i++) { + const auto& chunk = physical_chunks[i]; + if (chunk->length() == 0) continue; + chunks_holder.emplace_back(std::make_shared(chunk->data())); + ArrayType& arr = *chunks_holder[chunks_holder.size() - 1]; + chunks_indices_holder.emplace_back(std::vector(arr.length())); + std::vector& indices = + chunks_indices_holder[chunks_indices_holder.size() - 1]; + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + NullPartitionResult p = PartitionNullsOnly( + indices_begin, indices_end, arr, 0, null_placement_); + NullPartitionResult q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, arr, 0, null_placement_); + int64_t non_null_count = q.non_null_count(); + int64_t nan_count = q.null_count(); + int64_t null_count 
= p.null_count(); + all_non_null_count += non_null_count; + all_nan_count += nan_count; + all_null_count += null_count; + chunks_null_partions.emplace_back(p, q); + } + // non-null nan null + if (null_placement_ == NullPlacement::AtEnd) { + extract_non_null_count = all_non_null_count <= k_ ? all_non_null_count : k_; + extract_nan_count = extract_non_null_count >= k_ + ? 0 + : std::min(all_nan_count, k_ - extract_non_null_count); + extract_null_count = extract_non_null_count + extract_nan_count >= k_ + ? 0 + : (k_ - (extract_non_null_count + extract_nan_count)); + } else { // null nan non-null + extract_null_count = all_null_count <= k_ ? all_null_count : k_; + extract_nan_count = + extract_null_count >= k_ ? 0 : std::min(all_nan_count, k_ - extract_null_count); + extract_non_null_count = extract_null_count + extract_nan_count >= k_ + ? 0 + : (k_ - (extract_null_count + extract_nan_count)); + } + return {extract_non_null_count, extract_nan_count, extract_null_count}; + } + + Result> HeapSortInternal( + const std::vector>& chunks_holder, + ExtractCounter counter, + const std::vector>& + chunks_null_partions) { + std::function cmp; + SelectKComparator comparator; + cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool { + const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); + const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); + return comparator(lval, rval); + }; + using HeapContainer = + std::priority_queue, decltype(cmp)>; + HeapContainer non_null_heap(cmp); + HeapContainer nan_heap(cmp); + HeapContainer null_heap(cmp); + + uint64_t offset = 0; + for (size_t i = 0; i < chunks_null_partions.size(); i++) { + const auto& null_part_pair = chunks_null_partions[i]; + const auto& p = null_part_pair.first; + const auto& q = null_part_pair.second; + ArrayType& arr = *chunks_holder[i]; + + auto HeapPusher = [&](HeapContainer& heap, int64_t extract_non_null_count, + uint64_t* indices_begin, uint64_t* kth_begin, + 
uint64_t* end_iter) { + uint64_t* iter = indices_begin; + for (; iter != kth_begin && + heap.size() < static_cast(extract_non_null_count); + ++iter) { + heap.push(HeapItem{*iter, offset, &arr}); + } + for (; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); + auto top_item = heap.top(); + const auto& top_value = + GetView::LogicalValue(top_item.array->GetView(top_item.index)); + if (comparator(xval, top_value)) { + heap.pop(); + heap.push(HeapItem{x_index, offset, &arr}); + } + } + }; + HeapPusher( + non_null_heap, counter.extract_non_null_count, q.non_nulls_begin, + std::min(q.non_nulls_begin + counter.extract_non_null_count, q.non_nulls_end), + q.non_nulls_end); + HeapPusher(nan_heap, counter.extract_nan_count, q.nulls_begin, + std::min(q.nulls_begin + counter.extract_nan_count, q.nulls_end), + q.nulls_end); + HeapPusher(null_heap, counter.extract_null_count, p.nulls_begin, + std::min(p.nulls_begin + counter.extract_null_count, p.nulls_end), + p.nulls_end); + offset += arr.length(); + } + + int64_t out_size = counter.extract_non_null_count + counter.extract_nan_count + + counter.extract_null_count; + ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(out_size, pool_)); + + auto PopHeaper = [&](HeapContainer& heap, uint64_t* out_cbegin) { + while (heap.size() > 0) { + auto top_item = heap.top(); + *out_cbegin = top_item.index + top_item.offset; + heap.pop(); + --out_cbegin; + } + }; + + if (null_placement_ == NullPlacement::AtEnd) { + // non_null + auto* out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_non_null_count - 1; + PopHeaper(non_null_heap, out_cbegin); + // nan + out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_non_null_count + counter.extract_nan_count - 1; + PopHeaper(nan_heap, out_cbegin); + // null + out_cbegin = take_indices->template GetMutableValues(1) + out_size - 1; + PopHeaper(null_heap, 
out_cbegin); + } else { + // null + auto* out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_null_count - 1; + PopHeaper(null_heap, out_cbegin); + // nan + out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_null_count + counter.extract_nan_count - 1; + PopHeaper(nan_heap, out_cbegin); + // non_null + out_cbegin = take_indices->template GetMutableValues(1) + out_size - 1; + PopHeaper(non_null_heap, out_cbegin); + } + return take_indices; + } + + private: + int64_t k_; + NullPlacement null_placement_; + MemoryPool* pool_; +}; + class ChunkedArraySelecter : public TypeVisitor { public: ChunkedArraySelecter(ExecContext* ctx, const ChunkedArray& chunked_array, @@ -173,6 +470,7 @@ class ChunkedArraySelecter : public TypeVisitor { physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)), k_(options.k), order_(options.sort_keys[0].order), + null_placement_(options.sort_keys[0].null_placement), ctx_(ctx), output_(output) {} @@ -191,10 +489,6 @@ class ChunkedArraySelecter : public TypeVisitor { template Status SelectKthInternal() { - using GetView = GetViewType; - using ArrayType = typename TypeTraits::ArrayType; - using HeapItem = TypedHeapItem; - const auto num_chunks = chunked_array_.num_chunks(); if (num_chunks == 0) { return Status::OK(); @@ -202,63 +496,9 @@ class ChunkedArraySelecter : public TypeVisitor { if (k_ > chunked_array_.length()) { k_ = chunked_array_.length(); } - std::function cmp; - SelectKComparator comparator; - - cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool { - const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); - const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); - return comparator(lval, rval); - }; - using HeapContainer = - std::priority_queue, decltype(cmp)>; - - HeapContainer heap(cmp); - std::vector> chunks_holder; - uint64_t offset = 0; - for (const auto& chunk : physical_chunks_) { - if (chunk->length() == 0) 
continue; - chunks_holder.emplace_back(std::make_shared(chunk->data())); - ArrayType& arr = *chunks_holder[chunks_holder.size() - 1]; - - std::vector indices(arr.length()); - uint64_t* indices_begin = indices.data(); - uint64_t* indices_end = indices_begin + indices.size(); - std::iota(indices_begin, indices_end, 0); - - const auto p = PartitionNulls( - indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); - const auto end_iter = p.non_nulls_end; - - auto kth_begin = std::min(indices_begin + k_, end_iter); - uint64_t* iter = indices_begin; - for (; iter != kth_begin && heap.size() < static_cast(k_); ++iter) { - heap.push(HeapItem{*iter, offset, &arr}); - } - for (; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); - auto top_item = heap.top(); - const auto& top_value = - GetView::LogicalValue(top_item.array->GetView(top_item.index)); - if (comparator(xval, top_value)) { - heap.pop(); - heap.push(HeapItem{x_index, offset, &arr}); - } - } - offset += chunk->length(); - } - auto out_size = static_cast(heap.size()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, - MakeMutableUInt64Array(out_size, ctx_->memory_pool())); - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.size() > 0) { - auto top_item = heap.top(); - *out_cbegin = top_item.index + top_item.offset; - heap.pop(); - --out_cbegin; - } + ChunkedHeapSorter h(k_, null_placement_, ctx_->memory_pool()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, h.HeapSort(physical_chunks_)); *output_ = Datum(take_indices); return Status::OK(); } @@ -268,6 +508,7 @@ class ChunkedArraySelecter : public TypeVisitor { const ArrayVector physical_chunks_; int64_t k_; SortOrder order_; + NullPlacement null_placement_; ExecContext* ctx_; Datum* output_; }; @@ -286,7 +527,7 @@ class RecordBatchSelecter : public TypeVisitor { k_(options.k), output_(output), sort_keys_(ResolveSortKeys(record_batch, options.sort_keys, 
&status_)), - comparator_(sort_keys_, NullPlacement::AtEnd) {} + comparator_(sort_keys_) {} Status Run() { RETURN_NOT_OK(status_); @@ -312,7 +553,7 @@ class RecordBatchSelecter : public TypeVisitor { *status = maybe_array.status(); return {}; } - resolved.emplace_back(*std::move(maybe_array), key.order); + resolved.emplace_back(*std::move(maybe_array), key.order, key.null_placement); } return resolved; } @@ -337,7 +578,9 @@ class RecordBatchSelecter : public TypeVisitor { cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { const auto lval = GetView::LogicalValue(arr.GetView(left)); const auto rval = GetView::LogicalValue(arr.GetView(right)); - if (lval == rval) { + const bool is_null_left = arr.IsNull(left); + const bool is_null_right = arr.IsNull(right); + if ((lval == rval) || (is_null_left && is_null_right)) { // If the left value equals to the right value, // we need to compare the second and following // sort keys. @@ -353,30 +596,31 @@ class RecordBatchSelecter : public TypeVisitor { uint64_t* indices_end = indices_begin + indices.size(); std::iota(indices_begin, indices_end, 0); - const auto p = PartitionNulls( - indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); - const auto end_iter = p.non_nulls_end; - - auto kth_begin = std::min(indices_begin + k_, end_iter); + NullPartitionResult p = PartitionNullsOnly( + indices_begin, indices_end, arr, 0, first_sort_key.null_placement); + NullPartitionResult q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, arr, 0, first_sort_key.null_placement); - HeapContainer heap(indices_begin, kth_begin, cmp); - for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - auto top_item = heap.top(); - if (cmp(x_index, top_item)) { + auto HeapPusher = [&](uint64_t* indices_begin, uint64_t* kth_begin, + uint64_t* end_iter, uint64_t* out_cbegin) { + HeapContainer heap(indices_begin, kth_begin, cmp); + for (auto iter = kth_begin; iter != end_iter && 
!heap.empty(); ++iter) { + uint64_t x_index = *iter; + auto top_item = heap.top(); + if (cmp(x_index, top_item)) { + heap.pop(); + heap.push(x_index); + } + } + while (heap.size() > 0) { + *out_cbegin = heap.top(); heap.pop(); - heap.push(x_index); + --out_cbegin; } - } - auto out_size = static_cast(heap.size()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, - MakeMutableUInt64Array(out_size, ctx_->memory_pool())); - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.size() > 0) { - *out_cbegin = heap.top(); - heap.pop(); - --out_cbegin; - } + }; + + HeapSorter h(k_, first_sort_key.null_placement, ctx_->memory_pool()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, h.HeapSort(HeapPusher, p, q)); *output_ = Datum(take_indices); return Status::OK(); } @@ -394,12 +638,13 @@ class TableSelecter : public TypeVisitor { private: struct ResolvedSortKey { ResolvedSortKey(const std::shared_ptr& chunked_array, - const SortOrder order) + const SortOrder order, NullPlacement null_placement) : order(order), type(GetPhysicalType(chunked_array->type())), chunks(GetPhysicalChunks(*chunked_array, type)), null_count(chunked_array->null_count()), - resolver(GetArrayPointers(chunks)) {} + resolver(GetArrayPointers(chunks)), + null_placement(null_placement) {} using LocationType = int64_t; @@ -415,6 +660,7 @@ class TableSelecter : public TypeVisitor { const ArrayVector chunks; const int64_t null_count; const ChunkedArrayResolver resolver; + NullPlacement null_placement; }; using Comparator = MultipleKeyComparator; @@ -427,7 +673,7 @@ class TableSelecter : public TypeVisitor { k_(options.k), output_(output), sort_keys_(ResolveSortKeys(table, options.sort_keys, &status_)), - comparator_(sort_keys_, NullPlacement::AtEnd) {} + comparator_(sort_keys_) {} Status Run() { RETURN_NOT_OK(status_); @@ -454,36 +700,44 @@ class TableSelecter : public TypeVisitor { *status = maybe_chunked_array.status(); return {}; } - resolved.emplace_back(*std::move(maybe_chunked_array), 
key.order); + resolved.emplace_back(*std::move(maybe_chunked_array), key.order, + key.null_placement); } return resolved; } // Behaves like PartitionNulls() but this supports multiple sort keys. - template NullPartitionResult PartitionNullsInternal(uint64_t* indices_begin, uint64_t* indices_end, const ResolvedSortKey& first_sort_key) { - using ArrayType = typename TypeTraits::ArrayType; - const auto p = PartitionNullsOnly( indices_begin, indices_end, first_sort_key.resolver, first_sort_key.null_count, - NullPlacement::AtEnd); + first_sort_key.null_placement); DCHECK_EQ(p.nulls_end - p.nulls_begin, first_sort_key.null_count); + auto& comparator = comparator_; + // Sort all nulls by the second and following sort keys. + std::stable_sort(p.nulls_begin, p.nulls_end, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + + return p; + } + + template + NullPartitionResult PartitionNaNsInternal(uint64_t* indices_begin, + uint64_t* indices_end, + const ResolvedSortKey& first_sort_key) { + using ArrayType = typename TypeTraits::ArrayType; const auto q = PartitionNullLikes( - p.non_nulls_begin, p.non_nulls_end, first_sort_key.resolver, - NullPlacement::AtEnd); + indices_begin, indices_end, first_sort_key.resolver, + first_sort_key.null_placement); auto& comparator = comparator_; // Sort all NaNs by the second and following sort keys. std::stable_sort(q.nulls_begin, q.nulls_end, [&](uint64_t left, uint64_t right) { return comparator.Compare(left, right, 1); }); - // Sort all nulls by the second and following sort keys. 
- std::stable_sort(p.nulls_begin, p.nulls_end, [&](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); return q; } @@ -510,9 +764,11 @@ class TableSelecter : public TypeVisitor { cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { auto chunk_left = first_sort_key.template GetChunk(left); auto chunk_right = first_sort_key.template GetChunk(right); + const bool is_null_left = chunk_left.IsNull(); + const bool is_null_right = chunk_right.IsNull(); auto value_left = chunk_left.Value(); auto value_right = chunk_right.Value(); - if (value_left == value_right) { + if ((value_left == value_right) || (is_null_left && is_null_right)) { return comparator.Compare(left, right, 1); } return select_k_comparator(value_left, value_right); @@ -526,28 +782,30 @@ class TableSelecter : public TypeVisitor { std::iota(indices_begin, indices_end, 0); const auto p = - this->PartitionNullsInternal(indices_begin, indices_end, first_sort_key); - const auto end_iter = p.non_nulls_end; - auto kth_begin = std::min(indices_begin + k_, end_iter); - - HeapContainer heap(indices_begin, kth_begin, cmp); - for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - uint64_t top_item = heap.top(); - if (cmp(x_index, top_item)) { + this->PartitionNullsInternal(indices_begin, indices_end, first_sort_key); + const auto q = this->PartitionNaNsInternal(p.non_nulls_begin, p.non_nulls_end, + first_sort_key); + + auto HeapPusher = [&](uint64_t* indices_begin, uint64_t* kth_begin, + uint64_t* end_iter, uint64_t* out_cbegin) { + HeapContainer heap(indices_begin, kth_begin, cmp); + for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + uint64_t top_item = heap.top(); + if (cmp(x_index, top_item)) { + heap.pop(); + heap.push(x_index); + } + } + while (heap.size() > 0) { + *out_cbegin = heap.top(); heap.pop(); - heap.push(x_index); + --out_cbegin; } - } - auto out_size = 
static_cast(heap.size()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, - MakeMutableUInt64Array(out_size, ctx_->memory_pool())); - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.size() > 0) { - *out_cbegin = heap.top(); - heap.pop(); - --out_cbegin; - } + }; + + HeapSorter h(k_, first_sort_key.null_placement, ctx_->memory_pool()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, h.HeapSort(HeapPusher, p, q)); *output_ = Datum(take_indices); return Status::OK(); } diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 8ddcbb9905c..e897e393a8c 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -402,7 +402,7 @@ class RadixRecordBatchSorter { : physical_type(sort_key.type), array(sort_key.owned_array), order(sort_key.order), - null_placement(options.null_placement), + null_placement(sort_key.null_placement), next_column(next_column) {} Result> MakeColumnSort() { @@ -462,16 +462,14 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor { : indices_begin_(indices_begin), indices_end_(indices_end), sort_keys_(std::move(sort_keys)), - null_placement_(options.null_placement), - comparator_(sort_keys_, null_placement_) {} + comparator_(sort_keys_) {} MultipleKeyRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end, const RecordBatch& batch, const SortOptions& options) : indices_begin_(indices_begin), indices_end_(indices_end), sort_keys_(ResolveSortKeys(batch, options.sort_keys, &status_)), - null_placement_(options.null_placement), - comparator_(sort_keys_, null_placement_) {} + comparator_(sort_keys_) {} // This is optimized for the first sort key. The first sort key sort // is processed in this class. 
The second and following sort keys @@ -552,10 +550,10 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor { const ArrayType& array = ::arrow::internal::checked_cast(first_sort_key.array); - const auto p = PartitionNullsOnly(indices_begin_, indices_end_, - array, 0, null_placement_); + const auto p = PartitionNullsOnly( + indices_begin_, indices_end_, array, 0, first_sort_key.null_placement); const auto q = PartitionNullLikes( - p.non_nulls_begin, p.non_nulls_end, array, 0, null_placement_); + p.non_nulls_begin, p.non_nulls_end, array, 0, first_sort_key.null_placement); auto& comparator = comparator_; if (q.nulls_begin != q.nulls_end) { @@ -583,7 +581,6 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor { uint64_t* indices_end_; Status status_; std::vector sort_keys_; - NullPlacement null_placement_; Comparator comparator_; }; @@ -607,13 +604,12 @@ class TableSorter { table_(table), batches_(MakeBatches(table, &status_)), options_(options), - null_placement_(options.null_placement), left_resolver_(batches_), right_resolver_(batches_), sort_keys_(ResolveSortKeys(table, batches_, options.sort_keys, &status_)), indices_begin_(indices_begin), indices_end_(indices_end), - comparator_(sort_keys_, null_placement_) {} + comparator_(sort_keys_) {} // This is optimized for null partitioning and merging along the first sort key. // Other sort keys are delegated to the Comparator class. 
@@ -711,7 +707,7 @@ class TableSorter { MergeNonNulls(range_begin, range_middle, range_end, temp_indices); }; - MergeImpl merge_impl(options_.null_placement, std::move(merge_nulls), + MergeImpl merge_impl(sort_keys_[0].null_placement, std::move(merge_nulls), std::move(merge_non_nulls)); RETURN_NOT_OK(merge_impl.Init(ctx_, table_.num_rows())); @@ -758,7 +754,7 @@ class TableSorter { const auto right_is_null = chunk_right.IsNull(); if (left_is_null == right_is_null) { return comparator.Compare(left_loc, right_loc, 1); - } else if (options_.null_placement == NullPlacement::AtEnd) { + } else if (first_sort_key.null_placement == NullPlacement::AtEnd) { return right_is_null; } else { return left_is_null; @@ -847,7 +843,6 @@ class TableSorter { const Table& table_; const RecordBatchVector batches_; const SortOptions& options_; - const NullPlacement null_placement_; const ::arrow::internal::ChunkResolver left_resolver_, right_resolver_; const std::vector sort_keys_; uint64_t* indices_begin_; @@ -936,18 +931,22 @@ class SortIndicesMetaFunction : public MetaFunction { Result SortIndices(const Array& values, const SortOptions& options, ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; + NullPlacement null_placement = NullPlacement::AtEnd; if (!options.sort_keys.empty()) { order = options.sort_keys[0].order; + null_placement = options.sort_keys[0].null_placement; } - ArraySortOptions array_options(order, options.null_placement); + ArraySortOptions array_options(order, null_placement); return CallFunction("array_sort_indices", {values}, &array_options, ctx); } Result SortIndices(const ChunkedArray& chunked_array, const SortOptions& options, ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; + NullPlacement null_placement = NullPlacement::AtEnd; if (!options.sort_keys.empty()) { order = options.sort_keys[0].order; + null_placement = options.sort_keys[0].null_placement; } auto out_type = uint64(); @@ -962,8 +961,8 @@ class SortIndicesMetaFunction : 
public MetaFunction { auto out_end = out_begin + length; std::iota(out_begin, out_end, 0); - RETURN_NOT_OK(SortChunkedArray(ctx, out_begin, out_end, chunked_array, order, - options.null_placement)); + RETURN_NOT_OK( + SortChunkedArray(ctx, out_begin, out_end, chunked_array, order, null_placement)); return Datum(out); } @@ -1056,7 +1055,7 @@ struct SortFieldPopulator { PrependInvalidColumn(sort_key.target.FindOne(schema))); if (seen_.insert(match).second) { ARROW_ASSIGN_OR_RAISE(auto schema_field, match.Get(schema)); - AddField(*schema_field->type(), match, sort_key.order); + AddField(*schema_field->type(), match, sort_key.order, sort_key.null_placement); } } @@ -1064,7 +1063,8 @@ struct SortFieldPopulator { } protected: - void AddLeafFields(const FieldVector& fields, SortOrder order) { + void AddLeafFields(const FieldVector& fields, SortOrder order, + NullPlacement null_placement) { if (fields.empty()) { return; } @@ -1073,21 +1073,22 @@ struct SortFieldPopulator { for (const auto& f : fields) { const auto& type = *f->type(); if (type.id() == Type::STRUCT) { - AddLeafFields(type.fields(), order); + AddLeafFields(type.fields(), order, null_placement); } else { - sort_fields_.emplace_back(FieldPath(tmp_indices_), order, &type); + sort_fields_.emplace_back(FieldPath(tmp_indices_), order, &type, null_placement); } ++tmp_indices_.back(); } tmp_indices_.pop_back(); } - void AddField(const DataType& type, const FieldPath& path, SortOrder order) { + void AddField(const DataType& type, const FieldPath& path, SortOrder order, + NullPlacement null_placement) { if (type.id() == Type::STRUCT) { tmp_indices_ = path.indices(); - AddLeafFields(type.fields(), order); + AddLeafFields(type.fields(), order, null_placement); } else { - sort_fields_.emplace_back(path, order, &type); + sort_fields_.emplace_back(path, order, &type, null_placement); } } @@ -1135,10 +1136,9 @@ Result SortStructArray(ExecContext* ctx, uint64_t* indices_ std::move(columns)); auto options = 
SortOptions::Defaults(); - options.null_placement = null_placement; options.sort_keys.reserve(array.num_fields()); for (int i = 0; i < array.num_fields(); ++i) { - options.sort_keys.push_back(SortKey(FieldRef(i), sort_order)); + options.sort_keys.push_back(SortKey(FieldRef(i), sort_order, null_placement)); } ARROW_ASSIGN_OR_RAISE(auto sort_keys, diff --git a/cpp/src/arrow/compute/kernels/vector_sort_internal.h b/cpp/src/arrow/compute/kernels/vector_sort_internal.h index d7e5575c807..b18f84b3c05 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_internal.h +++ b/cpp/src/arrow/compute/kernels/vector_sort_internal.h @@ -475,16 +475,19 @@ Result SortStructArray(ExecContext* ctx, uint64_t* indices_ struct SortField { SortField() = default; - SortField(FieldPath path, SortOrder order, const DataType* type) - : path(std::move(path)), order(order), type(type) {} - SortField(int index, SortOrder order, const DataType* type) - : SortField(FieldPath({index}), order, type) {} + SortField(FieldPath path, SortOrder order, const DataType* type, + NullPlacement null_placement) + : path(std::move(path)), order(order), type(type), null_placement(null_placement) {} + SortField(int index, SortOrder order, const DataType* type, + NullPlacement null_placement) + : SortField(FieldPath({index}), order, type, null_placement) {} bool is_nested() const { return path.indices().size() > 1; } FieldPath path; SortOrder order; const DataType* type; + NullPlacement null_placement; }; inline Status CheckNonNested(const FieldRef& ref) { @@ -530,9 +533,10 @@ Result> ResolveSortKeys( // paths [0,0,0,0] and [0,0,0,1], we shouldn't need to flatten the first three // components more than once. 
ARROW_ASSIGN_OR_RAISE(auto child, f.path.GetFlattened(table_or_batch)); - return ResolvedSortKey{std::move(child), f.order}; + return ResolvedSortKey{std::move(child), f.order, f.null_placement}; } - return ResolvedSortKey{table_or_batch.column(f.path[0]), f.order}; + return ResolvedSortKey{table_or_batch.column(f.path[0]), f.order, + f.null_placement}; }); } @@ -582,15 +586,13 @@ template struct ColumnComparator { using Location = typename ResolvedSortKey::LocationType; - ColumnComparator(const ResolvedSortKey& sort_key, NullPlacement null_placement) - : sort_key_(sort_key), null_placement_(null_placement) {} + explicit ColumnComparator(const ResolvedSortKey& sort_key) : sort_key_(sort_key) {} virtual ~ColumnComparator() = default; virtual int Compare(const Location& left, const Location& right) const = 0; ResolvedSortKey sort_key_; - NullPlacement null_placement_; }; template @@ -611,13 +613,13 @@ struct ConcreteColumnComparator : public ColumnComparator { if (is_null_left && is_null_right) { return 0; } else if (is_null_left) { - return this->null_placement_ == NullPlacement::AtStart ? -1 : 1; + return sort_key.null_placement == NullPlacement::AtStart ? -1 : 1; } else if (is_null_right) { - return this->null_placement_ == NullPlacement::AtStart ? 1 : -1; + return sort_key.null_placement == NullPlacement::AtStart ? 
1 : -1; } } return CompareTypeValues(chunk_left.Value(), chunk_right.Value(), - sort_key.order, this->null_placement_); + sort_key.order, sort_key.null_placement); } }; @@ -638,9 +640,8 @@ class MultipleKeyComparator { public: using Location = typename ResolvedSortKey::LocationType; - MultipleKeyComparator(const std::vector& sort_keys, - NullPlacement null_placement) - : sort_keys_(sort_keys), null_placement_(null_placement) { + explicit MultipleKeyComparator(const std::vector& sort_keys) + : sort_keys_(sort_keys) { status_ &= MakeComparators(); } @@ -674,13 +675,11 @@ class MultipleKeyComparator { template Status VisitGeneric(const Type& type) { - res.reset( - new ConcreteColumnComparator{sort_key, null_placement}); + res.reset(new ConcreteColumnComparator{sort_key}); return Status::OK(); } const ResolvedSortKey& sort_key; - NullPlacement null_placement; std::unique_ptr> res; }; @@ -688,7 +687,7 @@ class MultipleKeyComparator { column_comparators_.reserve(sort_keys_.size()); for (const auto& sort_key : sort_keys_) { - ColumnComparatorFactory factory{sort_key, null_placement_, nullptr}; + ColumnComparatorFactory factory{sort_key, nullptr}; RETURN_NOT_OK(VisitTypeInline(*sort_key.type, &factory)); column_comparators_.push_back(std::move(factory.res)); } @@ -716,18 +715,19 @@ class MultipleKeyComparator { } const std::vector& sort_keys_; - const NullPlacement null_placement_; std::vector>> column_comparators_; Status status_; }; struct ResolvedRecordBatchSortKey { - ResolvedRecordBatchSortKey(const std::shared_ptr& array, SortOrder order) + ResolvedRecordBatchSortKey(const std::shared_ptr& array, SortOrder order, + NullPlacement null_placement) : type(GetPhysicalType(array->type())), owned_array(GetPhysicalArray(*array, type)), array(*owned_array), order(order), - null_count(array->null_count()) {} + null_count(array->null_count()), + null_placement(null_placement) {} using LocationType = int64_t; @@ -741,16 +741,18 @@ struct ResolvedRecordBatchSortKey { const Array& 
array; SortOrder order; int64_t null_count; + NullPlacement null_placement; }; struct ResolvedTableSortKey { ResolvedTableSortKey(const std::shared_ptr& type, ArrayVector chunks, - SortOrder order, int64_t null_count) + SortOrder order, int64_t null_count, NullPlacement null_placement) : type(GetPhysicalType(type)), owned_chunks(std::move(chunks)), chunks(GetArrayPointers(owned_chunks)), order(order), - null_count(null_count) {} + null_count(null_count), + null_placement(null_placement) {} using LocationType = ::arrow::internal::ChunkLocation; @@ -777,7 +779,7 @@ struct ResolvedTableSortKey { } return ResolvedTableSortKey(f.type->GetSharedPtr(), std::move(chunks), f.order, - null_count); + null_count, f.null_placement); }; return ::arrow::compute::internal::ResolveSortKeys( @@ -789,6 +791,7 @@ struct ResolvedTableSortKey { std::vector chunks; SortOrder order; int64_t null_count; + NullPlacement null_placement; }; inline Result> MakeMutableUInt64Array( diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index 1328dddc041..e324abd7791 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -1205,9 +1205,8 @@ TEST_F(TestRecordBatchSortIndices, NoNull) { ])"); for (auto null_placement : AllNullPlacements()) { - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}, - null_placement); + SortOptions options({SortKey("a", SortOrder::Ascending, null_placement), + SortKey("b", SortOrder::Descending, null_placement)}); AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]"); } @@ -1230,9 +1229,11 @@ TEST_F(TestRecordBatchSortIndices, Null) { const std::vector sort_keys{SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}; - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(batch, options, "[5, 1, 4, 6, 2, 0, 3]"); - 
options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[0, 3, 5, 1, 4, 6, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[3, 0, 5, 1, 4, 2, 6]"); } @@ -1251,12 +1252,14 @@ TEST_F(TestRecordBatchSortIndices, NaN) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(batch, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[4, 6, 5, 3, 7, 1, 0, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[5, 4, 6, 3, 1, 7, 0, 2]"); } @@ -1275,12 +1278,14 @@ TEST_F(TestRecordBatchSortIndices, NaNAndNull) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(batch, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); } @@ -1299,12 +1304,14 @@ TEST_F(TestRecordBatchSortIndices, Boolean) { {"a": false, "b": null}, {"a": null, "b": true} ])"); - const 
std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(batch, options, "[3, 1, 6, 2, 4, 0, 7, 5]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[7, 5, 3, 1, 6, 2, 4, 0]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[7, 5, 1, 6, 3, 0, 2, 4]"); } @@ -1322,12 +1329,15 @@ TEST_F(TestRecordBatchSortIndices, MoreTypes) { {"a": 2, "b": "05", "c": "aaa"}, {"a": 1, "b": "05", "c": "bbb"} ])"); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending), - SortKey("c", SortOrder::Ascending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending), + SortKey("c", SortOrder::Ascending)}; for (auto null_placement : AllNullPlacements()) { - SortOptions options(sort_keys, null_placement); + SortOptions options(sort_keys); + for (size_t i = 0; i < sort_keys.size(); i++) { + options.sort_keys[i].null_placement = null_placement; + } AssertSortIndices(batch, options, "[3, 5, 1, 4, 0, 2]"); } } @@ -1344,12 +1354,14 @@ TEST_F(TestRecordBatchSortIndices, Decimal) { {"a": "-12.3", "b": null}, {"a": "-12.3", "b": "-45.67"} ])"); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(batch, options, "[4, 3, 0, 2, 1]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + 
AssertSortIndices(batch, options, "[4, 3, 0, 2, 1]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[3, 4, 0, 2, 1]"); } @@ -1375,37 +1387,31 @@ TEST_F(TestRecordBatchSortIndices, NullType) { for (const auto order : AllOrders()) { // Uses radix sorter AssertSortIndices(batch, - SortOptions( - { - SortKey("a", order), - SortKey("i", order), - }, - null_placement), + SortOptions({ + SortKey("a", order, null_placement), + SortKey("i", order, null_placement), + }), "[0, 1, 2, 3]"); AssertSortIndices(batch, - SortOptions( - { - SortKey("a", order), - SortKey("b", SortOrder::Ascending), - SortKey("i", order), - }, - null_placement), + SortOptions({ + SortKey("a", order, null_placement), + SortKey("b", SortOrder::Ascending, null_placement), + SortKey("i", order, null_placement), + }), "[2, 3, 0, 1]"); // Uses multiple-key sorter AssertSortIndices(batch, - SortOptions( - { - SortKey("a", order), - SortKey("b", SortOrder::Ascending), - SortKey("c", SortOrder::Ascending), - SortKey("d", SortOrder::Ascending), - SortKey("e", SortOrder::Ascending), - SortKey("f", SortOrder::Ascending), - SortKey("g", SortOrder::Ascending), - SortKey("h", SortOrder::Ascending), - SortKey("i", order), - }, - null_placement), + SortOptions({ + SortKey("a", order, null_placement), + SortKey("b", SortOrder::Ascending, null_placement), + SortKey("c", SortOrder::Ascending, null_placement), + SortKey("d", SortOrder::Ascending, null_placement), + SortKey("e", SortOrder::Ascending, null_placement), + SortKey("f", SortOrder::Ascending, null_placement), + SortKey("g", SortOrder::Ascending, null_placement), + SortKey("h", SortOrder::Ascending, null_placement), + SortKey("i", order), + }), "[2, 3, 0, 1]"); } } @@ -1428,14 +1434,16 @@ TEST_F(TestRecordBatchSortIndices, DuplicateSortKeys) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"); - const std::vector sort_keys{ + std::vector sort_keys{ SortKey("a", SortOrder::Ascending), SortKey("b", 
SortOrder::Descending), SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Ascending), SortKey("a", SortOrder::Descending)}; - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(batch, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); } @@ -1447,16 +1455,19 @@ TEST_F(TestTableSortIndices, EmptyTable) { {field("a", uint8())}, {field("b", uint32())}, }); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; auto table = TableFromJSON(schema, {"[]"}); auto chunked_table = TableFromJSON(schema, {"[]", "[]"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(table, options, "[]"); AssertSortIndices(chunked_table, options, "[]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[]"); + AssertSortIndices(chunked_table, options, "[]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[]"); AssertSortIndices(chunked_table, options, "[]"); } @@ -1467,7 +1478,7 @@ TEST_F(TestTableSortIndices, EmptySortKeys) { {field("b", uint32())}, }); const std::vector sort_keys{}; - const SortOptions options(sort_keys, NullPlacement::AtEnd); + const SortOptions options(sort_keys); auto table = TableFromJSON(schema, {R"([{"a": null, "b": 5}])"}); EXPECT_RAISES_WITH_MESSAGE_THAT( @@ -1486,8 +1497,8 @@ TEST_F(TestTableSortIndices, Null) { {field("a", uint8())}, {field("b", 
uint32())}, }); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; std::shared_ptr
table; table = TableFromJSON(schema, {R"([{"a": null, "b": 5}, @@ -1498,9 +1509,11 @@ TEST_F(TestTableSortIndices, Null) { {"a": 1, "b": 5}, {"a": 3, "b": 5} ])"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(table, options, "[5, 1, 4, 6, 2, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[0, 3, 5, 1, 4, 6, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 5, 1, 4, 2, 6]"); // Same data, several chunks @@ -1513,9 +1526,12 @@ TEST_F(TestTableSortIndices, Null) { {"a": 1, "b": 5}, {"a": 3, "b": 5} ])"}); - options.null_placement = NullPlacement::AtEnd; + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + options.sort_keys[1].null_placement = NullPlacement::AtEnd; AssertSortIndices(table, options, "[5, 1, 4, 6, 2, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[0, 3, 5, 1, 4, 6, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 5, 1, 4, 2, 6]"); } @@ -1524,8 +1540,8 @@ TEST_F(TestTableSortIndices, NaN) { {field("a", float32())}, {field("b", float64())}, }); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; std::shared_ptr
table; table = TableFromJSON(schema, {R"([{"a": 3, "b": 5}, @@ -1537,9 +1553,11 @@ TEST_F(TestTableSortIndices, NaN) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(table, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[4, 6, 5, 3, 7, 1, 0, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[5, 4, 6, 3, 1, 7, 0, 2]"); // Same data, several chunks @@ -1553,9 +1571,12 @@ TEST_F(TestTableSortIndices, NaN) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); - options.null_placement = NullPlacement::AtEnd; + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + options.sort_keys[1].null_placement = NullPlacement::AtEnd; AssertSortIndices(table, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[4, 6, 5, 3, 7, 1, 0, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[5, 4, 6, 3, 1, 7, 0, 2]"); } @@ -1564,8 +1585,8 @@ TEST_F(TestTableSortIndices, NaNAndNull) { {field("a", float32())}, {field("b", float64())}, }); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; std::shared_ptr
table; table = TableFromJSON(schema, {R"([{"a": null, "b": 5}, @@ -1577,9 +1598,11 @@ TEST_F(TestTableSortIndices, NaNAndNull) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); // Same data, several chunks @@ -1593,9 +1616,12 @@ TEST_F(TestTableSortIndices, NaNAndNull) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); - options.null_placement = NullPlacement::AtEnd; + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + options.sort_keys[1].null_placement = NullPlacement::AtEnd; AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); } @@ -1604,8 +1630,8 @@ TEST_F(TestTableSortIndices, Boolean) { {field("a", boolean())}, {field("b", boolean())}, }); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; auto table = TableFromJSON(schema, {R"([{"a": true, "b": null}, {"a": false, "b": null}, @@ -1617,9 +1643,11 @@ TEST_F(TestTableSortIndices, Boolean) { {"a": false, "b": null}, {"a": null, "b": true} ])"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(table, options, "[3, 1, 6, 2, 4, 0, 7, 5]"); 
- options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[7, 5, 3, 1, 6, 2, 4, 0]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[7, 5, 1, 6, 3, 0, 2, 4]"); } @@ -1628,8 +1656,8 @@ TEST_F(TestTableSortIndices, BinaryLike) { {field("a", large_utf8())}, {field("b", fixed_size_binary(3))}, }); - const std::vector sort_keys{SortKey("a", SortOrder::Descending), - SortKey("b", SortOrder::Ascending)}; + std::vector sort_keys{SortKey("a", SortOrder::Descending), + SortKey("b", SortOrder::Ascending)}; auto table = TableFromJSON(schema, {R"([{"a": "one", "b": null}, {"a": "two", "b": "aaa"}, @@ -1641,9 +1669,10 @@ TEST_F(TestTableSortIndices, BinaryLike) { {"a": "three", "b": "bbb"}, {"a": "four", "b": "aaa"} ])"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(table, options, "[1, 5, 2, 6, 4, 0, 7, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[1, 5, 2, 6, 0, 4, 7, 3]"); } @@ -1652,8 +1681,8 @@ TEST_F(TestTableSortIndices, Decimal) { {field("a", decimal128(3, 1))}, {field("b", decimal256(4, 2))}, }); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; auto table = TableFromJSON(schema, {R"([{"a": "12.3", "b": "12.34"}, {"a": "45.6", "b": "12.34"}, @@ -1662,9 +1691,11 @@ TEST_F(TestTableSortIndices, Decimal) { R"([{"a": "-12.3", "b": null}, {"a": "-12.3", "b": "-45.67"} ])"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); + AssertSortIndices(table, options, "[4, 3, 0, 2, 1]"); + 
options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[4, 3, 0, 2, 1]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 4, 0, 2, 1]"); } @@ -1687,21 +1718,17 @@ TEST_F(TestTableSortIndices, NullType) { for (const auto null_placement : AllNullPlacements()) { for (const auto order : AllOrders()) { AssertSortIndices(table, - SortOptions( - { - SortKey("a", order), - SortKey("d", order), - }, - null_placement), + SortOptions({ + SortKey("a", order, null_placement), + SortKey("d", order, null_placement), + }), "[0, 1, 2, 3]"); AssertSortIndices(table, - SortOptions( - { - SortKey("a", order), - SortKey("b", SortOrder::Ascending), - SortKey("d", order), - }, - null_placement), + SortOptions({ + SortKey("a", order, null_placement), + SortKey("b", SortOrder::Ascending, null_placement), + SortKey("d", order, null_placement), + }), "[2, 3, 0, 1]"); } } @@ -1714,7 +1741,7 @@ TEST_F(TestTableSortIndices, DuplicateSortKeys) { {field("a", float32())}, {field("b", float64())}, }); - const std::vector sort_keys{ + std::vector sort_keys{ SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending), SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Ascending), SortKey("a", SortOrder::Descending)}; @@ -1730,9 +1757,11 @@ TEST_F(TestTableSortIndices, DuplicateSortKeys) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); } @@ -1752,13 +1781,17 @@ 
TEST_F(TestTableSortIndices, HeterogenousChunking) { SortOptions options( {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); options = SortOptions( {SortKey("b", SortOrder::Ascending), SortKey("a", SortOrder::Descending)}); AssertSortIndices(table, options, "[1, 7, 6, 0, 5, 2, 4, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[2, 4, 3, 5, 1, 7, 6, 0]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 4, 2, 5, 1, 0, 6, 7]"); } @@ -1772,8 +1805,8 @@ TYPED_TEST_SUITE(TestTableSortIndicesForTemporal, TemporalArrowTypes); TYPED_TEST(TestTableSortIndicesForTemporal, NoNull) { auto type = this->GetType(); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; auto table = TableFromJSON(schema({ {field("a", type)}, {field("b", type)}, @@ -1788,7 +1821,10 @@ TYPED_TEST(TestTableSortIndicesForTemporal, NoNull) { {"a": 1, "b": 2} ])"}); for (auto null_placement : AllNullPlacements()) { - SortOptions options(sort_keys, null_placement); + SortOptions options(sort_keys); + for (size_t i = 0; i < sort_keys.size(); i++) { + options.sort_keys[i].null_placement = null_placement; + } AssertSortIndices(table, options, "[0, 6, 1, 4, 7, 3, 2, 5]"); } } @@ -1861,12 +1897,12 @@ class TestTableSortIndicesRandom : public testing::TestWithParam { DCHECK(!sort_key.target.IsNested()); if (auto 
name = sort_key.target.name()) { - sort_columns_.emplace_back(table.GetColumnByName(*name).get(), sort_key.order); + sort_columns_.emplace_back(table.GetColumnByName(*name).get(), sort_key); continue; } auto index = sort_key.target.field_path()->indices()[0]; - sort_columns_.emplace_back(table.column(index).get(), sort_key.order); + sort_columns_.emplace_back(table.column(index).get(), sort_key); } } @@ -1874,7 +1910,7 @@ class TestTableSortIndicesRandom : public testing::TestWithParam { // false otherwise. bool operator()(uint64_t lhs, uint64_t rhs) { for (const auto& pair : sort_columns_) { - ColumnComparator comparator(pair.second, options_.null_placement); + ColumnComparator comparator(pair.second.order, pair.second.null_placement); const auto& chunked_array = *pair.first; int64_t lhs_index = 0, rhs_index = 0; const Array* lhs_array = FindTargetArray(chunked_array, lhs, &lhs_index); @@ -1903,7 +1939,7 @@ class TestTableSortIndicesRandom : public testing::TestWithParam { } const SortOptions& options_; - std::vector> sort_columns_; + std::vector> sort_columns_; }; public: @@ -2064,7 +2100,9 @@ TEST_P(TestTableSortIndicesRandom, Sort) { auto table = Table::Make(schema, std::move(columns)); for (auto null_placement : AllNullPlacements()) { ARROW_SCOPED_TRACE("null_placement = ", null_placement); - options.null_placement = null_placement; + for (auto& sort_key : options.sort_keys) { + sort_key.null_placement = null_placement; + } ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*table), options)); Validate(*table, options, *checked_pointer_cast(offsets)); } @@ -2083,7 +2121,9 @@ TEST_P(TestTableSortIndicesRandom, Sort) { for (auto null_placement : AllNullPlacements()) { ARROW_SCOPED_TRACE("null_placement = ", null_placement); - options.null_placement = null_placement; + for (auto& sort_key : options.sort_keys) { + sort_key.null_placement = null_placement; + } ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(batch), options)); Validate(*table, options, 
*checked_pointer_cast(offsets)); } @@ -2173,18 +2213,19 @@ class TestNestedSortIndices : public ::testing::Test { std::vector sort_keys = {SortKey(FieldRef("a", "a"), SortOrder::Ascending), SortKey(FieldRef("a", "b"), SortOrder::Descending)}; - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(datum, options, "[7, 6, 3, 4, 0, 2, 1, 8, 5]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(datum, options, "[5, 2, 1, 8, 3, 7, 6, 0, 4]"); // Implementations may have an optimized path for cases with one sort key. // Additionally, this key references a struct containing another struct, which should // work recursively options.sort_keys = {SortKey(FieldRef("a"), SortOrder::Ascending)}; - options.null_placement = NullPlacement::AtEnd; + options.sort_keys[0].null_placement = NullPlacement::AtEnd; AssertSortIndices(datum, options, "[6, 7, 3, 4, 0, 8, 1, 2, 5]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(datum, options, "[5, 8, 1, 2, 3, 6, 7, 0, 4]"); } @@ -2239,8 +2280,8 @@ class TestRank : public ::testing::Test { static void AssertRank(const DatumVector& datums, SortOrder order, NullPlacement null_placement, RankOptions::Tiebreaker tiebreaker, const std::shared_ptr& expected) { - const std::vector sort_keys{SortKey("foo", order)}; - RankOptions options(sort_keys, null_placement, tiebreaker); + const std::vector sort_keys{SortKey("foo", order, null_placement)}; + RankOptions options(sort_keys, tiebreaker); ARROW_SCOPED_TRACE("options = ", options.ToString()); for (const auto& datum : datums) { ASSERT_OK_AND_ASSIGN(auto actual, CallFunction("rank", {datum}, &options)); diff --git a/cpp/src/arrow/compute/ordering.cc b/cpp/src/arrow/compute/ordering.cc index 25ad6a5ca5f..fc1367284ec 
100644 --- a/cpp/src/arrow/compute/ordering.cc +++ b/cpp/src/arrow/compute/ordering.cc @@ -24,7 +24,8 @@ namespace arrow { namespace compute { bool SortKey::Equals(const SortKey& other) const { - return target == other.target && order == other.order; + return target == other.target && order == other.order && + null_placement == other.null_placement; } std::string SortKey::ToString() const { @@ -38,6 +39,15 @@ std::string SortKey::ToString() const { ss << "DESC"; break; } + + switch (null_placement) { + case NullPlacement::AtStart: + ss << " AtStart"; + break; + case NullPlacement::AtEnd: + ss << " AtEnd"; + break; + } return ss.str(); } @@ -47,14 +57,11 @@ bool Ordering::IsSuborderOf(const Ordering& other) const { // is a subordering of everything return !is_implicit_; } - if (null_placement_ != other.null_placement_) { - return false; - } if (sort_keys_.size() > other.sort_keys_.size()) { return false; } for (std::size_t key_idx = 0; key_idx < sort_keys_.size(); key_idx++) { - if (sort_keys_[key_idx] != other.sort_keys_[key_idx]) { + if (!sort_keys_[key_idx].Equals(other.sort_keys_[key_idx])) { return false; } } @@ -62,7 +69,7 @@ bool Ordering::IsSuborderOf(const Ordering& other) const { } bool Ordering::Equals(const Ordering& other) const { - return null_placement_ == other.null_placement_ && sort_keys_ == other.sort_keys_; + return sort_keys_ == other.sort_keys_; } std::string Ordering::ToString() const { @@ -78,16 +85,6 @@ std::string Ordering::ToString() const { ss << key.ToString(); } ss << "]"; - switch (null_placement_) { - case NullPlacement::AtEnd: - ss << " nulls last"; - break; - case NullPlacement::AtStart: - ss << " nulls first"; - break; - default: - Unreachable(); - } return ss.str(); } diff --git a/cpp/src/arrow/compute/ordering.h b/cpp/src/arrow/compute/ordering.h index e581269cc20..146571568d5 100644 --- a/cpp/src/arrow/compute/ordering.h +++ b/cpp/src/arrow/compute/ordering.h @@ -46,8 +46,11 @@ enum class NullPlacement { /// \brief One sort key 
for PartitionNthIndices (TODO) and SortIndices class ARROW_EXPORT SortKey : public util::EqualityComparable { public: - explicit SortKey(FieldRef target, SortOrder order = SortOrder::Ascending) - : target(std::move(target)), order(order) {} + explicit SortKey(FieldRef target, SortOrder order = SortOrder::Ascending, + NullPlacement null_placement = NullPlacement::AtEnd) + : target(std::move(target)), order(order), null_placement(null_placement) {} + explicit SortKey(FieldRef target, NullPlacement null_placement) + : SortKey(std::move(target), SortOrder::Ascending, null_placement) {} bool Equals(const SortKey& other) const; std::string ToString() const; @@ -56,13 +59,13 @@ class ARROW_EXPORT SortKey : public util::EqualityComparable { FieldRef target; /// How to order by this sort key. SortOrder order; + /// Whether nulls and NaNs are placed at the start or at the end + NullPlacement null_placement; }; class ARROW_EXPORT Ordering : public util::EqualityComparable { public: - Ordering(std::vector sort_keys, - NullPlacement null_placement = NullPlacement::AtStart) - : sort_keys_(std::move(sort_keys)), null_placement_(null_placement) {} + explicit Ordering(std::vector sort_keys) : sort_keys_(std::move(sort_keys)) {} /// true if data ordered by other is also ordered by this /// /// For example, if data is ordered by [a, b, c] then it is also ordered @@ -91,7 +94,6 @@ class ARROW_EXPORT Ordering : public util::EqualityComparable { bool is_unordered() const { return !is_implicit_ && sort_keys_.empty(); } const std::vector& sort_keys() const { return sort_keys_; } - NullPlacement null_placement() const { return null_placement_; } static const Ordering& Implicit() { static const Ordering kImplicit(true); @@ -107,12 +109,9 @@ class ARROW_EXPORT Ordering : public util::EqualityComparable { } private: - explicit Ordering(bool is_implicit) - : null_placement_(NullPlacement::AtStart), is_implicit_(is_implicit) {} + explicit Ordering(bool is_implicit) : is_implicit_(is_implicit) 
{} /// Column key(s) to order by and how to order by these sort keys. std::vector sort_keys_; - /// Whether nulls and NaNs are placed at the start or at the end - NullPlacement null_placement_; bool is_implicit_ = false; }; diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 73f55c27ee8..3f5294444b5 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -797,21 +797,10 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::vector sort_keys; sort_keys.reserve(sort.sorts_size()); - // Substrait allows null placement to differ for each field. Acero expects it to - // be consistent across all fields. So we grab the null placement from the first - // key and verify all other keys have the same null placement - std::optional sample_sort_behavior; + // Substrait allows null placement to differ for each field. for (const auto& sort : sort.sorts()) { ARROW_ASSIGN_OR_RAISE(SortBehavior sort_behavior, SortBehavior::Make(sort.direction())); - if (sample_sort_behavior) { - if (sample_sort_behavior->null_placement != sort_behavior.null_placement) { - return Status::NotImplemented( - "substrait::SortRel with ordering with mixed null placement"); - } - } else { - sample_sort_behavior = sort_behavior; - } if (sort.sort_kind_case() != substrait::SortField::SortKindCase::kDirection) { return Status::NotImplemented("substrait::SortRel with custom sort function"); } @@ -819,18 +808,17 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& FromProto(sort.expr(), ext_set, conversion_options)); const FieldRef* field_ref = expr.field_ref(); if (field_ref) { - sort_keys.push_back(compute::SortKey(*field_ref, sort_behavior.sort_order)); + sort_keys.push_back(compute::SortKey(*field_ref, sort_behavior.sort_order, + sort_behavior.null_placement)); } else { return Status::Invalid("Sort key expressions must be a direct 
reference."); } } - DCHECK(sample_sort_behavior.has_value()); acero::Declaration sort_dec{ "order_by", {input.declaration}, - acero::OrderByNodeOptions(compute::Ordering( - std::move(sort_keys), sample_sort_behavior->null_placement))}; + acero::OrderByNodeOptions(compute::Ordering(std::move(sort_keys)))}; DeclarationInfo sort_declaration{std::move(sort_dec), input.output_schema}; return ProcessEmit(sort, std::move(sort_declaration), diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 2e72ae70edd..e1a2c24184d 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -5378,8 +5378,6 @@ TEST(Substrait, SortAndFetch) { } TEST(Substrait, MixedSort) { - // Substrait allows two sort keys with differing direction but Acero - // does not. We should detect this and reject it. std::string substrait_json = R"({ "version": { "major_number": 9999, @@ -5474,10 +5472,9 @@ TEST(Substrait, MixedSort) { ConversionOptions conversion_options; conversion_options.named_table_provider = std::move(table_provider); - ASSERT_THAT( - DeserializePlan(*buf, /*registry=*/nullptr, /*ext_set_out=*/nullptr, - conversion_options), - Raises(StatusCode::NotImplemented, testing::HasSubstr("mixed null placement"))); + ASSERT_OK_AND_ASSIGN( + auto plan_info, DeserializePlan(*buf, /*registry=*/nullptr, /*ext_set_out=*/nullptr, + conversion_options)); } TEST(Substrait, PlanWithExtension) { diff --git a/python/pyarrow/_acero.pyx b/python/pyarrow/_acero.pyx index bb3196c86ef..6281aaf157c 100644 --- a/python/pyarrow/_acero.pyx +++ b/python/pyarrow/_acero.pyx @@ -233,18 +233,19 @@ class AggregateNodeOptions(_AggregateNodeOptions): cdef class _OrderByNodeOptions(ExecNodeOptions): - def _set_options(self, sort_keys, null_placement): + def _set_options(self, sort_keys): cdef: vector[CSortKey] c_sort_keys - for name, order in sort_keys: + for name, order, null_placement in sort_keys: 
c_sort_keys.push_back( - CSortKey(_ensure_field_ref(name), unwrap_sort_order(order)) + CSortKey(_ensure_field_ref(name), unwrap_sort_order( + order), unwrap_null_placement(null_placement)) ) self.wrapped.reset( new COrderByNodeOptions( - COrdering(c_sort_keys, unwrap_null_placement(null_placement)) + COrdering(c_sort_keys) ) ) @@ -261,19 +262,16 @@ class OrderByNodeOptions(_OrderByNodeOptions): Parameters ---------- - sort_keys : sequence of (name, order) tuples + sort_keys : sequence of (name, order, null_placement) tuples Names of field/column keys to sort the input on, along with the order each field/column is sorted in. Accepted values for `order` are "ascending", "descending". + Accepted values for `null_placement` are "at_start", "at_end". Each field reference can be a string column name or expression. - null_placement : str, default "at_end" - Where nulls in input should be sorted, only applying to - columns/fields mentioned in `sort_keys`. - Accepted values are "at_start", "at_end". 
""" - def __init__(self, sort_keys=(), *, null_placement="at_end"): - self._set_options(sort_keys, null_placement) + def __init__(self, sort_keys=()): + self._set_options(sort_keys) cdef class _HashJoinNodeOptions(ExecNodeOptions): diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 25f77d8160e..29b74841538 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2057,14 +2057,14 @@ class ArraySortOptions(_ArraySortOptions): cdef class _SortOptions(FunctionOptions): - def _set_options(self, sort_keys, null_placement): + def _set_options(self, sort_keys): cdef vector[CSortKey] c_sort_keys - for name, order in sort_keys: + for name, order, null_placement in sort_keys: c_sort_keys.push_back( - CSortKey(_ensure_field_ref(name), unwrap_sort_order(order)) + CSortKey(_ensure_field_ref(name), unwrap_sort_order( + order), unwrap_null_placement(null_placement)) ) - self.wrapped.reset(new CSortOptions( - c_sort_keys, unwrap_null_placement(null_placement))) + self.wrapped.reset(new CSortOptions(c_sort_keys)) class SortOptions(_SortOptions): @@ -2078,22 +2078,19 @@ class SortOptions(_SortOptions): along with the order each field/column is sorted in. Accepted values for `order` are "ascending", "descending". The field name can be a string column name or expression. - null_placement : str, default "at_end" - Where nulls in input should be sorted, only applying to - columns/fields mentioned in `sort_keys`. - Accepted values are "at_start", "at_end". 
""" - def __init__(self, sort_keys=(), *, null_placement="at_end"): - self._set_options(sort_keys, null_placement) + def __init__(self, sort_keys=()): + self._set_options(sort_keys) cdef class _SelectKOptions(FunctionOptions): def _set_options(self, k, sort_keys): cdef vector[CSortKey] c_sort_keys - for name, order in sort_keys: + for name, order, null_placement in sort_keys: c_sort_keys.push_back( - CSortKey(_ensure_field_ref(name), unwrap_sort_order(order)) + CSortKey(_ensure_field_ref(name), unwrap_sort_order( + order), unwrap_null_placement(null_placement)) ) self.wrapped.reset(new CSelectKOptions(k, c_sort_keys)) @@ -2285,18 +2282,18 @@ cdef class _RankOptions(FunctionOptions): cdef vector[CSortKey] c_sort_keys if isinstance(sort_keys, str): c_sort_keys.push_back( - CSortKey(_ensure_field_ref(""), unwrap_sort_order(sort_keys)) + CSortKey(_ensure_field_ref(""), + unwrap_sort_order(sort_keys), unwrap_null_placement(null_placement)) ) else: - for name, order in sort_keys: + for name, order, placement in sort_keys: c_sort_keys.push_back( - CSortKey(_ensure_field_ref(name), unwrap_sort_order(order)) + CSortKey(_ensure_field_ref(name), unwrap_sort_order( + order), unwrap_null_placement(placement)) ) try: self.wrapped.reset( - new CRankOptions(c_sort_keys, - unwrap_null_placement(null_placement), - self._tiebreaker_map[tiebreaker]) + new CRankOptions(c_sort_keys, self._tiebreaker_map[tiebreaker]) ) except KeyError: _raise_invalid_function_option(tiebreaker, "tiebreaker") @@ -2308,10 +2305,11 @@ class RankOptions(_RankOptions): Parameters ---------- - sort_keys : sequence of (name, order) tuples or str, default "ascending" + sort_keys : sequence of (name, order, null_placement) tuples or str, default "ascending" Names of field/column keys to sort the input on, along with the order each field/column is sorted in. Accepted values for `order` are "ascending", "descending". + Accepted values for `null_placement` are "at_start", "at_end". 
The field name can be a string column name or expression. Alternatively, one can simply pass "ascending" or "descending" as a string if the input is array-like. diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 48ee6769153..29ffbd1409f 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -802,11 +802,13 @@ cdef class Dataset(_Weakrefable): Parameters ---------- - sorting : str or list[tuple(name, order)] + sorting : str or list[tuple(name, order, null_placement)] Name of the column to use to sort (ascending), or a list of multiple sorting conditions where each entry is a tuple with column name and sorting order ("ascending" or "descending") + and nulls and NaNs are placed + at the start or at the end ("at_start" or "at_end") **kwargs : dict, optional Additional sorting options. As allowed by :class:`SortOptions` @@ -817,7 +819,7 @@ cdef class Dataset(_Weakrefable): A new dataset sorted according to the sort keys. """ if isinstance(sorting, str): - sorting = [(sorting, "ascending")] + sorting = [(sorting, "ascending", "at_end")] res = _pac()._sort_source( self, output_type=InMemoryDataset, sort_keys=sorting, **kwargs diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 2e975038227..0a80dd42848 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -1501,7 +1501,7 @@ cdef class Array(_PandasConvertible): """ return _pc().index(self, value, start, end, memory_pool=memory_pool) - def sort(self, order="ascending", **kwargs): + def sort(self, order="ascending", null_placement="at_end", **kwargs): """ Sort the Array @@ -1510,6 +1510,9 @@ cdef class Array(_PandasConvertible): order : str, default "ascending" Which order to sort values in. Accepted values are "ascending", "descending". + null_placement : str, default "at_end" + Whether nulls and NaNs are placed at the start or at the end. + Accepted values are "at_end", "at_start". **kwargs : dict, optional Additional sorting options. 
As allowed by :class:`SortOptions` @@ -1520,7 +1523,7 @@ cdef class Array(_PandasConvertible): """ indices = _pc().sort_indices( self, - options=_pc().SortOptions(sort_keys=[("", order)], **kwargs) + options=_pc().SortOptions(sort_keys=[("", order, null_placement)], **kwargs) ) return self.take(indices) @@ -3241,7 +3244,7 @@ cdef class StructArray(Array): result.validate() return result - def sort(self, order="ascending", by=None, **kwargs): + def sort(self, order="ascending", null_placement="at_end", by=None, **kwargs): """ Sort the StructArray @@ -3250,6 +3253,9 @@ cdef class StructArray(Array): order : str, default "ascending" Which order to sort values in. Accepted values are "ascending", "descending". + null_placement : str, default "at_end" + Whether nulls and NaNs are placed at the start or at the end. + Accepted values are "at_end", "at_start". by : str or None, default None If to sort the array by one of its fields or by the whole array. @@ -3267,7 +3273,7 @@ cdef class StructArray(Array): tosort = self indices = _pc().sort_indices( tosort, - options=_pc().SortOptions(sort_keys=[("", order)], **kwargs) + options=_pc().SortOptions(sort_keys=[("", order, null_placement)], **kwargs) ) return self.take(indices) diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 205ab393b8b..8b30a2f1e7b 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -545,7 +545,7 @@ def fill_null(values, fill_value): return call_function("coalesce", [values, fill_value]) -def top_k_unstable(values, k, sort_keys=None, *, memory_pool=None): +def top_k_unstable(values, k, sort_keys=None, null_placements=None, *, memory_pool=None): """ Select the indices of the top-k ordered elements from array- or table-like data. @@ -561,6 +561,9 @@ def top_k_unstable(values, k, sort_keys=None, *, memory_pool=None): The number of `k` elements to keep. sort_keys : List-like Column key names to order by when input is table-like data. 
+ null_placements : list of str, optional + Null/NaN placement ("at_start" or "at_end") for each entry of + `sort_keys`, in order. Defaults to "at_end" for every key. memory_pool : MemoryPool, optional If not passed, will allocate memory from the default memory pool. @@ -585,14 +588,15 @@ def top_k_unstable(values, k, sort_keys=None, *, memory_pool=None): if sort_keys is None: sort_keys = [] if isinstance(values, (pa.Array, pa.ChunkedArray)): - sort_keys.append(("dummy", "descending")) + sort_keys.append(("dummy", "descending", "at_end")) else: - sort_keys = map(lambda key_name: (key_name, "descending"), sort_keys) + sort_keys = [(key, "descending", placement) for key, placement in zip( + sort_keys, null_placements or ["at_end"] * len(sort_keys))] options = SelectKOptions(k, sort_keys) return call_function("select_k_unstable", [values], options, memory_pool) -def bottom_k_unstable(values, k, sort_keys=None, *, memory_pool=None): +def bottom_k_unstable(values, k, sort_keys=None, null_placements=None, *, memory_pool=None): """ Select the indices of the bottom-k ordered elements from array- or table-like data. @@ -608,6 +612,9 @@ def bottom_k_unstable(values, k, sort_keys=None, *, memory_pool=None): The number of `k` elements to keep. sort_keys : List-like Column key names to order by when input is table-like data. 
@@ -632,9 +639,11 @@ def bottom_k_unstable(values, k, sort_keys=None, *, memory_pool=None): if sort_keys is None: sort_keys = [] if isinstance(values, (pa.Array, pa.ChunkedArray)): - sort_keys.append(("dummy", "ascending")) + sort_keys.append(("dummy", "ascending", "at_end")) else: - sort_keys = map(lambda key_name: (key_name, "ascending"), sort_keys) + sort_keys = [(sort_key, "ascending", null_placement) + for sort_key, null_placement in zip(sort_keys, null_placements)] + options = SelectKOptions(k, sort_keys) return call_function("select_k_unstable", [values], options, memory_pool) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index fda9d444976..b4a8aa50ae5 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2429,18 +2429,18 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CNullPlacement null_placement cdef cppclass CSortKey" arrow::compute::SortKey": - CSortKey(CFieldRef target, CSortOrder order) + CSortKey(CFieldRef target, CSortOrder order, CNullPlacement null_placement) CFieldRef target CSortOrder order + CNullPlacement null_placement cdef cppclass COrdering" arrow::compute::Ordering": - COrdering(vector[CSortKey] sort_keys, CNullPlacement null_placement) + COrdering(vector[CSortKey] sort_keys) cdef cppclass CSortOptions \ "arrow::compute::SortOptions"(CFunctionOptions): - CSortOptions(vector[CSortKey] sort_keys, CNullPlacement) + CSortOptions(vector[CSortKey] sort_keys) vector[CSortKey] sort_keys - CNullPlacement null_placement cdef cppclass CSelectKOptions \ "arrow::compute::SelectKOptions"(CFunctionOptions): @@ -2513,10 +2513,8 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: cdef cppclass CRankOptions \ "arrow::compute::RankOptions"(CFunctionOptions): - CRankOptions(vector[CSortKey] sort_keys, CNullPlacement, - CRankOptionsTiebreaker tiebreaker) + CRankOptions(vector[CSortKey] sort_keys, CRankOptionsTiebreaker 
tiebreaker) vector[CSortKey] sort_keys - CNullPlacement null_placement CRankOptionsTiebreaker tiebreaker cdef enum DatumType" arrow::Datum::type": diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index bbf60416de9..53d842dee73 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -1081,7 +1081,7 @@ cdef class ChunkedArray(_PandasConvertible): """ return _pc().drop_null(self) - def sort(self, order="ascending", **kwargs): + def sort(self, order="ascending", null_placement="at_end", **kwargs): """ Sort the ChunkedArray @@ -1090,6 +1090,9 @@ cdef class ChunkedArray(_PandasConvertible): order : str, default "ascending" Which order to sort values in. Accepted values are "ascending", "descending". + null_placement : str, default "at_end" + Whether nulls and NaNs are placed at the start or at the end. + Accepted values are "at_end", "at_start". **kwargs : dict, optional Additional sorting options. As allowed by :class:`SortOptions` @@ -1100,7 +1103,7 @@ cdef class ChunkedArray(_PandasConvertible): """ indices = _pc().sort_indices( self, - options=_pc().SortOptions(sort_keys=[("", order)], **kwargs) + options=_pc().SortOptions(sort_keys=[("", order, null_placement)], **kwargs) ) return self.take(indices) @@ -1922,11 +1925,13 @@ cdef class _Tabular(_PandasConvertible): Parameters ---------- - sorting : str or list[tuple(name, order)] + sorting : str or list[tuple(name, order, null_placement)] Name of the column to use to sort (ascending), or a list of multiple sorting conditions where each entry is a tuple with column name - and sorting order ("ascending" or "descending") + and sorting order ("ascending" or "descending") + and nulls and NaNs are placed + at the start or at the end ("at_start" or "at_end") **kwargs : dict, optional Additional sorting options. 
As allowed by :class:`SortOptions` @@ -1958,7 +1963,7 @@ cdef class _Tabular(_PandasConvertible): animal: [["Brittle stars","Centipede","Dog","Flamingo","Horse","Parrot"]] """ if isinstance(sorting, str): - sorting = [(sorting, "ascending")] + sorting = [(sorting, "ascending", "at_end")] indices = _pc().sort_indices( self, diff --git a/python/pyarrow/tests/test_acero.py b/python/pyarrow/tests/test_acero.py index 988e9b6e314..2086c51809a 100644 --- a/python/pyarrow/tests/test_acero.py +++ b/python/pyarrow/tests/test_acero.py @@ -247,19 +247,19 @@ def test_order_by(): table = pa.table({'a': [1, 2, 3, 4], 'b': [1, 3, None, 2]}) table_source = Declaration("table_source", TableSourceNodeOptions(table)) - ord_opts = OrderByNodeOptions([("b", "ascending")]) + ord_opts = OrderByNodeOptions([("b", "ascending", "at_end")]) decl = Declaration.from_sequence([table_source, Declaration("order_by", ord_opts)]) result = decl.to_table() expected = pa.table({"a": [1, 4, 2, 3], "b": [1, 2, 3, None]}) assert result.equals(expected) - ord_opts = OrderByNodeOptions([(field("b"), "descending")]) + ord_opts = OrderByNodeOptions([(field("b"), "descending", "at_end")]) decl = Declaration.from_sequence([table_source, Declaration("order_by", ord_opts)]) result = decl.to_table() expected = pa.table({"a": [2, 4, 1, 3], "b": [3, 2, 1, None]}) assert result.equals(expected) - ord_opts = OrderByNodeOptions([(1, "descending")], null_placement="at_start") + ord_opts = OrderByNodeOptions([(1, "descending", "at_start")]) decl = Declaration.from_sequence([table_source, Declaration("order_by", ord_opts)]) result = decl.to_table() expected = pa.table({"a": [3, 2, 4, 1], "b": [None, 3, 2, 1]}) @@ -274,10 +274,10 @@ def test_order_by(): _ = decl.to_table() with pytest.raises(ValueError, match="\"decreasing\" is not a valid sort order"): - _ = OrderByNodeOptions([("b", "decreasing")]) + _ = OrderByNodeOptions([("b", "decreasing", "at_end")]) with pytest.raises(ValueError, match="\"start\" is not a valid 
null placement"): - _ = OrderByNodeOptions([("b", "ascending")], null_placement="start") + _ = OrderByNodeOptions([("b", "ascending", "start")]) def test_hash_join(): diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 4b2144d702c..44befb29605 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -167,7 +167,7 @@ def test_option_class_equality(): pc.QuantileOptions(), pc.RandomOptions(), pc.RankOptions(sort_keys="ascending", - null_placement="at_start", tiebreaker="max"), + null_placement="at_end", tiebreaker="max"), pc.ReplaceSliceOptions(0, 1, "a"), pc.ReplaceSubstringOptions("a", "b"), pc.RoundOptions(2, "towards_infinity"), @@ -175,10 +175,10 @@ def test_option_class_equality(): pc.RoundTemporalOptions(1, "second", week_starts_monday=True), pc.RoundToMultipleOptions(100, "towards_infinity"), pc.ScalarAggregateOptions(), - pc.SelectKOptions(0, sort_keys=[("b", "ascending")]), + pc.SelectKOptions(0, sort_keys=[("b", "ascending", "at_end")]), pc.SetLookupOptions(pa.array([1])), pc.SliceOptions(0, 1, 1), - pc.SortOptions([("dummy", "descending")], null_placement="at_start"), + pc.SortOptions([("dummy", "descending", "at_end")]), pc.SplitOptions(), pc.SplitPatternOptions("pattern"), pc.StrftimeOptions(), @@ -2558,8 +2558,9 @@ def test_partition_nth_null_placement(): def test_select_k_array(): - def validate_select_k(select_k_indices, arr, order, stable_sort=False): - sorted_indices = pc.sort_indices(arr, sort_keys=[("dummy", order)]) + def validate_select_k(select_k_indices, arr, order, null_placement="at_end", stable_sort=False): + sorted_indices = pc.sort_indices( + arr, sort_keys=[("dummy", order, null_placement)]) head_k_indices = sorted_indices.slice(0, len(select_k_indices)) if stable_sort: assert select_k_indices == head_k_indices @@ -2572,8 +2573,8 @@ def validate_select_k(select_k_indices, arr, order, stable_sort=False): for k in [0, 2, 4]: for order in ["descending", 
"ascending"]: result = pc.select_k_unstable( - arr, k=k, sort_keys=[("dummy", order)]) - validate_select_k(result, arr, order) + arr, k=k, sort_keys=[("dummy", order, "at_end")]) + validate_select_k(result, arr, order, "at_end") result = pc.top_k_unstable(arr, k=k) validate_select_k(result, arr, "descending") @@ -2583,19 +2584,20 @@ def validate_select_k(select_k_indices, arr, order, stable_sort=False): result = pc.select_k_unstable( arr, options=pc.SelectKOptions( - k=2, sort_keys=[("dummy", "descending")]) + k=2, sort_keys=[("dummy", "descending", "at_end")]) ) validate_select_k(result, arr, "descending") result = pc.select_k_unstable( - arr, options=pc.SelectKOptions(k=2, sort_keys=[("dummy", "ascending")]) + arr, options=pc.SelectKOptions( + k=2, sort_keys=[("dummy", "ascending", "at_end")]) ) validate_select_k(result, arr, "ascending") # Position options assert pc.select_k_unstable(arr, 2, - sort_keys=[("dummy", "ascending")]) == result - assert pc.select_k_unstable(arr, 2, [("dummy", "ascending")]) == result + sort_keys=[("dummy", "ascending", "at_end")]) == result + assert pc.select_k_unstable(arr, 2, [("dummy", "ascending", "at_end")]) == result def test_select_k_table(): @@ -2612,20 +2614,22 @@ def validate_select_k(select_k_indices, tbl, sort_keys, stable_sort=False): table = pa.table({"a": [1, 2, 0], "b": [1, 0, 1]}) for k in [0, 2, 4]: result = pc.select_k_unstable( - table, k=k, sort_keys=[("a", "ascending")]) - validate_select_k(result, table, sort_keys=[("a", "ascending")]) + table, k=k, sort_keys=[("a", "ascending", "at_end")]) + validate_select_k(result, table, sort_keys=[("a", "ascending", "at_end")]) result = pc.select_k_unstable( - table, k=k, sort_keys=[(pc.field("a"), "ascending"), ("b", "ascending")]) + table, k=k, sort_keys=[(pc.field("a"), "ascending", "at_end"), ("b", "ascending", "at_end")]) validate_select_k( - result, table, sort_keys=[("a", "ascending"), ("b", "ascending")]) + result, table, sort_keys=[("a", "ascending", "at_end"), 
("b", "ascending", "at_end")]) - result = pc.top_k_unstable(table, k=k, sort_keys=["a"]) - validate_select_k(result, table, sort_keys=[("a", "descending")]) + result = pc.top_k_unstable(table, k=k, sort_keys=[ + "a"], null_placements=["at_end"]) + validate_select_k(result, table, sort_keys=[("a", "descending", "at_end")]) - result = pc.bottom_k_unstable(table, k=k, sort_keys=["a", "b"]) + result = pc.bottom_k_unstable( + table, k=k, sort_keys=["a", "b"], null_placements=["at_end", "at_start"]) validate_select_k( - result, table, sort_keys=[("a", "ascending"), ("b", "ascending")]) + result, table, sort_keys=[("a", "ascending", "at_end"), ("b", "ascending", "at_start")]) with pytest.raises( ValueError, @@ -2634,7 +2638,7 @@ def validate_select_k(select_k_indices, tbl, sort_keys, stable_sort=False): with pytest.raises(ValueError, match="select_k_unstable requires a nonnegative `k`"): - pc.select_k_unstable(table, k=-1, sort_keys=[("a", "ascending")]) + pc.select_k_unstable(table, k=-1, sort_keys=[("a", "ascending", "at_end")]) with pytest.raises(ValueError, match="select_k_unstable requires a " @@ -2642,11 +2646,11 @@ def validate_select_k(select_k_indices, tbl, sort_keys, stable_sort=False): pc.select_k_unstable(table, k=2, sort_keys=[]) with pytest.raises(ValueError, match="not a valid sort order"): - pc.select_k_unstable(table, k=k, sort_keys=[("a", "nonscending")]) + pc.select_k_unstable(table, k=k, sort_keys=[("a", "nonscending", "at_end")]) with pytest.raises(ValueError, match="Invalid sort key column: No match for.*unknown"): - pc.select_k_unstable(table, k=k, sort_keys=[("unknown", "ascending")]) + pc.select_k_unstable(table, k=k, sort_keys=[("unknown", "ascending", "at_end")]) def test_array_sort_indices(): @@ -2672,25 +2676,22 @@ def test_sort_indices_array(): arr = pa.array([1, 2, None, 0]) result = pc.sort_indices(arr) assert result.to_pylist() == [3, 0, 1, 2] - result = pc.sort_indices(arr, sort_keys=[("dummy", "ascending")]) + result = 
pc.sort_indices(arr, sort_keys=[("dummy", "ascending", "at_end")]) assert result.to_pylist() == [3, 0, 1, 2] - result = pc.sort_indices(arr, sort_keys=[("dummy", "descending")]) + result = pc.sort_indices(arr, sort_keys=[("dummy", "descending", "at_end")]) assert result.to_pylist() == [1, 0, 3, 2] - result = pc.sort_indices(arr, sort_keys=[("dummy", "descending")], - null_placement="at_start") + result = pc.sort_indices(arr, sort_keys=[("dummy", "descending", "at_start")]) assert result.to_pylist() == [2, 1, 0, 3] # Positional `sort_keys` - result = pc.sort_indices(arr, [("dummy", "descending")], - null_placement="at_start") + result = pc.sort_indices(arr, [("dummy", "descending", "at_start")]) assert result.to_pylist() == [2, 1, 0, 3] # Using SortOptions result = pc.sort_indices( - arr, options=pc.SortOptions(sort_keys=[("dummy", "descending")]) + arr, options=pc.SortOptions(sort_keys=[("dummy", "descending", "at_end")]) ) assert result.to_pylist() == [1, 0, 3, 2] result = pc.sort_indices( - arr, options=pc.SortOptions(sort_keys=[("dummy", "descending")], - null_placement="at_start") + arr, options=pc.SortOptions(sort_keys=[("dummy", "descending", "at_start")]) ) assert result.to_pylist() == [2, 1, 0, 3] @@ -2698,26 +2699,22 @@ def test_sort_indices_array(): def test_sort_indices_table(): table = pa.table({"a": [1, 1, None, 0], "b": [1, 0, 0, 1]}) - result = pc.sort_indices(table, sort_keys=[("a", "ascending")]) + result = pc.sort_indices(table, sort_keys=[("a", "ascending", "at_end")]) assert result.to_pylist() == [3, 0, 1, 2] - result = pc.sort_indices(table, sort_keys=[(pc.field("a"), "ascending")], - null_placement="at_start") + result = pc.sort_indices( + table, sort_keys=[(pc.field("a"), "ascending", "at_start")]) assert result.to_pylist() == [2, 3, 0, 1] result = pc.sort_indices( - table, sort_keys=[("a", "descending"), ("b", "ascending")] + table, sort_keys=[("a", "descending", "at_end"), ("b", "ascending", "at_end")] ) assert result.to_pylist() == [1, 0, 
3, 2] result = pc.sort_indices( - table, sort_keys=[("a", "descending"), ("b", "ascending")], - null_placement="at_start" - ) + table, sort_keys=[("a", "descending", "at_start"), ("b", "ascending", "at_start")]) assert result.to_pylist() == [2, 1, 0, 3] # Positional `sort_keys` result = pc.sort_indices( - table, [("a", "descending"), ("b", "ascending")], - null_placement="at_start" - ) + table, [("a", "descending", "at_start"), ("b", "ascending", "at_start")]) assert result.to_pylist() == [2, 1, 0, 3] with pytest.raises(ValueError, match="Must specify one or more sort keys"): @@ -2725,10 +2722,10 @@ def test_sort_indices_table(): with pytest.raises(ValueError, match="Invalid sort key column: No match for.*unknown"): - pc.sort_indices(table, sort_keys=[("unknown", "ascending")]) + pc.sort_indices(table, sort_keys=[("unknown", "ascending", "at_end")]) with pytest.raises(ValueError, match="not a valid sort order"): - pc.sort_indices(table, sort_keys=[("a", "nonscending")]) + pc.sort_indices(table, sort_keys=[("a", "nonscending", "at_end")]) def test_is_in(): @@ -3277,7 +3274,7 @@ def test_rank_options(): # Ensure sort_keys tuple usage result = pc.rank(arr, options=pc.RankOptions( - sort_keys=[("b", "ascending")]) + sort_keys=[("b", "ascending", "at_end")]) ) assert result.equals(expected) diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index 6f3b54b0cd6..0b8336688ce 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -3893,7 +3893,7 @@ def test_legacy_write_to_dataset_drops_null(tempdir): def _sort_table(tab, sort_col): import pyarrow.compute as pc sorted_indices = pc.sort_indices( - tab, options=pc.SortOptions([(sort_col, 'ascending')])) + tab, options=pc.SortOptions([(sort_col, 'ascending', 'at_end')])) return pc.take(tab, sorted_indices) @@ -5349,7 +5349,7 @@ def test_dataset_sort_by(tempdir, dstype): "values": [1, 2, 3, 4, 5] } - assert dt.sort_by([("values", 
"descending")]).to_table().to_pydict() == { + assert dt.sort_by([("values", "descending", "at_end")]).to_table().to_pydict() == { "keys": ["c", "b", "b", "a", "a"], "values": [5, 4, 3, 2, 1] } @@ -5367,12 +5367,12 @@ def test_dataset_sort_by(tempdir, dstype): ], names=["a", "b"]) dt = ds.dataset(table) - sorted_tab = dt.sort_by([("a", "descending")]) + sorted_tab = dt.sort_by([("a", "descending", "at_end")]) sorted_tab_dict = sorted_tab.to_table().to_pydict() assert sorted_tab_dict["a"] == [35, 7, 7, 5] assert sorted_tab_dict["b"] == ["foobar", "car", "bar", "foo"] - sorted_tab = dt.sort_by([("a", "ascending")]) + sorted_tab = dt.sort_by([("a", "ascending", "at_end")]) sorted_tab_dict = sorted_tab.to_table().to_pydict() assert sorted_tab_dict["a"] == [5, 7, 7, 35] assert sorted_tab_dict["b"] == ["foo", "car", "bar", "foobar"] diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index a678f521e38..0f76c8f2ee8 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -2501,7 +2501,7 @@ def test_table_sort_by(): "values": [1, 2, 3, 4, 5] } - assert table.sort_by([("values", "descending")]).to_pydict() == { + assert table.sort_by([("values", "descending", "at_end")]).to_pydict() == { "keys": ["c", "b", "b", "a", "a"], "values": [5, 4, 3, 2, 1] } @@ -2511,12 +2511,12 @@ def test_table_sort_by(): pa.array(["foo", "car", "bar", "foobar"]) ], names=["a", "b"]) - sorted_tab = tab.sort_by([("a", "descending")]) + sorted_tab = tab.sort_by([("a", "descending", "at_end")]) sorted_tab_dict = sorted_tab.to_pydict() assert sorted_tab_dict["a"] == [35, 7, 7, 5] assert sorted_tab_dict["b"] == ["foobar", "car", "bar", "foo"] - sorted_tab = tab.sort_by([("a", "ascending")]) + sorted_tab = tab.sort_by([("a", "ascending", "at_end")]) sorted_tab_dict = sorted_tab.to_pydict() assert sorted_tab_dict["a"] == [5, 7, 7, 35] assert sorted_tab_dict["b"] == ["foo", "car", "bar", "foobar"] @@ -2529,13 +2529,14 @@ def 
test_record_batch_sort(): pa.array(["foo", "car", "bar", "foobar"]) ], names=["a", "b", "c"]) - sorted_rb = rb.sort_by([("a", "descending"), ("b", "descending")]) + sorted_rb = rb.sort_by([("a", "descending", "at_end"), + ("b", "descending", "at_end")]) sorted_rb_dict = sorted_rb.to_pydict() assert sorted_rb_dict["a"] == [35, 7, 7, 5] assert sorted_rb_dict["b"] == [1, 4, 3, 2] assert sorted_rb_dict["c"] == ["car", "foo", "bar", "foobar"] - sorted_rb = rb.sort_by([("a", "ascending"), ("b", "ascending")]) + sorted_rb = rb.sort_by([("a", "ascending", "at_end"), ("b", "ascending", "at_end")]) sorted_rb_dict = sorted_rb.to_pydict() assert sorted_rb_dict["a"] == [5, 7, 7, 35] assert sorted_rb_dict["b"] == [2, 3, 4, 1] diff --git a/ruby/red-arrow/lib/arrow/sort-key.rb b/ruby/red-arrow/lib/arrow/sort-key.rb index 7ceab631ea2..31b0363cc35 100644 --- a/ruby/red-arrow/lib/arrow/sort-key.rb +++ b/ruby/red-arrow/lib/arrow/sort-key.rb @@ -46,16 +46,16 @@ class << self # @return [Arrow::SortKey] A new suitable sort key. # # @since 4.0.0 - def resolve(target, order=nil) + def resolve(target, order=nil, null_placement = nil) return target if target.is_a?(self) - new(target, order) + new(target, order, null_placement) end # @api private def try_convert(value) case value when Symbol, String - new(value.to_s, :ascending) + new(value.to_s, :ascending, :at_end) else nil end @@ -139,10 +139,11 @@ def try_convert(value) # key.order # => Arrow::SortOrder::DESCENDING # # @since 4.0.0 - def initialize(target, order=nil) - target, order = normalize_target(target, order) + def initialize(target, order=nil, null_placement=nil) + target, order, null_placement = normalize_target(target, order, null_placement) order = normalize_order(order) || :ascending - initialize_raw(target, order) + null_placement = normalize_null_placement(null_placement) || :at_end + initialize_raw(target, order, null_placement) end # @return [String] The string representation of this sort key. 
You @@ -156,32 +157,49 @@ def initialize(target, order=nil) # # @since 4.0.0 def to_s - if order == SortOrder::ASCENDING + result = if order == SortOrder::ASCENDING "+#{target}" else "-#{target}" end + if null_placement == NullPlacement::AT_START + result += "_at_start" + else + result += "_at_end" + end + return result end # For backward compatibility alias_method :name, :target private - def normalize_target(target, order) + def normalize_target(target, order, null_placement) + # for recreatable, we should remove suffix + if target.end_with?("_at_start") + suffix_length = "_at_start".length + target = target[0..-(suffix_length + 1)] + elsif target.end_with?("_at_end") + suffix_length = "_at_end".length + target = target[0..-(suffix_length + 1)] + end + case target when Symbol - return target.to_s, order + return target.to_s, order, null_placement when String - return target, order if order + if order + return target, order, null_placement + end if target.start_with?("-") - return target[1..-1], order || :descending + return target[1..-1], order || :descending, null_placement || :at_end elsif target.start_with?("+") - return target[1..-1], order || :ascending + return target[1..-1], order || :ascending, null_placement || :at_end else - return target, order + return target, order, null_placement end else - return target, order + return target, order, null_placement end end @@ -195,5 +213,16 @@ def normalize_order(order) order end end + + def normalize_null_placement(null_placement) + case null_placement + when :at_end, "at_end" + :at_end + when :at_start, "at_start" + :at_start + else + null_placement + end + end end end diff --git a/ruby/red-arrow/lib/arrow/sort-options.rb b/ruby/red-arrow/lib/arrow/sort-options.rb index 24a027406b6..6e4af22eb38 100644 --- a/ruby/red-arrow/lib/arrow/sort-options.rb +++ b/ruby/red-arrow/lib/arrow/sort-options.rb @@ -102,8 +102,8 @@ def initialize(*sort_keys) # options.sort_keys.collect(&:to_s) # => ["-price"] # # @since 4.0.0 - def 
add_sort_key(target, order=nil) - add_sort_key_raw(SortKey.resolve(target, order)) + def add_sort_key(target, order=nil, null_placement=nil) + add_sort_key_raw(SortKey.resolve(target, order, null_placement)) end end end diff --git a/ruby/red-arrow/test/test-sort-key.rb b/ruby/red-arrow/test/test-sort-key.rb index 0a31f84610d..499f90fcadd 100644 --- a/ruby/red-arrow/test/test-sort-key.rb +++ b/ruby/red-arrow/test/test-sort-key.rb @@ -35,37 +35,37 @@ class SortKeyTest < Test::Unit::TestCase sub_test_case("#initialize") do test("String") do - assert_equal("+count", + assert_equal("+count_at_end", Arrow::SortKey.new("count").to_s) end test("+String") do - assert_equal("+count", + assert_equal("+count_at_end", Arrow::SortKey.new("+count").to_s) end test("-String") do - assert_equal("-count", + assert_equal("-count_at_end", Arrow::SortKey.new("-count").to_s) end test("Symbol") do - assert_equal("+-count", + assert_equal("+-count_at_end", Arrow::SortKey.new(:"-count").to_s) end test("String, Symbol") do - assert_equal("--count", + assert_equal("--count_at_end", Arrow::SortKey.new("-count", :desc).to_s) end test("String, String") do - assert_equal("--count", + assert_equal("--count_at_end", Arrow::SortKey.new("-count", "desc").to_s) end test("String, SortOrder") do - assert_equal("--count", + assert_equal("--count_at_end", Arrow::SortKey.new("-count", Arrow::SortOrder::DESCENDING).to_s) end diff --git a/ruby/red-arrow/test/test-sort-options.rb b/ruby/red-arrow/test/test-sort-options.rb index 0afd65b0f46..260c62e50f7 100644 --- a/ruby/red-arrow/test/test-sort-options.rb +++ b/ruby/red-arrow/test/test-sort-options.rb @@ -25,7 +25,7 @@ class SortOptionsTest < Test::Unit::TestCase test("-String, Symbol") do options = Arrow::SortOptions.new("-count", :age) - assert_equal(["-count", "+age"], + assert_equal(["-count_at_end", "+age_at_end"], options.sort_keys.collect(&:to_s)) end end @@ -38,19 +38,19 @@ class SortOptionsTest < Test::Unit::TestCase sub_test_case("#add_sort_key") do 
test("-String") do @options.add_sort_key("-count") - assert_equal(["-count"], + assert_equal(["-count_at_end"], @options.sort_keys.collect(&:to_s)) end test("-String, Symbol") do @options.add_sort_key("-count", :desc) - assert_equal(["--count"], + assert_equal(["--count_at_end"], @options.sort_keys.collect(&:to_s)) end test("SortKey") do @options.add_sort_key(Arrow::SortKey.new("-count")) - assert_equal(["-count"], + assert_equal(["-count_at_end"], @options.sort_keys.collect(&:to_s)) end end From d8bca7c5f76853d13697b73a6bd52f442b28c933 Mon Sep 17 00:00:00 2001 From: light-city <455954986@qq.com> Date: Fri, 10 Nov 2023 11:12:34 +0800 Subject: [PATCH 02/83] fix test --- cpp/src/arrow/compute/api_vector.h | 2 ++ cpp/src/arrow/compute/ordering.h | 1 + cpp/src/arrow/engine/substrait/serde_test.cc | 16 ++++++++++++++-- python/pyarrow/compute.py | 6 ++++-- python/pyarrow/tests/test_compute.py | 15 ++++++++++----- 5 files changed, 31 insertions(+), 9 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 40aacea2606..087a45a2c57 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -105,6 +105,7 @@ class ARROW_EXPORT ArraySortOptions : public FunctionOptions { class ARROW_EXPORT SortOptions : public FunctionOptions { public: + /// DEPRECATED(null_placement has been removed, please use SortKey.null_placement) explicit SortOptions(std::vector sort_keys = {}); explicit SortOptions(const Ordering& ordering); static constexpr char const kTypeName[] = "SortOptions"; @@ -173,6 +174,7 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { Dense }; + /// DEPRECATED(null_placement has been removed, please use SortKey.null_placement) explicit RankOptions(std::vector sort_keys = {}, Tiebreaker tiebreaker = RankOptions::First); /// Convenience constructor for array inputs diff --git a/cpp/src/arrow/compute/ordering.h b/cpp/src/arrow/compute/ordering.h index 146571568d5..b41d91905e2 100644 --- 
a/cpp/src/arrow/compute/ordering.h +++ b/cpp/src/arrow/compute/ordering.h @@ -65,6 +65,7 @@ class ARROW_EXPORT SortKey : public util::EqualityComparable { class ARROW_EXPORT Ordering : public util::EqualityComparable { public: + /// DEPRECATED(null_placement has been removed, please use SortKey.null_placement) explicit Ordering(std::vector sort_keys) : sort_keys_(std::move(sort_keys)) {} /// true if data ordered by other is also ordered by this /// diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index e1a2c24184d..3163c06b149 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -5453,10 +5453,10 @@ TEST(Substrait, MixedSort) { auto test_schema = schema({field("A", int32()), field("B", int32())}); auto input_table = TableFromJSON(test_schema, {R"([ - [null, null], + [null, 10], [5, 8], [null, null], - [null, null], + [null, 3], [3, 4], [9, 6], [4, 5] @@ -5475,6 +5475,18 @@ TEST(Substrait, MixedSort) { ASSERT_OK_AND_ASSIGN( auto plan_info, DeserializePlan(*buf, /*registry=*/nullptr, /*ext_set_out=*/nullptr, conversion_options)); + ASSERT_OK_AND_ASSIGN(auto result_table, + DeclarationToTable(std::move(plan_info.root.declaration))); + auto expected_table = TableFromJSON(test_schema, {R"([ + [null, 3], + [null, 10], + [null, null], + [3, 4], + [4, 5], + [5, 8], + [9, 6] + ])"}); + AssertTablesEqual(*result_table, *expected_table); } TEST(Substrait, PlanWithExtension) { diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 8b30a2f1e7b..a6b7694afd7 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -545,7 +545,8 @@ def fill_null(values, fill_value): return call_function("coalesce", [values, fill_value]) -def top_k_unstable(values, k, sort_keys=None, null_placements=None, *, memory_pool=None): +def top_k_unstable( + values, k, sort_keys=None, null_placements=None, *, memory_pool=None): """ Select the indices of the 
top-k ordered elements from array- or table-like data. @@ -596,7 +597,8 @@ def top_k_unstable(values, k, sort_keys=None, null_placements=None, *, memory_po return call_function("select_k_unstable", [values], options, memory_pool) -def bottom_k_unstable(values, k, sort_keys=None, null_placements=None, *, memory_pool=None): +def bottom_k_unstable( + values, k, sort_keys=None, null_placements=None, *, memory_pool=None): """ Select the indices of the bottom-k ordered elements from array- or table-like data. diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 44befb29605..e2e3baa0f6f 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -2558,7 +2558,8 @@ def test_partition_nth_null_placement(): def test_select_k_array(): - def validate_select_k(select_k_indices, arr, order, null_placement="at_end", stable_sort=False): + def validate_select_k(select_k_indices, arr, order, null_placement="at_end", + stable_sort=False): sorted_indices = pc.sort_indices( arr, sort_keys=[("dummy", order, null_placement)]) head_k_indices = sorted_indices.slice(0, len(select_k_indices)) @@ -2618,9 +2619,11 @@ def validate_select_k(select_k_indices, tbl, sort_keys, stable_sort=False): validate_select_k(result, table, sort_keys=[("a", "ascending", "at_end")]) result = pc.select_k_unstable( - table, k=k, sort_keys=[(pc.field("a"), "ascending", "at_end"), ("b", "ascending", "at_end")]) + table, k=k, sort_keys=[(pc.field("a"), "ascending", "at_end"), + ("b", "ascending", "at_end")]) validate_select_k( - result, table, sort_keys=[("a", "ascending", "at_end"), ("b", "ascending", "at_end")]) + result, table, sort_keys=[("a", "ascending", "at_end"), + ("b", "ascending", "at_end")]) result = pc.top_k_unstable(table, k=k, sort_keys=[ "a"], null_placements=["at_end"]) @@ -2629,7 +2632,8 @@ def validate_select_k(select_k_indices, tbl, sort_keys, stable_sort=False): result = pc.bottom_k_unstable( table, k=k, 
sort_keys=["a", "b"], null_placements=["at_end", "at_start"]) validate_select_k( - result, table, sort_keys=[("a", "ascending", "at_end"), ("b", "ascending", "at_start")]) + result, table, sort_keys=[("a", "ascending", "at_end"), ("b", "ascending", + "at_start")]) with pytest.raises( ValueError, @@ -2710,7 +2714,8 @@ def test_sort_indices_table(): ) assert result.to_pylist() == [1, 0, 3, 2] result = pc.sort_indices( - table, sort_keys=[("a", "descending", "at_start"), ("b", "ascending", "at_start")]) + table, sort_keys=[("a", "descending", "at_start"), ("b", "ascending", + "at_start")]) assert result.to_pylist() == [2, 1, 0, 3] # Positional `sort_keys` result = pc.sort_indices( From 970f5bf2f282d097f7d6b3f4338891ac16ea5365 Mon Sep 17 00:00:00 2001 From: light-city <455954986@qq.com> Date: Tue, 14 Nov 2023 10:35:18 +0800 Subject: [PATCH 03/83] fix sortkey assert --- cpp/src/arrow/engine/substrait/serde_test.cc | 24 ++++++++------------ 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 3163c06b149..8a8ce066ea9 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -5453,10 +5453,10 @@ TEST(Substrait, MixedSort) { auto test_schema = schema({field("A", int32()), field("B", int32())}); auto input_table = TableFromJSON(test_schema, {R"([ - [null, 10], + [null, null], [5, 8], [null, null], - [null, 3], + [null, null], [3, 4], [9, 6], [4, 5] @@ -5475,18 +5475,14 @@ TEST(Substrait, MixedSort) { ASSERT_OK_AND_ASSIGN( auto plan_info, DeserializePlan(*buf, /*registry=*/nullptr, /*ext_set_out=*/nullptr, conversion_options)); - ASSERT_OK_AND_ASSIGN(auto result_table, - DeclarationToTable(std::move(plan_info.root.declaration))); - auto expected_table = TableFromJSON(test_schema, {R"([ - [null, 3], - [null, 10], - [null, null], - [3, 4], - [4, 5], - [5, 8], - [9, 6] - ])"}); - AssertTablesEqual(*result_table, 
*expected_table); + auto& order_by_options = + checked_cast(*plan_info.root.declaration.options); + EXPECT_THAT( + order_by_options.ordering.sort_keys(), + ElementsAre(arrow::compute::SortKey{"A", arrow::compute::SortOrder::Ascending, + arrow::compute::NullPlacement::AtStart}, + arrow::compute::SortKey{"B", arrow::compute::SortOrder::Ascending, + arrow::compute::NullPlacement::AtEnd})); } TEST(Substrait, PlanWithExtension) { From 6c0bc766e9256757eff65fa013b1bbf119fea0d6 Mon Sep 17 00:00:00 2001 From: light-city <455954986@qq.com> Date: Tue, 14 Nov 2023 12:02:35 +0800 Subject: [PATCH 04/83] fix serde test --- cpp/src/arrow/engine/substrait/serde_test.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 8a8ce066ea9..beedcfe36ba 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -5477,12 +5477,14 @@ TEST(Substrait, MixedSort) { conversion_options)); auto& order_by_options = checked_cast(*plan_info.root.declaration.options); + EXPECT_THAT( order_by_options.ordering.sort_keys(), - ElementsAre(arrow::compute::SortKey{"A", arrow::compute::SortOrder::Ascending, - arrow::compute::NullPlacement::AtStart}, - arrow::compute::SortKey{"B", arrow::compute::SortOrder::Ascending, - arrow::compute::NullPlacement::AtEnd})); + ElementsAre( + arrow::compute::SortKey{FieldPath({0}), arrow::compute::SortOrder::Ascending, + arrow::compute::NullPlacement::AtStart}, + arrow::compute::SortKey{FieldPath({1}), arrow::compute::SortOrder::Ascending, + arrow::compute::NullPlacement::AtEnd})); } TEST(Substrait, PlanWithExtension) { From 25984c5442548ac863c5dd9b4db73c46ebdedd9f Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Wed, 25 Jun 2025 14:05:27 +0200 Subject: [PATCH 05/83] better api --- cpp/src/arrow/compute/api_vector.cc | 27 ++--- cpp/src/arrow/compute/api_vector.h | 80 ++++++++++++--- 
cpp/src/arrow/compute/kernels/vector_rank.cc | 9 +- .../arrow/compute/kernels/vector_select_k.cc | 22 +++-- cpp/src/arrow/compute/kernels/vector_sort.cc | 98 +++++++++---------- .../compute/kernels/vector_sort_internal.h | 45 ++++----- .../arrow/compute/kernels/vector_sort_test.cc | 45 +++++++-- cpp/src/arrow/compute/ordering.cc | 28 ++++-- cpp/src/arrow/compute/ordering.h | 17 +++- 9 files changed, 235 insertions(+), 136 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 06e6cf6c1ad..ff4fe1be09e 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -123,6 +123,7 @@ namespace compute { namespace internal { namespace { +using ::arrow::internal::CoercedDataMember; using ::arrow::internal::DataMember; static auto kFilterOptionsType = GetFunctionOptionsType( DataMember("null_selection_behavior", &FilterOptions::null_selection_behavior)); @@ -137,8 +138,8 @@ static auto kArraySortOptionsType = GetFunctionOptionsType( DataMember("order", &ArraySortOptions::order), DataMember("null_placement", &ArraySortOptions::null_placement)); static auto kSortOptionsType = GetFunctionOptionsType( - DataMember("sort_keys", &SortOptions::sort_keys), - DataMember("null_placement", &SortOptions::null_placement)); + CoercedDataMember("sort_keys", &SortOptions::sort_keys_, + &SortOptions::GetSortKeys)); static auto kPartitionNthOptionsType = GetFunctionOptionsType( DataMember("pivot", &PartitionNthOptions::pivot), DataMember("null_placement", &PartitionNthOptions::null_placement)); @@ -152,12 +153,12 @@ static auto kCumulativeOptionsType = GetFunctionOptionsType( DataMember("start", &CumulativeOptions::start), DataMember("skip_nulls", &CumulativeOptions::skip_nulls)); static auto kRankOptionsType = GetFunctionOptionsType( - DataMember("sort_keys", &RankOptions::sort_keys), - DataMember("null_placement", &RankOptions::null_placement), + CoercedDataMember("sort_keys", &RankOptions::sort_keys_, + 
&RankOptions::GetSortKeys), DataMember("tiebreaker", &RankOptions::tiebreaker)); static auto kRankQuantileOptionsType = GetFunctionOptionsType( - DataMember("sort_keys", &RankQuantileOptions::sort_keys), - DataMember("null_placement", &RankQuantileOptions::null_placement)); + CoercedDataMember("sort_keys", &RankQuantileOptions::sort_keys_, + &RankQuantileOptions::GetSortKeys)); static auto kPairwiseOptionsType = GetFunctionOptionsType( DataMember("periods", &PairwiseOptions::periods)); static auto kListFlattenOptionsType = GetFunctionOptionsType( @@ -195,13 +196,13 @@ ArraySortOptions::ArraySortOptions(SortOrder order, NullPlacement null_placement null_placement(null_placement) {} constexpr char ArraySortOptions::kTypeName[]; -SortOptions::SortOptions(std::vector sort_keys, NullPlacement null_placement) +SortOptions::SortOptions(std::vector sort_keys, std::optional null_placement) : FunctionOptions(internal::kSortOptionsType), - sort_keys(std::move(sort_keys)), + sort_keys_(std::move(sort_keys)), null_placement(null_placement) {} SortOptions::SortOptions(const Ordering& ordering) : FunctionOptions(internal::kSortOptionsType), - sort_keys(ordering.sort_keys()), + sort_keys_(ordering.sort_keys()), null_placement(ordering.null_placement()) {} constexpr char SortOptions::kTypeName[]; @@ -232,18 +233,18 @@ CumulativeOptions::CumulativeOptions(std::shared_ptr start, bool skip_nu skip_nulls(skip_nulls) {} constexpr char CumulativeOptions::kTypeName[]; -RankOptions::RankOptions(std::vector sort_keys, NullPlacement null_placement, +RankOptions::RankOptions(std::vector sort_keys, std::optional null_placement, RankOptions::Tiebreaker tiebreaker) : FunctionOptions(internal::kRankOptionsType), - sort_keys(std::move(sort_keys)), + sort_keys_(std::move(sort_keys)), null_placement(null_placement), tiebreaker(tiebreaker) {} constexpr char RankOptions::kTypeName[]; RankQuantileOptions::RankQuantileOptions(std::vector sort_keys, - NullPlacement null_placement) + std::optional 
null_placement) : FunctionOptions(internal::kRankQuantileOptionsType), - sort_keys(std::move(sort_keys)), + sort_keys_(std::move(sort_keys)), null_placement(null_placement) {} constexpr char RankQuantileOptions::kTypeName[]; diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 69e4b243c97..398870f2bce 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -105,7 +105,7 @@ class ARROW_EXPORT ArraySortOptions : public FunctionOptions { class ARROW_EXPORT SortOptions : public FunctionOptions { public: explicit SortOptions(std::vector sort_keys = {}, - NullPlacement null_placement = NullPlacement::AtEnd); + std::optional null_placement = std::nullopt); explicit SortOptions(const Ordering& ordering); static constexpr char const kTypeName[] = "SortOptions"; static SortOptions Defaults() { return SortOptions(); } @@ -114,13 +114,29 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { /// Note: Both classes contain the exact same information. However, /// sort_options should only be used in a "function options" context while Ordering /// is used more generally. - Ordering AsOrdering() && { return Ordering(std::move(sort_keys), null_placement); } - Ordering AsOrdering() const& { return Ordering(sort_keys, null_placement); } + Ordering AsOrdering() && { return Ordering(std::move(sort_keys_), null_placement); } + Ordering AsOrdering() const& { return Ordering(sort_keys_, null_placement); } /// Column key(s) to order by and how to order by these sort keys. 
- std::vector sort_keys; + std::vector sort_keys_; + + // DEPRECATED(will be removed after null_placement has been removed) + /// Get sort_keys with overwritten null_placement + std::vector GetSortKeys() const { + if(!null_placement.has_value()){ + return sort_keys_; + } + auto overwritten_sort_keys = sort_keys_; + for(auto& sort_key : overwritten_sort_keys){ + sort_key.null_placement = null_placement.value(); + } + return overwritten_sort_keys; + } + + // DEPRECATED(set null_placement in sort_keys instead) /// Whether nulls and NaNs are placed at the start or at the end - NullPlacement null_placement; + /// Will overwrite null ordering of sort keys + std::optional null_placement; }; /// \brief SelectK options @@ -156,6 +172,12 @@ class ARROW_EXPORT SelectKOptions : public FunctionOptions { int64_t k; /// Column key(s) to order by and how to order by these sort keys. std::vector sort_keys; + + // DEPRECATED(will be removed after null_placement has been removed from other SortOptions-like structs) + /// Get sort_keys + std::vector GetSortKeys() const{ + return sort_keys; + } }; /// \brief Rank options @@ -176,11 +198,11 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { }; explicit RankOptions(std::vector sort_keys = {}, - NullPlacement null_placement = NullPlacement::AtEnd, + std::optional null_placement = std::nullopt, Tiebreaker tiebreaker = RankOptions::First); /// Convenience constructor for array inputs explicit RankOptions(SortOrder order, - NullPlacement null_placement = NullPlacement::AtEnd, + std::optional null_placement = std::nullopt, Tiebreaker tiebreaker = RankOptions::First) : RankOptions({SortKey("", order)}, null_placement, tiebreaker) {} @@ -188,9 +210,25 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { static RankOptions Defaults() { return RankOptions(); } /// Column key(s) to order by and how to order by these sort keys. 
- std::vector sort_keys; + std::vector sort_keys_; + + // DEPRECATED(will be removed after null_placement has been removed) + /// Get sort_keys with overwritten null_placement + std::vector GetSortKeys() const { + if(!null_placement.has_value()){ + return sort_keys_; + } + auto overwritten_sort_keys = sort_keys_; + for(auto& sort_key : overwritten_sort_keys){ + sort_key.null_placement = null_placement.value(); + } + return overwritten_sort_keys; + } + + // DEPRECATED(set null_placement in sort_keys instead) /// Whether nulls and NaNs are placed at the start or at the end - NullPlacement null_placement; + /// Will overwrite null ordering of sort keys + std::optional null_placement; /// Tiebreaker for dealing with equal values in ranks Tiebreaker tiebreaker; }; @@ -199,19 +237,35 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { class ARROW_EXPORT RankQuantileOptions : public FunctionOptions { public: explicit RankQuantileOptions(std::vector sort_keys = {}, - NullPlacement null_placement = NullPlacement::AtEnd); + std::optional null_placement = std::nullopt); /// Convenience constructor for array inputs explicit RankQuantileOptions(SortOrder order, - NullPlacement null_placement = NullPlacement::AtEnd) + std::optional null_placement = std::nullopt) : RankQuantileOptions({SortKey("", order)}, null_placement) {} static constexpr char const kTypeName[] = "RankQuantileOptions"; static RankQuantileOptions Defaults() { return RankQuantileOptions(); } /// Column key(s) to order by and how to order by these sort keys. 
- std::vector sort_keys; + std::vector sort_keys_; + + // DEPRECATED(will be removed after null_placement has been removed) + /// Get sort_keys with overwritten null_placement + std::vector GetSortKeys() const { + if(!null_placement.has_value()){ + return sort_keys_; + } + auto overwritten_sort_keys = sort_keys_; + for(auto& sort_key : overwritten_sort_keys){ + sort_key.null_placement = null_placement.value(); + } + return overwritten_sort_keys; + } + + // DEPRECATED(set null_placement in sort_keys instead) /// Whether nulls and NaNs are placed at the start or at the end - NullPlacement null_placement; + /// Will overwrite null ordering of sort keys + std::optional null_placement; }; /// \brief Partitioning options for NthToIndices diff --git a/cpp/src/arrow/compute/kernels/vector_rank.cc b/cpp/src/arrow/compute/kernels/vector_rank.cc index 1338ebedbe9..5f41cdcb55f 100644 --- a/cpp/src/arrow/compute/kernels/vector_rank.cc +++ b/cpp/src/arrow/compute/kernels/vector_rank.cc @@ -346,8 +346,11 @@ class RankMetaFunctionBase : public MetaFunction { checked_cast(function_options); SortOrder order = SortOrder::Ascending; - if (!options.sort_keys.empty()) { - order = options.sort_keys[0].order; + NullPlacement null_placement = NullPlacement::AtStart; + auto sort_keys = options.GetSortKeys(); + if (!sort_keys.empty()) { + order = sort_keys[0].order; + null_placement = sort_keys[0].null_placement; } int64_t length = input.length(); @@ -359,7 +362,7 @@ class RankMetaFunctionBase : public MetaFunction { auto needs_duplicates = Derived::NeedsDuplicates(options); ARROW_ASSIGN_OR_RAISE( auto sorted, SortAndMarkDuplicate(ctx, indices_begin, indices_end, input, order, - options.null_placement, needs_duplicates) + null_placement, needs_duplicates) .Run()); auto ranker = Derived::GetRanker(options); diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index eba7873e510..b0ec043e62d 100644 --- 
a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -81,7 +81,7 @@ class ArraySelector : public TypeVisitor { ctx_(ctx), array_(array), k_(options.k), - order_(options.sort_keys[0].order), + order_(options.GetSortKeys()[0].order), physical_type_(GetPhysicalType(array.type())), output_(output) {} @@ -287,8 +287,8 @@ class RecordBatchSelector : public TypeVisitor { record_batch_(record_batch), k_(options.k), output_(output), - sort_keys_(ResolveSortKeys(record_batch, options.sort_keys, &status_)), - comparator_(sort_keys_, NullPlacement::AtEnd) {} + sort_keys_(ResolveSortKeys(record_batch, options.GetSortKeys(), &status_)), + comparator_(sort_keys_) {} Status Run() { RETURN_NOT_OK(status_); @@ -314,7 +314,7 @@ class RecordBatchSelector : public TypeVisitor { *status = maybe_array.status(); return {}; } - resolved.emplace_back(*std::move(maybe_array), key.order); + resolved.emplace_back(*std::move(maybe_array), key.order, key.null_placement); } return resolved; } @@ -396,8 +396,9 @@ class TableSelector : public TypeVisitor { private: struct ResolvedSortKey { ResolvedSortKey(const std::shared_ptr& chunked_array, - const SortOrder order) + const SortOrder order, NullPlacement null_placement) : order(order), + null_placement(null_placement), type(GetPhysicalType(chunked_array->type())), chunks(GetPhysicalChunks(*chunked_array, type)), null_count(chunked_array->null_count()), @@ -410,6 +411,7 @@ class TableSelector : public TypeVisitor { ResolvedChunk GetChunk(int64_t index) const { return resolver.Resolve(index); } const SortOrder order; + const NullPlacement null_placement; const std::shared_ptr type; const ArrayVector chunks; const int64_t null_count; @@ -425,8 +427,8 @@ class TableSelector : public TypeVisitor { table_(table), k_(options.k), output_(output), - sort_keys_(ResolveSortKeys(table, options.sort_keys, &status_)), - comparator_(sort_keys_, NullPlacement::AtEnd) {} + sort_keys_(ResolveSortKeys(table, 
options.GetSortKeys(), &status_)), + comparator_(sort_keys_) {} Status Run() { RETURN_NOT_OK(status_); @@ -453,7 +455,7 @@ class TableSelector : public TypeVisitor { *status = maybe_chunked_array.status(); return {}; } - resolved.emplace_back(*std::move(maybe_chunked_array), key.order); + resolved.emplace_back(*std::move(maybe_chunked_array), key.order, key.null_placement); } return resolved; } @@ -621,7 +623,7 @@ class SelectKUnstableMetaFunction : public MetaFunction { } Result SelectKth(const RecordBatch& record_batch, const SelectKOptions& options, ExecContext* ctx) const { - ARROW_RETURN_NOT_OK(CheckConsistency(*record_batch.schema(), options.sort_keys)); + ARROW_RETURN_NOT_OK(CheckConsistency(*record_batch.schema(), options.GetSortKeys())); Datum output; RecordBatchSelector selector(ctx, record_batch, options, &output); ARROW_RETURN_NOT_OK(selector.Run()); @@ -629,7 +631,7 @@ class SelectKUnstableMetaFunction : public MetaFunction { } Result SelectKth(const Table& table, const SelectKOptions& options, ExecContext* ctx) const { - ARROW_RETURN_NOT_OK(CheckConsistency(*table.schema(), options.sort_keys)); + ARROW_RETURN_NOT_OK(CheckConsistency(*table.schema(), options.GetSortKeys())); Datum output; TableSelector selector(ctx, table, options, &output); ARROW_RETURN_NOT_OK(selector.Run()); diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 28868849fc5..545b2e536fc 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -392,17 +392,14 @@ class RadixRecordBatchSorter { using ResolvedSortKey = ResolvedRecordBatchSortKey; RadixRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end, - std::vector sort_keys, - const SortOptions& options) + std::vector sort_keys) : sort_keys_(std::move(sort_keys)), - options_(options), indices_begin_(indices_begin), indices_end_(indices_end) {} RadixRecordBatchSorter(uint64_t* indices_begin, uint64_t* 
indices_end, const RecordBatch& batch, const SortOptions& options) - : sort_keys_(ResolveRecordBatchSortKeys(batch, options.sort_keys, &status_)), - options_(options), + : sort_keys_(ResolveRecordBatchSortKeys(batch, options.GetSortKeys(), &status_)), indices_begin_(indices_begin), indices_end_(indices_end) {} @@ -414,7 +411,7 @@ class RadixRecordBatchSorter { std::vector> column_sorts(sort_keys_.size()); RecordBatchColumnSorter* next_column = nullptr; for (int64_t i = static_cast(sort_keys_.size() - 1); i >= 0; --i) { - ColumnSortFactory factory(sort_keys_[i], options_, next_column); + ColumnSortFactory factory(sort_keys_[i], next_column); ARROW_ASSIGN_OR_RAISE(column_sorts[i], factory.MakeColumnSort()); next_column = column_sorts[i].get(); } @@ -425,12 +422,12 @@ class RadixRecordBatchSorter { protected: struct ColumnSortFactory { - ColumnSortFactory(const ResolvedSortKey& sort_key, const SortOptions& options, + ColumnSortFactory(const ResolvedSortKey& sort_key, RecordBatchColumnSorter* next_column) : physical_type(sort_key.type), array(sort_key.owned_array), order(sort_key.order), - null_placement(options.null_placement), + null_placement(sort_key.null_placement), next_column(next_column) {} Result> MakeColumnSort() { @@ -473,7 +470,6 @@ class RadixRecordBatchSorter { } const std::vector sort_keys_; - const SortOptions& options_; uint64_t* indices_begin_; uint64_t* indices_end_; Status status_; @@ -485,21 +481,18 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor { using ResolvedSortKey = ResolvedRecordBatchSortKey; MultipleKeyRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end, - std::vector sort_keys, - const SortOptions& options) + std::vector sort_keys) : indices_begin_(indices_begin), indices_end_(indices_end), sort_keys_(std::move(sort_keys)), - null_placement_(options.null_placement), - comparator_(sort_keys_, null_placement_) {} + comparator_(sort_keys_) {} MultipleKeyRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end, 
const RecordBatch& batch, const SortOptions& options) : indices_begin_(indices_begin), indices_end_(indices_end), - sort_keys_(ResolveSortKeys(batch, options.sort_keys, &status_)), - null_placement_(options.null_placement), - comparator_(sort_keys_, null_placement_) {} + sort_keys_(ResolveSortKeys(batch, options.GetSortKeys(), &status_)), + comparator_(sort_keys_) {} // This is optimized for the first sort key. The first sort key sort // is processed in this class. The second and following sort keys @@ -635,11 +628,10 @@ class TableSorter { table_(table), batches_(MakeBatches(table, &status_)), options_(options), - null_placement_(options.null_placement), - sort_keys_(ResolveSortKeys(table, batches_, options.sort_keys, &status_)), + sort_keys_(ResolveSortKeys(table, batches_, options.GetSortKeys(), &status_)), indices_begin_(indices_begin), indices_end_(indices_end), - comparator_(sort_keys_, null_placement_) {} + comparator_(sort_keys_) {} // This is optimized for null partitioning and merging along the first sort key. // Other sort keys are delegated to the Comparator class. 
@@ -712,10 +704,11 @@ class TableSorter { TableSorter* sorter; std::vector* chunk_sorted; int64_t null_count; + NullPlacement null_placement; #define VISIT(TYPE) \ Status Visit(const TYPE& type) { \ - return sorter->MergeInternal(chunk_sorted, null_count); \ + return sorter->MergeInternal(chunk_sorted, null_count, null_placement); \ } VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) @@ -727,7 +720,7 @@ class TableSorter { type.ToString()); } }; - Visitor visitor{this, &chunk_sorted, null_count}; + Visitor visitor{this, &chunk_sorted, null_count, sort_keys_[0].null_placement}; RETURN_NOT_OK(VisitTypeInline(*sort_keys_[0].type, &visitor)); DCHECK_EQ(chunk_sorted.size(), 1); @@ -742,7 +735,7 @@ class TableSorter { // Recursive merge routine, typed on the first sort key template Status MergeInternal(std::vector* sorted, - int64_t null_count) { + int64_t null_count, NullPlacement null_placement) { auto merge_nulls = [&](CompressedChunkLocation* nulls_begin, CompressedChunkLocation* nulls_middle, CompressedChunkLocation* nulls_end, @@ -756,7 +749,7 @@ class TableSorter { MergeNonNulls(range_begin, range_middle, range_end, temp_indices); }; - ChunkedMergeImpl merge_impl(options_.null_placement, std::move(merge_nulls), + ChunkedMergeImpl merge_impl(null_placement, std::move(merge_nulls), std::move(merge_non_nulls)); RETURN_NOT_OK(merge_impl.Init(ctx_, table_.num_rows())); @@ -798,7 +791,7 @@ class TableSorter { const auto right_is_null = chunk_right.IsNull(); if (left_is_null == right_is_null) { return comparator.Compare(left_loc, right_loc, 1); - } else if (options_.null_placement == NullPlacement::AtEnd) { + } else if (first_sort_key.null_placement == NullPlacement::AtEnd) { return right_is_null; } else { return left_is_null; @@ -881,7 +874,6 @@ class TableSorter { const Table& table_; const RecordBatchVector batches_; const SortOptions& options_; - const NullPlacement null_placement_; const std::vector sort_keys_; uint64_t* indices_begin_; uint64_t* indices_end_; @@ -969,18 +961,24 
@@ class SortIndicesMetaFunction : public MetaFunction { Result SortIndices(const Array& values, const SortOptions& options, ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; - if (!options.sort_keys.empty()) { - order = options.sort_keys[0].order; + NullPlacement null_placement = NullPlacement::AtStart; + auto sort_keys = options.GetSortKeys(); + if (!sort_keys.empty()) { + order = sort_keys[0].order; + null_placement = sort_keys[0].null_placement; } - ArraySortOptions array_options(order, options.null_placement); + ArraySortOptions array_options(order, null_placement); return CallFunction("array_sort_indices", {values}, &array_options, ctx); } Result SortIndices(const ChunkedArray& chunked_array, const SortOptions& options, ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; - if (!options.sort_keys.empty()) { - order = options.sort_keys[0].order; + NullPlacement null_placement = NullPlacement::AtStart; + auto sort_keys = options.GetSortKeys(); + if (!sort_keys.empty()) { + order = sort_keys[0].order; + null_placement = sort_keys[0].null_placement; } auto out_type = uint64(); @@ -996,14 +994,14 @@ class SortIndicesMetaFunction : public MetaFunction { std::iota(out_begin, out_end, 0); RETURN_NOT_OK(SortChunkedArray(ctx, out_begin, out_end, chunked_array, order, - options.null_placement)); + null_placement)); return Datum(out); } Result SortIndices(const RecordBatch& batch, const SortOptions& options, ExecContext* ctx) const { ARROW_ASSIGN_OR_RAISE(auto sort_keys, - ResolveRecordBatchSortKeys(batch, options.sort_keys)); + ResolveRecordBatchSortKeys(batch, options.GetSortKeys())); auto n_sort_keys = sort_keys.size(); if (n_sort_keys == 0) { @@ -1026,11 +1024,10 @@ class SortIndicesMetaFunction : public MetaFunction { std::iota(out_begin, out_end, 0); if (n_sort_keys <= kMaxRadixSortKeys) { - RadixRecordBatchSorter sorter(out_begin, out_end, std::move(sort_keys), options); + RadixRecordBatchSorter sorter(out_begin, out_end, 
std::move(sort_keys)); ARROW_RETURN_NOT_OK(sorter.Sort()); } else { - MultipleKeyRecordBatchSorter sorter(out_begin, out_end, std::move(sort_keys), - options); + MultipleKeyRecordBatchSorter sorter(out_begin, out_end, std::move(sort_keys)); ARROW_RETURN_NOT_OK(sorter.Sort()); } return Datum(out); @@ -1038,7 +1035,7 @@ class SortIndicesMetaFunction : public MetaFunction { Result SortIndices(const Table& table, const SortOptions& options, ExecContext* ctx) const { - auto n_sort_keys = options.sort_keys.size(); + auto n_sort_keys = options.sort_keys_.size(); if (n_sort_keys == 0) { return Status::Invalid("Must specify one or more sort keys"); } @@ -1048,7 +1045,7 @@ class SortIndicesMetaFunction : public MetaFunction { // need to do here. ARROW_ASSIGN_OR_RAISE( auto chunked_array, - PrependInvalidColumn(options.sort_keys[0].target.GetOneFlattened(table))); + PrependInvalidColumn(options.GetSortKeys()[0].target.GetOneFlattened(table))); if (chunked_array->type()->id() != Type::STRUCT) { return SortIndices(*chunked_array, options, ctx); } @@ -1089,7 +1086,7 @@ struct SortFieldPopulator { PrependInvalidColumn(sort_key.target.FindOne(schema))); if (seen_.insert(match).second) { ARROW_ASSIGN_OR_RAISE(auto schema_field, match.Get(schema)); - AddField(*schema_field->type(), match, sort_key.order); + AddField(*schema_field->type(), match, sort_key.order, sort_key.null_placement); } } @@ -1097,7 +1094,7 @@ struct SortFieldPopulator { } protected: - void AddLeafFields(const FieldVector& fields, SortOrder order) { + void AddLeafFields(const FieldVector& fields, SortOrder order, NullPlacement null_placement) { if (fields.empty()) { return; } @@ -1106,21 +1103,21 @@ struct SortFieldPopulator { for (const auto& f : fields) { const auto& type = *f->type(); if (type.id() == Type::STRUCT) { - AddLeafFields(type.fields(), order); + AddLeafFields(type.fields(), order, null_placement); } else { - sort_fields_.emplace_back(FieldPath(tmp_indices_), order, &type); + 
sort_fields_.emplace_back(FieldPath(tmp_indices_), order, null_placement, &type); } ++tmp_indices_.back(); } tmp_indices_.pop_back(); } - void AddField(const DataType& type, const FieldPath& path, SortOrder order) { + void AddField(const DataType& type, const FieldPath& path, SortOrder order, NullPlacement null_placement) { if (type.id() == Type::STRUCT) { tmp_indices_ = path.indices(); - AddLeafFields(type.fields(), order); + AddLeafFields(type.fields(), order, null_placement); } else { - sort_fields_.emplace_back(path, order, &type); + sort_fields_.emplace_back(path, order, null_placement, &type); } } @@ -1168,21 +1165,18 @@ Result SortStructArray(ExecContext* ctx, uint64_t* indices_ std::move(columns)); auto options = SortOptions::Defaults(); - options.null_placement = null_placement; - options.sort_keys.reserve(array.num_fields()); + options.sort_keys_.reserve(array.num_fields()); for (int i = 0; i < array.num_fields(); ++i) { - options.sort_keys.push_back(SortKey(FieldRef(i), sort_order)); + options.sort_keys_.push_back(SortKey(FieldRef(i), sort_order, null_placement)); } ARROW_ASSIGN_OR_RAISE(auto sort_keys, - ResolveRecordBatchSortKeys(*batch, options.sort_keys)); + ResolveRecordBatchSortKeys(*batch, options.GetSortKeys())); if (sort_keys.size() <= kMaxRadixSortKeys) { - RadixRecordBatchSorter sorter(indices_begin, indices_end, std::move(sort_keys), - options); + RadixRecordBatchSorter sorter(indices_begin, indices_end, std::move(sort_keys)); return sorter.Sort(); } else { - MultipleKeyRecordBatchSorter sorter(indices_begin, indices_end, std::move(sort_keys), - options); + MultipleKeyRecordBatchSorter sorter(indices_begin, indices_end, std::move(sort_keys)); return sorter.Sort(); } } diff --git a/cpp/src/arrow/compute/kernels/vector_sort_internal.h b/cpp/src/arrow/compute/kernels/vector_sort_internal.h index 49704ff8069..b8d2a527098 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_internal.h +++ b/cpp/src/arrow/compute/kernels/vector_sort_internal.h @@ 
-487,15 +487,16 @@ Result SortStructArray(ExecContext* ctx, uint64_t* indices_ struct SortField { SortField() = default; - SortField(FieldPath path, SortOrder order, const DataType* type) - : path(std::move(path)), order(order), type(type) {} - SortField(int index, SortOrder order, const DataType* type) - : SortField(FieldPath({index}), order, type) {} + SortField(FieldPath path, SortOrder order, NullPlacement null_placement, const DataType* type) + : path(std::move(path)), order(order), null_placement(null_placement), type(type) {} + SortField(int index, SortOrder order, NullPlacement null_placement, const DataType* type) + : SortField(FieldPath({index}), order, null_placement, type) {} bool is_nested() const { return path.indices().size() > 1; } FieldPath path; SortOrder order; + NullPlacement null_placement; const DataType* type; }; @@ -542,9 +543,9 @@ Result> ResolveSortKeys( // paths [0,0,0,0] and [0,0,0,1], we shouldn't need to flatten the first three // components more than once. 
ARROW_ASSIGN_OR_RAISE(auto child, f.path.GetFlattened(table_or_batch)); - return ResolvedSortKey{std::move(child), f.order}; + return ResolvedSortKey{std::move(child), f.order, f.null_placement}; } - return ResolvedSortKey{table_or_batch.column(f.path[0]), f.order}; + return ResolvedSortKey{table_or_batch.column(f.path[0]), f.order, f.null_placement}; }); } @@ -594,15 +595,14 @@ template struct ColumnComparator { using Location = typename ResolvedSortKey::LocationType; - ColumnComparator(const ResolvedSortKey& sort_key, NullPlacement null_placement) - : sort_key_(sort_key), null_placement_(null_placement) {} + ColumnComparator(const ResolvedSortKey& sort_key) + : sort_key_(sort_key) {} virtual ~ColumnComparator() = default; virtual int Compare(const Location& left, const Location& right) const = 0; ResolvedSortKey sort_key_; - NullPlacement null_placement_; }; template @@ -622,14 +622,14 @@ struct ConcreteColumnComparator : public ColumnComparator { if (is_null_left && is_null_right) { return 0; } else if (is_null_left) { - return this->null_placement_ == NullPlacement::AtStart ? -1 : 1; + return sort_key.null_placement == NullPlacement::AtStart ? -1 : 1; } else if (is_null_right) { - return this->null_placement_ == NullPlacement::AtStart ? 1 : -1; + return sort_key.null_placement == NullPlacement::AtStart ? 
1 : -1; } } return CompareTypeValues(chunk_left.template Value(), chunk_right.template Value(), sort_key.order, - this->null_placement_); + sort_key.null_placement); } }; @@ -650,9 +650,8 @@ class MultipleKeyComparator { public: using Location = typename ResolvedSortKey::LocationType; - MultipleKeyComparator(const std::vector& sort_keys, - NullPlacement null_placement) - : sort_keys_(sort_keys), null_placement_(null_placement) { + explicit MultipleKeyComparator(const std::vector& sort_keys) + : sort_keys_(sort_keys) { status_ &= MakeComparators(); } @@ -687,12 +686,11 @@ class MultipleKeyComparator { template Status VisitGeneric(const Type& type) { res.reset( - new ConcreteColumnComparator{sort_key, null_placement}); + new ConcreteColumnComparator{sort_key}); return Status::OK(); } const ResolvedSortKey& sort_key; - NullPlacement null_placement; std::unique_ptr> res; }; @@ -700,7 +698,7 @@ class MultipleKeyComparator { column_comparators_.reserve(sort_keys_.size()); for (const auto& sort_key : sort_keys_) { - ColumnComparatorFactory factory{sort_key, null_placement_, nullptr}; + ColumnComparatorFactory factory{sort_key, nullptr}; RETURN_NOT_OK(VisitTypeInline(*sort_key.type, &factory)); column_comparators_.push_back(std::move(factory.res)); } @@ -728,17 +726,17 @@ class MultipleKeyComparator { } const std::vector& sort_keys_; - const NullPlacement null_placement_; std::vector>> column_comparators_; Status status_; }; struct ResolvedRecordBatchSortKey { - ResolvedRecordBatchSortKey(const std::shared_ptr& array, SortOrder order) + ResolvedRecordBatchSortKey(const std::shared_ptr& array, SortOrder order, NullPlacement null_placement) : type(GetPhysicalType(array->type())), owned_array(GetPhysicalArray(*array, type)), array(*owned_array), order(order), + null_placement(null_placement), null_count(array->null_count()) {} using LocationType = int64_t; @@ -749,16 +747,18 @@ struct ResolvedRecordBatchSortKey { std::shared_ptr owned_array; const Array& array; SortOrder 
order; + NullPlacement null_placement; int64_t null_count; }; struct ResolvedTableSortKey { ResolvedTableSortKey(const std::shared_ptr& type, ArrayVector chunks, - SortOrder order, int64_t null_count) + SortOrder order, NullPlacement null_placement, int64_t null_count) : type(GetPhysicalType(type)), owned_chunks(std::move(chunks)), chunks(GetArrayPointers(owned_chunks)), order(order), + null_placement(null_placement), null_count(null_count) {} using LocationType = ::arrow::ChunkLocation; @@ -784,7 +784,7 @@ struct ResolvedTableSortKey { chunks.push_back(std::move(child)); } - return ResolvedTableSortKey(f.type->GetSharedPtr(), std::move(chunks), f.order, + return ResolvedTableSortKey(f.type->GetSharedPtr(), std::move(chunks), f.order, f.null_placement, null_count); }; @@ -796,6 +796,7 @@ struct ResolvedTableSortKey { ArrayVector owned_chunks; std::vector chunks; SortOrder order; + NullPlacement null_placement; int64_t null_count; }; diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index 0569f1f2abb..91360558375 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -1207,7 +1207,7 @@ TEST_F(TestRecordBatchSortIndices, NoNull) { for (auto null_placement : AllNullPlacements()) { SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}, + {SortKey("a", SortOrder::Ascending, null_placement), SortKey("b", SortOrder::Descending, null_placement)}, null_placement); AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]"); @@ -1237,6 +1237,33 @@ TEST_F(TestRecordBatchSortIndices, Null) { AssertSortIndices(batch, options, "[3, 0, 5, 1, 4, 2, 6]"); } +TEST_F(TestRecordBatchSortIndices, MixedNullOrdering) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + auto batch = RecordBatchFromJSON(schema, + R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": 
null}, + {"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5}, + {"a": 3, "b": 5} + ])"); + const std::vector sort_keys{SortKey("a", SortOrder::Ascending, NullPlacement::AtEnd), + SortKey("b", SortOrder::Descending, NullPlacement::AtEnd)}; + + SortOptions options(sort_keys, std::nullopt); + AssertSortIndices(batch, options, "[5, 1, 4, 6, 2, 0, 3]"); + + options.sort_keys_.at(0).null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[0, 3, 5, 1, 4, 6, 2]"); + + options.sort_keys_.at(1).null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[3, 0, 5, 1, 4, 2, 6]"); +} + TEST_F(TestRecordBatchSortIndices, NaN) { auto schema = ::arrow::schema({ {field("a", float32())}, @@ -1858,25 +1885,25 @@ class TestTableSortIndicesRandom : public testing::TestWithParam { class Comparator { public: Comparator(const Table& table, const SortOptions& options) : options_(options) { - for (const auto& sort_key : options_.sort_keys) { + for (const auto& sort_key : options_.GetSortKeys()) { DCHECK(!sort_key.target.IsNested()); if (auto name = sort_key.target.name()) { - sort_columns_.emplace_back(table.GetColumnByName(*name).get(), sort_key.order); + sort_columns_.emplace_back(table.GetColumnByName(*name).get(), sort_key.order, sort_key.null_placement); continue; } auto index = sort_key.target.field_path()->indices()[0]; - sort_columns_.emplace_back(table.column(index).get(), sort_key.order); + sort_columns_.emplace_back(table.column(index).get(), sort_key.order, sort_key.null_placement); } } // Return true if the left record is less or equals to the right record, // false otherwise. 
bool operator()(uint64_t lhs, uint64_t rhs) { - for (const auto& pair : sort_columns_) { - ColumnComparator comparator(pair.second, options_.null_placement); - const auto& chunked_array = *pair.first; + for (const auto& tuple : sort_columns_) { + ColumnComparator comparator(std::get<1>(tuple), std::get<2>(tuple)); + const auto& chunked_array = *std::get<0>(tuple); int64_t lhs_index = 0, rhs_index = 0; const Array* lhs_array = FindTargetArray(chunked_array, lhs, &lhs_index); const Array* rhs_array = FindTargetArray(chunked_array, rhs, &rhs_index); @@ -1904,7 +1931,7 @@ class TestTableSortIndicesRandom : public testing::TestWithParam { } const SortOptions& options_; - std::vector> sort_columns_; + std::vector> sort_columns_; }; public: @@ -2182,7 +2209,7 @@ class TestNestedSortIndices : public ::testing::Test { // Implementations may have an optimized path for cases with one sort key. // Additionally, this key references a struct containing another struct, which should // work recursively - options.sort_keys = {SortKey(FieldRef("a"), SortOrder::Ascending)}; + options.sort_keys_ = {SortKey(FieldRef("a"), SortOrder::Ascending)}; options.null_placement = NullPlacement::AtEnd; AssertSortIndices(datum, options, "[6, 7, 3, 4, 0, 8, 1, 2, 5]"); options.null_placement = NullPlacement::AtStart; diff --git a/cpp/src/arrow/compute/ordering.cc b/cpp/src/arrow/compute/ordering.cc index 25ad6a5ca5f..ce19052672e 100644 --- a/cpp/src/arrow/compute/ordering.cc +++ b/cpp/src/arrow/compute/ordering.cc @@ -38,6 +38,14 @@ std::string SortKey::ToString() const { ss << "DESC"; break; } + switch (null_placement) { + case NullPlacement::AtStart: + ss << " NULLS FIRST"; + break; + case NullPlacement::AtEnd: + ss << " NULLS LAST"; + break; + } return ss.str(); } @@ -78,15 +86,17 @@ std::string Ordering::ToString() const { ss << key.ToString(); } ss << "]"; - switch (null_placement_) { - case NullPlacement::AtEnd: - ss << " nulls last"; - break; - case NullPlacement::AtStart: - ss << " nulls 
first"; - break; - default: - Unreachable(); + if(null_placement_.has_value()){ + switch (null_placement_.value()) { + case NullPlacement::AtEnd: + ss << " nulls last"; + break; + case NullPlacement::AtStart: + ss << " nulls first"; + break; + default: + Unreachable(); + } } return ss.str(); } diff --git a/cpp/src/arrow/compute/ordering.h b/cpp/src/arrow/compute/ordering.h index 61caa2b570d..f85b07bddf7 100644 --- a/cpp/src/arrow/compute/ordering.h +++ b/cpp/src/arrow/compute/ordering.h @@ -46,8 +46,8 @@ enum class NullPlacement { /// \brief One sort key for PartitionNthIndices (TODO) and SortIndices class ARROW_EXPORT SortKey : public util::EqualityComparable { public: - explicit SortKey(FieldRef target, SortOrder order = SortOrder::Ascending) - : target(std::move(target)), order(order) {} + explicit SortKey(FieldRef target, SortOrder order = SortOrder::Ascending, NullPlacement null_placement = NullPlacement::AtEnd) + : target(std::move(target)), order(order), null_placement(null_placement) {} bool Equals(const SortKey& other) const; std::string ToString() const; @@ -56,12 +56,14 @@ class ARROW_EXPORT SortKey : public util::EqualityComparable { FieldRef target; /// How to order by this sort key. SortOrder order; + /// Null placement for this sort key. 
+ NullPlacement null_placement; }; class ARROW_EXPORT Ordering : public util::EqualityComparable { public: Ordering(std::vector sort_keys, - NullPlacement null_placement = NullPlacement::AtStart) + std::optional null_placement = NullPlacement::AtStart) : sort_keys_(std::move(sort_keys)), null_placement_(null_placement) {} /// true if data ordered by other is also ordered by this /// @@ -91,7 +93,9 @@ class ARROW_EXPORT Ordering : public util::EqualityComparable { bool is_unordered() const { return !is_implicit_ && sort_keys_.empty(); } const std::vector& sort_keys() const { return sort_keys_; } - NullPlacement null_placement() const { return null_placement_; } + + // DEPRECATED(will be removed after member null_placement_ has been removed) + std::optional null_placement() const { return null_placement_; } static const Ordering& Implicit() { static const Ordering kImplicit(true); @@ -111,8 +115,11 @@ class ARROW_EXPORT Ordering : public util::EqualityComparable { : null_placement_(NullPlacement::AtStart), is_implicit_(is_implicit) {} /// Column key(s) to order by and how to order by these sort keys. 
std::vector sort_keys_; + + // DEPRECATED(set null_placement in the sort keys instead) + /// Whether nulls and NaNs are placed at the start or at the end - NullPlacement null_placement_; + /// Will overwrite null ordering of sort keys + std::optional null_placement_; bool is_implicit_ = false; }; From dafbe1421af294c8a786526a230240ddc1391d22 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Wed, 25 Jun 2025 22:35:34 +0200 Subject: [PATCH 06/83] merge follow-up --- cpp/src/arrow/compute/api_vector.h | 9 +++++++++ cpp/src/arrow/compute/kernels/vector_sort.cc | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 398870f2bce..0ba20bad74b 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -206,6 +206,15 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { Tiebreaker tiebreaker = RankOptions::First) : RankOptions({SortKey("", order)}, null_placement, tiebreaker) {} + explicit RankOptions(std::vector sort_keys, + Tiebreaker tiebreaker = RankOptions::First) + : RankOptions(std::move(sort_keys), std::nullopt, tiebreaker) {} + + /// Convenience constructor for array inputs + explicit RankOptions(SortOrder order, + Tiebreaker tiebreaker = RankOptions::First) + : RankOptions({SortKey("", order)}, std::nullopt, tiebreaker) {} + static constexpr char const kTypeName[] = "RankOptions"; static RankOptions Defaults() { return RankOptions(); } diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index b7a12f81a7d..6f8f225244a 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -1104,7 +1104,7 @@ struct SortFieldPopulator { if (type.id() == Type::STRUCT) { AddLeafFields(type.fields(), order, null_placement); } else { - sort_fields_.emplace_back(FieldPath(tmp_indices_), order, &type, null_placement); +
sort_fields_.emplace_back(FieldPath(tmp_indices_), order, null_placement, &type); } ++tmp_indices_.back(); } @@ -1117,7 +1117,7 @@ struct SortFieldPopulator { tmp_indices_ = path.indices(); AddLeafFields(type.fields(), order, null_placement); } else { - sort_fields_.emplace_back(path, order, &type, null_placement); + sort_fields_.emplace_back(path, order, null_placement, &type); } } From 48a3e0cda8380af9e089f4021d0a88375032ffe2 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 11:04:25 +0200 Subject: [PATCH 07/83] merge follow-up --- cpp/src/arrow/acero/sink_node.cc | 2 +- cpp/src/arrow/compute/exec_test.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/acero/sink_node.cc b/cpp/src/arrow/acero/sink_node.cc index ab06dd8ffd8..dce3e6eddec 100644 --- a/cpp/src/arrow/acero/sink_node.cc +++ b/cpp/src/arrow/acero/sink_node.cc @@ -476,7 +476,7 @@ struct OrderBySinkNode final : public SinkNode { } static Status ValidateOrderByOptions(const OrderBySinkNodeOptions& options) { - if (options.sort_options.sort_keys.empty()) { + if (options.sort_options.sort_keys_.empty()) { return Status::Invalid("At least one sort key should be specified"); } return ValidateCommonOrderOptions(options); diff --git a/cpp/src/arrow/compute/exec_test.cc b/cpp/src/arrow/compute/exec_test.cc index 27aee9e9f8d..9c361366e31 100644 --- a/cpp/src/arrow/compute/exec_test.cc +++ b/cpp/src/arrow/compute/exec_test.cc @@ -1400,7 +1400,7 @@ TEST(Ordering, IsSuborderOf) { Ordering a{{SortKey{3}, SortKey{1}, SortKey{7}}}; Ordering b{{SortKey{3}, SortKey{1}}}; Ordering c{{SortKey{1}, SortKey{7}}}; - Ordering d{{SortKey{1}, SortKey{7, NullPlacement::AtStart}}}; + Ordering d{{SortKey{1}, SortKey{7, SortOrder::Ascending, NullPlacement::AtStart}}}; Ordering imp = Ordering::Implicit(); Ordering unordered = Ordering::Unordered(); From 5c9eb50fa05b38452fcbc51b345d7fd26030b104 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 11:46:36 +0200 
Subject: [PATCH 08/83] formatting --- c_glib/arrow-glib/compute.cpp | 20 +++++++------ cpp/src/arrow/compute/api_vector.cc | 12 ++++---- cpp/src/arrow/compute/api_vector.h | 28 +++++++++---------- cpp/src/arrow/compute/kernels/vector_sort.cc | 6 ++-- .../compute/kernels/vector_sort_internal.h | 10 ++++--- .../arrow/compute/kernels/vector_sort_test.cc | 12 ++++---- cpp/src/arrow/compute/ordering.cc | 2 +- cpp/src/arrow/compute/ordering.h | 3 +- python/pyarrow/_compute.pyx | 3 +- python/pyarrow/array.pxi | 3 +- 10 files changed, 54 insertions(+), 45 deletions(-) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index c15a56251bc..c91d747e177 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -3049,13 +3049,13 @@ garrow_sort_key_class_init(GArrowSortKeyClass *klass) * * Since: 15.0.0 */ - spec = g_param_spec_enum("null-placement", - "Null Placement", - "Whether nulls and NaNs are placed at the start or at the end", - GARROW_TYPE_NULL_PLACEMENT, - 0, - static_cast(G_PARAM_READWRITE | - G_PARAM_CONSTRUCT_ONLY)); + spec = g_param_spec_enum( + "null-placement", + "Null Placement", + "Whether nulls and NaNs are placed at the start or at the end", + GARROW_TYPE_NULL_PLACEMENT, + 0, + static_cast(G_PARAM_READWRITE | G_PARAM_CONSTRUCT_ONLY)); g_object_class_install_property(gobject_class, PROP_SORT_KEY_NULL_PLACEMENT, spec); } @@ -3080,8 +3080,10 @@ garrow_sort_key_new(const gchar *target, return NULL; } auto sort_key = g_object_new(GARROW_TYPE_SORT_KEY, - "order", order, - "null-placement", null_placement, + "order", + order, + "null-placement", + null_placement, NULL); auto priv = GARROW_SORT_KEY_GET_PRIVATE(sort_key); priv->sort_key.target = *arrow_reference_result; diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index b10208854cb..1fbcd70af21 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -138,8 +138,7 @@ static auto 
kArraySortOptionsType = GetFunctionOptionsType( DataMember("order", &ArraySortOptions::order), DataMember("null_placement", &ArraySortOptions::null_placement)); static auto kSortOptionsType = GetFunctionOptionsType( - CoercedDataMember("sort_keys", &SortOptions::sort_keys_, - &SortOptions::GetSortKeys)); + CoercedDataMember("sort_keys", &SortOptions::sort_keys_, &SortOptions::GetSortKeys)); static auto kPartitionNthOptionsType = GetFunctionOptionsType( DataMember("pivot", &PartitionNthOptions::pivot), DataMember("null_placement", &PartitionNthOptions::null_placement)); @@ -153,8 +152,7 @@ static auto kCumulativeOptionsType = GetFunctionOptionsType( DataMember("start", &CumulativeOptions::start), DataMember("skip_nulls", &CumulativeOptions::skip_nulls)); static auto kRankOptionsType = GetFunctionOptionsType( - CoercedDataMember("sort_keys", &RankOptions::sort_keys_, - &RankOptions::GetSortKeys), + CoercedDataMember("sort_keys", &RankOptions::sort_keys_, &RankOptions::GetSortKeys), DataMember("tiebreaker", &RankOptions::tiebreaker)); static auto kRankQuantileOptionsType = GetFunctionOptionsType( CoercedDataMember("sort_keys", &RankQuantileOptions::sort_keys_, @@ -196,7 +194,8 @@ ArraySortOptions::ArraySortOptions(SortOrder order, NullPlacement null_placement null_placement(null_placement) {} constexpr char ArraySortOptions::kTypeName[]; -SortOptions::SortOptions(std::vector sort_keys, std::optional null_placement) +SortOptions::SortOptions(std::vector sort_keys, + std::optional null_placement) : FunctionOptions(internal::kSortOptionsType), sort_keys_(std::move(sort_keys)), null_placement(null_placement) {} @@ -233,7 +232,8 @@ CumulativeOptions::CumulativeOptions(std::shared_ptr start, bool skip_nu skip_nulls(skip_nulls) {} constexpr char CumulativeOptions::kTypeName[]; -RankOptions::RankOptions(std::vector sort_keys, std::optional null_placement, +RankOptions::RankOptions(std::vector sort_keys, + std::optional null_placement, RankOptions::Tiebreaker tiebreaker) : 
FunctionOptions(internal::kRankOptionsType), sort_keys_(std::move(sort_keys)), diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 0ba20bad74b..b0dc50dc3e0 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -123,11 +123,11 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { // DEPRECATED(will be removed after null_placement has been removed) /// Get sort_keys with overwritten null_placement std::vector GetSortKeys() const { - if(!null_placement.has_value()){ + if (!null_placement.has_value()){ return sort_keys_; } auto overwritten_sort_keys = sort_keys_; - for(auto& sort_key : overwritten_sort_keys){ + for (auto& sort_key : overwritten_sort_keys){ sort_key.null_placement = null_placement.value(); } return overwritten_sort_keys; @@ -173,11 +173,10 @@ class ARROW_EXPORT SelectKOptions : public FunctionOptions { /// Column key(s) to order by and how to order by these sort keys. std::vector sort_keys; - // DEPRECATED(will be removed after null_placement has been removed from other SortOptions-like structs) + // DEPRECATED(will be removed after null_placement has been removed from other + // SortOptions-like structs) /// Get sort_keys - std::vector GetSortKeys() const{ - return sort_keys; - } + std::vector GetSortKeys() const { return sort_keys; } }; /// \brief Rank options @@ -211,8 +210,7 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { : RankOptions(std::move(sort_keys), std::nullopt, tiebreaker) {} /// Convenience constructor for array inputs - explicit RankOptions(SortOrder order, - Tiebreaker tiebreaker = RankOptions::First) + explicit RankOptions(SortOrder order, Tiebreaker tiebreaker = RankOptions::First) : RankOptions({SortKey("", order)}, std::nullopt, tiebreaker) {} static constexpr char const kTypeName[] = "RankOptions"; @@ -224,11 +222,11 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { // DEPRECATED(will be removed after null_placement has been 
removed) /// Get sort_keys with overwritten null_placement std::vector GetSortKeys() const { - if(!null_placement.has_value()){ + if (!null_placement.has_value()){ return sort_keys_; } auto overwritten_sort_keys = sort_keys_; - for(auto& sort_key : overwritten_sort_keys){ + for (auto& sort_key : overwritten_sort_keys){ sort_key.null_placement = null_placement.value(); } return overwritten_sort_keys; @@ -245,8 +243,10 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { /// \brief Quantile rank options class ARROW_EXPORT RankQuantileOptions : public FunctionOptions { public: - explicit RankQuantileOptions(std::vector sort_keys = {}, - std::optional null_placement = std::nullopt); + explicit RankQuantileOptions( + std::vector sort_keys = {}, + std::optional null_placement = std::nullopt); + /// Convenience constructor for array inputs explicit RankQuantileOptions(SortOrder order, std::optional null_placement = std::nullopt) @@ -261,11 +261,11 @@ class ARROW_EXPORT RankQuantileOptions : public FunctionOptions { // DEPRECATED(will be removed after null_placement has been removed) /// Get sort_keys with overwritten null_placement std::vector GetSortKeys() const { - if(!null_placement.has_value()){ + if (!null_placement.has_value()){ return sort_keys_; } auto overwritten_sort_keys = sort_keys_; - for(auto& sort_key : overwritten_sort_keys){ + for (auto& sort_key : overwritten_sort_keys){ sort_key.null_placement = null_placement.value(); } return overwritten_sort_keys; diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 6f8f225244a..58f82fb0743 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -748,7 +748,7 @@ class TableSorter { }; ChunkedMergeImpl merge_impl(sort_keys_[0].null_placement, std::move(merge_nulls), - std::move(merge_non_nulls)); + std::move(merge_non_nulls)); RETURN_NOT_OK(merge_impl.Init(ctx_, table_.num_rows())); while 
(sorted->size() > 1) { @@ -991,8 +991,8 @@ class SortIndicesMetaFunction : public MetaFunction { auto out_end = out_begin + length; std::iota(out_begin, out_end, 0); - RETURN_NOT_OK(SortChunkedArray(ctx, out_begin, out_end, chunked_array, order, - null_placement)); + RETURN_NOT_OK( + SortChunkedArray(ctx, out_begin, out_end, chunked_array, order, null_placement)); return Datum(out); } diff --git a/cpp/src/arrow/compute/kernels/vector_sort_internal.h b/cpp/src/arrow/compute/kernels/vector_sort_internal.h index 03c850554f7..5cdad2d4a65 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_internal.h +++ b/cpp/src/arrow/compute/kernels/vector_sort_internal.h @@ -487,9 +487,11 @@ Result SortStructArray(ExecContext* ctx, uint64_t* indices_ struct SortField { SortField() = default; - SortField(FieldPath path, SortOrder order, NullPlacement null_placement, const DataType* type) + SortField(FieldPath path, SortOrder order, NullPlacement null_placement, + const DataType* type) : path(std::move(path)), order(order), null_placement(null_placement), type(type) {} - SortField(int index, SortOrder order, NullPlacement null_placement, const DataType* type) + SortField(int index, SortOrder order, NullPlacement null_placement, + const DataType* type) : SortField(FieldPath({index}), order, null_placement, type) {} bool is_nested() const { return path.indices().size() > 1; } @@ -784,8 +786,8 @@ struct ResolvedTableSortKey { chunks.push_back(std::move(child)); } - return ResolvedTableSortKey(f.type->GetSharedPtr(), std::move(chunks), f.order, f.null_placement, - null_count); + return ResolvedTableSortKey(f.type->GetSharedPtr(), std::move(chunks), f.order, + f.null_placement, null_count); }; return ::arrow::compute::internal::ResolveSortKeys( diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index e010fb9ca07..a37f15970a1 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ 
b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -1206,9 +1206,10 @@ TEST_F(TestRecordBatchSortIndices, NoNull) { ])"); for (auto null_placement : AllNullPlacements()) { - SortOptions options( - {SortKey("a", SortOrder::Ascending, null_placement), SortKey("b", SortOrder::Descending, null_placement)}, - null_placement); + SortOptions options({SortKey("a", SortOrder::Ascending, null_placement), + SortKey("b", SortOrder::Descending, null_placement)}, + null_placement); + AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]"); } @@ -1253,8 +1254,9 @@ TEST_F(TestRecordBatchSortIndices, MixedNullOrdering) { {"a": 1, "b": 5}, {"a": 3, "b": 5} ])"); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending, NullPlacement::AtEnd), - SortKey("b", SortOrder::Descending, NullPlacement::AtEnd)}; + const std::vector sort_keys{ + SortKey("a", SortOrder::Ascending, NullPlacement::AtEnd), + SortKey("b", SortOrder::Descending, NullPlacement::AtEnd)}; SortOptions options(sort_keys, std::nullopt); AssertSortIndices(batch, options, "[5, 1, 4, 6, 2, 0, 3]"); diff --git a/cpp/src/arrow/compute/ordering.cc b/cpp/src/arrow/compute/ordering.cc index 0bb89d02511..f65cf4fc4fd 100644 --- a/cpp/src/arrow/compute/ordering.cc +++ b/cpp/src/arrow/compute/ordering.cc @@ -87,7 +87,7 @@ std::string Ordering::ToString() const { ss << key.ToString(); } ss << "]"; - if(null_placement_.has_value()){ + if (null_placement_.has_value()){ switch (null_placement_.value()) { case NullPlacement::AtEnd: ss << " nulls last"; diff --git a/cpp/src/arrow/compute/ordering.h b/cpp/src/arrow/compute/ordering.h index f85b07bddf7..91bab5fe5b2 100644 --- a/cpp/src/arrow/compute/ordering.h +++ b/cpp/src/arrow/compute/ordering.h @@ -46,7 +46,8 @@ enum class NullPlacement { /// \brief One sort key for PartitionNthIndices (TODO) and SortIndices class ARROW_EXPORT SortKey : public util::EqualityComparable { public: - explicit SortKey(FieldRef target, SortOrder order = SortOrder::Ascending, NullPlacement 
null_placement = NullPlacement::AtEnd) + explicit SortKey(FieldRef target, SortOrder order = SortOrder::Ascending, + NullPlacement null_placement = NullPlacement::AtEnd) : target(std::move(target)), order(order), null_placement(null_placement) {} bool Equals(const SortKey& other) const; diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index c75ee89c9df..946a3fc243e 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -83,7 +83,8 @@ cdef vector[CSortKey] unwrap_sort_keys(sort_keys, allow_str=True): else: for name, order, null_placement in sort_keys: c_sort_keys.push_back( - CSortKey(_ensure_field_ref(name), unwrap_sort_order(order), unwrap_null_placement(null_placement)) + CSortKey(_ensure_field_ref(name), unwrap_sort_order( + order), unwrap_null_placement(null_placement)) ) return c_sort_keys diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 57f4c1c9705..bef86e023a8 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -4324,7 +4324,8 @@ cdef class StructArray(Array): if by is not None: tosort, sort_keys = self._flattened_field(by), [("", order)] else: - tosort, sort_keys = self, [(field.name, order, null_placement) for field in self.type] + tosort, sort_keys = self, [ + (field.name, order, null_placement) for field in self.type] indices = _pc().sort_indices( tosort, options=_pc().SortOptions(sort_keys=sort_keys, **kwargs) ) From 12d1b0d47dbd15c9169111de7f6623a537410975 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 12:05:32 +0200 Subject: [PATCH 09/83] formatting --- cpp/src/arrow/compute/api_vector.h | 12 ++++++------ cpp/src/arrow/compute/ordering.cc | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index b0dc50dc3e0..83f3fd6438b 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -123,11 +123,11 @@ class ARROW_EXPORT 
SortOptions : public FunctionOptions { // DEPRECATED(will be removed after null_placement has been removed) /// Get sort_keys with overwritten null_placement std::vector GetSortKeys() const { - if (!null_placement.has_value()){ + if (!null_placement.has_value()) { return sort_keys_; } auto overwritten_sort_keys = sort_keys_; - for (auto& sort_key : overwritten_sort_keys){ + for (auto& sort_key : overwritten_sort_keys) { sort_key.null_placement = null_placement.value(); } return overwritten_sort_keys; @@ -222,11 +222,11 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { // DEPRECATED(will be removed after null_placement has been removed) /// Get sort_keys with overwritten null_placement std::vector GetSortKeys() const { - if (!null_placement.has_value()){ + if (!null_placement.has_value()) { return sort_keys_; } auto overwritten_sort_keys = sort_keys_; - for (auto& sort_key : overwritten_sort_keys){ + for (auto& sort_key : overwritten_sort_keys) { sort_key.null_placement = null_placement.value(); } return overwritten_sort_keys; @@ -261,11 +261,11 @@ class ARROW_EXPORT RankQuantileOptions : public FunctionOptions { // DEPRECATED(will be removed after null_placement has been removed) /// Get sort_keys with overwritten null_placement std::vector GetSortKeys() const { - if (!null_placement.has_value()){ + if (!null_placement.has_value()) { return sort_keys_; } auto overwritten_sort_keys = sort_keys_; - for (auto& sort_key : overwritten_sort_keys){ + for (auto& sort_key : overwritten_sort_keys) { sort_key.null_placement = null_placement.value(); } return overwritten_sort_keys; diff --git a/cpp/src/arrow/compute/ordering.cc b/cpp/src/arrow/compute/ordering.cc index f65cf4fc4fd..5ee78026229 100644 --- a/cpp/src/arrow/compute/ordering.cc +++ b/cpp/src/arrow/compute/ordering.cc @@ -87,7 +87,7 @@ std::string Ordering::ToString() const { ss << key.ToString(); } ss << "]"; - if (null_placement_.has_value()){ + if (null_placement_.has_value()) { switch 
(null_placement_.value()) { case NullPlacement::AtEnd: ss << " nulls last"; From 0b442f7f49eeb5ae9798bde897b70da824fc1dcc Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 12:07:27 +0200 Subject: [PATCH 10/83] missing unwrap_null_placement in python --- python/pyarrow/_compute.pyx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 946a3fc243e..bfdac4edfb2 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -78,7 +78,8 @@ cdef vector[CSortKey] unwrap_sort_keys(sort_keys, allow_str=True): cdef vector[CSortKey] c_sort_keys if allow_str and isinstance(sort_keys, str): c_sort_keys.push_back( - CSortKey(_ensure_field_ref(""), unwrap_sort_order(sort_keys)) + CSortKey(_ensure_field_ref(name), unwrap_sort_order(order), + unwrap_null_placement(null_placement)) ) else: for name, order, null_placement in sort_keys: From 4398d16833b25b172bc3c0da514a67a15ea555b7 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 12:13:39 +0200 Subject: [PATCH 11/83] do not remove demoted null_placement from python api --- python/pyarrow/includes/libarrow.pxd | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 35b85db7827..b03f83ad67d 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2763,12 +2763,13 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CNullPlacement null_placement cdef cppclass COrdering" arrow::compute::Ordering": - COrdering(vector[CSortKey] sort_keys) + COrdering(vector[CSortKey] sort_keys, CNullPlacement null_placement) cdef cppclass CSortOptions \ "arrow::compute::SortOptions"(CFunctionOptions): - CSortOptions(vector[CSortKey] sort_keys) + CSortOptions(vector[CSortKey] sort_keys, CNullPlacement) vector[CSortKey] sort_keys + CNullPlacement null_placement 
cdef cppclass CSelectKOptions \ "arrow::compute::SelectKOptions"(CFunctionOptions): @@ -2841,8 +2842,10 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: cdef cppclass CRankOptions \ "arrow::compute::RankOptions"(CFunctionOptions): - CRankOptions(vector[CSortKey] sort_keys, CRankOptionsTiebreaker tiebreaker) + CRankOptions(vector[CSortKey] sort_keys, CNullPlacement, + CRankOptionsTiebreaker tiebreaker) vector[CSortKey] sort_keys + CNullPlacement null_placement CRankOptionsTiebreaker tiebreaker cdef cppclass CRankQuantileOptions \ From bdc4069a52f7d82d31fe9cf2572637b60c90c332 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 13:37:36 +0200 Subject: [PATCH 12/83] fix member name in hash_aggregate_test --- cpp/src/arrow/acero/hash_aggregate_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/acero/hash_aggregate_test.cc b/cpp/src/arrow/acero/hash_aggregate_test.cc index dce0e44eb13..b1a493ecb9a 100644 --- a/cpp/src/arrow/acero/hash_aggregate_test.cc +++ b/cpp/src/arrow/acero/hash_aggregate_test.cc @@ -688,7 +688,7 @@ namespace { void SortBy(std::vector names, Datum* aggregated_and_grouped) { SortOptions options; for (auto&& name : names) { - options.sort_keys.emplace_back(std::move(name), SortOrder::Ascending); + options.sort_keys_.emplace_back(std::move(name), SortOrder::Ascending); } ASSERT_OK_AND_ASSIGN( From 8fc630cf2ef507f9dd39dafa9b049717fa05e072 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 13:42:57 +0200 Subject: [PATCH 13/83] update ToString method output --- cpp/src/arrow/acero/plan_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/acero/plan_test.cc b/cpp/src/arrow/acero/plan_test.cc index 1a11138de99..2cb03114de6 100644 --- a/cpp/src/arrow/acero/plan_test.cc +++ b/cpp/src/arrow/acero/plan_test.cc @@ -518,7 +518,7 @@ TEST(ExecPlan, ToString) { }); ASSERT_OK_AND_ASSIGN(std::string plan_str, 
DeclarationToString(declaration)); EXPECT_EQ(plan_str, R"a(ExecPlan with 6 nodes: -custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32, 2))) ASC AtEnd]}} +custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32, 2))) ASC NULLS LAST]}} :FilterNode{filter=(sum(multiply(i32, 2)) > 10)} :GroupByNode{keys=["bool"], aggregates=[ hash_sum(multiply(i32, 2)), From dcb3b7ab2de6a0b0f9f29d417b9d2e0260dc3cd4 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 14:30:57 +0200 Subject: [PATCH 14/83] format remove extra empty line --- cpp/src/arrow/compute/kernels/vector_sort_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index a37f15970a1..1100af19b35 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -1210,7 +1210,6 @@ TEST_F(TestRecordBatchSortIndices, NoNull) { SortKey("b", SortOrder::Descending, null_placement)}, null_placement); - AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]"); } } From 682a37adbf4250a1b64cb06153a51c5adcb89b87 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 14:31:03 +0200 Subject: [PATCH 15/83] fix python interface --- python/pyarrow/_compute.pyx | 51 ++++++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index bfdac4edfb2..08294f62bf5 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -78,14 +78,19 @@ cdef vector[CSortKey] unwrap_sort_keys(sort_keys, allow_str=True): cdef vector[CSortKey] c_sort_keys if allow_str and isinstance(sort_keys, str): c_sort_keys.push_back( - CSortKey(_ensure_field_ref(name), unwrap_sort_order(order), - unwrap_null_placement(null_placement)) + CSortKey(_ensure_field_ref(""), unwrap_sort_order(order), + 
unwrap_null_placement("at_end")) ) else: - for name, order, null_placement in sort_keys: + for item in sort_keys: + if len(item) == 2: + name, order = item + null_placement = "at_end" + else: + name, order, null_placement = item c_sort_keys.push_back( - CSortKey(_ensure_field_ref(name), unwrap_sort_order( - order), unwrap_null_placement(null_placement)) + CSortKey(_ensure_field_ref(name), unwrap_sort_order(order), + unwrap_null_placement(null_placement)) ) return c_sort_keys @@ -2163,9 +2168,14 @@ class ArraySortOptions(_ArraySortOptions): cdef class _SortOptions(FunctionOptions): def _set_options(self, sort_keys, null_placement): - self.wrapped.reset(new CSortOptions( - unwrap_sort_keys(sort_keys, allow_str=False), - unwrap_null_placement(null_placement))) + if null_placement is None: + self.wrapped.reset(new CSortOptions( + unwrap_sort_keys(sort_keys, allow_str=False), + None)) + else: + self.wrapped.reset(new CSortOptions( + unwrap_sort_keys(sort_keys, allow_str=False), + unwrap_null_placement(null_placement))) class SortOptions(_SortOptions): @@ -2180,7 +2190,7 @@ class SortOptions(_SortOptions): Accepted values for `order` are "ascending", "descending". Accepted values for `null_placement` are "at_start", "at_end". The field name can be a string column name or expression. - null_placement : str, default None + null_placement : str | None, default None Where nulls in input should be sorted, overwrites `null_placement` in `sort_keys`. Accepted values are "at_start", "at_end". 
@@ -2381,9 +2391,11 @@ cdef class _RankOptions(FunctionOptions): def _set_options(self, sort_keys, null_placement, tiebreaker): try: self.wrapped.reset( - new CRankOptions(unwrap_sort_keys(sort_keys), - unwrap_null_placement(null_placement), - self._tiebreaker_map[tiebreaker]) + new CRankOptions( + unwrap_sort_keys(sort_keys), + unwrap_null_placement(null_placement) if null_placement is not None else None, + self._tiebreaker_map[tiebreaker] + ) ) except KeyError: _raise_invalid_function_option(tiebreaker, "tiebreaker") @@ -2403,9 +2415,10 @@ class RankOptions(_RankOptions): The field name can be a string column name or expression. Alternatively, one can simply pass "ascending" or "descending" as a string if the input is array-like. - null_placement : str, default "at_end" + null_placement : str | None, default None Where nulls in input should be sorted. Accepted values are "at_start", "at_end". + Overwrites the null_placement inside sort_keys tiebreaker : str, default "first" Configure how ties between equal values are handled. Accepted values are: @@ -2419,7 +2432,7 @@ class RankOptions(_RankOptions): number of distinct values in the input. """ - def __init__(self, sort_keys="ascending", *, null_placement="at_end", tiebreaker="first"): + def __init__(self, sort_keys="ascending", *, null_placement=None, tiebreaker="first"): self._set_options(sort_keys, null_placement, tiebreaker) @@ -2427,8 +2440,10 @@ cdef class _RankQuantileOptions(FunctionOptions): def _set_options(self, sort_keys, null_placement): self.wrapped.reset( - new CRankQuantileOptions(unwrap_sort_keys(sort_keys), - unwrap_null_placement(null_placement)) + new CRankQuantileOptions( + unwrap_sort_keys(sort_keys), + unwrap_null_placement(null_placement) if null_placement is not None else None + ) ) @@ -2446,12 +2461,12 @@ class RankQuantileOptions(_RankQuantileOptions): The field name can be a string column name or expression. 
Alternatively, one can simply pass "ascending" or "descending" as a string if the input is array-like. - null_placement : str, default "at_end" + null_placement : str | None, default None Where nulls in input should be sorted. Accepted values are "at_start", "at_end". """ - def __init__(self, sort_keys="ascending", *, null_placement="at_end"): + def __init__(self, sort_keys="ascending", *, null_placement=None): self._set_options(sort_keys, null_placement) From e6afb8a90192b7eb828e5d67776d60f8e4234363 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 14:55:00 +0200 Subject: [PATCH 16/83] python formatting --- python/pyarrow/_compute.pyx | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 08294f62bf5..3000a6c4fba 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2393,7 +2393,8 @@ cdef class _RankOptions(FunctionOptions): self.wrapped.reset( new CRankOptions( unwrap_sort_keys(sort_keys), - unwrap_null_placement(null_placement) if null_placement is not None else None, + unwrap_null_placement( + null_placement) if null_placement is not None else None, self._tiebreaker_map[tiebreaker] ) ) @@ -2442,7 +2443,8 @@ cdef class _RankQuantileOptions(FunctionOptions): self.wrapped.reset( new CRankQuantileOptions( unwrap_sort_keys(sort_keys), - unwrap_null_placement(null_placement) if null_placement is not None else None + unwrap_null_placement( + null_placement) if null_placement is not None else None ) ) From 82299bcf7a8f770761e48e54e0ebe50f5fdf14af Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 15:07:24 +0200 Subject: [PATCH 17/83] do not pass None to python C binding --- python/pyarrow/_compute.pyx | 42 ++++++++++++++++++---------- python/pyarrow/includes/libarrow.pxd | 8 ++++-- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 
3000a6c4fba..0ad33434bb3 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2170,8 +2170,7 @@ cdef class _SortOptions(FunctionOptions): def _set_options(self, sort_keys, null_placement): if null_placement is None: self.wrapped.reset(new CSortOptions( - unwrap_sort_keys(sort_keys, allow_str=False), - None)) + unwrap_sort_keys(sort_keys, allow_str=False))) else: self.wrapped.reset(new CSortOptions( unwrap_sort_keys(sort_keys, allow_str=False), @@ -2390,14 +2389,21 @@ cdef class _RankOptions(FunctionOptions): def _set_options(self, sort_keys, null_placement, tiebreaker): try: - self.wrapped.reset( - new CRankOptions( - unwrap_sort_keys(sort_keys), - unwrap_null_placement( - null_placement) if null_placement is not None else None, - self._tiebreaker_map[tiebreaker] + if null_placement is None: + self.wrapped.reset( + new CRankOptions( + unwrap_sort_keys(sort_keys), + self._tiebreaker_map[tiebreaker] + ) + ) + else: + self.wrapped.reset( + new CRankOptions( + unwrap_sort_keys(sort_keys), + unwrap_null_placement(null_placement), + self._tiebreaker_map[tiebreaker] + ) ) - ) except KeyError: _raise_invalid_function_option(tiebreaker, "tiebreaker") @@ -2440,13 +2446,19 @@ class RankOptions(_RankOptions): cdef class _RankQuantileOptions(FunctionOptions): def _set_options(self, sort_keys, null_placement): - self.wrapped.reset( - new CRankQuantileOptions( - unwrap_sort_keys(sort_keys), - unwrap_null_placement( - null_placement) if null_placement is not None else None + if null_placement is None: + self.wrapped.reset( + new CRankQuantileOptions( + unwrap_sort_keys(sort_keys) + ) + ) + else: + self.wrapped.reset( + new CRankQuantileOptions( + unwrap_sort_keys(sort_keys), + unwrap_null_placement(null_placement) + ) ) - ) class RankQuantileOptions(_RankQuantileOptions): diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index b03f83ad67d..98f66ba34a1 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ 
b/python/pyarrow/includes/libarrow.pxd @@ -2769,7 +2769,7 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: "arrow::compute::SortOptions"(CFunctionOptions): CSortOptions(vector[CSortKey] sort_keys, CNullPlacement) vector[CSortKey] sort_keys - CNullPlacement null_placement + optional[CNullPlacement] null_placement cdef cppclass CSelectKOptions \ "arrow::compute::SelectKOptions"(CFunctionOptions): @@ -2842,17 +2842,19 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: cdef cppclass CRankOptions \ "arrow::compute::RankOptions"(CFunctionOptions): + CRankOptions(vector[CSortKey] sort_keys, CRankOptionsTiebreaker tiebreaker) CRankOptions(vector[CSortKey] sort_keys, CNullPlacement, CRankOptionsTiebreaker tiebreaker) vector[CSortKey] sort_keys - CNullPlacement null_placement + optional[CNullPlacement] null_placement CRankOptionsTiebreaker tiebreaker cdef cppclass CRankQuantileOptions \ "arrow::compute::RankQuantileOptions"(CFunctionOptions): + CRankQuantileOptions(vector[CSortKey] sort_keys) CRankQuantileOptions(vector[CSortKey] sort_keys, CNullPlacement) vector[CSortKey] sort_keys - CNullPlacement null_placement + optional[CNullPlacement] null_placement cdef enum PivotWiderUnexpectedKeyBehavior \ "arrow::compute::PivotWiderOptions::UnexpectedKeyBehavior": From 370cdacee4ae1aa15df2d2ef3a254ff588545ddd Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 15:22:13 +0200 Subject: [PATCH 18/83] format python --- python/pyarrow/_compute.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 0ad33434bb3..1c031f4a5c2 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2170,7 +2170,7 @@ cdef class _SortOptions(FunctionOptions): def _set_options(self, sort_keys, null_placement): if null_placement is None: self.wrapped.reset(new CSortOptions( - unwrap_sort_keys(sort_keys, allow_str=False))) + 
unwrap_sort_keys(sort_keys, allow_str=False))) else: self.wrapped.reset(new CSortOptions( unwrap_sort_keys(sort_keys, allow_str=False), From 503a7d1ebc562e58d972156d0a1a4b60769c4c99 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 15:24:32 +0200 Subject: [PATCH 19/83] fix minor python api mistakes --- python/pyarrow/_compute.pyx | 3 +-- python/pyarrow/includes/libarrow.pxd | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 1c031f4a5c2..1039d7b09c2 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -78,8 +78,7 @@ cdef vector[CSortKey] unwrap_sort_keys(sort_keys, allow_str=True): cdef vector[CSortKey] c_sort_keys if allow_str and isinstance(sort_keys, str): c_sort_keys.push_back( - CSortKey(_ensure_field_ref(""), unwrap_sort_order(order), - unwrap_null_placement("at_end")) + CSortKey(_ensure_field_ref(""), unwrap_sort_order(sort_keys)) ) else: for item in sort_keys: diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 98f66ba34a1..0ce01f2bc95 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2757,6 +2757,7 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CNullPlacement null_placement cdef cppclass CSortKey" arrow::compute::SortKey": + CSortKey(CFieldRef target, CSortOrder order) CSortKey(CFieldRef target, CSortOrder order, CNullPlacement null_placement) CFieldRef target CSortOrder order @@ -2767,6 +2768,7 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: cdef cppclass CSortOptions \ "arrow::compute::SortOptions"(CFunctionOptions): + CSortOptions(vector[CSortKey] sort_keys) CSortOptions(vector[CSortKey] sort_keys, CNullPlacement) vector[CSortKey] sort_keys optional[CNullPlacement] null_placement From 1d0883e8ef714f34dc6f2cf96d6b0e5db8257033 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 
26 Jun 2025 15:35:16 +0200 Subject: [PATCH 20/83] python formatting --- python/pyarrow/tests/test_table.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index 43eda7ec598..3b546c833b3 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -3428,6 +3428,7 @@ def test_record_batch_sort(): assert sorted_rb_dict["b"] == [2, 3, 4, 1] assert sorted_rb_dict["c"] == ["foobar", "bar", "foo", "car"] + @pytest.mark.numpy @pytest.mark.parametrize("constructor", [pa.table, pa.record_batch]) def test_numpy_asarray(constructor): From 7833addbe6827f8cf7ae923f85db9287eb6ff299 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 16:12:30 +0200 Subject: [PATCH 21/83] do not rename member of exposed API struct --- cpp/src/arrow/acero/hash_aggregate_test.cc | 2 +- cpp/src/arrow/acero/sink_node.cc | 2 +- cpp/src/arrow/compute/api_vector.cc | 14 ++++++------- cpp/src/arrow/compute/api_vector.h | 22 ++++++++++---------- cpp/src/arrow/compute/kernels/vector_sort.cc | 6 +++--- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/cpp/src/arrow/acero/hash_aggregate_test.cc b/cpp/src/arrow/acero/hash_aggregate_test.cc index b1a493ecb9a..dce0e44eb13 100644 --- a/cpp/src/arrow/acero/hash_aggregate_test.cc +++ b/cpp/src/arrow/acero/hash_aggregate_test.cc @@ -688,7 +688,7 @@ namespace { void SortBy(std::vector names, Datum* aggregated_and_grouped) { SortOptions options; for (auto&& name : names) { - options.sort_keys_.emplace_back(std::move(name), SortOrder::Ascending); + options.sort_keys.emplace_back(std::move(name), SortOrder::Ascending); } ASSERT_OK_AND_ASSIGN( diff --git a/cpp/src/arrow/acero/sink_node.cc b/cpp/src/arrow/acero/sink_node.cc index dce3e6eddec..ab06dd8ffd8 100644 --- a/cpp/src/arrow/acero/sink_node.cc +++ b/cpp/src/arrow/acero/sink_node.cc @@ -476,7 +476,7 @@ struct OrderBySinkNode final : public SinkNode { } static Status 
ValidateOrderByOptions(const OrderBySinkNodeOptions& options) { - if (options.sort_options.sort_keys_.empty()) { + if (options.sort_options.sort_keys.empty()) { return Status::Invalid("At least one sort key should be specified"); } return ValidateCommonOrderOptions(options); diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 1fbcd70af21..6a6dee47886 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -138,7 +138,7 @@ static auto kArraySortOptionsType = GetFunctionOptionsType( DataMember("order", &ArraySortOptions::order), DataMember("null_placement", &ArraySortOptions::null_placement)); static auto kSortOptionsType = GetFunctionOptionsType( - CoercedDataMember("sort_keys", &SortOptions::sort_keys_, &SortOptions::GetSortKeys)); + CoercedDataMember("sort_keys", &SortOptions::sort_keys, &SortOptions::GetSortKeys)); static auto kPartitionNthOptionsType = GetFunctionOptionsType( DataMember("pivot", &PartitionNthOptions::pivot), DataMember("null_placement", &PartitionNthOptions::null_placement)); @@ -152,10 +152,10 @@ static auto kCumulativeOptionsType = GetFunctionOptionsType( DataMember("start", &CumulativeOptions::start), DataMember("skip_nulls", &CumulativeOptions::skip_nulls)); static auto kRankOptionsType = GetFunctionOptionsType( - CoercedDataMember("sort_keys", &RankOptions::sort_keys_, &RankOptions::GetSortKeys), + CoercedDataMember("sort_keys", &RankOptions::sort_keys, &RankOptions::GetSortKeys), DataMember("tiebreaker", &RankOptions::tiebreaker)); static auto kRankQuantileOptionsType = GetFunctionOptionsType( - CoercedDataMember("sort_keys", &RankQuantileOptions::sort_keys_, + CoercedDataMember("sort_keys", &RankQuantileOptions::sort_keys, &RankQuantileOptions::GetSortKeys)); static auto kPairwiseOptionsType = GetFunctionOptionsType( DataMember("periods", &PairwiseOptions::periods)); @@ -197,11 +197,11 @@ constexpr char ArraySortOptions::kTypeName[]; 
SortOptions::SortOptions(std::vector sort_keys, std::optional null_placement) : FunctionOptions(internal::kSortOptionsType), - sort_keys_(std::move(sort_keys)), + sort_keys(std::move(sort_keys)), null_placement(null_placement) {} SortOptions::SortOptions(const Ordering& ordering) : FunctionOptions(internal::kSortOptionsType), - sort_keys_(ordering.sort_keys()), + sort_keys(ordering.sort_keys()), null_placement(ordering.null_placement()) {} constexpr char SortOptions::kTypeName[]; @@ -236,7 +236,7 @@ RankOptions::RankOptions(std::vector sort_keys, std::optional null_placement, RankOptions::Tiebreaker tiebreaker) : FunctionOptions(internal::kRankOptionsType), - sort_keys_(std::move(sort_keys)), + sort_keys(std::move(sort_keys)), null_placement(null_placement), tiebreaker(tiebreaker) {} constexpr char RankOptions::kTypeName[]; @@ -244,7 +244,7 @@ constexpr char RankOptions::kTypeName[]; RankQuantileOptions::RankQuantileOptions(std::vector sort_keys, std::optional null_placement) : FunctionOptions(internal::kRankQuantileOptionsType), - sort_keys_(std::move(sort_keys)), + sort_keys(std::move(sort_keys)), null_placement(null_placement) {} constexpr char RankQuantileOptions::kTypeName[]; diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 83f3fd6438b..e56242a3af6 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -114,19 +114,19 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { /// Note: Both classes contain the exact same information. However, /// sort_options should only be used in a "function options" context while Ordering /// is used more generally. 
- Ordering AsOrdering() && { return Ordering(std::move(sort_keys_), null_placement); } - Ordering AsOrdering() const& { return Ordering(sort_keys_, null_placement); } + Ordering AsOrdering() && { return Ordering(std::move(sort_keys), null_placement); } + Ordering AsOrdering() const& { return Ordering(sort_keys, null_placement); } /// Column key(s) to order by and how to order by these sort keys. - std::vector sort_keys_; + std::vector sort_keys; // DEPRECATED(will be removed after null_placement has been removed) /// Get sort_keys with overwritten null_placement std::vector GetSortKeys() const { if (!null_placement.has_value()) { - return sort_keys_; + return sort_keys; } - auto overwritten_sort_keys = sort_keys_; + auto overwritten_sort_keys = sort_keys; for (auto& sort_key : overwritten_sort_keys) { sort_key.null_placement = null_placement.value(); } @@ -217,15 +217,15 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { static RankOptions Defaults() { return RankOptions(); } /// Column key(s) to order by and how to order by these sort keys. - std::vector sort_keys_; + std::vector sort_keys; // DEPRECATED(will be removed after null_placement has been removed) /// Get sort_keys with overwritten null_placement std::vector GetSortKeys() const { if (!null_placement.has_value()) { - return sort_keys_; + return sort_keys; } - auto overwritten_sort_keys = sort_keys_; + auto overwritten_sort_keys = sort_keys; for (auto& sort_key : overwritten_sort_keys) { sort_key.null_placement = null_placement.value(); } @@ -256,15 +256,15 @@ class ARROW_EXPORT RankQuantileOptions : public FunctionOptions { static RankQuantileOptions Defaults() { return RankQuantileOptions(); } /// Column key(s) to order by and how to order by these sort keys. 
- std::vector sort_keys_; + std::vector sort_keys; // DEPRECATED(will be removed after null_placement has been removed) /// Get sort_keys with overwritten null_placement std::vector GetSortKeys() const { if (!null_placement.has_value()) { - return sort_keys_; + return sort_keys; } - auto overwritten_sort_keys = sort_keys_; + auto overwritten_sort_keys = sort_keys; for (auto& sort_key : overwritten_sort_keys) { sort_key.null_placement = null_placement.value(); } diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 58f82fb0743..03e6090a207 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -1033,7 +1033,7 @@ class SortIndicesMetaFunction : public MetaFunction { Result SortIndices(const Table& table, const SortOptions& options, ExecContext* ctx) const { - auto n_sort_keys = options.sort_keys_.size(); + auto n_sort_keys = options.sort_keys.size(); if (n_sort_keys == 0) { return Status::Invalid("Must specify one or more sort keys"); } @@ -1165,9 +1165,9 @@ Result SortStructArray(ExecContext* ctx, uint64_t* indices_ std::move(columns)); auto options = SortOptions::Defaults(); - options.sort_keys_.reserve(array.num_fields()); + options.sort_keys.reserve(array.num_fields()); for (int i = 0; i < array.num_fields(); ++i) { - options.sort_keys_.push_back(SortKey(FieldRef(i), sort_order, null_placement)); + options.sort_keys.push_back(SortKey(FieldRef(i), sort_order, null_placement)); } ARROW_ASSIGN_OR_RAISE(auto sort_keys, From 6e2759babfafe5aa839b5d745fcc27e2a1764d01 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 16:37:49 +0200 Subject: [PATCH 22/83] do not rename member of exposed API struct, missed test file --- .../arrow/compute/kernels/vector_sort_test.cc | 106 +++++++++--------- 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc 
b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index 1100af19b35..48f69323a22 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -1233,9 +1233,9 @@ TEST_F(TestRecordBatchSortIndices, Null) { SortOptions options(sort_keys); AssertSortIndices(batch, options, "[5, 1, 4, 6, 2, 0, 3]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[0, 3, 5, 1, 4, 6, 2]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[3, 0, 5, 1, 4, 2, 6]"); } @@ -1260,10 +1260,10 @@ TEST_F(TestRecordBatchSortIndices, MixedNullOrdering) { SortOptions options(sort_keys, std::nullopt); AssertSortIndices(batch, options, "[5, 1, 4, 6, 2, 0, 3]"); - options.sort_keys_.at(0).null_placement = NullPlacement::AtStart; + options.sort_keys.at(0).null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[0, 3, 5, 1, 4, 6, 2]"); - options.sort_keys_.at(1).null_placement = NullPlacement::AtStart; + options.sort_keys.at(1).null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[3, 0, 5, 1, 4, 2, 6]"); } @@ -1287,9 +1287,9 @@ TEST_F(TestRecordBatchSortIndices, NaN) { SortOptions options(sort_keys); AssertSortIndices(batch, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[4, 6, 5, 3, 7, 1, 0, 2]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[5, 4, 6, 3, 1, 7, 0, 2]"); } @@ -1313,9 +1313,9 @@ TEST_F(TestRecordBatchSortIndices, NaNAndNull) { SortOptions options(sort_keys); AssertSortIndices(batch, options, 
"[7, 1, 2, 6, 5, 4, 0, 3]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); } @@ -1339,9 +1339,9 @@ TEST_F(TestRecordBatchSortIndices, Boolean) { SortOptions options(sort_keys); AssertSortIndices(batch, options, "[3, 1, 6, 2, 4, 0, 7, 5]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[7, 5, 3, 1, 6, 2, 4, 0]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[7, 5, 1, 6, 3, 0, 2, 4]"); } @@ -1366,7 +1366,7 @@ TEST_F(TestRecordBatchSortIndices, MoreTypes) { for (auto null_placement : AllNullPlacements()) { SortOptions options(sort_keys); for (size_t i = 0; i < sort_keys.size(); i++) { - options.sort_keys_[i].null_placement = null_placement; + options.sort_keys[i].null_placement = null_placement; } AssertSortIndices(batch, options, "[3, 5, 1, 4, 0, 2]"); } @@ -1389,9 +1389,9 @@ TEST_F(TestRecordBatchSortIndices, Decimal) { SortOptions options(sort_keys); AssertSortIndices(batch, options, "[4, 3, 0, 2, 1]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[4, 3, 0, 2, 1]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[3, 4, 0, 2, 1]"); } @@ -1471,9 +1471,9 @@ TEST_F(TestRecordBatchSortIndices, DuplicateSortKeys) { SortOptions options(sort_keys); AssertSortIndices(batch, 
options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); } @@ -1494,10 +1494,10 @@ TEST_F(TestTableSortIndices, EmptyTable) { SortOptions options(sort_keys); AssertSortIndices(table, options, "[]"); AssertSortIndices(chunked_table, options, "[]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[]"); AssertSortIndices(chunked_table, options, "[]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[]"); AssertSortIndices(chunked_table, options, "[]"); } @@ -1541,9 +1541,9 @@ TEST_F(TestTableSortIndices, Null) { ])"}); SortOptions options(sort_keys); AssertSortIndices(table, options, "[5, 1, 4, 6, 2, 0, 3]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[0, 3, 5, 1, 4, 6, 2]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 5, 1, 4, 2, 6]"); // Same data, several chunks @@ -1556,12 +1556,12 @@ TEST_F(TestTableSortIndices, Null) { {"a": 1, "b": 5}, {"a": 3, "b": 5} ])"}); - options.sort_keys_[0].null_placement = NullPlacement::AtEnd; - options.sort_keys_[1].null_placement = NullPlacement::AtEnd; + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + options.sort_keys[1].null_placement = NullPlacement::AtEnd; AssertSortIndices(table, 
options, "[5, 1, 4, 6, 2, 0, 3]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[0, 3, 5, 1, 4, 6, 2]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 5, 1, 4, 2, 6]"); } @@ -1585,9 +1585,9 @@ TEST_F(TestTableSortIndices, NaN) { ])"}); SortOptions options(sort_keys); AssertSortIndices(table, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[4, 6, 5, 3, 7, 1, 0, 2]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[5, 4, 6, 3, 1, 7, 0, 2]"); // Same data, several chunks @@ -1601,12 +1601,12 @@ TEST_F(TestTableSortIndices, NaN) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); - options.sort_keys_[0].null_placement = NullPlacement::AtEnd; - options.sort_keys_[1].null_placement = NullPlacement::AtEnd; + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + options.sort_keys[1].null_placement = NullPlacement::AtEnd; AssertSortIndices(table, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[4, 6, 5, 3, 7, 1, 0, 2]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[5, 4, 6, 3, 1, 7, 0, 2]"); } @@ -1630,9 +1630,9 @@ TEST_F(TestTableSortIndices, NaNAndNull) { ])"}); SortOptions options(sort_keys); AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.sort_keys_[0].null_placement = 
NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); // Same data, several chunks @@ -1646,12 +1646,12 @@ TEST_F(TestTableSortIndices, NaNAndNull) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); - options.sort_keys_[0].null_placement = NullPlacement::AtEnd; - options.sort_keys_[1].null_placement = NullPlacement::AtEnd; + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + options.sort_keys[1].null_placement = NullPlacement::AtEnd; AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); } @@ -1675,9 +1675,9 @@ TEST_F(TestTableSortIndices, Boolean) { ])"}); SortOptions options(sort_keys); AssertSortIndices(table, options, "[3, 1, 6, 2, 4, 0, 7, 5]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[7, 5, 3, 1, 6, 2, 4, 0]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[7, 5, 1, 6, 3, 0, 2, 4]"); } @@ -1701,8 +1701,8 @@ TEST_F(TestTableSortIndices, BinaryLike) { ])"}); SortOptions options(sort_keys); AssertSortIndices(table, options, "[1, 5, 2, 6, 4, 0, 7, 3]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; - 
options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[1, 5, 2, 6, 0, 4, 7, 3]"); } @@ -1723,9 +1723,9 @@ TEST_F(TestTableSortIndices, Decimal) { ])"}); SortOptions options(sort_keys); AssertSortIndices(table, options, "[4, 3, 0, 2, 1]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[4, 3, 0, 2, 1]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 4, 0, 2, 1]"); } @@ -1789,9 +1789,9 @@ TEST_F(TestTableSortIndices, DuplicateSortKeys) { ])"}); SortOptions options(sort_keys); AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); } @@ -1811,17 +1811,17 @@ TEST_F(TestTableSortIndices, HeterogenousChunking) { SortOptions options( {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); options = SortOptions( {SortKey("b", 
SortOrder::Ascending), SortKey("a", SortOrder::Descending)}); AssertSortIndices(table, options, "[1, 7, 6, 0, 5, 2, 4, 3]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[2, 4, 3, 5, 1, 7, 6, 0]"); - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 4, 2, 5, 1, 0, 6, 7]"); } @@ -1853,7 +1853,7 @@ TYPED_TEST(TestTableSortIndicesForTemporal, NoNull) { for (auto null_placement : AllNullPlacements()) { SortOptions options(sort_keys); for (size_t i = 0; i < sort_keys.size(); i++) { - options.sort_keys_[i].null_placement = null_placement; + options.sort_keys[i].null_placement = null_placement; } AssertSortIndices(table, options, "[0, 6, 1, 4, 7, 3, 2, 5]"); } @@ -2245,17 +2245,17 @@ class TestNestedSortIndices : public ::testing::Test { SortOptions options(sort_keys); AssertSortIndices(datum, options, "[7, 6, 3, 4, 0, 2, 1, 8, 5]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; - options.sort_keys_[1].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(datum, options, "[5, 2, 1, 8, 3, 7, 6, 0, 4]"); // Implementations may have an optimized path for cases with one sort key. 
// Additionally, this key references a struct containing another struct, which should // work recursively - options.sort_keys_ = {SortKey(FieldRef("a"), SortOrder::Ascending)}; - options.sort_keys_[0].null_placement = NullPlacement::AtEnd; + options.sort_keys = {SortKey(FieldRef("a"), SortOrder::Ascending)}; + options.sort_keys[0].null_placement = NullPlacement::AtEnd; AssertSortIndices(datum, options, "[6, 7, 3, 4, 0, 8, 1, 2, 5]"); - options.sort_keys_[0].null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(datum, options, "[5, 8, 1, 2, 3, 6, 7, 0, 4]"); } From 0a029cac00e68644a4f668d21207b568e3e60c88 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 16:39:12 +0200 Subject: [PATCH 23/83] format cc file --- cpp/src/arrow/compute/api_vector.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 6a6dee47886..1f776f9d9dd 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -154,9 +154,9 @@ static auto kCumulativeOptionsType = GetFunctionOptionsType( static auto kRankOptionsType = GetFunctionOptionsType( CoercedDataMember("sort_keys", &RankOptions::sort_keys, &RankOptions::GetSortKeys), DataMember("tiebreaker", &RankOptions::tiebreaker)); -static auto kRankQuantileOptionsType = GetFunctionOptionsType( - CoercedDataMember("sort_keys", &RankQuantileOptions::sort_keys, - &RankQuantileOptions::GetSortKeys)); +static auto kRankQuantileOptionsType = + GetFunctionOptionsType(CoercedDataMember( + "sort_keys", &RankQuantileOptions::sort_keys, &RankQuantileOptions::GetSortKeys)); static auto kPairwiseOptionsType = GetFunctionOptionsType( DataMember("periods", &PairwiseOptions::periods)); static auto kListFlattenOptionsType = GetFunctionOptionsType( From 3463a59198ef4552062720be0a4734043b31893e Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: 
Thu, 26 Jun 2025 17:36:55 +0200 Subject: [PATCH 24/83] updating api that was not using additional argument for some reason (most likely human-error while merging) --- cpp/src/arrow/acero/order_by_node.cc | 2 +- cpp/src/arrow/acero/order_by_node_test.cc | 10 +++++----- python/pyarrow/array.pxi | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/acero/order_by_node.cc b/cpp/src/arrow/acero/order_by_node.cc index 51a07a6c5f7..65aa83247f8 100644 --- a/cpp/src/arrow/acero/order_by_node.cc +++ b/cpp/src/arrow/acero/order_by_node.cc @@ -116,7 +116,7 @@ class OrderByNode : public ExecNode, public TracedNode { ARROW_ASSIGN_OR_RAISE( auto table, Table::FromRecordBatches(output_schema_, std::move(accumulation_queue_))); - SortOptions sort_options(ordering_.sort_keys()); + SortOptions sort_options(ordering_.sort_keys(), ordering_.null_placement()); ExecContext* ctx = plan_->query_context()->exec_context(); ARROW_ASSIGN_OR_RAISE(auto indices, SortIndices(table, sort_options, ctx)); ARROW_ASSIGN_OR_RAISE(Datum sorted, diff --git a/cpp/src/arrow/acero/order_by_node_test.cc b/cpp/src/arrow/acero/order_by_node_test.cc index a76ac16da3e..37e6862ed0f 100644 --- a/cpp/src/arrow/acero/order_by_node_test.cc +++ b/cpp/src/arrow/acero/order_by_node_test.cc @@ -76,10 +76,10 @@ void CheckOrderByInvalid(OrderByNodeOptions options, const std::string& message) } TEST(OrderByNode, Basic) { - CheckOrderBy(OrderByNodeOptions(Ordering{{SortKey("up")}})); - CheckOrderBy(OrderByNodeOptions(Ordering({SortKey("down", SortOrder::Descending)}))); - CheckOrderBy(OrderByNodeOptions( - Ordering({SortKey("up"), SortKey("down", SortOrder::Descending)}))); + CheckOrderBy(OrderByNodeOptions({{SortKey("up")}})); + CheckOrderBy(OrderByNodeOptions({{SortKey("down", SortOrder::Descending)}})); + CheckOrderBy( + OrderByNodeOptions({{SortKey("up"), SortKey("down", SortOrder::Descending)}})); } TEST(OrderByNode, Large) { @@ -94,7 +94,7 @@ TEST(OrderByNode, Large) { 
->Table(ExecPlan::kMaxBatchSize, kSmallNumBatches); Declaration plan = Declaration::Sequence({ {"table_source", TableSourceNodeOptions(input)}, - {"order_by", OrderByNodeOptions(Ordering({SortKey("up", SortOrder::Descending)}))}, + {"order_by", OrderByNodeOptions({{SortKey("up", SortOrder::Descending)}})}, {"jitter", JitterNodeOptions(kSeed, kJitterMod)}, }); ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema batches_and_schema, diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index bef86e023a8..4bc4851d747 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -1694,7 +1694,7 @@ cdef class Array(_PandasConvertible): self._assert_cpu() indices = _pc().sort_indices( self, - options=_pc().SortOptions(sort_keys=[("", order)], **kwargs) + options=_pc().SortOptions(sort_keys=[("", order, null_placement)], **kwargs) ) return self.take(indices) @@ -4322,7 +4322,7 @@ cdef class StructArray(Array): result : StructArray """ if by is not None: - tosort, sort_keys = self._flattened_field(by), [("", order)] + tosort, sort_keys = self._flattened_field(by), [("", order, null_placement)] else: tosort, sort_keys = self, [ (field.name, order, null_placement) for field in self.type] From e01a009056060edbd13f7159e672f51fe4a185d6 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 17:57:16 +0200 Subject: [PATCH 25/83] make null_placement optional in the python acero api --- python/pyarrow/_acero.pyx | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/python/pyarrow/_acero.pyx b/python/pyarrow/_acero.pyx index 04f72b4b9fd..b7b56cc8b69 100644 --- a/python/pyarrow/_acero.pyx +++ b/python/pyarrow/_acero.pyx @@ -234,12 +234,19 @@ class AggregateNodeOptions(_AggregateNodeOptions): cdef class _OrderByNodeOptions(ExecNodeOptions): def _set_options(self, sort_keys, null_placement): - self.wrapped.reset( - new COrderByNodeOptions( - COrdering(unwrap_sort_keys(sort_keys, allow_str=False), - 
unwrap_null_placement(null_placement)) + if null_placement is None: + self.wrapped.reset( + new COrderByNodeOptions( + COrdering(unwrap_sort_keys(sort_keys, allow_str=False)) + ) + ) + else: + self.wrapped.reset( + new COrderByNodeOptions( + COrdering(unwrap_sort_keys(sort_keys, allow_str=False), + unwrap_null_placement(null_placement)) + ) ) - ) class OrderByNodeOptions(_OrderByNodeOptions): From 794a2cd42929a1bea1ad055f90d12f9f5199f008 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Thu, 26 Jun 2025 23:55:15 +0200 Subject: [PATCH 26/83] minor additional fixes --- cpp/src/arrow/compute/ordering.h | 2 +- python/pyarrow/_acero.pyx | 2 +- python/pyarrow/includes/libarrow.pxd | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/ordering.h b/cpp/src/arrow/compute/ordering.h index 91bab5fe5b2..3764dfcdcb1 100644 --- a/cpp/src/arrow/compute/ordering.h +++ b/cpp/src/arrow/compute/ordering.h @@ -64,7 +64,7 @@ class ARROW_EXPORT SortKey : public util::EqualityComparable { class ARROW_EXPORT Ordering : public util::EqualityComparable { public: Ordering(std::vector sort_keys, - std::optional null_placement = NullPlacement::AtStart) + std::optional null_placement = std::nullopt) : sort_keys_(std::move(sort_keys)), null_placement_(null_placement) {} /// true if data ordered by other is also ordered by this /// diff --git a/python/pyarrow/_acero.pyx b/python/pyarrow/_acero.pyx index b7b56cc8b69..2b6b2af1e41 100644 --- a/python/pyarrow/_acero.pyx +++ b/python/pyarrow/_acero.pyx @@ -270,7 +270,7 @@ class OrderByNodeOptions(_OrderByNodeOptions): null_placement : str, optional Where nulls in input should be sorted, only applying to columns/fields mentioned in `sort_keys`. - Accepted values are "at_start", "at_end", with "at_end" being the default. 
+ Accepted values are "at_start" and "at_end". """ def __init__(self, sort_keys=(), *, null_placement=None): diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 0ce01f2bc95..74770ce3111 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2764,6 +2764,7 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: CNullPlacement null_placement cdef cppclass COrdering" arrow::compute::Ordering": + COrdering(vector[CSortKey] sort_keys) COrdering(vector[CSortKey] sort_keys, CNullPlacement null_placement) cdef cppclass CSortOptions \ "arrow::compute::SortOptions"(CFunctionOptions): From adaae376d8a78df1b80e0808b140dfd5997633d9 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 27 Jun 2025 08:45:03 +0200 Subject: [PATCH 27/83] amend c_glib api to use std::optional for RankOptions --- c_glib/arrow-glib/compute.cpp | 6 +++--- c_glib/arrow-glib/compute.h | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index c91d747e177..03a42bc0be4 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -4391,7 +4391,7 @@ garrow_rank_options_get_property(GObject *object, switch (prop_id) { case PROP_RANK_OPTIONS_NULL_PLACEMENT: - g_value_set_enum(value, static_cast(options->null_placement)); + g_value_set_enum(value, static_cast(options->null_placement)); break; case PROP_RANK_OPTIONS_TIEBREAKER: g_value_set_enum(value, static_cast(options->tiebreaker)); @@ -4432,8 +4432,8 @@ garrow_rank_options_class_init(GArrowRankOptionsClass *klass) "Null placement", "Whether nulls and NaNs are placed " "at the start or at the end.", - GARROW_TYPE_NULL_PLACEMENT, - static_cast(options.null_placement), + GARROW_TYPE_OPTIONAL_NULL_PLACEMENT, + static_cast(options.null_placement), static_cast(G_PARAM_READWRITE)); g_object_class_install_property(gobject_class, PROP_RANK_OPTIONS_NULL_PLACEMENT, spec); diff --git
a/c_glib/arrow-glib/compute.h b/c_glib/arrow-glib/compute.h index afdb7579a02..1f4f3c2ee1e 100644 --- a/c_glib/arrow-glib/compute.h +++ b/c_glib/arrow-glib/compute.h @@ -508,6 +508,28 @@ typedef enum /**/ { GARROW_NULL_PLACEMENT_AT_END, } GArrowNullPlacement; +/** + * GArrowOptionalNullPlacement: + * @GARROW_OPTIONAL_NULL_PLACEMENT_AT_START: + * Place nulls and NaNs before any non-null values. + * NaNs will come after nulls. + * @GARROW_OPTIONAL_NULL_PLACEMENT_AT_END: + * Place nulls and NaNs after any non-null values. + * NaNs will come before nulls. + * @GARROW_OPTIONAL_NULL_PLACEMENT_UNSET: + * Do not specify null placement. + * Null placement should instead be taken from each sort key. + * + * They are corresponding to `std::optional` values. + * + * Since: 12.0.0 + */ +typedef enum /**/ { + GARROW_OPTIONAL_NULL_PLACEMENT_AT_START, + GARROW_OPTIONAL_NULL_PLACEMENT_AT_END, + GARROW_OPTIONAL_NULL_PLACEMENT_UNSET, +} GArrowOptionalNullPlacement; + #define GARROW_TYPE_ARRAY_SORT_OPTIONS (garrow_array_sort_options_get_type()) GARROW_AVAILABLE_IN_3_0 G_DECLARE_DERIVABLE_TYPE(GArrowArraySortOptions, From 14c3914cfad32e2eacb19936dd61107bb97181d3 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 27 Jun 2025 08:53:46 +0200 Subject: [PATCH 28/83] amend c_glib api to use std::optional for RankOptions --- c_glib/arrow-glib/compute.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 03a42bc0be4..8e5070afdf3 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -4369,7 +4369,7 @@ garrow_rank_options_set_property(GObject *object, switch (prop_id) { case PROP_RANK_OPTIONS_NULL_PLACEMENT: options->null_placement = - static_cast(g_value_get_enum(value)); + static_cast>(g_value_get_enum(value)); break; case PROP_RANK_OPTIONS_TIEBREAKER: options->tiebreaker = From 791d8e9909b8a2097dfeb8d791894d26f61693be Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 27 Jun 2025 09:09:48
+0200 Subject: [PATCH 29/83] amend c_glib api to use std::optional for RankOptions --- c_glib/arrow-glib/compute.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 8e5070afdf3..e147c557308 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -4368,8 +4368,16 @@ garrow_rank_options_set_property(GObject *object, switch (prop_id) { case PROP_RANK_OPTIONS_NULL_PLACEMENT: - options->null_placement = - static_cast>(g_value_get_enum(value)); + auto val = g_value_get_enum(value); + if (val == GARROW_OPTIONAL_NULL_PLACEMENT_AT_START) { + options->null_placement = arrow::compute::NullPlacement::AtStart; + } + else if (val == GARROW_OPTIONAL_NULL_PLACEMENT_AT_END) { + options->null_placement = arrow::compute::NullPlacement::AtEnd; + } + else if (val == GARROW_OPTIONAL_NULL_PLACEMENT_UNSET) { + options->null_placement = std::nullopt; + } break; case PROP_RANK_OPTIONS_TIEBREAKER: options->tiebreaker = From bee7e81662c270d2ddc2203c39b226a1db4296b6 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sun, 29 Jun 2025 12:28:38 +0200 Subject: [PATCH 30/83] amend c_glib api to use std::optional for RankOptions --- c_glib/arrow-glib/compute.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index e147c557308..0ac0529db84 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -4399,7 +4399,15 @@ garrow_rank_options_get_property(GObject *object, switch (prop_id) { case PROP_RANK_OPTIONS_NULL_PLACEMENT: - g_value_set_enum(value, static_cast(options->null_placement)); + if(!options->null_placement.has_value()){ + g_value_set_enum(value, GARROW_OPTIONAL_NULL_PLACEMENT_UNSET); + } + else if(options->null_placement.value() == arrow::compute::NullPlacement::AtStart){ + g_value_set_enum(value, GARROW_OPTIONAL_NULL_PLACEMENT_AT_START); + } + else { 
+ g_value_set_enum(value, GARROW_OPTIONAL_NULL_PLACEMENT_AT_END); + } break; case PROP_RANK_OPTIONS_TIEBREAKER: g_value_set_enum(value, static_cast(options->tiebreaker)); From cbfd596d3ad3bc91f2022d3ac469dc9ac06980b9 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sun, 29 Jun 2025 12:28:58 +0200 Subject: [PATCH 31/83] fix mistake where default null placement was AtStart --- cpp/src/arrow/compute/kernels/vector_sort.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 03e6090a207..692c918625a 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -959,7 +959,7 @@ class SortIndicesMetaFunction : public MetaFunction { Result SortIndices(const Array& values, const SortOptions& options, ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; - NullPlacement null_placement = NullPlacement::AtStart; + NullPlacement null_placement = NullPlacement::AtEnd; auto sort_keys = options.GetSortKeys(); if (!sort_keys.empty()) { order = sort_keys[0].order; @@ -972,7 +972,7 @@ class SortIndicesMetaFunction : public MetaFunction { Result SortIndices(const ChunkedArray& chunked_array, const SortOptions& options, ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; - NullPlacement null_placement = NullPlacement::AtStart; + NullPlacement null_placement = NullPlacement::AtEnd; auto sort_keys = options.GetSortKeys(); if (!sort_keys.empty()) { order = sort_keys[0].order; From 779169eb33c584fe8eb4559d3b22733a06c7a51b Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sun, 29 Jun 2025 12:45:15 +0200 Subject: [PATCH 32/83] fix mistake where null_placement was not correctly set when sort_keys are empty --- cpp/src/arrow/compute/kernels/vector_rank.cc | 10 ++++++---- cpp/src/arrow/compute/kernels/vector_sort.cc | 20 ++++++++++++-------- 2 files changed, 18 insertions(+), 12 deletions(-) 
diff --git a/cpp/src/arrow/compute/kernels/vector_rank.cc b/cpp/src/arrow/compute/kernels/vector_rank.cc index f862ac0bd22..c6fe7430e52 100644 --- a/cpp/src/arrow/compute/kernels/vector_rank.cc +++ b/cpp/src/arrow/compute/kernels/vector_rank.cc @@ -347,10 +347,12 @@ class RankMetaFunctionBase : public MetaFunction { SortOrder order = SortOrder::Ascending; NullPlacement null_placement = NullPlacement::AtEnd; - auto sort_keys = options.GetSortKeys(); - if (!sort_keys.empty()) { - order = sort_keys[0].order; - null_placement = sort_keys[0].null_placement; + if (!options.sort_keys.empty()) { + order = options.sort_keys[0].order; + null_placement = options.sort_keys[0].null_placement; + } + if(options.null_placement.has_value()){ + null_placement = options.null_placement.value(); } int64_t length = input.length(); diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 692c918625a..44bbbf2e302 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -960,10 +960,12 @@ class SortIndicesMetaFunction : public MetaFunction { ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; NullPlacement null_placement = NullPlacement::AtEnd; - auto sort_keys = options.GetSortKeys(); - if (!sort_keys.empty()) { - order = sort_keys[0].order; - null_placement = sort_keys[0].null_placement; + if (!options.sort_keys.empty()) { + order = options.sort_keys[0].order; + null_placement = options.sort_keys[0].null_placement; + } + if(options.null_placement.has_value()){ + null_placement = options.null_placement.value(); } ArraySortOptions array_options(order, null_placement); return CallFunction("array_sort_indices", {values}, &array_options, ctx); @@ -973,10 +975,12 @@ class SortIndicesMetaFunction : public MetaFunction { ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; NullPlacement null_placement = NullPlacement::AtEnd; - auto sort_keys = 
options.GetSortKeys(); - if (!sort_keys.empty()) { - order = sort_keys[0].order; - null_placement = sort_keys[0].null_placement; + if (!options.sort_keys.empty()) { + order = options.sort_keys[0].order; + null_placement = options.sort_keys[0].null_placement; + } + if(options.null_placement.has_value()){ + null_placement = options.null_placement.value(); } auto out_type = uint64(); From dbd75d0ae74b67934436820fd129655228463cc5 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sun, 29 Jun 2025 14:16:05 +0200 Subject: [PATCH 33/83] fix c_glib bindings --- c_glib/arrow-glib/compute.cpp | 60 +++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 0ac0529db84..b8b295da63e 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -4347,6 +4347,39 @@ garrow_index_options_new(void) return GARROW_INDEX_OPTIONS(g_object_new(GARROW_TYPE_INDEX_OPTIONS, NULL)); } +namespace { +static GArrowOptionalNullPlacement +garrow_optional_null_placement_from_raw(const std::optional& arrow_placement) +{ + if (!arrow_placement.has_value()) { + return GARROW_OPTIONAL_NULL_PLACEMENT_UNSET; + } + + switch (arrow_placement.value()) { + case arrow::compute::NullPlacement::AtStart: + return GARROW_OPTIONAL_NULL_PLACEMENT_AT_START; + case arrow::compute::NullPlacement::AtEnd: + return GARROW_OPTIONAL_NULL_PLACEMENT_AT_END; + default: + return GARROW_OPTIONAL_NULL_PLACEMENT_UNSET; + } +} + +static std::optional +garrow_optional_null_placement_to_raw(GArrowOptionalNullPlacement garrow_placement) +{ + switch (garrow_placement) { + case GARROW_OPTIONAL_NULL_PLACEMENT_AT_START: + return arrow::compute::NullPlacement::AtStart; + case GARROW_OPTIONAL_NULL_PLACEMENT_AT_END: + return arrow::compute::NullPlacement::AtEnd; + case GARROW_OPTIONAL_NULL_PLACEMENT_UNSET: + default: + return std::nullopt; + } +} +} + enum { PROP_RANK_OPTIONS_NULL_PLACEMENT = 1, 
PROP_RANK_OPTIONS_TIEBREAKER, @@ -4368,16 +4401,10 @@ garrow_rank_options_set_property(GObject *object, switch (prop_id) { case PROP_RANK_OPTIONS_NULL_PLACEMENT: - auto val = g_value_get_enum(value); - if (val == GARROW_OPTIONAL_NULL_PLACEMENT_AT_START) { - options->null_placement = arrow::compute::NullPlacement::AtStart; - } - else if (val == GARROW_OPTIONAL_NULL_PLACEMENT_AT_END) { - options->null_placement = arrow::compute::NullPlacement::AtEnd; - } - else if (val == GARROW_OPTIONAL_NULL_PLACEMENT_UNSET) { - options->null_placement = std::nullopt; - } + options->null_placement = + garrow_optional_null_placement_to_raw( + static_cast(g_value_get_enum(value)) + ); break; case PROP_RANK_OPTIONS_TIEBREAKER: options->tiebreaker = @@ -4399,15 +4426,8 @@ garrow_rank_options_get_property(GObject *object, switch (prop_id) { case PROP_RANK_OPTIONS_NULL_PLACEMENT: - if(!options->null_placement.has_value()){ - g_value_set_enum(value, GARROW_OPTIONAL_NULL_PLACEMENT_UNSET); - } - else if(options->null_placement.value() == arrow::compute::NullPlacement::AtStart){ - g_value_set_enum(value, GARROW_OPTIONAL_NULL_PLACEMENT_AT_START); - } - else { - g_value_set_enum(value, GARROW_OPTIONAL_NULL_PLACEMENT_AT_END); - } + g_value_set_enum(value, + garrow_optional_null_placement_from_raw(options->null_placement)); break; case PROP_RANK_OPTIONS_TIEBREAKER: g_value_set_enum(value, static_cast(options->tiebreaker)); @@ -4449,7 +4469,7 @@ garrow_rank_options_class_init(GArrowRankOptionsClass *klass) "Whether nulls and NaNs are placed " "at the start or at the end.", GARROW_TYPE_OPTIONAL_NULL_PLACEMENT, - static_cast(options.null_placement), + garrow_optional_null_placement_from_raw(options.null_placement), static_cast(G_PARAM_READWRITE)); g_object_class_install_property(gobject_class, PROP_RANK_OPTIONS_NULL_PLACEMENT, spec); From 67e99ac06ea926e253cb687e3a6c664d6b875310 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sun, 29 Jun 2025 16:33:50 +0200 Subject: [PATCH 34/83] formatting 
--- c_glib/arrow-glib/compute.cpp | 78 ++++++++++---------- cpp/src/arrow/compute/kernels/vector_rank.cc | 2 +- cpp/src/arrow/compute/kernels/vector_sort.cc | 4 +- 3 files changed, 42 insertions(+), 42 deletions(-) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index b8b295da63e..603cb3ffffa 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -4348,37 +4348,37 @@ garrow_index_options_new(void) } namespace { -static GArrowOptionalNullPlacement -garrow_optional_null_placement_from_raw(const std::optional& arrow_placement) -{ - if (!arrow_placement.has_value()) { - return GARROW_OPTIONAL_NULL_PLACEMENT_UNSET; - } + static GArrowOptionalNullPlacement + garrow_optional_null_placement_from_raw( + const std::optional &arrow_placement) + { + if (!arrow_placement.has_value()) { + return GARROW_OPTIONAL_NULL_PLACEMENT_UNSET; + } - switch (arrow_placement.value()) { - case arrow::compute::NullPlacement::AtStart: - return GARROW_OPTIONAL_NULL_PLACEMENT_AT_START; - case arrow::compute::NullPlacement::AtEnd: - return GARROW_OPTIONAL_NULL_PLACEMENT_AT_END; - default: - return GARROW_OPTIONAL_NULL_PLACEMENT_UNSET; - } -} + switch (arrow_placement.value()) { + case arrow::compute::NullPlacement::AtStart: + return GARROW_OPTIONAL_NULL_PLACEMENT_AT_START; + case arrow::compute::NullPlacement::AtEnd: + return GARROW_OPTIONAL_NULL_PLACEMENT_AT_END; + default: + return GARROW_OPTIONAL_NULL_PLACEMENT_UNSET; + } -static std::optional -garrow_optional_null_placement_to_raw(GArrowOptionalNullPlacement garrow_placement) -{ - switch (garrow_placement) { - case GARROW_OPTIONAL_NULL_PLACEMENT_AT_START: - return arrow::compute::NullPlacement::AtStart; - case GARROW_OPTIONAL_NULL_PLACEMENT_AT_END: - return arrow::compute::NullPlacement::AtEnd; - case GARROW_OPTIONAL_NULL_PLACEMENT_UNSET: - default: - return std::nullopt; + static std::optional + garrow_optional_null_placement_to_raw(GArrowOptionalNullPlacement garrow_placement) + { + 
switch (garrow_placement) { + case GARROW_OPTIONAL_NULL_PLACEMENT_AT_START: + return arrow::compute::NullPlacement::AtStart; + case GARROW_OPTIONAL_NULL_PLACEMENT_AT_END: + return arrow::compute::NullPlacement::AtEnd; + case GARROW_OPTIONAL_NULL_PLACEMENT_UNSET: + default: + return std::nullopt; + } } -} -} +} // namespace enum { PROP_RANK_OPTIONS_NULL_PLACEMENT = 1, @@ -4401,10 +4401,8 @@ garrow_rank_options_set_property(GObject *object, switch (prop_id) { case PROP_RANK_OPTIONS_NULL_PLACEMENT: - options->null_placement = - garrow_optional_null_placement_to_raw( - static_cast(g_value_get_enum(value)) - ); + options->null_placement = garrow_optional_null_placement_to_raw( + static_cast(g_value_get_enum(value))); break; case PROP_RANK_OPTIONS_TIEBREAKER: options->tiebreaker = @@ -4464,13 +4462,15 @@ garrow_rank_options_class_init(GArrowRankOptionsClass *klass) * * Since: 12.0.0 */ - spec = g_param_spec_enum("null-placement", - "Null placement", - "Whether nulls and NaNs are placed " - "at the start or at the end.", - GARROW_TYPE_OPTIONAL_NULL_PLACEMENT, - garrow_optional_null_placement_from_raw(options.null_placement), - static_cast(G_PARAM_READWRITE)); + spec = + g_param_spec_enum("null-placement", + "Null placement", + "Whether nulls and NaNs are placed " + "at the start or at the end.", + GARROW_TYPE_OPTIONAL_NULL_PLACEMENT, + garrow_optional_null_placement_from_raw(options.null_placement), + static_cast(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, PROP_RANK_OPTIONS_NULL_PLACEMENT, spec); /** diff --git a/cpp/src/arrow/compute/kernels/vector_rank.cc b/cpp/src/arrow/compute/kernels/vector_rank.cc index c6fe7430e52..fd5e2075d73 100644 --- a/cpp/src/arrow/compute/kernels/vector_rank.cc +++ b/cpp/src/arrow/compute/kernels/vector_rank.cc @@ -351,7 +351,7 @@ class RankMetaFunctionBase : public MetaFunction { order = options.sort_keys[0].order; null_placement = options.sort_keys[0].null_placement; } - if(options.null_placement.has_value()){ + if 
(options.null_placement.has_value()) { null_placement = options.null_placement.value(); } diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 44bbbf2e302..0d292652537 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -964,7 +964,7 @@ class SortIndicesMetaFunction : public MetaFunction { order = options.sort_keys[0].order; null_placement = options.sort_keys[0].null_placement; } - if(options.null_placement.has_value()){ + if (options.null_placement.has_value()) { null_placement = options.null_placement.value(); } ArraySortOptions array_options(order, null_placement); @@ -979,7 +979,7 @@ class SortIndicesMetaFunction : public MetaFunction { order = options.sort_keys[0].order; null_placement = options.sort_keys[0].null_placement; } - if(options.null_placement.has_value()){ + if (options.null_placement.has_value()) { null_placement = options.null_placement.value(); } From e26fbc4d5fabcff85deb1bf744d7f0e36046ad4a Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sun, 29 Jun 2025 17:19:14 +0200 Subject: [PATCH 35/83] formatting --- c_glib/arrow-glib/compute.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 603cb3ffffa..847ca8ba078 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -4364,6 +4364,7 @@ namespace { default: return GARROW_OPTIONAL_NULL_PLACEMENT_UNSET; } + } static std::optional garrow_optional_null_placement_to_raw(GArrowOptionalNullPlacement garrow_placement) From e07a67339d2b3346ea173850f5cd10bc38932998 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 23 Jan 2026 10:49:02 +0100 Subject: [PATCH 36/83] also update RankQuantileOptions to new NullPlacement api --- c_glib/arrow-glib/compute.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/c_glib/arrow-glib/compute.cpp 
b/c_glib/arrow-glib/compute.cpp index 6db7cd44ce1..8a92d36f70f 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -8892,8 +8892,8 @@ garrow_rank_quantile_options_set_property(GObject *object, switch (prop_id) { case PROP_RANK_QUANTILE_OPTIONS_NULL_PLACEMENT: - options->null_placement = - static_cast(g_value_get_enum(value)); + options->null_placement = garrow_optional_null_placement_to_raw( + static_cast(g_value_get_enum(value))); break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); @@ -8912,7 +8912,8 @@ garrow_rank_quantile_options_get_property(GObject *object, switch (prop_id) { case PROP_RANK_QUANTILE_OPTIONS_NULL_PLACEMENT: - g_value_set_enum(value, static_cast(options->null_placement)); + g_value_set_enum(value, + garrow_optional_null_placement_from_raw(options->null_placement)); break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); @@ -8950,8 +8951,8 @@ garrow_rank_quantile_options_class_init(GArrowRankQuantileOptionsClass *klass) "Null placement", "Whether nulls and NaNs are placed " "at the start or at the end.", - GARROW_TYPE_NULL_PLACEMENT, - static_cast(options.null_placement), + GARROW_TYPE_OPTIONAL_NULL_PLACEMENT, + garrow_optional_null_placement_from_raw(options.null_placement), static_cast(G_PARAM_READWRITE)); g_object_class_install_property(gobject_class, PROP_RANK_QUANTILE_OPTIONS_NULL_PLACEMENT, From 290812f0144fc1294bb0b57f9a3e977122507b37 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 23 Jan 2026 11:02:47 +0100 Subject: [PATCH 37/83] update library version documentation --- c_glib/arrow-glib/compute.cpp | 2 +- c_glib/arrow-glib/compute.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 8a92d36f70f..565fe28d20c 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -3229,7 +3229,7 @@ garrow_sort_key_class_init(GArrowSortKeyClass *klass) * * Whether 
nulls and NaNs are placed at the start or at the end. * - * Since: 15.0.0 + * Since: 24.0.0 */ spec = g_param_spec_enum( "null-placement", diff --git a/c_glib/arrow-glib/compute.h b/c_glib/arrow-glib/compute.h index c45c3717fb7..1bdb8609ead 100644 --- a/c_glib/arrow-glib/compute.h +++ b/c_glib/arrow-glib/compute.h @@ -523,7 +523,7 @@ typedef enum /**/ { * * They are corresponding to `std::optional` values. * - * Since: 12.0.0 + * Since: 24.0.0 */ typedef enum /**/ { GARROW_OPTIONAL_NULL_PLACEMENT_AT_START, From 5320622b504ef807945034f6a6d0c14395af05ee Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 23 Jan 2026 11:03:15 +0100 Subject: [PATCH 38/83] rename GARROW_OPTIONAL_NULL_PLACEMENT_UNSET to GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED --- c_glib/arrow-glib/compute.cpp | 6 +++--- c_glib/arrow-glib/compute.h | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 565fe28d20c..2e813b51734 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -4535,7 +4535,7 @@ namespace { const std::optional &arrow_placement) { if (!arrow_placement.has_value()) { - return GARROW_OPTIONAL_NULL_PLACEMENT_UNSET; + return GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED; } switch (arrow_placement.value()) { @@ -4544,7 +4544,7 @@ namespace { case arrow::compute::NullPlacement::AtEnd: return GARROW_OPTIONAL_NULL_PLACEMENT_AT_END; default: - return GARROW_OPTIONAL_NULL_PLACEMENT_UNSET; + return GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED; } } @@ -4556,7 +4556,7 @@ namespace { return arrow::compute::NullPlacement::AtStart; case GARROW_OPTIONAL_NULL_PLACEMENT_AT_END: return arrow::compute::NullPlacement::AtEnd; - case GARROW_OPTIONAL_NULL_PLACEMENT_UNSET: + case GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED: default: return std::nullopt; } diff --git a/c_glib/arrow-glib/compute.h b/c_glib/arrow-glib/compute.h index 1bdb8609ead..26e8d25db24 100644 --- a/c_glib/arrow-glib/compute.h 
+++ b/c_glib/arrow-glib/compute.h @@ -517,7 +517,7 @@ typedef enum /**/ { * @GARROW_OPTIONAL_NULL_PLACEMENT_AT_END: * Place nulls and NaNs after any non-null values. * NaNs will come before nulls. - * @GARROW_OPTIONAL_NULL_PLACEMENT_UNSET: + * @GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED: * Do not specify null placement. * Null placement should instead * @@ -528,7 +528,7 @@ typedef enum /**/ { typedef enum /**/ { GARROW_OPTIONAL_NULL_PLACEMENT_AT_START, GARROW_OPTIONAL_NULL_PLACEMENT_AT_END, - GARROW_OPTIONAL_NULL_PLACEMENT_UNSET, + GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED, } GArrowOptionalNullPlacement; #define GARROW_TYPE_ARRAY_SORT_OPTIONS (garrow_array_sort_options_get_type()) From bf31270a161813b8a7a12ae0b3ba8d06032c222d Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 23 Jan 2026 12:52:15 +0100 Subject: [PATCH 39/83] formatting --- c_glib/arrow-glib/compute.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 2e813b51734..2d402ca9608 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -8947,13 +8947,14 @@ garrow_rank_quantile_options_class_init(GArrowRankQuantileOptionsClass *klass) * * Since: 23.0.0 */ - spec = g_param_spec_enum("null-placement", - "Null placement", - "Whether nulls and NaNs are placed " - "at the start or at the end.", - GARROW_TYPE_OPTIONAL_NULL_PLACEMENT, - garrow_optional_null_placement_from_raw(options.null_placement), - static_cast(G_PARAM_READWRITE)); + spec = + g_param_spec_enum("null-placement", + "Null placement", + "Whether nulls and NaNs are placed " + "at the start or at the end.", + GARROW_TYPE_OPTIONAL_NULL_PLACEMENT, + garrow_optional_null_placement_from_raw(options.null_placement), + static_cast(G_PARAM_READWRITE)); g_object_class_install_property(gobject_class, PROP_RANK_QUANTILE_OPTIONS_NULL_PLACEMENT, spec); From 2e58463c0341a4c0e006d650524634808072e157 Mon Sep 17 00:00:00 2001 
From: Alexander Taepper Date: Fri, 23 Jan 2026 13:33:56 +0100 Subject: [PATCH 40/83] ruby default value formatting --- ruby/red-arrow/lib/arrow/sort-key.rb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ruby/red-arrow/lib/arrow/sort-key.rb b/ruby/red-arrow/lib/arrow/sort-key.rb index 6438d371623..604c79fcfec 100644 --- a/ruby/red-arrow/lib/arrow/sort-key.rb +++ b/ruby/red-arrow/lib/arrow/sort-key.rb @@ -46,7 +46,7 @@ class << self # @return [Arrow::SortKey] A new suitable sort key. # # @since 4.0.0 - def resolve(target, order=nil, null_placement = nil) + def resolve(target, order=nil, null_placement=nil) return target if target.is_a?(self) new(target, order, null_placement) end From 79995dff06025aefdfaeb2fe55f073c390804981 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 23 Jan 2026 13:34:34 +0100 Subject: [PATCH 41/83] change red-arrow gem to use different syntax for specifying null_placement --- ruby/red-arrow/lib/arrow/sort-key.rb | 151 ++++++++++++++++++--------- ruby/red-arrow/test/test-sort-key.rb | 39 +++++-- 2 files changed, 131 insertions(+), 59 deletions(-) diff --git a/ruby/red-arrow/lib/arrow/sort-key.rb b/ruby/red-arrow/lib/arrow/sort-key.rb index 604c79fcfec..ec5b40a98b6 100644 --- a/ruby/red-arrow/lib/arrow/sort-key.rb +++ b/ruby/red-arrow/lib/arrow/sort-key.rb @@ -71,37 +71,46 @@ def try_convert(value) # @param target [Symbol, String] The name or dot path of the # sort column. # - # If `target` is a String, the first character may be - # processed as the "leading order mark". If the first - # character is `"+"` or `"-"`, they are processed as a leading - # order mark. If the first character is processed as a leading - # order mark, the first character is removed from sort column - # target and corresponding order is used. `"+"` uses ascending - # order and `"-"` uses ascending order. 
- # - # If `target` is either not a String or `target` doesn't start - # with the leading order mark, sort column is `target` as-is - # and ascending order is used. - # - # @example String without the leading order mark - # key = Arrow::SortKey.new("count") - # key.target # => "count" - # key.order # => Arrow::SortOrder::ASCENDING + # If `target` is a String, it may have prefix markers that specify + # the sort order and null placement. The format is `[+/-][^/$]column`: # - # @example String with the "+" leading order mark - # key = Arrow::SortKey.new("+count") - # key.target # => "count" - # key.order # => Arrow::SortOrder::ASCENDING + # - `"+"` prefix means ascending order + # - `"-"` prefix means descending order + # - `"^"` prefix means nulls at start + # - `"$"` prefix means nulls at end # - # @example String with the "-" leading order mark + # If `target` is a Symbol, it is converted to String and used as-is + # (no prefix processing). + # + # @example String without any prefix + # key = Arrow::SortKey.new("count") + # key.target # => "count" + # key.order # => Arrow::SortOrder::ASCENDING + # key.null_placement # => Arrow::NullPlacement::AT_END + # + # @example String with order prefix only # key = Arrow::SortKey.new("-count") - # key.target # => "count" - # key.order # => Arrow::SortOrder::DESCENDING + # key.target # => "count" + # key.order # => Arrow::SortOrder::DESCENDING + # key.null_placement # => Arrow::NullPlacement::AT_END + # + # @example String with order and null placement prefixes + # key = Arrow::SortKey.new("-^count") + # key.target # => "count" + # key.order # => Arrow::SortOrder::DESCENDING + # key.null_placement # => Arrow::NullPlacement::AT_START # - # @example Symbol that starts with "-" + # @example String with null placement prefix only + # key = Arrow::SortKey.new("^count") + # key.target # => "count" + # key.order # => Arrow::SortOrder::ASCENDING + # key.null_placement # => Arrow::NullPlacement::AT_START + # + # @example Symbol (no 
prefix processing) # key = Arrow::SortKey.new(:"-count") - # key.target # => "-count" - # key.order # => Arrow::SortOrder::ASCENDING + # key.target # => "-count" + # key.order # => Arrow::SortOrder::ASCENDING + # key.null_placement # => Arrow::NullPlacement::AT_END # # @overload initialize(target, order) # @@ -122,21 +131,47 @@ def try_convert(value) # key = Arrow::SortKey.new("-count", :ascending) # key.target # => "-count" # key.order # => Arrow::SortOrder::ASCENDING + # key.null_placement # => Arrow::NullPlacement::AT_END # # @example Order by abbreviated target with Symbol # key = Arrow::SortKey.new("count", :desc) # key.target # => "count" # key.order # => Arrow::SortOrder::DESCENDING + # key.null_placement # => Arrow::NullPlacement::AT_END # # @example Order by String # key = Arrow::SortKey.new("count", "descending") # key.target # => "count" # key.order # => Arrow::SortOrder::DESCENDING + # key.null_placement # => Arrow::NullPlacement::AT_END # - # @example Order by Arrow::SortOrder - # key = Arrow::SortKey.new("count", Arrow::SortOrder::DESCENDING) + # @example Order by Arrow::SortOrder, give null_placement with target + # key = Arrow::SortKey.new("^count", Arrow::SortOrder::DESCENDING) # key.target # => "count" # key.order # => Arrow::SortOrder::DESCENDING + # key.null_placement # => Arrow::NullPlacement::AT_START + # + # @overload initialize(target, order, null_placement) + # + # @param target [Symbol, String] The name or dot path of the + # sort column. + # + # @param order [Symbol, String, Arrow::SortOrder] How to order + # by this sort key. + # + # If this is a Symbol or String, this must be `:ascending`, + # `"ascending"`, `:asc`, `"asc"`, `:descending`, + # `"descending"`, `:desc` or `"desc"`. + # + # @param null_placement [Symbol, String, Arrow::NullPlacement] + # Where to place nulls and NaNs. Must be `:at_start`, `"at_start"`, + # `:at_end`, or `"at_end"`. 
+ # + # @example With all explicit parameters + # key = Arrow::SortKey.new("count", :desc, :at_start) + # key.target # => "count" + # key.order # => Arrow::SortOrder::DESCENDING + # key.null_placement # => Arrow::NullPlacement::AT_START # # @since 4.0.0 def initialize(target, order=nil, null_placement=nil) @@ -152,52 +187,64 @@ def initialize(target, order=nil, null_placement=nil) # # @example Recreate Arrow::SortKey # key = Arrow::SortKey.new("-count") - # key.to_s # => "-count" + # key.to_s # => "-$count" # key == Arrow::SortKey.new(key.to_s) # => true # # @since 4.0.0 def to_s - result = if order == SortOrder::ASCENDING - "+#{target}" + result = "" + if order == SortOrder::ASCENDING + result += "+" else - "-#{target}" + result += "-" end if null_placement == NullPlacement::AT_START - result += "_at_start" + result += "^" else - result += "_at_end" + result += "$" end - return result + result += target + result end # For backward compatibility alias_method :name, :target private + # Parse prefix format: [+/-][^/$]column + # Examples: -$column, +^column, ^column, -column + # + # Only strips prefixes if the corresponding parameter is not already set. + # This preserves backward compatibility where specifying order explicitly + # means the target is used as-is for order prefixes. 
def normalize_target(target, order, null_placement) - # for recreatable, we should remove suffix - if target.end_with?("_at_start") - suffix_length = "_at_start".length - target = target[0..-(suffix_length + 1)] - elsif target.end_with?("_at_end") - suffix_length = "_at_end".length - target = target[0..-(suffix_length + 1)] - end - case target when Symbol return target.to_s, order, null_placement when String - if order - return target, order, null_placement + remaining = target + + unless order + if remaining.start_with?("-") + order = :descending + remaining = remaining[1..-1] + elsif remaining.start_with?("+") + order = :ascending + remaining = remaining[1..-1] + end end - if target.start_with?("-") - return target[1..-1], order || :descending, null_placement || :at_end - elsif target.start_with?("+") - return target[1..-1], order || :ascending, null_placement || :at_end - else - return target, order, null_placement + + unless null_placement + if remaining.start_with?("^") + null_placement = :at_start + remaining = remaining[1..-1] + elsif remaining.start_with?("$") + null_placement = :at_end + remaining = remaining[1..-1] + end end + + return remaining, order, null_placement else return target, order, null_placement end diff --git a/ruby/red-arrow/test/test-sort-key.rb b/ruby/red-arrow/test/test-sort-key.rb index 499f90fcadd..fbc4848e91d 100644 --- a/ruby/red-arrow/test/test-sort-key.rb +++ b/ruby/red-arrow/test/test-sort-key.rb @@ -35,40 +35,65 @@ class SortKeyTest < Test::Unit::TestCase sub_test_case("#initialize") do test("String") do - assert_equal("+count_at_end", + assert_equal("+$count", Arrow::SortKey.new("count").to_s) end test("+String") do - assert_equal("+count_at_end", + assert_equal("+$count", Arrow::SortKey.new("+count").to_s) end test("-String") do - assert_equal("-count_at_end", + assert_equal("-$count", Arrow::SortKey.new("-count").to_s) end test("Symbol") do - assert_equal("+-count_at_end", + assert_equal("+$-count", 
Arrow::SortKey.new(:"-count").to_s) end test("String, Symbol") do - assert_equal("--count_at_end", + assert_equal("-$-count", Arrow::SortKey.new("-count", :desc).to_s) end test("String, String") do - assert_equal("--count_at_end", + assert_equal("-$-count", Arrow::SortKey.new("-count", "desc").to_s) end test("String, SortOrder") do - assert_equal("--count_at_end", + assert_equal("-$-count", Arrow::SortKey.new("-count", Arrow::SortOrder::DESCENDING).to_s) end + + test("^String") do + assert_equal("+^count", + Arrow::SortKey.new("^count").to_s) + end + + test("-^String") do + assert_equal("-^count", + Arrow::SortKey.new("-^count").to_s) + end + + test("+$String") do + assert_equal("+$count", + Arrow::SortKey.new("+$count").to_s) + end + + test("+^^String") do + assert_equal("+^^count", + Arrow::SortKey.new("^count", Arrow::SortOrder::ASCENDING, Arrow::NullPlacement::AtStart).to_s) + end + + test("+$$String") do + assert_equal("+$$count", + Arrow::SortKey.new("$count", Arrow::SortOrder::ASCENDING, Arrow::NullPlacement::AtEnd).to_s) + end end sub_test_case("#to_s") do From bac578a4b0c98611954bdd7820c6b71ff9922e90 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 23 Jan 2026 13:36:46 +0100 Subject: [PATCH 42/83] update ruby test functions that use direct GObject introspection --- c_glib/test/test-rank-quantile-options.rb | 12 ++++++------ c_glib/test/test-select-k-options.rb | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/c_glib/test/test-rank-quantile-options.rb b/c_glib/test/test-rank-quantile-options.rb index 359f59ade00..ebe6ec6e184 100644 --- a/c_glib/test/test-rank-quantile-options.rb +++ b/c_glib/test/test-rank-quantile-options.rb @@ -24,19 +24,19 @@ def setup def test_sort_keys sort_keys = [ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ] @options.sort_keys = 
sort_keys assert_equal(sort_keys, @options.sort_keys) end def test_add_sort_key - @options.add_sort_key(Arrow::SortKey.new("column1", :ascending)) - @options.add_sort_key(Arrow::SortKey.new("column2", :descending)) + @options.add_sort_key(Arrow::SortKey.new("column1", :ascending, :at_end)) + @options.add_sort_key(Arrow::SortKey.new("column2", :descending, :at_end)) assert_equal([ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ], @options.sort_keys) end diff --git a/c_glib/test/test-select-k-options.rb b/c_glib/test/test-select-k-options.rb index 78c17bf1bed..ab894f626de 100644 --- a/c_glib/test/test-select-k-options.rb +++ b/c_glib/test/test-select-k-options.rb @@ -30,19 +30,19 @@ def test_k def test_sort_keys sort_keys = [ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ] @options.sort_keys = sort_keys assert_equal(sort_keys, @options.sort_keys) end def test_add_sort_key - @options.add_sort_key(Arrow::SortKey.new("column1", :ascending)) - @options.add_sort_key(Arrow::SortKey.new("column2", :descending)) + @options.add_sort_key(Arrow::SortKey.new("column1", :ascending, :at_end)) + @options.add_sort_key(Arrow::SortKey.new("column2", :descending, :at_end)) assert_equal([ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ], @options.sort_keys) end @@ -53,7 +53,7 @@ def test_select_k_unstable_function Arrow::ArrayDatum.new(input_array), ] @options.k = 3 - @options.add_sort_key(Arrow::SortKey.new("dummy", :descending)) + @options.add_sort_key(Arrow::SortKey.new("dummy", :descending, :at_end)) select_k_unstable_function = 
Arrow::Function.find("select_k_unstable") result = select_k_unstable_function.execute(args, @options).value assert_equal(build_uint64_array([4, 2, 0]), result) From 14111159dbc34b0b070ad3a7394ff92fa87f9871 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 23 Jan 2026 14:30:53 +0100 Subject: [PATCH 43/83] simplify helper functions --- c_glib/arrow-glib/compute.cpp | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 2d402ca9608..22815d80354 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -4534,31 +4534,19 @@ namespace { garrow_optional_null_placement_from_raw( const std::optional &arrow_placement) { - if (!arrow_placement.has_value()) { - return GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED; - } - - switch (arrow_placement.value()) { - case arrow::compute::NullPlacement::AtStart: - return GARROW_OPTIONAL_NULL_PLACEMENT_AT_START; - case arrow::compute::NullPlacement::AtEnd: - return GARROW_OPTIONAL_NULL_PLACEMENT_AT_END; - default: + if (!arrow_null_placement.has_value()) { return GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED; } + return static_cast(arrow_null_placement.value()); } static std::optional garrow_optional_null_placement_to_raw(GArrowOptionalNullPlacement garrow_placement) { - switch (garrow_placement) { - case GARROW_OPTIONAL_NULL_PLACEMENT_AT_START: - return arrow::compute::NullPlacement::AtStart; - case GARROW_OPTIONAL_NULL_PLACEMENT_AT_END: - return arrow::compute::NullPlacement::AtEnd; - case GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED: - default: + if (garrow_null_placement == GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED) { return std::nullopt; + } else { + return static_cast(garrow_null_placement); } } } // namespace From 8f041b59b56f86aa739fb5cb8c3621fc760f4ecf Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 23 Jan 2026 14:31:08 +0100 Subject: [PATCH 44/83] move and export helper functions --- 
c_glib/arrow-glib/compute.cpp | 42 +++++++++++++++++------------------ c_glib/arrow-glib/compute.hpp | 7 ++++++ 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 22815d80354..5ce61cc454f 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -86,6 +86,26 @@ garrow_take(arrow::Datum arrow_values, } } +GArrowOptionalNullPlacement +garrow_optional_null_placement_from_raw( + const std::optional &arrow_null_placement) +{ + if (!arrow_null_placement.has_value()) { + return GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED; + } + return static_cast(arrow_null_placement.value()); +} + +std::optional +garrow_optional_null_placement_to_raw(GArrowOptionalNullPlacement garrow_null_placement) +{ + if (garrow_null_placement == GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED) { + return std::nullopt; + } else { + return static_cast(garrow_null_placement); + } +} + namespace { gboolean garrow_field_refs_add(std::vector &arrow_field_refs, @@ -4529,28 +4549,6 @@ garrow_index_options_new(void) return GARROW_INDEX_OPTIONS(g_object_new(GARROW_TYPE_INDEX_OPTIONS, NULL)); } -namespace { - static GArrowOptionalNullPlacement - garrow_optional_null_placement_from_raw( - const std::optional &arrow_placement) - { - if (!arrow_null_placement.has_value()) { - return GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED; - } - return static_cast(arrow_null_placement.value()); - } - - static std::optional - garrow_optional_null_placement_to_raw(GArrowOptionalNullPlacement garrow_placement) - { - if (garrow_null_placement == GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED) { - return std::nullopt; - } else { - return static_cast(garrow_null_placement); - } - } -} // namespace - enum { PROP_RANK_OPTIONS_NULL_PLACEMENT = 1, PROP_RANK_OPTIONS_TIEBREAKER, diff --git a/c_glib/arrow-glib/compute.hpp b/c_glib/arrow-glib/compute.hpp index ff0698cd781..c171a62a602 100644 --- a/c_glib/arrow-glib/compute.hpp +++ 
b/c_glib/arrow-glib/compute.hpp @@ -20,6 +20,7 @@ #pragma once #include +#include #include #include @@ -143,6 +144,12 @@ garrow_index_options_new_raw(const arrow::compute::IndexOptions *arrow_options); arrow::compute::IndexOptions * garrow_index_options_get_raw(GArrowIndexOptions *options); +GArrowOptionalNullPlacement +garrow_optional_null_placement_from_raw( + const std::optional &arrow_null_placement); +std::optional +garrow_optional_null_placement_to_raw(GArrowOptionalNullPlacement garrow_null_placement); + GArrowRankOptions * garrow_rank_options_new_raw(const arrow::compute::RankOptions *arrow_options); arrow::compute::RankOptions * From d36fbb69867eb65b8300259dbce257f559f330d7 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 23 Jan 2026 14:33:52 +0100 Subject: [PATCH 45/83] fixup move and export helper functions --- c_glib/arrow-glib/compute.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/c_glib/arrow-glib/compute.hpp b/c_glib/arrow-glib/compute.hpp index c171a62a602..3134524242b 100644 --- a/c_glib/arrow-glib/compute.hpp +++ b/c_glib/arrow-glib/compute.hpp @@ -20,7 +20,6 @@ #pragma once #include -#include #include #include From 272ec3f864b737b6e3bc3be79b7bd604b578711a Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 23 Jan 2026 14:56:53 +0100 Subject: [PATCH 46/83] format ruby --- ruby/red-arrow/test/test-sort-key.rb | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ruby/red-arrow/test/test-sort-key.rb b/ruby/red-arrow/test/test-sort-key.rb index fbc4848e91d..f217cae493e 100644 --- a/ruby/red-arrow/test/test-sort-key.rb +++ b/ruby/red-arrow/test/test-sort-key.rb @@ -87,12 +87,14 @@ class SortKeyTest < Test::Unit::TestCase test("+^^String") do assert_equal("+^^count", - Arrow::SortKey.new("^count", Arrow::SortOrder::ASCENDING, Arrow::NullPlacement::AtStart).to_s) + Arrow::SortKey.new("^count", Arrow::SortOrder::ASCENDING, + Arrow::NullPlacement::AtStart).to_s) end test("+$$String") do assert_equal("+$$count", - 
Arrow::SortKey.new("$count", Arrow::SortOrder::ASCENDING, Arrow::NullPlacement::AtEnd).to_s) + Arrow::SortKey.new("$count", Arrow::SortOrder::ASCENDING, + Arrow::NullPlacement::AtEnd).to_s) end end From f1fda4883ee9593b94674443217b812c3c47651a Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 23 Jan 2026 15:18:51 +0100 Subject: [PATCH 47/83] fix another ruby test --- c_glib/test/test-rank-quantile-options.rb | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/c_glib/test/test-rank-quantile-options.rb b/c_glib/test/test-rank-quantile-options.rb index ebe6ec6e184..4a2aa75a2d8 100644 --- a/c_glib/test/test-rank-quantile-options.rb +++ b/c_glib/test/test-rank-quantile-options.rb @@ -42,9 +42,11 @@ def test_add_sort_key end def test_null_placement - assert_equal(Arrow::NullPlacement::AT_END, @options.null_placement) + assert_equal(Arrow::OptionalNullPlacement::UNSPECIFIED, @options.null_placement) + @options.null_placement = :at_end + assert_equal(Arrow::OptionalNullPlacement::AT_END, @options.null_placement) @options.null_placement = :at_start - assert_equal(Arrow::NullPlacement::AT_START, @options.null_placement) + assert_equal(Arrow::OptionalNullPlacement::AT_START, @options.null_placement) end def test_rank_quantile_function From e398fe4856138397586c1cb6b591306c6eca88f7 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sat, 24 Jan 2026 09:16:05 +0100 Subject: [PATCH 48/83] fix another ruby test --- ruby/red-arrow/test/test-sort-options.rb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ruby/red-arrow/test/test-sort-options.rb b/ruby/red-arrow/test/test-sort-options.rb index 260c62e50f7..6c044b09342 100644 --- a/ruby/red-arrow/test/test-sort-options.rb +++ b/ruby/red-arrow/test/test-sort-options.rb @@ -25,7 +25,7 @@ class SortOptionsTest < Test::Unit::TestCase test("-String, Symbol") do options = Arrow::SortOptions.new("-count", :age) - assert_equal(["-count_at_end", "+age_at_end"], + 
assert_equal(["-$count", "+$age"], options.sort_keys.collect(&:to_s)) end end @@ -38,19 +38,19 @@ class SortOptionsTest < Test::Unit::TestCase sub_test_case("#add_sort_key") do test("-String") do @options.add_sort_key("-count") - assert_equal(["-count_at_end"], + assert_equal(["-$count"], @options.sort_keys.collect(&:to_s)) end test("-String, Symbol") do @options.add_sort_key("-count", :desc) - assert_equal(["--count_at_end"], + assert_equal(["--$count"], @options.sort_keys.collect(&:to_s)) end test("SortKey") do @options.add_sort_key(Arrow::SortKey.new("-count")) - assert_equal(["-count_at_end"], + assert_equal(["-$count"], @options.sort_keys.collect(&:to_s)) end end From 4778d977f8e8d499ec557189bf1e4d123d0f2123 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sat, 24 Jan 2026 09:25:25 +0100 Subject: [PATCH 49/83] Improve `GArrowOptionalNullPlacement` documentation --- c_glib/arrow-glib/compute.h | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/c_glib/arrow-glib/compute.h b/c_glib/arrow-glib/compute.h index 26e8d25db24..ceccf32b733 100644 --- a/c_glib/arrow-glib/compute.h +++ b/c_glib/arrow-glib/compute.h @@ -514,14 +514,22 @@ typedef enum /**/ { * @GARROW_OPTIONAL_NULL_PLACEMENT_AT_START: * Place nulls and NaNs before any non-null values. * NaNs will come after nulls. + * Ignore null-placement of each individual + * `arrow:compute::SortKey`. * @GARROW_OPTIONAL_NULL_PLACEMENT_AT_END: * Place nulls and NaNs after any non-null values. * NaNs will come before nulls. + * Ignore null-placement of each individual + * `arrow:compute::SortKey`. * @GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED: * Do not specify null placement. - * Null placement should instead + * Instead, the null-placement of each individual + * `arrow:compute::SortKey` will be followed. * - * They are corresponding to `std::optional` values. + * They are corresponding to `arrow::compute::NullPlacement` values except + * `GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED`. 
+ * `GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED` is used to specify + * `std::nullopt`. * * Since: 24.0.0 */ From 15542b013526b6850f7b7b89dacd5658e77caad0 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sat, 24 Jan 2026 09:26:52 +0100 Subject: [PATCH 50/83] Make static_cast of GARROW_OPTIONAL_NULL_PLACEMENT safer by having GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED correspond to -1. All other (possibly future) values will have a 1:1 mapping --- c_glib/arrow-glib/compute.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c_glib/arrow-glib/compute.h b/c_glib/arrow-glib/compute.h index ceccf32b733..525f1de17f2 100644 --- a/c_glib/arrow-glib/compute.h +++ b/c_glib/arrow-glib/compute.h @@ -534,9 +534,9 @@ typedef enum /**/ { * Since: 24.0.0 */ typedef enum /**/ { + GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED = -1, GARROW_OPTIONAL_NULL_PLACEMENT_AT_START, GARROW_OPTIONAL_NULL_PLACEMENT_AT_END, - GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED, } GArrowOptionalNullPlacement; #define GARROW_TYPE_ARRAY_SORT_OPTIONS (garrow_array_sort_options_get_type()) From b7293cbc817571fad5b099f72a0252a0b7bd6f23 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sat, 24 Jan 2026 09:34:08 +0100 Subject: [PATCH 51/83] improve placement of helper garrow_optional_null_placement_* helper functions in c_glib/arrow-glib/compute.{hpp,cpp} --- c_glib/arrow-glib/compute.cpp | 40 +++++++++++++++++------------------ c_glib/arrow-glib/compute.hpp | 12 +++++------ 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 5ce61cc454f..d77eeacad77 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -86,26 +86,6 @@ garrow_take(arrow::Datum arrow_values, } } -GArrowOptionalNullPlacement -garrow_optional_null_placement_from_raw( - const std::optional &arrow_null_placement) -{ - if (!arrow_null_placement.has_value()) { - return GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED; - } - return 
static_cast(arrow_null_placement.value()); -} - -std::optional -garrow_optional_null_placement_to_raw(GArrowOptionalNullPlacement garrow_null_placement) -{ - if (garrow_null_placement == GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED) { - return std::nullopt; - } else { - return static_cast(garrow_null_placement); - } -} - namespace { gboolean garrow_field_refs_add(std::vector &arrow_field_refs, @@ -11239,6 +11219,26 @@ garrow_sort_options_get_raw(GArrowSortOptions *options) garrow_function_options_get_raw(GARROW_FUNCTION_OPTIONS(options))); } +GArrowOptionalNullPlacement +garrow_optional_null_placement_from_raw( + const std::optional &arrow_null_placement) +{ + if (!arrow_null_placement.has_value()) { + return GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED; + } + return static_cast(arrow_null_placement.value()); +} + +std::optional +garrow_optional_null_placement_to_raw(GArrowOptionalNullPlacement garrow_null_placement) +{ + if (garrow_null_placement == GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED) { + return std::nullopt; + } else { + return static_cast(garrow_null_placement); + } +} + GArrowSetLookupOptions * garrow_set_lookup_options_new_raw(const arrow::compute::SetLookupOptions *arrow_options) { diff --git a/c_glib/arrow-glib/compute.hpp b/c_glib/arrow-glib/compute.hpp index 3134524242b..7da0f30745b 100644 --- a/c_glib/arrow-glib/compute.hpp +++ b/c_glib/arrow-glib/compute.hpp @@ -100,6 +100,12 @@ garrow_sort_options_new_raw(const arrow::compute::SortOptions *arrow_options); arrow::compute::SortOptions * garrow_sort_options_get_raw(GArrowSortOptions *options); +GArrowOptionalNullPlacement +garrow_optional_null_placement_from_raw( + const std::optional &arrow_null_placement); +std::optional +garrow_optional_null_placement_to_raw(GArrowOptionalNullPlacement garrow_null_placement); + GArrowSetLookupOptions * garrow_set_lookup_options_new_raw(const arrow::compute::SetLookupOptions *arrow_options); arrow::compute::SetLookupOptions * @@ -143,12 +149,6 @@ 
garrow_index_options_new_raw(const arrow::compute::IndexOptions *arrow_options); arrow::compute::IndexOptions * garrow_index_options_get_raw(GArrowIndexOptions *options); -GArrowOptionalNullPlacement -garrow_optional_null_placement_from_raw( - const std::optional &arrow_null_placement); -std::optional -garrow_optional_null_placement_to_raw(GArrowOptionalNullPlacement garrow_null_placement); - GArrowRankOptions * garrow_rank_options_new_raw(const arrow::compute::RankOptions *arrow_options); arrow::compute::RankOptions * From b33465e35ad0cd827a570550a22849b10665b464 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sat, 24 Jan 2026 20:28:11 +0100 Subject: [PATCH 52/83] more ruby test fixes --- ruby/red-arrow/test/test-sort-key.rb | 4 ++-- ruby/red-arrow/test/test-sort-options.rb | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ruby/red-arrow/test/test-sort-key.rb b/ruby/red-arrow/test/test-sort-key.rb index f217cae493e..0a0c7fccf27 100644 --- a/ruby/red-arrow/test/test-sort-key.rb +++ b/ruby/red-arrow/test/test-sort-key.rb @@ -88,13 +88,13 @@ class SortKeyTest < Test::Unit::TestCase test("+^^String") do assert_equal("+^^count", Arrow::SortKey.new("^count", Arrow::SortOrder::ASCENDING, - Arrow::NullPlacement::AtStart).to_s) + Arrow::NullPlacement::AT_START).to_s) end test("+$$String") do assert_equal("+$$count", Arrow::SortKey.new("$count", Arrow::SortOrder::ASCENDING, - Arrow::NullPlacement::AtEnd).to_s) + Arrow::NullPlacement::AT_END).to_s) end end diff --git a/ruby/red-arrow/test/test-sort-options.rb b/ruby/red-arrow/test/test-sort-options.rb index 6c044b09342..99cea89bc7f 100644 --- a/ruby/red-arrow/test/test-sort-options.rb +++ b/ruby/red-arrow/test/test-sort-options.rb @@ -44,7 +44,7 @@ class SortOptionsTest < Test::Unit::TestCase test("-String, Symbol") do @options.add_sort_key("-count", :desc) - assert_equal(["--$count"], + assert_equal(["-$-count"], @options.sort_keys.collect(&:to_s)) end From 74d451b6a4a25ec216839d5b9031be369450afed 
Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Tue, 27 Jan 2026 15:41:21 +0100 Subject: [PATCH 53/83] improve syntax of ruby test --- ruby/red-arrow/test/test-sort-key.rb | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ruby/red-arrow/test/test-sort-key.rb b/ruby/red-arrow/test/test-sort-key.rb index 0a0c7fccf27..da5017dd245 100644 --- a/ruby/red-arrow/test/test-sort-key.rb +++ b/ruby/red-arrow/test/test-sort-key.rb @@ -87,14 +87,12 @@ class SortKeyTest < Test::Unit::TestCase test("+^^String") do assert_equal("+^^count", - Arrow::SortKey.new("^count", Arrow::SortOrder::ASCENDING, - Arrow::NullPlacement::AT_START).to_s) + Arrow::SortKey.new("^count", :ascending, :at_start).to_s) end test("+$$String") do assert_equal("+$$count", - Arrow::SortKey.new("$count", Arrow::SortOrder::ASCENDING, - Arrow::NullPlacement::AT_END).to_s) + Arrow::SortKey.new("$count", :ascending, :at_end).to_s) end end From db87498ef21ea866730e73ae1ec9266cfa171841 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Wed, 4 Feb 2026 22:11:34 +0100 Subject: [PATCH 54/83] wip review --- .../arrow/compute/kernels/select_k_test.cc | 44 +++++++------------ .../arrow/compute/kernels/vector_sort_test.cc | 42 +++++++++++++----- cpp/src/arrow/compute/ordering.h | 8 +++- python/pyarrow/_acero.pyx | 6 +++ python/pyarrow/_compute.pyx | 5 +++ python/pyarrow/table.pxi | 2 +- python/pyarrow/tests/test_acero.py | 12 +++++ 7 files changed, 77 insertions(+), 42 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc index 664cdeb5899..fd748e2a6f4 100644 --- a/cpp/src/arrow/compute/kernels/select_k_test.cc +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -62,8 +62,7 @@ Result> SelectK(const Datum& values, int64_t k) { } } -void ValidateSelectK(const Datum& datum, Array& select_k_indices, SortOrder order, - bool stable_sort = false) { +void ValidateSelectK(const Datum& datum, Array& select_k_indices, SortOrder 
order) { ASSERT_TRUE(datum.is_arraylike()); ASSERT_OK_AND_ASSIGN(auto sorted_indices, SortIndices(datum, SortOptions({SortKey("unused", order)}))); @@ -71,15 +70,11 @@ void ValidateSelectK(const Datum& datum, Array& select_k_indices, SortOrder orde int64_t k = select_k_indices.length(); // head(k) auto head_k_indices = sorted_indices->Slice(0, k); - if (stable_sort) { - AssertDatumsEqual(*head_k_indices, select_k_indices); - } else { - ASSERT_OK_AND_ASSIGN(auto expected, - Take(datum, *head_k_indices, TakeOptions::NoBoundsCheck())); - ASSERT_OK_AND_ASSIGN(auto actual, - Take(datum, select_k_indices, TakeOptions::NoBoundsCheck())); - AssertDatumsEqual(Datum(expected), Datum(actual)); - } + ASSERT_OK_AND_ASSIGN(auto expected, + Take(datum, *head_k_indices, TakeOptions::NoBoundsCheck())); + ASSERT_OK_AND_ASSIGN(auto actual, + Take(datum, select_k_indices, TakeOptions::NoBoundsCheck())); + AssertDatumsEqual(Datum(expected), Datum(actual)); } template @@ -88,27 +83,24 @@ class TestSelectKBase : public ::testing::Test { protected: template - void AssertSelectKArray(const std::shared_ptr values, int k, - bool check_indices = false) { + void AssertSelectKArray(const std::shared_ptr values, int k) { std::shared_ptr select_k; ASSERT_OK_AND_ASSIGN(select_k, SelectK(Datum(*values), k)); ASSERT_EQ(select_k->data()->null_count, 0); ValidateOutput(*select_k); - ValidateSelectK(Datum(*values), *select_k, order, check_indices); + ValidateSelectK(Datum(*values), *select_k, order); } - void AssertTopKArray(const std::shared_ptr values, int n, - bool check_indices = false) { - AssertSelectKArray(values, n, check_indices); + void AssertTopKArray(const std::shared_ptr values, int n) { + AssertSelectKArray(values, n); } - void AssertBottomKArray(const std::shared_ptr values, int n, - bool check_indices = false) { - AssertSelectKArray(values, n, check_indices); + void AssertBottomKArray(const std::shared_ptr values, int n) { + AssertSelectKArray(values, n); } - void 
AssertSelectKJson(const std::string& values, int n, bool check_indices = false) { - AssertTopKArray(ArrayFromJSON(type_singleton(), values), n, check_indices); - AssertBottomKArray(ArrayFromJSON(type_singleton(), values), n, check_indices); + void AssertSelectKJson(const std::string& values, int n) { + AssertTopKArray(ArrayFromJSON(type_singleton(), values), n); + AssertBottomKArray(ArrayFromJSON(type_singleton(), values), n); } virtual std::shared_ptr type_singleton() = 0; @@ -165,11 +157,9 @@ TYPED_TEST(TestSelectKForReal, Real) { this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 1); this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 2); this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 3); - // The result will contain nan. By default, the comparison of NaN is not equal, so - // indices are used for comparison. - this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 4, true); + this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 4); this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 3); - this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 4, true); + this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 4); this->AssertSelectKJson("[100, 4, 2, 7, 8, 3, NaN, 3, 1]", 4); } diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index 14d5f9a7d49..d5eae9c7024 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -63,6 +63,10 @@ std::vector AllNullPlacements() { return {NullPlacement::AtEnd, NullPlacement::AtStart}; } +std::vector> AllOptionalNullPlacements() { + return {std::nullopt, NullPlacement::AtEnd, NullPlacement::AtStart}; +} + std::vector AllTiebreakers() { return {RankOptions::Min, RankOptions::Max, RankOptions::First, RankOptions::Dense}; } @@ -72,6 +76,16 @@ std::ostream& operator<<(std::ostream& os, NullPlacement null_placement) { return os; } +std::ostream& operator<<(std::ostream& os, std::optional null_placement) { + if(null_placement.has_value()){ + os 
<< null_placement.value(); + } + else { + os << "None"; + } + return os; +} + // ---------------------------------------------------------------------- // Tests for NthToIndices @@ -1205,12 +1219,14 @@ TEST_F(TestRecordBatchSortIndices, NoNull) { {"a": 1, "b": 3} ])"); - for (auto null_placement : AllNullPlacements()) { - SortOptions options({SortKey("a", SortOrder::Ascending, null_placement), - SortKey("b", SortOrder::Descending, null_placement)}, - null_placement); + for (auto overwrite_null_placement : AllOptionalNullPlacements()){ + for (auto null_placement : AllNullPlacements()) { + SortOptions options({SortKey("a", SortOrder::Ascending, null_placement), + SortKey("b", SortOrder::Descending, null_placement)}, + overwrite_null_placement); - AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]"); + AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]"); + } } } @@ -2082,15 +2098,19 @@ TEST_P(TestTableSortIndicesRandom, Sort) { return (distribution(engine) & 1) ? SortOrder::Ascending : SortOrder::Descending; }; + auto generate_null_placement = [&]() { + return (distribution(engine) % 3) ? 
NullPlacement::AtEnd : NullPlacement::AtStart; + }; + std::vector sort_keys; sort_keys.reserve(fields.size()); for (const auto& field : fields) { if (field->name() != first_sort_key_name) { - sort_keys.emplace_back(field->name(), generate_order()); + sort_keys.emplace_back(field->name(), generate_order(), generate_null_placement()); } } std::shuffle(sort_keys.begin(), sort_keys.end(), engine); - sort_keys.emplace(sort_keys.begin(), first_sort_key_name, generate_order()); + sort_keys.emplace(sort_keys.begin(), first_sort_key_name, generate_order(), generate_null_placement()); sort_keys.erase(sort_keys.begin() + n_sort_keys, sort_keys.end()); ASSERT_EQ(sort_keys.size(), n_sort_keys); @@ -2128,11 +2148,9 @@ TEST_P(TestTableSortIndicesRandom, Sort) { } auto table = Table::Make(schema, std::move(columns)); - for (auto null_placement : AllNullPlacements()) { - ARROW_SCOPED_TRACE("null_placement = ", null_placement); - for (auto& sort_key : sort_keys) { - sort_key.null_placement = null_placement; - } + for (auto overwrite_null_placement : AllOptionalNullPlacements()) { + ARROW_SCOPED_TRACE("overwrite_null_placement = ", overwrite_null_placement); + options.null_placement = overwrite_null_placement; ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*table), options)); Validate(*table, options, *checked_pointer_cast(offsets)); } diff --git a/cpp/src/arrow/compute/ordering.h b/cpp/src/arrow/compute/ordering.h index 3764dfcdcb1..dc1f8cf29d1 100644 --- a/cpp/src/arrow/compute/ordering.h +++ b/cpp/src/arrow/compute/ordering.h @@ -63,8 +63,12 @@ class ARROW_EXPORT SortKey : public util::EqualityComparable { class ARROW_EXPORT Ordering : public util::EqualityComparable { public: + Ordering(std::vector sort_keys) + : sort_keys_(std::move(sort_keys)) {} + + // DEPRECATED(will be removed after removing null_placement from Ordering) Ordering(std::vector sort_keys, - std::optional null_placement = std::nullopt) + std::optional null_placement) : sort_keys_(std::move(sort_keys)), 
null_placement_(null_placement) {} /// true if data ordered by other is also ordered by this /// @@ -117,7 +121,7 @@ class ARROW_EXPORT Ordering : public util::EqualityComparable { /// Column key(s) to order by and how to order by these sort keys. std::vector sort_keys_; - // DEPRECATED(set null_placement in instead) + // DEPRECATED(set null_placement in sort_keys instead) /// Whether nulls and NaNs are placed at the start or at the end /// Will overwrite null ordering of sort keys std::optional null_placement_; diff --git a/python/pyarrow/_acero.pyx b/python/pyarrow/_acero.pyx index 1b6b5bd9622..0e62db64245 100644 --- a/python/pyarrow/_acero.pyx +++ b/python/pyarrow/_acero.pyx @@ -28,6 +28,7 @@ from pyarrow.includes.libarrow_acero cimport * from pyarrow.lib cimport (Table, pyarrow_unwrap_table, pyarrow_wrap_table, RecordBatchReader) from pyarrow.lib import frombytes, tobytes +import warnings from pyarrow._compute cimport ( Expression, FunctionOptions, _ensure_field_ref, _true, unwrap_null_placement, unwrap_sort_keys @@ -274,6 +275,11 @@ class OrderByNodeOptions(_OrderByNodeOptions): """ def __init__(self, sort_keys=(), *, null_placement=None): + if null_placement is not None: + warnings.warn( + "Specifying null_placement in OrderByNodeOptions is deprecated " + "as of 24.0.0. Specify null_placement per sort_key instead." + ) self._set_options(sort_keys, null_placement) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 7b72d09bd93..5fb736ec937 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -2279,6 +2279,11 @@ class SortOptions(_SortOptions): """ def __init__(self, sort_keys=(), *, null_placement=None): + if null_placement is not None: + warnings.warn( + "Specifying null_placement in SortOptions is deprecated " + "as of 24.0.0. Specify null_placement per sort_key instead." 
+ ) self._set_options(sort_keys, null_placement) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 0e2df0fa5c0..ba8145a6dc1 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -2118,7 +2118,7 @@ cdef class _Tabular(_PandasConvertible): Parameters ---------- - sorting : str or list[tuple(name, order, null_placement)] + sorting : str or list[tuple(name, order, null_placement="at_end")] Name of the column to use to sort (ascending), or a list of multiple sorting conditions where each entry is a tuple with column name diff --git a/python/pyarrow/tests/test_acero.py b/python/pyarrow/tests/test_acero.py index 2e4770eeab9..e789b3bc7fa 100644 --- a/python/pyarrow/tests/test_acero.py +++ b/python/pyarrow/tests/test_acero.py @@ -273,12 +273,24 @@ def test_order_by(): expected = pa.table({"a": [1, 4, 2, 3], "b": [1, 2, 3, None]}) assert result.equals(expected) + ord_opts = OrderByNodeOptions([("b", "ascending")]) + decl = Declaration.from_sequence([table_source, Declaration("order_by", ord_opts)]) + result = decl.to_table() + expected = pa.table({"a": [1, 4, 2, 3], "b": [1, 2, 3, None]}) + assert result.equals(expected) + ord_opts = OrderByNodeOptions([(field("b"), "descending", "at_end")]) decl = Declaration.from_sequence([table_source, Declaration("order_by", ord_opts)]) result = decl.to_table() expected = pa.table({"a": [2, 4, 1, 3], "b": [3, 2, 1, None]}) assert result.equals(expected) + ord_opts = OrderByNodeOptions([(field("b"), "descending")]) + decl = Declaration.from_sequence([table_source, Declaration("order_by", ord_opts)]) + result = decl.to_table() + expected = pa.table({"a": [2, 4, 1, 3], "b": [3, 2, 1, None]}) + assert result.equals(expected) + ord_opts = OrderByNodeOptions([(1, "descending", "at_start")]) decl = Declaration.from_sequence([table_source, Declaration("order_by", ord_opts)]) result = decl.to_table() From a844d96b85dc5a3eb539f3121efe1695973b70a3 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Wed, 4 
Feb 2026 22:31:16 +0100 Subject: [PATCH 55/83] wip review --- cpp/src/arrow/compute/api_vector.h | 11 +++++++---- cpp/src/arrow/compute/kernels/select_k_test.cc | 11 +++++------ cpp/src/arrow/compute/kernels/vector_sort.cc | 9 +++++++++ 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index f0edb12e274..3e1f4d8fd04 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -114,12 +114,14 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { /// Note: Both classes contain the exact same information. However, /// sort_options should only be used in a "function options" context while Ordering /// is used more generally. - Ordering AsOrdering() && { return Ordering(std::move(sort_keys), null_placement); } - Ordering AsOrdering() const& { return Ordering(sort_keys, null_placement); } + Ordering AsOrdering() && { return {std::move(sort_keys)}; } + Ordering AsOrdering() const& { return {sort_keys}; } /// Column key(s) to order by and how to order by these sort keys. 
std::vector sort_keys; +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" // DEPRECATED(will be removed after null_placement has been removed) /// Get sort_keys with overwritten null_placement std::vector GetSortKeys() const { @@ -132,11 +134,12 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { } return overwritten_sort_keys; } +#pragma clang diagnostic pop - // DEPRECATED(set null_placement in sort_keys instead) + // DEPRECATED(Deprecated in arrow 24.0.0, use null_placement in sort_keys instead) /// Whether nulls and NaNs are placed at the start or at the end /// Will overwrite null ordering of sort keys - std::optional null_placement; + ARROW_DEPRECATED("Deprecated in arrow 24.0.0, use null_placement in sort_keys instead") std::optional null_placement; }; /// \brief SelectK options diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc index fd748e2a6f4..58a85edf407 100644 --- a/cpp/src/arrow/compute/kernels/select_k_test.cc +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -242,11 +242,9 @@ class TestSelectKWithArray : public ::testing::Test { const SelectKOptions& options, const std::string& expected_json) { auto array = ArrayFromJSON(type, array_json); auto expected = ArrayFromJSON(uint64(), expected_json); - auto indices = SelectKUnstable(Datum(*array), options); - ASSERT_OK(indices); - auto actual = indices.MoveValueUnsafe(); - ValidateOutput(*actual); - AssertArraysEqual(*expected, *actual, /*verbose=*/true); + ASSERT_OK_AND_ASSIGN(auto indices, SelectKUnstable(Datum(*array), options)); + ValidateOutput(*indices); + AssertArraysEqual(*expected, *indices, /*verbose=*/true); } Status DoSelectK(const std::shared_ptr& type, const std::string& array_json, @@ -416,7 +414,8 @@ TYPED_TEST(TestSelectKWithChunkedArray, PartialSelectKNullNaN) { std::vector sort_keys{SortKey("a", SortOrder::Descending)}; auto options = SelectKOptions(3, sort_keys); 
options.sort_keys[0].null_placement = NullPlacement::AtStart; - this->CheckIndices(chunked_array, options, "[3, 0, 4]"); + auto expected = ChunkedArrayFromJSON(uint8(), {"[3, 0, 4]"}); + this->Check(chunked_array, options, expected); options.sort_keys[0].null_placement = NullPlacement::AtEnd; this->CheckIndices(chunked_array, options, "[5, 2, 7]"); } diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 68fc4cf816f..0b9d8939a51 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -957,6 +957,8 @@ class SortIndicesMetaFunction : public MetaFunction { chunked_array->length()); } +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" Result SortIndices(const Array& values, const SortOptions& options, ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; @@ -965,13 +967,17 @@ class SortIndicesMetaFunction : public MetaFunction { order = options.sort_keys[0].order; null_placement = options.sort_keys[0].null_placement; } + // TODO.TAE this member is deprecated. Is there a way to implement it without it? if (options.null_placement.has_value()) { null_placement = options.null_placement.value(); } ArraySortOptions array_options(order, null_placement); return CallFunction("array_sort_indices", {values}, &array_options, ctx); } +#pragma clang diagnostic pop +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" Result SortIndices(const ChunkedArray& chunked_array, const SortOptions& options, ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; @@ -980,6 +986,8 @@ class SortIndicesMetaFunction : public MetaFunction { order = options.sort_keys[0].order; null_placement = options.sort_keys[0].null_placement; } + // TODO.TAE this member is deprecated. Is there a way to implement it without it? + // Ah the method is only private?? 
if (options.null_placement.has_value()) { null_placement = options.null_placement.value(); } @@ -1000,6 +1008,7 @@ class SortIndicesMetaFunction : public MetaFunction { SortChunkedArray(ctx, out_begin, out_end, chunked_array, order, null_placement)); return Datum(out); } +#pragma clang diagnostic pop Result SortIndices(const RecordBatch& batch, const SortOptions& options, ExecContext* ctx) const { From 1938e7af0775387ec72f13a4015d427fb91bbb25 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Wed, 4 Feb 2026 23:07:24 +0100 Subject: [PATCH 56/83] minor renaming --- .../arrow/compute/kernels/vector_select_k.cc | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 2986bf65ebf..1fbe0e1c1ff 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -104,7 +104,7 @@ class HeapSorter { int64_t null_count = p.null_count(); // non-null nan null if (null_placement_ == NullPlacement::AtEnd) { - extract_non_null_count = non_null_count <= k_ ? non_null_count : k_; + extract_non_null_count = std::min(non_null_count, k_); extract_nan_count = extract_non_null_count >= k_ ? 
0 : std::min(nan_count, k_ - extract_non_null_count); @@ -129,7 +129,7 @@ class HeapSorter { int64_t out_size = counter.extract_non_null_count + counter.extract_nan_count + counter.extract_null_count; ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(out_size, pool_)); - // [extrat_count....extract_nan_count...extract_null_count] + // [extract_count....extract_nan_count...extract_null_count] if (null_placement_ == NullPlacement::AtEnd) { if (counter.extract_non_null_count) { auto* out_cbegin = take_indices->template GetMutableValues(1) + @@ -183,8 +183,8 @@ class HeapSorter { } private: - int64_t k_; - NullPlacement null_placement_; + const int64_t k_; + const NullPlacement null_placement_; MemoryPool* pool_; }; @@ -292,20 +292,20 @@ class ChunkedHeapSorter { : k_(k), null_placement_(null_placement), pool_(pool) {} Result> HeapSort(const ArrayVector physical_chunks) { - std::vector> chunks_null_partions; + std::vector> chunks_null_partitions; std::vector> chunks_holder; std::vector> chunks_indices_holder; - chunks_null_partions.reserve(physical_chunks.size()); - ExtractCounter counter = ComputeExtractCounter(physical_chunks, chunks_null_partions, + chunks_null_partitions.reserve(physical_chunks.size()); + ExtractCounter counter = ComputeExtractCounter(physical_chunks, chunks_null_partitions, chunks_holder, chunks_indices_holder); - return HeapSortInternal(chunks_holder, counter, chunks_null_partions); + return HeapSortInternal(chunks_holder, counter, chunks_null_partitions); } // Extract the total count of non-nulls, nans, and nulls for all chunks ExtractCounter ComputeExtractCounter( const ArrayVector physical_chunks, std::vector>& - chunks_null_partions, + chunks_null_partitions, std::vector>& chunks_holder, std::vector>& chunks_indices_holder) { int64_t all_non_null_count = 0; @@ -335,7 +335,7 @@ class ChunkedHeapSorter { all_non_null_count += non_null_count; all_nan_count += nan_count; all_null_count += null_count; - 
chunks_null_partions.emplace_back(p, q); + chunks_null_partitions.emplace_back(p, q); } // non-null nan null if (null_placement_ == NullPlacement::AtEnd) { @@ -361,7 +361,7 @@ class ChunkedHeapSorter { const std::vector>& chunks_holder, ExtractCounter counter, const std::vector>& - chunks_null_partions) { + chunks_null_partitions) { std::function cmp; SelectKComparator comparator; cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool { @@ -376,8 +376,8 @@ class ChunkedHeapSorter { HeapContainer null_heap(cmp); uint64_t offset = 0; - for (size_t i = 0; i < chunks_null_partions.size(); i++) { - const auto& null_part_pair = chunks_null_partions[i]; + for (size_t i = 0; i < chunks_null_partitions.size(); i++) { + const auto& null_part_pair = chunks_null_partitions[i]; const auto& p = null_part_pair.first; const auto& q = null_part_pair.second; ArrayType& arr = *chunks_holder[i]; From a8d544b7c11392124f8b350952d4f35d8662a558 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Wed, 4 Feb 2026 23:07:33 +0100 Subject: [PATCH 57/83] fixup tests --- cpp/src/arrow/compute/kernels/select_k_test.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc index 58a85edf407..147d597ef51 100644 --- a/cpp/src/arrow/compute/kernels/select_k_test.cc +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -414,10 +414,12 @@ TYPED_TEST(TestSelectKWithChunkedArray, PartialSelectKNullNaN) { std::vector sort_keys{SortKey("a", SortOrder::Descending)}; auto options = SelectKOptions(3, sort_keys); options.sort_keys[0].null_placement = NullPlacement::AtStart; - auto expected = ChunkedArrayFromJSON(uint8(), {"[3, 0, 4]"}); + auto expected = ChunkedArrayFromJSON(float64(), {"[null, null, NaN]"}); this->Check(chunked_array, options, expected); options.sort_keys[0].null_placement = NullPlacement::AtEnd; - this->CheckIndices(chunked_array, options, "[5, 2, 7]"); + 
expected = ChunkedArrayFromJSON(float64(), {"[10, 3, 2]"}); + this->Check(chunked_array, options, expected); + // TODO.TAE more CheckIndices? } TYPED_TEST(TestSelectKWithChunkedArray, FullSelectKNullNaN) { From b2bf5d9b67ab931ae8e482fd29329bb8d59f0d16 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Wed, 4 Feb 2026 23:07:53 +0100 Subject: [PATCH 58/83] deprecate one function --- cpp/src/arrow/compute/api_vector.cc | 5 +++++ cpp/src/arrow/compute/api_vector.h | 13 ++++++++----- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 4aa661f3988..3e635baf659 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -195,6 +195,11 @@ ArraySortOptions::ArraySortOptions(SortOrder order, NullPlacement null_placement null_placement(null_placement) {} constexpr char ArraySortOptions::kTypeName[]; +SortOptions::SortOptions(std::vector sort_keys) + : FunctionOptions(internal::kSortOptionsType), + sort_keys(std::move(sort_keys)), + null_placement(std::nullopt) {} + SortOptions::SortOptions(std::vector sort_keys, std::optional null_placement) : FunctionOptions(internal::kSortOptionsType), diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 3e1f4d8fd04..198c1a0ef7e 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -104,8 +104,12 @@ class ARROW_EXPORT ArraySortOptions : public FunctionOptions { class ARROW_EXPORT SortOptions : public FunctionOptions { public: - explicit SortOptions(std::vector sort_keys = {}, - std::optional null_placement = std::nullopt); + explicit SortOptions(std::vector sort_keys = {}); + + ARROW_DEPRECATED("Deprecated in arrow 24.0.0, use null_placement in sort_keys instead") + explicit SortOptions(std::vector sort_keys, + std::optional null_placement); + explicit SortOptions(const Ordering& ordering); static constexpr const char kTypeName[] = 
"SortOptions"; static SortOptions Defaults() { return SortOptions(); } @@ -120,8 +124,7 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { /// Column key(s) to order by and how to order by these sort keys. std::vector sort_keys; -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wdeprecated-declarations" + ARROW_SUPPRESS_DEPRECATION_WARNING // DEPRECATED(will be removed after null_placement has been removed) /// Get sort_keys with overwritten null_placement std::vector GetSortKeys() const { @@ -134,7 +137,7 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { } return overwritten_sort_keys; } -#pragma clang diagnostic pop + ARROW_UNSUPPRESS_DEPRECATION_WARNING // DEPRECATED(Deprecated in arrow 24.0.0, use null_placement in sort_keys instead) /// Whether nulls and NaNs are placed at the start or at the end From 1afaf5c53f62c2d88dd8c4aef44c454d2a8f8330 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 6 Feb 2026 15:08:56 +0100 Subject: [PATCH 59/83] fix select_k doc text --- cpp/src/arrow/compute/kernels/vector_select_k.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 1fbe0e1c1ff..274a9903ca0 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -44,9 +44,9 @@ const FunctionDoc select_k_unstable_doc( ("This function selects an array of indices of the first `k` ordered elements\n" "from the `input` array, record batch or table specified in the column keys\n" "(`options.sort_keys`). Output is not guaranteed to be stable.\n" - "Null values are considered greater than any other value and are\n" - "therefore ordered at the end. 
For floating-point types, NaNs are considered\n" - "greater than any other non-null value, but smaller than null values."), + "Null values will be ordered according to the null_placement as specified per\n" + "sort-key. For floating-point types, NaNs are always ordered between\n" + "null values and non-null values."), {"input"}, "SelectKOptions", /*options_required=*/true); template From d12a6bc11f991f85f4540b0ab91f551b1bab07e8 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Fri, 6 Feb 2026 17:23:40 +0100 Subject: [PATCH 60/83] revert old pr changes to vector_select_k.cc --- .../arrow/compute/kernels/vector_select_k.cc | 552 +++++------------- 1 file changed, 147 insertions(+), 405 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 274a9903ca0..591a2509673 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -44,9 +44,9 @@ const FunctionDoc select_k_unstable_doc( ("This function selects an array of indices of the first `k` ordered elements\n" "from the `input` array, record batch or table specified in the column keys\n" "(`options.sort_keys`). Output is not guaranteed to be stable.\n" - "Null values will be ordered according to the null_placement as specified per\n" - "sort-key. For floating-point types, NaNs are always ordered between\n" - "null values and non-null values."), + "Null values are considered greater than any other value and are\n" + "therefore ordered at the end. 
For floating-point types, NaNs are considered\n" + "greater than any other non-null value, but smaller than null values."), {"input"}, "SelectKOptions", /*options_required=*/true); template @@ -74,120 +74,6 @@ class SelectKComparator { } }; -struct ExtractCounter { - int64_t extract_non_null_count; - int64_t extract_nan_count; - int64_t extract_null_count; -}; - -class HeapSorter { - public: - using HeapPusherFunction = - std::function; - - HeapSorter(int64_t k, NullPlacement null_placement, MemoryPool* pool) - : k_(k), null_placement_(null_placement), pool_(pool) {} - - Result> HeapSort(HeapPusherFunction heap_pusher, - NullPartitionResult p, - NullPartitionResult q) { - ExtractCounter counter = ComputeExtractCounter(p, q); - return HeapSortInternal(counter, heap_pusher, p, q); - } - - ExtractCounter ComputeExtractCounter(NullPartitionResult p, NullPartitionResult q) { - int64_t extract_non_null_count = 0; - int64_t extract_nan_count = 0; - int64_t extract_null_count = 0; - int64_t non_null_count = q.non_null_count(); - int64_t nan_count = q.null_count(); - int64_t null_count = p.null_count(); - // non-null nan null - if (null_placement_ == NullPlacement::AtEnd) { - extract_non_null_count = std::min(non_null_count, k_); - extract_nan_count = extract_non_null_count >= k_ - ? 0 - : std::min(nan_count, k_ - extract_non_null_count); - extract_null_count = extract_non_null_count + extract_nan_count >= k_ - ? 0 - : (k_ - (extract_non_null_count + extract_nan_count)); - } else { // null nan non-null - extract_null_count = null_count <= k_ ? null_count : k_; - extract_nan_count = - extract_null_count >= k_ ? 0 : std::min(nan_count, k_ - extract_null_count); - extract_non_null_count = extract_null_count + extract_nan_count >= k_ - ? 
0 - : (k_ - (extract_null_count + extract_nan_count)); - } - return {extract_non_null_count, extract_nan_count, extract_null_count}; - } - - Result> HeapSortInternal(ExtractCounter counter, - HeapPusherFunction heap_pusher, - NullPartitionResult p, - NullPartitionResult q) { - int64_t out_size = counter.extract_non_null_count + counter.extract_nan_count + - counter.extract_null_count; - ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(out_size, pool_)); - // [extract_count....extract_nan_count...extract_null_count] - if (null_placement_ == NullPlacement::AtEnd) { - if (counter.extract_non_null_count) { - auto* out_cbegin = take_indices->template GetMutableValues(1) + - counter.extract_non_null_count - 1; - auto kth_begin = std::min(q.non_nulls_begin + k_, q.non_nulls_end); - heap_pusher(q.non_nulls_begin, kth_begin, q.non_nulls_end, out_cbegin); - } - - if (counter.extract_nan_count) { - auto* out_cbegin = take_indices->template GetMutableValues(1) + - counter.extract_non_null_count + counter.extract_nan_count - 1; - auto kth_begin = - std::min(q.nulls_begin + k_ - counter.extract_non_null_count, q.nulls_end); - heap_pusher(q.nulls_begin, kth_begin, q.nulls_end, out_cbegin); - } - - if (counter.extract_null_count) { - auto* out_cbegin = - take_indices->template GetMutableValues(1) + out_size - 1; - auto kth_begin = std::min(p.nulls_begin + k_ - counter.extract_non_null_count - - counter.extract_nan_count, - p.nulls_end); - heap_pusher(p.nulls_begin, kth_begin, p.nulls_end, out_cbegin); - } - } else { // [extract_null_count....extract_nan_count...extrat_count] - if (counter.extract_null_count) { - auto* out_cbegin = take_indices->template GetMutableValues(1) + - counter.extract_null_count - 1; - auto kth_begin = std::min(p.nulls_begin + k_, p.nulls_end); - heap_pusher(p.nulls_begin, kth_begin, p.nulls_end, out_cbegin); - } - - if (counter.extract_nan_count) { - auto* out_cbegin = take_indices->template GetMutableValues(1) + - 
counter.extract_null_count + counter.extract_nan_count - 1; - auto kth_begin = - std::min(q.nulls_begin + k_ - counter.extract_null_count, q.nulls_end); - heap_pusher(q.nulls_begin, kth_begin, q.nulls_end, out_cbegin); - } - - if (counter.extract_non_null_count) { - auto* out_cbegin = - take_indices->template GetMutableValues(1) + out_size - 1; - auto kth_begin = std::min(q.non_nulls_begin + k_ - counter.extract_null_count - - counter.extract_nan_count, - q.non_nulls_end); - heap_pusher(q.non_nulls_begin, kth_begin, q.non_nulls_end, out_cbegin); - } - } - return take_indices; - } - - private: - const int64_t k_; - const NullPlacement null_placement_; - MemoryPool* pool_; -}; - class ArraySelector : public TypeVisitor { public: ArraySelector(ExecContext* ctx, const Array& array, const SelectKOptions& options, @@ -196,8 +82,7 @@ class ArraySelector : public TypeVisitor { ctx_(ctx), array_(array), k_(options.k), - order_(options.GetSortKeys()[0].order), - null_placement_(options.GetSortKeys()[0].null_placement), + order_(options.sort_keys[0].order), physical_type_(GetPhysicalType(array.type())), output_(output) {} @@ -230,10 +115,11 @@ class ArraySelector : public TypeVisitor { k_ = arr.length(); } - const auto p = PartitionNullsOnly(indices_begin, indices_end, - arr, 0, null_placement_); - const auto q = PartitionNullLikes( - p.non_nulls_begin, p.non_nulls_end, arr, 0, null_placement_); + const auto p = PartitionNulls( + indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); + const auto end_iter = p.non_nulls_end; + + auto kth_begin = std::min(indices_begin + k_, end_iter); SelectKComparator comparator; auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) { @@ -243,24 +129,24 @@ class ArraySelector : public TypeVisitor { }; using HeapContainer = std::priority_queue, decltype(cmp)>; - auto HeapPusher = [&](uint64_t* indices_begin, uint64_t* kth_begin, - uint64_t* end_iter, uint64_t* out_cbegin) { - HeapContainer heap(indices_begin, kth_begin, cmp); - for 
(auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - if (cmp(x_index, heap.top())) { - heap.pop(); - heap.push(x_index); - } - } - while (heap.size() > 0) { - *out_cbegin = heap.top(); + HeapContainer heap(indices_begin, kth_begin, cmp); + for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + if (cmp(x_index, heap.top())) { heap.pop(); - --out_cbegin; + heap.push(x_index); } - }; - HeapSorter h(k_, null_placement_, ctx_->memory_pool()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, h.HeapSort(HeapPusher, p, q)); + } + auto out_size = static_cast(heap.size()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, + MakeMutableUInt64Array(out_size, ctx_->memory_pool())); + + auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; + while (heap.size() > 0) { + *out_cbegin = heap.top(); + heap.pop(); + --out_cbegin; + } *output_ = Datum(take_indices); return Status::OK(); } @@ -269,7 +155,6 @@ class ArraySelector : public TypeVisitor { const Array& array_; int64_t k_; SortOrder order_; - NullPlacement null_placement_; const std::shared_ptr physical_type_; Datum* output_; }; @@ -281,188 +166,6 @@ struct TypedHeapItem { ArrayType* array; }; -template -class ChunkedHeapSorter { - public: - using GetView = GetViewType; - using ArrayType = typename TypeTraits::ArrayType; - using HeapItem = TypedHeapItem; - - ChunkedHeapSorter(int64_t k, NullPlacement null_placement, MemoryPool* pool) - : k_(k), null_placement_(null_placement), pool_(pool) {} - - Result> HeapSort(const ArrayVector physical_chunks) { - std::vector> chunks_null_partitions; - std::vector> chunks_holder; - std::vector> chunks_indices_holder; - chunks_null_partitions.reserve(physical_chunks.size()); - ExtractCounter counter = ComputeExtractCounter(physical_chunks, chunks_null_partitions, - chunks_holder, chunks_indices_holder); - return HeapSortInternal(chunks_holder, counter, chunks_null_partitions); - } - - // Extract 
the total count of non-nulls, nans, and nulls for all chunks - ExtractCounter ComputeExtractCounter( - const ArrayVector physical_chunks, - std::vector>& - chunks_null_partitions, - std::vector>& chunks_holder, - std::vector>& chunks_indices_holder) { - int64_t all_non_null_count = 0; - int64_t all_nan_count = 0; - int64_t all_null_count = 0; - int64_t extract_non_null_count = 0; - int64_t extract_nan_count = 0; - int64_t extract_null_count = 0; - for (size_t i = 0; i < physical_chunks.size(); i++) { - const auto& chunk = physical_chunks[i]; - if (chunk->length() == 0) continue; - chunks_holder.emplace_back(std::make_shared(chunk->data())); - ArrayType& arr = *chunks_holder[chunks_holder.size() - 1]; - chunks_indices_holder.emplace_back(std::vector(arr.length())); - std::vector& indices = - chunks_indices_holder[chunks_indices_holder.size() - 1]; - uint64_t* indices_begin = indices.data(); - uint64_t* indices_end = indices_begin + indices.size(); - std::iota(indices_begin, indices_end, 0); - NullPartitionResult p = PartitionNullsOnly( - indices_begin, indices_end, arr, 0, null_placement_); - NullPartitionResult q = PartitionNullLikes( - p.non_nulls_begin, p.non_nulls_end, arr, 0, null_placement_); - int64_t non_null_count = q.non_null_count(); - int64_t nan_count = q.null_count(); - int64_t null_count = p.null_count(); - all_non_null_count += non_null_count; - all_nan_count += nan_count; - all_null_count += null_count; - chunks_null_partitions.emplace_back(p, q); - } - // non-null nan null - if (null_placement_ == NullPlacement::AtEnd) { - extract_non_null_count = all_non_null_count <= k_ ? all_non_null_count : k_; - extract_nan_count = extract_non_null_count >= k_ - ? 0 - : std::min(all_nan_count, k_ - extract_non_null_count); - extract_null_count = extract_non_null_count + extract_nan_count >= k_ - ? 0 - : (k_ - (extract_non_null_count + extract_nan_count)); - } else { // null nan non-null - extract_null_count = all_null_count <= k_ ? 
all_null_count : k_; - extract_nan_count = - extract_null_count >= k_ ? 0 : std::min(all_nan_count, k_ - extract_null_count); - extract_non_null_count = extract_null_count + extract_nan_count >= k_ - ? 0 - : (k_ - (extract_null_count + extract_nan_count)); - } - return {extract_non_null_count, extract_nan_count, extract_null_count}; - } - - Result> HeapSortInternal( - const std::vector>& chunks_holder, - ExtractCounter counter, - const std::vector>& - chunks_null_partitions) { - std::function cmp; - SelectKComparator comparator; - cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool { - const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); - const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); - return comparator(lval, rval); - }; - using HeapContainer = - std::priority_queue, decltype(cmp)>; - HeapContainer non_null_heap(cmp); - HeapContainer nan_heap(cmp); - HeapContainer null_heap(cmp); - - uint64_t offset = 0; - for (size_t i = 0; i < chunks_null_partitions.size(); i++) { - const auto& null_part_pair = chunks_null_partitions[i]; - const auto& p = null_part_pair.first; - const auto& q = null_part_pair.second; - ArrayType& arr = *chunks_holder[i]; - - auto HeapPusher = [&](HeapContainer& heap, int64_t extract_non_null_count, - uint64_t* indices_begin, uint64_t* kth_begin, - uint64_t* end_iter) { - uint64_t* iter = indices_begin; - for (; iter != kth_begin && - heap.size() < static_cast(extract_non_null_count); - ++iter) { - heap.push(HeapItem{*iter, offset, &arr}); - } - for (; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); - auto top_item = heap.top(); - const auto& top_value = - GetView::LogicalValue(top_item.array->GetView(top_item.index)); - if (comparator(xval, top_value)) { - heap.pop(); - heap.push(HeapItem{x_index, offset, &arr}); - } - } - }; - HeapPusher( - non_null_heap, 
counter.extract_non_null_count, q.non_nulls_begin, - std::min(q.non_nulls_begin + counter.extract_non_null_count, q.non_nulls_end), - q.non_nulls_end); - HeapPusher(nan_heap, counter.extract_nan_count, q.nulls_begin, - std::min(q.nulls_begin + counter.extract_nan_count, q.nulls_end), - q.nulls_end); - HeapPusher(null_heap, counter.extract_null_count, p.nulls_begin, - std::min(p.nulls_begin + counter.extract_null_count, p.nulls_end), - p.nulls_end); - offset += arr.length(); - } - - int64_t out_size = counter.extract_non_null_count + counter.extract_nan_count + - counter.extract_null_count; - ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(out_size, pool_)); - - auto PopHeaper = [&](HeapContainer& heap, uint64_t* out_cbegin) { - while (heap.size() > 0) { - auto top_item = heap.top(); - *out_cbegin = top_item.index + top_item.offset; - heap.pop(); - --out_cbegin; - } - }; - - if (null_placement_ == NullPlacement::AtEnd) { - // non_null - auto* out_cbegin = take_indices->template GetMutableValues(1) + - counter.extract_non_null_count - 1; - PopHeaper(non_null_heap, out_cbegin); - // nan - out_cbegin = take_indices->template GetMutableValues(1) + - counter.extract_non_null_count + counter.extract_nan_count - 1; - PopHeaper(nan_heap, out_cbegin); - // null - out_cbegin = take_indices->template GetMutableValues(1) + out_size - 1; - PopHeaper(null_heap, out_cbegin); - } else { - // null - auto* out_cbegin = take_indices->template GetMutableValues(1) + - counter.extract_null_count - 1; - PopHeaper(null_heap, out_cbegin); - // nan - out_cbegin = take_indices->template GetMutableValues(1) + - counter.extract_null_count + counter.extract_nan_count - 1; - PopHeaper(nan_heap, out_cbegin); - // non_null - out_cbegin = take_indices->template GetMutableValues(1) + out_size - 1; - PopHeaper(non_null_heap, out_cbegin); - } - return take_indices; - } - - private: - int64_t k_; - NullPlacement null_placement_; - MemoryPool* pool_; -}; - class ChunkedArraySelector : 
public TypeVisitor { public: ChunkedArraySelector(ExecContext* ctx, const ChunkedArray& chunked_array, @@ -473,7 +176,6 @@ class ChunkedArraySelector : public TypeVisitor { physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)), k_(options.k), order_(options.sort_keys[0].order), - null_placement_(options.sort_keys[0].null_placement), ctx_(ctx), output_(output) {} @@ -492,6 +194,10 @@ class ChunkedArraySelector : public TypeVisitor { template Status SelectKthInternal() { + using GetView = GetViewType; + using ArrayType = typename TypeTraits::ArrayType; + using HeapItem = TypedHeapItem; + const auto num_chunks = chunked_array_.num_chunks(); if (num_chunks == 0) { return Status::OK(); @@ -499,9 +205,63 @@ class ChunkedArraySelector : public TypeVisitor { if (k_ > chunked_array_.length()) { k_ = chunked_array_.length(); } + std::function cmp; + SelectKComparator comparator; + + cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool { + const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); + const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); + return comparator(lval, rval); + }; + using HeapContainer = + std::priority_queue, decltype(cmp)>; + + HeapContainer heap(cmp); + std::vector> chunks_holder; + uint64_t offset = 0; + for (const auto& chunk : physical_chunks_) { + if (chunk->length() == 0) continue; + chunks_holder.emplace_back(std::make_shared(chunk->data())); + ArrayType& arr = *chunks_holder[chunks_holder.size() - 1]; + + std::vector indices(arr.length()); + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + + const auto p = PartitionNulls( + indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); + const auto end_iter = p.non_nulls_end; + + auto kth_begin = std::min(indices_begin + k_, end_iter); + uint64_t* iter = indices_begin; + for (; iter != kth_begin && heap.size() < static_cast(k_); 
++iter) { + heap.push(HeapItem{*iter, offset, &arr}); + } + for (; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); + auto top_item = heap.top(); + const auto& top_value = + GetView::LogicalValue(top_item.array->GetView(top_item.index)); + if (comparator(xval, top_value)) { + heap.pop(); + heap.push(HeapItem{x_index, offset, &arr}); + } + } + offset += chunk->length(); + } - ChunkedHeapSorter h(k_, null_placement_, ctx_->memory_pool()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, h.HeapSort(physical_chunks_)); + auto out_size = static_cast(heap.size()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, + MakeMutableUInt64Array(out_size, ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; + while (heap.size() > 0) { + auto top_item = heap.top(); + *out_cbegin = top_item.index + top_item.offset; + heap.pop(); + --out_cbegin; + } *output_ = Datum(take_indices); return Status::OK(); } @@ -511,7 +271,6 @@ class ChunkedArraySelector : public TypeVisitor { const ArrayVector physical_chunks_; int64_t k_; SortOrder order_; - NullPlacement null_placement_; ExecContext* ctx_; Datum* output_; }; @@ -529,8 +288,8 @@ class RecordBatchSelector : public TypeVisitor { record_batch_(record_batch), k_(options.k), output_(output), - sort_keys_(ResolveSortKeys(record_batch, options.GetSortKeys(), &status_)), - comparator_(sort_keys_) {} + sort_keys_(ResolveSortKeys(record_batch, options.sort_keys, &status_)), + comparator_(sort_keys_, NullPlacement::AtEnd) {} Status Run() { RETURN_NOT_OK(status_); @@ -556,7 +315,7 @@ class RecordBatchSelector : public TypeVisitor { *status = maybe_array.status(); return {}; } - resolved.emplace_back(*std::move(maybe_array), key.order, key.null_placement); + resolved.emplace_back(*std::move(maybe_array), key.order); } return resolved; } @@ -581,9 +340,7 @@ class RecordBatchSelector : public TypeVisitor { cmp = [&](const uint64_t& 
left, const uint64_t& right) -> bool { const auto lval = GetView::LogicalValue(arr.GetView(left)); const auto rval = GetView::LogicalValue(arr.GetView(right)); - const bool is_null_left = arr.IsNull(left); - const bool is_null_right = arr.IsNull(right); - if ((lval == rval) || (is_null_left && is_null_right)) { + if (lval == rval) { // If the left value equals to the right value, // we need to compare the second and following // sort keys. @@ -599,31 +356,30 @@ class RecordBatchSelector : public TypeVisitor { uint64_t* indices_end = indices_begin + indices.size(); std::iota(indices_begin, indices_end, 0); - NullPartitionResult p = PartitionNullsOnly( - indices_begin, indices_end, arr, 0, first_sort_key.null_placement); - NullPartitionResult q = PartitionNullLikes( - p.non_nulls_begin, p.non_nulls_end, arr, 0, first_sort_key.null_placement); + const auto p = PartitionNulls( + indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); + const auto end_iter = p.non_nulls_end; - auto HeapPusher = [&](uint64_t* indices_begin, uint64_t* kth_begin, - uint64_t* end_iter, uint64_t* out_cbegin) { - HeapContainer heap(indices_begin, kth_begin, cmp); - for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - auto top_item = heap.top(); - if (cmp(x_index, top_item)) { - heap.pop(); - heap.push(x_index); - } - } - while (heap.size() > 0) { - *out_cbegin = heap.top(); + auto kth_begin = std::min(indices_begin + k_, end_iter); + + HeapContainer heap(indices_begin, kth_begin, cmp); + for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + auto top_item = heap.top(); + if (cmp(x_index, top_item)) { heap.pop(); - --out_cbegin; + heap.push(x_index); } - }; - - HeapSorter h(k_, first_sort_key.null_placement, ctx_->memory_pool()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, h.HeapSort(HeapPusher, p, q)); + } + auto out_size = static_cast(heap.size()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, + 
MakeMutableUInt64Array(out_size, ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; + while (heap.size() > 0) { + *out_cbegin = heap.top(); + heap.pop(); + --out_cbegin; + } *output_ = Datum(take_indices); return Status::OK(); } @@ -641,9 +397,8 @@ class TableSelector : public TypeVisitor { private: struct ResolvedSortKey { ResolvedSortKey(const std::shared_ptr& chunked_array, - const SortOrder order, const NullPlacement null_placement) + const SortOrder order) : order(order), - null_placement(null_placement), type(GetPhysicalType(chunked_array->type())), chunks(GetPhysicalChunks(*chunked_array, type)), null_count(chunked_array->null_count()), @@ -656,7 +411,6 @@ class TableSelector : public TypeVisitor { ResolvedChunk GetChunk(int64_t index) const { return resolver.Resolve(index); } const SortOrder order; - const NullPlacement null_placement; const std::shared_ptr type; const ArrayVector chunks; const int64_t null_count; @@ -672,8 +426,8 @@ class TableSelector : public TypeVisitor { table_(table), k_(options.k), output_(output), - sort_keys_(ResolveSortKeys(table, options.GetSortKeys(), &status_)), - comparator_(sort_keys_) {} + sort_keys_(ResolveSortKeys(table, options.sort_keys, &status_)), + comparator_(sort_keys_, NullPlacement::AtEnd) {} Status Run() { RETURN_NOT_OK(status_); @@ -700,44 +454,36 @@ class TableSelector : public TypeVisitor { *status = maybe_chunked_array.status(); return {}; } - resolved.emplace_back(*std::move(maybe_chunked_array), key.order, - key.null_placement); + resolved.emplace_back(*std::move(maybe_chunked_array), key.order); } return resolved; } // Behaves like PartitionNulls() but this supports multiple sort keys. 
+ template NullPartitionResult PartitionNullsInternal(uint64_t* indices_begin, uint64_t* indices_end, const ResolvedSortKey& first_sort_key) { + using ArrayType = typename TypeTraits::ArrayType; + const auto p = PartitionNullsOnly( indices_begin, indices_end, first_sort_key.resolver, first_sort_key.null_count, - first_sort_key.null_placement); + NullPlacement::AtEnd); DCHECK_EQ(p.nulls_end - p.nulls_begin, first_sort_key.null_count); - auto& comparator = comparator_; - // Sort all nulls by the second and following sort keys. - std::stable_sort(p.nulls_begin, p.nulls_end, [&](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); - - return p; - } - - template - NullPartitionResult PartitionNaNsInternal(uint64_t* indices_begin, - uint64_t* indices_end, - const ResolvedSortKey& first_sort_key) { - using ArrayType = typename TypeTraits::ArrayType; const auto q = PartitionNullLikes( - indices_begin, indices_end, first_sort_key.resolver, - first_sort_key.null_placement); + p.non_nulls_begin, p.non_nulls_end, first_sort_key.resolver, + NullPlacement::AtEnd); auto& comparator = comparator_; // Sort all NaNs by the second and following sort keys. std::stable_sort(q.nulls_begin, q.nulls_end, [&](uint64_t left, uint64_t right) { return comparator.Compare(left, right, 1); }); + // Sort all nulls by the second and following sort keys. 
+ std::stable_sort(p.nulls_begin, p.nulls_end, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); return q; } @@ -763,11 +509,9 @@ class TableSelector : public TypeVisitor { cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { auto chunk_left = first_sort_key.GetChunk(left); auto chunk_right = first_sort_key.GetChunk(right); - const bool is_null_left = chunk_left.IsNull(); - const bool is_null_right = chunk_right.IsNull(); auto value_left = chunk_left.Value(); auto value_right = chunk_right.Value(); - if ((value_left == value_right) || (is_null_left && is_null_right)) { + if (value_left == value_right) { return comparator.Compare(left, right, 1); } return select_k_comparator(value_left, value_right); @@ -781,30 +525,28 @@ class TableSelector : public TypeVisitor { std::iota(indices_begin, indices_end, 0); const auto p = - this->PartitionNullsInternal(indices_begin, indices_end, first_sort_key); - const auto q = this->PartitionNaNsInternal(p.non_nulls_begin, p.non_nulls_end, - first_sort_key); - - auto HeapPusher = [&](uint64_t* indices_begin, uint64_t* kth_begin, - uint64_t* end_iter, uint64_t* out_cbegin) { - HeapContainer heap(indices_begin, kth_begin, cmp); - for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - uint64_t top_item = heap.top(); - if (cmp(x_index, top_item)) { - heap.pop(); - heap.push(x_index); - } - } - while (heap.size() > 0) { - *out_cbegin = heap.top(); + this->PartitionNullsInternal(indices_begin, indices_end, first_sort_key); + const auto end_iter = p.non_nulls_end; + auto kth_begin = std::min(indices_begin + k_, end_iter); + + HeapContainer heap(indices_begin, kth_begin, cmp); + for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + uint64_t top_item = heap.top(); + if (cmp(x_index, top_item)) { heap.pop(); - --out_cbegin; + heap.push(x_index); } - }; - - HeapSorter h(k_, 
first_sort_key.null_placement, ctx_->memory_pool()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, h.HeapSort(HeapPusher, p, q)); + } + auto out_size = static_cast(heap.size()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, + MakeMutableUInt64Array(out_size, ctx_->memory_pool())); + auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; + while (heap.size() > 0) { + *out_cbegin = heap.top(); + heap.pop(); + --out_cbegin; + } *output_ = Datum(take_indices); return Status::OK(); } @@ -880,7 +622,7 @@ class SelectKUnstableMetaFunction : public MetaFunction { } Result SelectKth(const RecordBatch& record_batch, const SelectKOptions& options, ExecContext* ctx) const { - ARROW_RETURN_NOT_OK(CheckConsistency(*record_batch.schema(), options.GetSortKeys())); + ARROW_RETURN_NOT_OK(CheckConsistency(*record_batch.schema(), options.sort_keys)); Datum output; RecordBatchSelector selector(ctx, record_batch, options, &output); ARROW_RETURN_NOT_OK(selector.Run()); @@ -888,7 +630,7 @@ class SelectKUnstableMetaFunction : public MetaFunction { } Result SelectKth(const Table& table, const SelectKOptions& options, ExecContext* ctx) const { - ARROW_RETURN_NOT_OK(CheckConsistency(*table.schema(), options.GetSortKeys())); + ARROW_RETURN_NOT_OK(CheckConsistency(*table.schema(), options.sort_keys)); Datum output; TableSelector selector(ctx, table, options, &output); ARROW_RETURN_NOT_OK(selector.Run()); From 37b4978ff40197d22aee26664d6f3da6a493828d Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sun, 8 Feb 2026 15:26:38 +0100 Subject: [PATCH 61/83] another select_k_test --- .../arrow/compute/kernels/select_k_test.cc | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc index 147d597ef51..4401376a86d 100644 --- a/cpp/src/arrow/compute/kernels/select_k_test.cc +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -242,7 +242,7 @@ class 
TestSelectKWithArray : public ::testing::Test { const SelectKOptions& options, const std::string& expected_json) { auto array = ArrayFromJSON(type, array_json); auto expected = ArrayFromJSON(uint64(), expected_json); - ASSERT_OK_AND_ASSIGN(auto indices, SelectKUnstable(Datum(*array), options)); + ASSERT_OK_AND_ASSIGN(auto indices, SelectKUnstable(Datum(*array), options)); ValidateOutput(*indices); AssertArraysEqual(*expected, *indices, /*verbose=*/true); } @@ -263,8 +263,11 @@ class TestSelectKWithArray : public ::testing::Test { TEST_F(TestSelectKWithArray, PartialSelectKNull) { auto array_input = R"([null, 30, 20, 10, null])"; std::vector sort_keys{SortKey("a", SortOrder::Ascending)}; - auto options = SelectKOptions(3, sort_keys); - auto expected = R"([10, 20, 30])"; + auto options = SelectKOptions(4, sort_keys); + auto expected = R"([10, 20, 30, null])"; + Check(uint8(), array_input, options, expected); + options.k = 3; + expected = R"([10, 20, 30])"; Check(uint8(), array_input, options, expected); options.sort_keys[0].null_placement = NullPlacement::AtStart; expected = R"([null, null, 10])"; @@ -589,6 +592,17 @@ TEST_F(TestSelectKWithRecordBatch, TopKNull) { ])"; Check(schema, batch_input, options, expected_batch); + + auto options_with_null = SelectKOptions::TopKDefault(4, {"a", "b"}); + + auto expected_batch_with_null = R"([ + {"a": 30, "b": 3}, + {"a": 20, "b": 5}, + {"a": 10, "b": 3}, + {"a": null, "b": 6} + ])"; + + Check(schema, batch_input, options_with_null, expected_batch_with_null); } TEST_F(TestSelectKWithRecordBatch, TopKOneColumnKey) { From 18ca67c8367e51ec2f04485519ac2174a5941190 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sun, 8 Feb 2026 15:33:28 +0100 Subject: [PATCH 62/83] version almost finished. 
Some second key issues --- .../arrow/compute/kernels/vector_select_k.cc | 299 +++++++++++------- 1 file changed, 193 insertions(+), 106 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 591a2509673..4968ec82f2c 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -#include +#include +#include #include "arrow/compute/function.h" #include "arrow/compute/kernels/vector_sort_internal.h" @@ -74,6 +75,69 @@ class SelectKComparator { } }; +template +std::pair calculateNumberNonNullAndNullToTake( + const GenericNullPartitionResult& partition_result, int64_t k, + NullPlacement null_placement) { + if (null_placement == NullPlacement::AtEnd) { + int64_t l = std::min(k, partition_result.non_null_count()); + int64_t m = std::min(k - l, partition_result.null_count()); + return {l, m}; + } else { + int64_t m = std::min(k, partition_result.null_count()); + int64_t l = std::min(k - m, partition_result.non_null_count()); + return {l, m}; + } +} + +template +void HeapSortNonNullsToOutput(const typename TypeTraits::ArrayType& arr, + uint64_t* non_null_begin, uint64_t* non_null_end, int64_t l, + uint64_t* output) { + using GetView = GetViewType; + + SelectKComparator comparator; + auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + return comparator(lval, rval); + }; + std::span heap{non_null_begin, non_null_begin + l}; + std::make_heap(heap.begin(), heap.end(), cmp); + for (auto iter = non_null_begin + l; iter != non_null_end; ++iter) { + uint64_t x_index = *iter; + if (cmp(x_index, heap.front())) { + std::pop_heap(heap.begin(), heap.end(), cmp); + heap.back() = x_index; + std::push_heap(heap.begin(), 
heap.end(), cmp); + } + } + + // fill output in reverse when destructing, + // as the "worst" (next-to-would-have-been-replaced) element is at heap-top + uint64_t* heap_begin = non_null_begin; + uint64_t* heap_end = non_null_begin + l; + for (auto reverse_out_iter = output + l - 1; reverse_out_iter >= output; + --reverse_out_iter) { + *reverse_out_iter = *heap_begin; // heap-top has the next element + std::pop_heap(heap_begin, heap_end, cmp); + --heap_end; + } +} + +template +void HeapSortNonNullsToOutput(const typename TypeTraits::ArrayType& arr, + uint64_t* non_null_begin, uint64_t* non_null_end, int64_t l, + SortOrder order, uint64_t* output) { + if (order == SortOrder::Ascending) { + HeapSortNonNullsToOutput(arr, non_null_begin, + non_null_end, l, output); + } else { + HeapSortNonNullsToOutput(arr, non_null_begin, + non_null_end, l, output); + } +} + class ArraySelector : public TypeVisitor { public: ArraySelector(ExecContext* ctx, const Array& array, const SelectKOptions& options, @@ -82,7 +146,8 @@ class ArraySelector : public TypeVisitor { ctx_(ctx), array_(array), k_(options.k), - order_(options.sort_keys[0].order), + order_(options.GetSortKeys()[0].order), + null_placement_(options.GetSortKeys()[0].null_placement), physical_type_(GetPhysicalType(array.type())), output_(output) {} @@ -102,7 +167,6 @@ class ArraySelector : public TypeVisitor { template Status SelectKthInternal() { - using GetView = GetViewType; using ArrayType = typename TypeTraits::ArrayType; ArrayType arr(array_.data()); @@ -111,42 +175,30 @@ class ArraySelector : public TypeVisitor { uint64_t* indices_begin = indices.data(); uint64_t* indices_end = indices_begin + indices.size(); std::iota(indices_begin, indices_end, 0); - if (k_ > arr.length()) { - k_ = arr.length(); - } const auto p = PartitionNulls( - indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); - const auto end_iter = p.non_nulls_end; + indices_begin, indices_end, arr, 0, null_placement_); - auto kth_begin = 
std::min(indices_begin + k_, end_iter); + // From k, calculate + // l = non_null elements to take from PartitionResult + // m = null elements to take from PartitionResult + // k = l + m if enough elements in input + auto [l, m] = calculateNumberNonNullAndNullToTake(p, k_, null_placement_); - SelectKComparator comparator; - auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) { - const auto lval = GetView::LogicalValue(arr.GetView(left)); - const auto rval = GetView::LogicalValue(arr.GetView(right)); - return comparator(lval, rval); - }; - using HeapContainer = - std::priority_queue, decltype(cmp)>; - HeapContainer heap(indices_begin, kth_begin, cmp); - for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - if (cmp(x_index, heap.top())) { - heap.pop(); - heap.push(x_index); - } - } - auto out_size = static_cast(heap.size()); ARROW_ASSIGN_OR_RAISE(auto take_indices, - MakeMutableUInt64Array(out_size, ctx_->memory_pool())); - - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.size() > 0) { - *out_cbegin = heap.top(); - heap.pop(); - --out_cbegin; + MakeMutableUInt64Array(l + m, ctx_->memory_pool())); + auto* output = take_indices->template GetMutableValues(1); + + if (null_placement_ == NullPlacement::AtEnd) { + HeapSortNonNullsToOutput(arr, p.non_nulls_begin, + p.non_nulls_end, l, output); + std::copy(p.nulls_begin, p.nulls_begin + m, output + l); + } else { + std::copy(p.nulls_begin, p.nulls_begin + m, output); + HeapSortNonNullsToOutput(arr, p.non_nulls_begin, + p.non_nulls_end, l, output + m); } + *output_ = Datum(take_indices); return Status::OK(); } @@ -155,6 +207,7 @@ class ArraySelector : public TypeVisitor { const Array& array_; int64_t k_; SortOrder order_; + NullPlacement null_placement_; const std::shared_ptr physical_type_; Datum* output_; }; @@ -275,7 +328,7 @@ class ChunkedArraySelector : public TypeVisitor { Datum* output_; }; -class RecordBatchSelector : public 
TypeVisitor { +class RecordBatchSelector { private: using ResolvedSortKey = ResolvedRecordBatchSortKey; using Comparator = MultipleKeyComparator; @@ -283,29 +336,19 @@ class RecordBatchSelector : public TypeVisitor { public: RecordBatchSelector(ExecContext* ctx, const RecordBatch& record_batch, const SelectKOptions& options, Datum* output) - : TypeVisitor(), - ctx_(ctx), + : ctx_(ctx), record_batch_(record_batch), k_(options.k), output_(output), sort_keys_(ResolveSortKeys(record_batch, options.sort_keys, &status_)), - comparator_(sort_keys_, NullPlacement::AtEnd) {} + comparator_(sort_keys_) {} Status Run() { RETURN_NOT_OK(status_); - return sort_keys_[0].type->Accept(this); + return SelectKthInternal(); } protected: -#define VISIT(TYPE) \ - Status Visit(const TYPE& type) { \ - if (sort_keys_[0].order == SortOrder::Descending) \ - return SelectKthInternal(); \ - return SelectKthInternal(); \ - } - VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) -#undef VISIT - static std::vector ResolveSortKeys( const RecordBatch& batch, const std::vector& sort_keys, Status* status) { std::vector resolved; @@ -315,73 +358,113 @@ class RecordBatchSelector : public TypeVisitor { *status = maybe_array.status(); return {}; } - resolved.emplace_back(*std::move(maybe_array), key.order); + resolved.emplace_back(*std::move(maybe_array), key.order, key.null_placement); } return resolved; } - template - Status SelectKthInternal() { - using GetView = GetViewType; - using ArrayType = typename TypeTraits::ArrayType; - auto& comparator = comparator_; - const auto& first_sort_key = sort_keys_[0]; - const auto& arr = checked_cast(first_sort_key.array); + class SelectKForKey : public TypeVisitor { + public: + SelectKForKey(RecordBatchSelector* selector, size_t start_sort_key_index, + std::span input_indices, int64_t k_remaining, + uint64_t* output_indices) + : TypeVisitor(), + selector_(selector), + start_sort_key_index_(start_sort_key_index), + input_indices_(input_indices), + k_remaining_(k_remaining), + 
output_indices_(output_indices) {} + + private: + template + Status Do() { + using ArrayType = typename TypeTraits::ArrayType; + const auto& first_remaining_sort_key = selector_->sort_keys_[start_sort_key_index_]; + const auto& arr = checked_cast(first_remaining_sort_key.array); + + // TODO.TAE uhh this might be prettier + uint64_t* input_indices_begin = &*input_indices_.begin(); + uint64_t* input_indices_end = input_indices_begin + input_indices_.size(); + + const auto p = PartitionNulls( + input_indices_begin, input_indices_end, arr, 0, + first_remaining_sort_key.null_placement); + + // From k, calculate + // l = non_null elements to take from PartitionResult + // m = null elements to take from PartitionResult + // k = l + m if enough elements in input + auto [l, m] = calculateNumberNonNullAndNullToTake( + p, k_remaining_, first_remaining_sort_key.null_placement); + + if (first_remaining_sort_key.null_placement == NullPlacement::AtEnd) { + HeapSortNonNullsToOutput(arr, p.non_nulls_begin, p.non_nulls_end, l, + first_remaining_sort_key.order, output_indices_); + if (m > 0) { + if (start_sort_key_index_ + 1 == selector_->sort_keys_.size()) { + // We have the last sort_key, can just copy over the null values + std::copy(p.nulls_begin, p.nulls_begin + m, output_indices_ + l); + } else { + ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey( + start_sort_key_index_ + 1, + std::span{p.nulls_begin, p.nulls_end}, l, output_indices_ + l)); + } + } + } else { + if (start_sort_key_index_ + 1 == selector_->sort_keys_.size()) { + // We have the last sort_key, can just copy over the null values + std::copy(p.nulls_begin, p.nulls_begin + m, output_indices_); + } else { + ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey( + start_sort_key_index_ + 1, std::span{p.nulls_begin, p.nulls_end}, + l, output_indices_)); + } + HeapSortNonNullsToOutput(arr, p.non_nulls_begin, p.non_nulls_end, l, + first_remaining_sort_key.order, + output_indices_ + m); + } - const auto num_rows = 
record_batch_.num_rows(); - if (num_rows == 0) { return Status::OK(); } - if (k_ > record_batch_.num_rows()) { - k_ = record_batch_.num_rows(); - } - std::function cmp; - SelectKComparator select_k_comparator; - cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { - const auto lval = GetView::LogicalValue(arr.GetView(left)); - const auto rval = GetView::LogicalValue(arr.GetView(right)); - if (lval == rval) { - // If the left value equals to the right value, - // we need to compare the second and following - // sort keys. - return comparator.Compare(left, right, 1); - } - return select_k_comparator(lval, rval); - }; - using HeapContainer = - std::priority_queue, decltype(cmp)>; - std::vector indices(arr.length()); - uint64_t* indices_begin = indices.data(); - uint64_t* indices_end = indices_begin + indices.size(); - std::iota(indices_begin, indices_end, 0); +#define VISIT(TYPE) \ + Status Visit(const TYPE& type) { return Do(); } + VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) - const auto p = PartitionNulls( - indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); - const auto end_iter = p.non_nulls_end; +#undef VISIT - auto kth_begin = std::min(indices_begin + k_, end_iter); + RecordBatchSelector* selector_; + size_t start_sort_key_index_; + std::span input_indices_; + int64_t k_remaining_; + uint64_t* output_indices_; + }; - HeapContainer heap(indices_begin, kth_begin, cmp); - for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - auto top_item = heap.top(); - if (cmp(x_index, top_item)) { - heap.pop(); - heap.push(x_index); - } + Status DoSelectKForKey(size_t start_sort_key_index, std::span input_indices, + int64_t k_remaining, uint64_t* output_indices) { + SelectKForKey tmp(this, start_sort_key_index, input_indices, k_remaining, + output_indices); + return sort_keys_.at(start_sort_key_index).type->Accept(&tmp); + } + + Status SelectKthInternal() { + if (k_ > record_batch_.num_rows()) { + k_ = 
record_batch_.num_rows(); } - auto out_size = static_cast(heap.size()); + + std::vector input_indices(record_batch_.num_rows()); + std::iota(input_indices.begin(), input_indices.end(), 0); + + // We do not directly sort indices in output_indices, as it hold only k_ indices, + // but e.g. need to partition all record_batch_.num_rows() of them ARROW_ASSIGN_OR_RAISE(auto take_indices, - MakeMutableUInt64Array(out_size, ctx_->memory_pool())); - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.size() > 0) { - *out_cbegin = heap.top(); - heap.pop(); - --out_cbegin; - } + MakeMutableUInt64Array(k_, ctx_->memory_pool())); + auto* output_indices = take_indices->template GetMutableValues(1); + + std::span input_indices_span(input_indices); + ARROW_RETURN_NOT_OK(DoSelectKForKey(0, input_indices_span, k_, output_indices)); *output_ = Datum(take_indices); - return Status::OK(); + return arrow::Status::OK(); } Status status_; @@ -397,8 +480,9 @@ class TableSelector : public TypeVisitor { private: struct ResolvedSortKey { ResolvedSortKey(const std::shared_ptr& chunked_array, - const SortOrder order) + const SortOrder order, const NullPlacement null_placement) : order(order), + null_placement(null_placement), type(GetPhysicalType(chunked_array->type())), chunks(GetPhysicalChunks(*chunked_array, type)), null_count(chunked_array->null_count()), @@ -411,6 +495,7 @@ class TableSelector : public TypeVisitor { ResolvedChunk GetChunk(int64_t index) const { return resolver.Resolve(index); } const SortOrder order; + const NullPlacement null_placement; const std::shared_ptr type; const ArrayVector chunks; const int64_t null_count; @@ -426,8 +511,8 @@ class TableSelector : public TypeVisitor { table_(table), k_(options.k), output_(output), - sort_keys_(ResolveSortKeys(table, options.sort_keys, &status_)), - comparator_(sort_keys_, NullPlacement::AtEnd) {} + sort_keys_(ResolveSortKeys(table, options.GetSortKeys(), &status_)), + comparator_(sort_keys_) {} 
Status Run() { RETURN_NOT_OK(status_); @@ -454,11 +539,13 @@ class TableSelector : public TypeVisitor { *status = maybe_chunked_array.status(); return {}; } - resolved.emplace_back(*std::move(maybe_chunked_array), key.order); + resolved.emplace_back(*std::move(maybe_chunked_array), key.order, + key.null_placement); } return resolved; } + // TODO.TAE remove, it sorts ALL non-null inputs // Behaves like PartitionNulls() but this supports multiple sort keys. template NullPartitionResult PartitionNullsInternal(uint64_t* indices_begin, @@ -622,7 +709,7 @@ class SelectKUnstableMetaFunction : public MetaFunction { } Result SelectKth(const RecordBatch& record_batch, const SelectKOptions& options, ExecContext* ctx) const { - ARROW_RETURN_NOT_OK(CheckConsistency(*record_batch.schema(), options.sort_keys)); + ARROW_RETURN_NOT_OK(CheckConsistency(*record_batch.schema(), options.GetSortKeys())); Datum output; RecordBatchSelector selector(ctx, record_batch, options, &output); ARROW_RETURN_NOT_OK(selector.Run()); @@ -630,7 +717,7 @@ class SelectKUnstableMetaFunction : public MetaFunction { } Result SelectKth(const Table& table, const SelectKOptions& options, ExecContext* ctx) const { - ARROW_RETURN_NOT_OK(CheckConsistency(*table.schema(), options.sort_keys)); + ARROW_RETURN_NOT_OK(CheckConsistency(*table.schema(), options.GetSortKeys())); Datum output; TableSelector selector(ctx, table, options, &output); ARROW_RETURN_NOT_OK(selector.Run()); From 7bd1d581eb4afde37a94cbbe99e2f84250d66de1 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sun, 8 Feb 2026 16:42:42 +0100 Subject: [PATCH 63/83] improvements --- .../arrow/compute/kernels/vector_select_k.cc | 94 +++++++++++-------- 1 file changed, 55 insertions(+), 39 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 4968ec82f2c..490a14e2d5c 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ 
b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -90,18 +90,9 @@ std::pair calculateNumberNonNullAndNullToTake( } } -template -void HeapSortNonNullsToOutput(const typename TypeTraits::ArrayType& arr, - uint64_t* non_null_begin, uint64_t* non_null_end, int64_t l, - uint64_t* output) { - using GetView = GetViewType; - - SelectKComparator comparator; - auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) { - const auto lval = GetView::LogicalValue(arr.GetView(left)); - const auto rval = GetView::LogicalValue(arr.GetView(right)); - return comparator(lval, rval); - }; +template +void HeapSortNonNullsToOutput(uint64_t* non_null_begin, uint64_t* non_null_end, int64_t l, + Comparator cmp, uint64_t* output) { std::span heap{non_null_begin, non_null_begin + l}; std::make_heap(heap.begin(), heap.end(), cmp); for (auto iter = non_null_begin + l; iter != non_null_end; ++iter) { @@ -125,16 +116,30 @@ void HeapSortNonNullsToOutput(const typename TypeTraits::ArrayType& arr, } } +template +void HeapSortNonNullsToOutput(uint64_t* non_null_begin, uint64_t* non_null_end, int64_t l, + const typename TypeTraits::ArrayType& arr, + uint64_t* output) { + using GetView = GetViewType; + SelectKComparator comparator; + auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) { + const auto lval = GetView::LogicalValue(arr.GetView(left)); + const auto rval = GetView::LogicalValue(arr.GetView(right)); + return comparator(lval, rval); + }; + HeapSortNonNullsToOutput(non_null_begin, non_null_end, l, cmp, output); +} + template -void HeapSortNonNullsToOutput(const typename TypeTraits::ArrayType& arr, - uint64_t* non_null_begin, uint64_t* non_null_end, int64_t l, +void HeapSortNonNullsToOutput(uint64_t* non_null_begin, uint64_t* non_null_end, int64_t l, + const typename TypeTraits::ArrayType& arr, SortOrder order, uint64_t* output) { if (order == SortOrder::Ascending) { - HeapSortNonNullsToOutput(arr, non_null_begin, - non_null_end, l, output); + 
HeapSortNonNullsToOutput(non_null_begin, non_null_end, + l, arr, output); } else { - HeapSortNonNullsToOutput(arr, non_null_begin, - non_null_end, l, output); + HeapSortNonNullsToOutput(non_null_begin, non_null_end, + l, arr, output); } } @@ -190,13 +195,13 @@ class ArraySelector : public TypeVisitor { auto* output = take_indices->template GetMutableValues(1); if (null_placement_ == NullPlacement::AtEnd) { - HeapSortNonNullsToOutput(arr, p.non_nulls_begin, - p.non_nulls_end, l, output); + HeapSortNonNullsToOutput(p.non_nulls_begin, p.non_nulls_end, l, + arr, output); std::copy(p.nulls_begin, p.nulls_begin + m, output + l); } else { std::copy(p.nulls_begin, p.nulls_begin + m, output); - HeapSortNonNullsToOutput(arr, p.non_nulls_begin, - p.non_nulls_end, l, output + m); + HeapSortNonNullsToOutput(p.non_nulls_begin, p.non_nulls_end, l, + arr, output + m); } *output_ = Datum(take_indices); @@ -397,31 +402,42 @@ class RecordBatchSelector { auto [l, m] = calculateNumberNonNullAndNullToTake( p, k_remaining_, first_remaining_sort_key.null_placement); + uint64_t* non_null_output_indices_begin; + uint64_t* null_output_indices_begin; if (first_remaining_sort_key.null_placement == NullPlacement::AtEnd) { - HeapSortNonNullsToOutput(arr, p.non_nulls_begin, p.non_nulls_end, l, - first_remaining_sort_key.order, output_indices_); + non_null_output_indices_begin = output_indices_; + null_output_indices_begin = output_indices_ + l; + } else { + non_null_output_indices_begin = output_indices_ + m; + null_output_indices_begin = output_indices_; + } + + bool last_sort_key = start_sort_key_index_ + 1 == selector_->sort_keys_.size(); + + if (last_sort_key) { + if (l > 0) { + HeapSortNonNullsToOutput(p.non_nulls_begin, p.non_nulls_end, l, arr, + first_remaining_sort_key.order, + non_null_output_indices_begin); + } if (m > 0) { - if (start_sort_key_index_ + 1 == selector_->sort_keys_.size()) { - // We have the last sort_key, can just copy over the null values - std::copy(p.nulls_begin, 
p.nulls_begin + m, output_indices_ + l); - } else { - ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey( - start_sort_key_index_ + 1, - std::span{p.nulls_begin, p.nulls_end}, l, output_indices_ + l)); - } + // We have the last sort_key, can just copy over the null values + std::copy(p.nulls_begin, p.nulls_begin + m, null_output_indices_begin); } } else { - if (start_sort_key_index_ + 1 == selector_->sort_keys_.size()) { + if (l > 0) { + auto cmp = [&](uint64_t left, uint64_t right) { + return selector_->comparator_.Compare(left, right, 1); + }; + HeapSortNonNullsToOutput(p.non_nulls_begin, p.non_nulls_end, l, cmp, + non_null_output_indices_begin); + } + if (m > 0) { // We have the last sort_key, can just copy over the null values - std::copy(p.nulls_begin, p.nulls_begin + m, output_indices_); - } else { ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey( start_sort_key_index_ + 1, std::span{p.nulls_begin, p.nulls_end}, - l, output_indices_)); + l, null_output_indices_begin)); } - HeapSortNonNullsToOutput(arr, p.non_nulls_begin, p.non_nulls_end, l, - first_remaining_sort_key.order, - output_indices_ + m); } return Status::OK(); From c5c1fd3e38450868a497147e4fd72d08502e87dd Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sun, 8 Feb 2026 16:59:31 +0100 Subject: [PATCH 64/83] fixes --- .../arrow/compute/kernels/vector_select_k.cc | 135 ++++++++++++------ 1 file changed, 92 insertions(+), 43 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 490a14e2d5c..9bc266bbd1a 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -75,27 +75,26 @@ class SelectKComparator { } }; -template -std::pair calculateNumberNonNullAndNullToTake( - const GenericNullPartitionResult& partition_result, int64_t k, - NullPlacement null_placement) { +std::pair calculateNumberNonNullAndNullLikesToTake( + const std::span& non_null_like_range, + const 
std::span& null_like_range, int64_t k, NullPlacement null_placement) { if (null_placement == NullPlacement::AtEnd) { - int64_t l = std::min(k, partition_result.non_null_count()); - int64_t m = std::min(k - l, partition_result.null_count()); + int64_t l = std::min(k, static_cast(non_null_like_range.size())); + int64_t m = std::min(k - l, static_cast(null_like_range.size())); return {l, m}; } else { - int64_t m = std::min(k, partition_result.null_count()); - int64_t l = std::min(k - m, partition_result.non_null_count()); + int64_t m = std::min(k, static_cast(null_like_range.size())); + int64_t l = std::min(k - m, static_cast(non_null_like_range.size())); return {l, m}; } } template -void HeapSortNonNullsToOutput(uint64_t* non_null_begin, uint64_t* non_null_end, int64_t l, +void HeapSortNonNullsToOutput(std::span non_null_range, int64_t l, Comparator cmp, uint64_t* output) { - std::span heap{non_null_begin, non_null_begin + l}; + std::span heap{non_null_range.begin(), non_null_range.begin() + l}; std::make_heap(heap.begin(), heap.end(), cmp); - for (auto iter = non_null_begin + l; iter != non_null_end; ++iter) { + for (auto iter = non_null_range.begin() + l; iter != non_null_range.end(); ++iter) { uint64_t x_index = *iter; if (cmp(x_index, heap.front())) { std::pop_heap(heap.begin(), heap.end(), cmp); @@ -106,8 +105,8 @@ void HeapSortNonNullsToOutput(uint64_t* non_null_begin, uint64_t* non_null_end, // fill output in reverse when destructing, // as the "worst" (next-to-would-have-been-replaced) element is at heap-top - uint64_t* heap_begin = non_null_begin; - uint64_t* heap_end = non_null_begin + l; + uint64_t* heap_begin = &*heap.begin(); + uint64_t* heap_end = &*heap.begin() + l; for (auto reverse_out_iter = output + l - 1; reverse_out_iter >= output; --reverse_out_iter) { *reverse_out_iter = *heap_begin; // heap-top has the next element @@ -117,7 +116,7 @@ void HeapSortNonNullsToOutput(uint64_t* non_null_begin, uint64_t* non_null_end, } template -void 
HeapSortNonNullsToOutput(uint64_t* non_null_begin, uint64_t* non_null_end, int64_t l, +void HeapSortNonNullsToOutput(std::span non_null_range, int64_t l, const typename TypeTraits::ArrayType& arr, uint64_t* output) { using GetView = GetViewType; @@ -127,19 +126,19 @@ void HeapSortNonNullsToOutput(uint64_t* non_null_begin, uint64_t* non_null_end, const auto rval = GetView::LogicalValue(arr.GetView(right)); return comparator(lval, rval); }; - HeapSortNonNullsToOutput(non_null_begin, non_null_end, l, cmp, output); + HeapSortNonNullsToOutput(non_null_range, l, cmp, output); } template -void HeapSortNonNullsToOutput(uint64_t* non_null_begin, uint64_t* non_null_end, int64_t l, +void HeapSortNonNullsToOutput(std::span non_null_range, int64_t l, const typename TypeTraits::ArrayType& arr, SortOrder order, uint64_t* output) { if (order == SortOrder::Ascending) { - HeapSortNonNullsToOutput(non_null_begin, non_null_end, - l, arr, output); + HeapSortNonNullsToOutput(non_null_range, l, arr, + output); } else { - HeapSortNonNullsToOutput(non_null_begin, non_null_end, - l, arr, output); + HeapSortNonNullsToOutput(non_null_range, l, arr, + output); } } @@ -185,23 +184,25 @@ class ArraySelector : public TypeVisitor { indices_begin, indices_end, arr, 0, null_placement_); // From k, calculate - // l = non_null elements to take from PartitionResult - // m = null elements to take from PartitionResult + // l = non-null elements to take from PartitionResult + // m = null-like elements to take from PartitionResult // k = l + m if enough elements in input - auto [l, m] = calculateNumberNonNullAndNullToTake(p, k_, null_placement_); + auto [l, m] = calculateNumberNonNullAndNullLikesToTake( + {p.non_nulls_begin, p.non_nulls_end}, {p.nulls_begin, p.nulls_end}, k_, + null_placement_); ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(l + m, ctx_->memory_pool())); auto* output = take_indices->template GetMutableValues(1); if (null_placement_ == NullPlacement::AtEnd) { - 
HeapSortNonNullsToOutput(p.non_nulls_begin, p.non_nulls_end, l, - arr, output); + HeapSortNonNullsToOutput({p.non_nulls_begin, p.non_nulls_end}, + l, arr, output); std::copy(p.nulls_begin, p.nulls_begin + m, output + l); } else { std::copy(p.nulls_begin, p.nulls_begin + m, output); - HeapSortNonNullsToOutput(p.non_nulls_begin, p.non_nulls_end, l, - arr, output + m); + HeapSortNonNullsToOutput({p.non_nulls_begin, p.non_nulls_end}, + l, arr, output + m); } *output_ = Datum(take_indices); @@ -333,6 +334,32 @@ class ChunkedArraySelector : public TypeVisitor { Datum* output_; }; +struct NullNanPartitionResult { + std::span non_null_like_range; + std::span null_like_range; + // Also store the null/nan distribution within null_like_range + std::span null_range; + std::span nan_range; +}; + +template +NullNanPartitionResult PartitionNullsAndNans(uint64_t* indices_begin, + uint64_t* indices_end, + const ArrayType& values, int64_t offset, + NullPlacement null_placement) { + // Partition nulls at start (resp. end), and null-like values just before (resp. 
after) + NullPartitionResult p = PartitionNullsOnly(indices_begin, indices_end, + values, offset, null_placement); + NullPartitionResult q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, values, offset, null_placement); + return NullNanPartitionResult{ + .non_null_like_range = {q.non_nulls_begin, q.non_nulls_end}, + .null_like_range = {std::min(q.nulls_begin, p.nulls_begin), + std::max(q.nulls_end, p.nulls_end)}, + .null_range = {p.nulls_begin, p.nulls_end}, + .nan_range = {q.nulls_begin, q.nulls_end}}; +} + class RecordBatchSelector { private: using ResolvedSortKey = ResolvedRecordBatchSortKey; @@ -387,56 +414,78 @@ class RecordBatchSelector { const auto& first_remaining_sort_key = selector_->sort_keys_[start_sort_key_index_]; const auto& arr = checked_cast(first_remaining_sort_key.array); - // TODO.TAE uhh this might be prettier + // TODO.TAE uhh this could be prettier uint64_t* input_indices_begin = &*input_indices_.begin(); uint64_t* input_indices_end = input_indices_begin + input_indices_.size(); - const auto p = PartitionNulls( + const auto p = PartitionNullsAndNans( input_indices_begin, input_indices_end, arr, 0, first_remaining_sort_key.null_placement); // From k, calculate // l = non_null elements to take from PartitionResult // m = null elements to take from PartitionResult - // k = l + m if enough elements in input - auto [l, m] = calculateNumberNonNullAndNullToTake( - p, k_remaining_, first_remaining_sort_key.null_placement); + // k = l + m because k was clipped to num_rows() + auto [l, m] = calculateNumberNonNullAndNullLikesToTake( + p.non_null_like_range, p.null_like_range, k_remaining_, + first_remaining_sort_key.null_placement); uint64_t* non_null_output_indices_begin; - uint64_t* null_output_indices_begin; + uint64_t* nulllike_output_indices_begin; if (first_remaining_sort_key.null_placement == NullPlacement::AtEnd) { non_null_output_indices_begin = output_indices_; - null_output_indices_begin = output_indices_ + l; + 
nulllike_output_indices_begin = output_indices_ + l; } else { + nulllike_output_indices_begin = output_indices_; non_null_output_indices_begin = output_indices_ + m; - null_output_indices_begin = output_indices_; } bool last_sort_key = start_sort_key_index_ + 1 == selector_->sort_keys_.size(); if (last_sort_key) { if (l > 0) { - HeapSortNonNullsToOutput(p.non_nulls_begin, p.non_nulls_end, l, arr, + HeapSortNonNullsToOutput(p.non_null_like_range, l, arr, first_remaining_sort_key.order, non_null_output_indices_begin); } if (m > 0) { // We have the last sort_key, can just copy over the null values - std::copy(p.nulls_begin, p.nulls_begin + m, null_output_indices_begin); + std::copy(p.null_like_range.begin(), p.null_like_range.begin() + m, + nulllike_output_indices_begin); } } else { if (l > 0) { auto cmp = [&](uint64_t left, uint64_t right) { - return selector_->comparator_.Compare(left, right, 1); + return selector_->comparator_.Compare(left, right, start_sort_key_index_); }; - HeapSortNonNullsToOutput(p.non_nulls_begin, p.non_nulls_end, l, cmp, + HeapSortNonNullsToOutput(p.non_null_like_range, l, cmp, non_null_output_indices_begin); } if (m > 0) { - // We have the last sort_key, can just copy over the null values - ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey( - start_sort_key_index_ + 1, std::span{p.nulls_begin, p.nulls_end}, - l, null_output_indices_begin)); + // Need to subdivide into null and nan, we use non-nulllike / nulllike stratey + // again for nan / null division + auto [nan_count, null_count] = calculateNumberNonNullAndNullLikesToTake( + p.nan_range, p.null_range, m, first_remaining_sort_key.null_placement); + uint64_t* nan_output_indices_begin; + uint64_t* null_output_indices_begin; + if (first_remaining_sort_key.null_placement == NullPlacement::AtEnd) { + nan_output_indices_begin = nulllike_output_indices_begin; + null_output_indices_begin = nulllike_output_indices_begin + nan_count; + } else { + null_output_indices_begin = 
nulllike_output_indices_begin; + nan_output_indices_begin = nulllike_output_indices_begin + null_count; + } + + if (nan_count > 0) { + ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey(start_sort_key_index_ + 1, + p.nan_range, nan_count, + nan_output_indices_begin)); + } + if (null_count > 0) { + ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey(start_sort_key_index_ + 1, + p.null_range, null_count, + null_output_indices_begin)); + } } } From 6ea56d1943047e38c1f25e7dcc9a019b5cd21067 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sun, 8 Feb 2026 22:18:13 +0100 Subject: [PATCH 65/83] one more comment for later --- cpp/src/arrow/compute/kernels/vector_select_k.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 9bc266bbd1a..c2d700bba60 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -426,6 +426,9 @@ class RecordBatchSelector { // l = non_null elements to take from PartitionResult // m = null elements to take from PartitionResult // k = l + m because k was clipped to num_rows() + + // TODO.TAE change this function to directly return TARGET/OUTPUT ranges + // -> no need for counts and begins (begins are below) auto [l, m] = calculateNumberNonNullAndNullLikesToTake( p.non_null_like_range, p.null_like_range, k_remaining_, first_remaining_sort_key.null_placement); From 309eb01236e5f73744dd8ad25c81c7208ef3b044 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Sun, 8 Feb 2026 22:36:06 +0100 Subject: [PATCH 66/83] move from source_range to source_count to make helper viable for chunked arrays --- .../arrow/compute/kernels/vector_select_k.cc | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index c2d700bba60..e9b807927bc 100644 --- 
a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -76,15 +76,14 @@ class SelectKComparator { }; std::pair calculateNumberNonNullAndNullLikesToTake( - const std::span& non_null_like_range, - const std::span& null_like_range, int64_t k, NullPlacement null_placement) { + int64_t non_null_like_count, int64_t null_like_count, int64_t k, NullPlacement null_placement) { if (null_placement == NullPlacement::AtEnd) { - int64_t l = std::min(k, static_cast(non_null_like_range.size())); - int64_t m = std::min(k - l, static_cast(null_like_range.size())); + int64_t l = std::min(k, non_null_like_count); + int64_t m = std::min(k - l, null_like_count); return {l, m}; } else { - int64_t m = std::min(k, static_cast(null_like_range.size())); - int64_t l = std::min(k - m, static_cast(non_null_like_range.size())); + int64_t m = std::min(k, null_like_count); + int64_t l = std::min(k - m, non_null_like_count); return {l, m}; } } @@ -188,7 +187,8 @@ class ArraySelector : public TypeVisitor { // m = null-like elements to take from PartitionResult // k = l + m if enough elements in input auto [l, m] = calculateNumberNonNullAndNullLikesToTake( - {p.non_nulls_begin, p.non_nulls_end}, {p.nulls_begin, p.nulls_end}, k_, + static_cast(p.non_nulls_end - p.non_nulls_begin), + static_cast(p.nulls_end - p.nulls_begin), k_, null_placement_); ARROW_ASSIGN_OR_RAISE(auto take_indices, @@ -430,7 +430,8 @@ class RecordBatchSelector { // TODO.TAE change this function to directly return TARGET/OUTPUT ranges // -> no need for counts and begins (begins are below) auto [l, m] = calculateNumberNonNullAndNullLikesToTake( - p.non_null_like_range, p.null_like_range, k_remaining_, + static_cast(p.non_null_like_range.size()), + static_cast(p.null_like_range.size()), k_remaining_, first_remaining_sort_key.null_placement); uint64_t* non_null_output_indices_begin; @@ -468,7 +469,7 @@ class RecordBatchSelector { // Need to subdivide into null and nan, we use 
non-nulllike / nulllike stratey // again for nan / null division auto [nan_count, null_count] = calculateNumberNonNullAndNullLikesToTake( - p.nan_range, p.null_range, m, first_remaining_sort_key.null_placement); + static_cast(p.nan_range.size()), static_cast(p.null_range.size()), m, first_remaining_sort_key.null_placement); uint64_t* nan_output_indices_begin; uint64_t* null_output_indices_begin; if (first_remaining_sort_key.null_placement == NullPlacement::AtEnd) { From 0d59ce9a364d3672581e0c12b6cb7911ff9019f5 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Mon, 9 Feb 2026 11:57:55 +0100 Subject: [PATCH 67/83] prepare ChunkedArray implementation --- cpp/src/arrow/compute/kernels/vector_select_k.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index e9b807927bc..05b219fc1d3 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -235,6 +235,7 @@ class ChunkedArraySelector : public TypeVisitor { physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)), k_(options.k), order_(options.sort_keys[0].order), + null_placement_(options.sort_keys[0].null_placement), ctx_(ctx), output_(output) {} @@ -288,6 +289,7 @@ class ChunkedArraySelector : public TypeVisitor { uint64_t* indices_end = indices_begin + indices.size(); std::iota(indices_begin, indices_end, 0); + // TODO.TAE maybe just remove this partitioning? 
const auto p = PartitionNulls( indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); const auto end_iter = p.non_nulls_end; @@ -330,6 +332,7 @@ class ChunkedArraySelector : public TypeVisitor { const ArrayVector physical_chunks_; int64_t k_; SortOrder order_; + NullPlacement null_placement_; ExecContext* ctx_; Datum* output_; }; From c53d0f976482bad551d78eca99b392a459be18d9 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Mon, 9 Feb 2026 11:58:42 +0100 Subject: [PATCH 68/83] simple TableSelector fix --- .../arrow/compute/kernels/vector_select_k.cc | 58 +++---------------- 1 file changed, 9 insertions(+), 49 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 05b219fc1d3..1c8039bb58c 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -617,44 +617,15 @@ class TableSelector : public TypeVisitor { return resolved; } - // TODO.TAE remove, it sorts ALL non-null inputs - // Behaves like PartitionNulls() but this supports multiple sort keys. - template - NullPartitionResult PartitionNullsInternal(uint64_t* indices_begin, - uint64_t* indices_end, - const ResolvedSortKey& first_sort_key) { - using ArrayType = typename TypeTraits::ArrayType; - - const auto p = PartitionNullsOnly( - indices_begin, indices_end, first_sort_key.resolver, first_sort_key.null_count, - NullPlacement::AtEnd); - DCHECK_EQ(p.nulls_end - p.nulls_begin, first_sort_key.null_count); - - const auto q = PartitionNullLikes( - p.non_nulls_begin, p.non_nulls_end, first_sort_key.resolver, - NullPlacement::AtEnd); - - auto& comparator = comparator_; - // Sort all NaNs by the second and following sort keys. - std::stable_sort(q.nulls_begin, q.nulls_end, [&](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); - // Sort all nulls by the second and following sort keys. 
- std::stable_sort(p.nulls_begin, p.nulls_end, [&](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); - - return q; - } - // XXX this implementation is rather inefficient as it computes chunk indices // at every comparison. Instead we should iterate over individual batches // and remember ChunkLocation entries in the max-heap. - +// TODO.TAE remove sort_order? template Status SelectKthInternal() { - auto& comparator = comparator_; - const auto& first_sort_key = sort_keys_[0]; + auto& comparator = comparator_; +// TODO.TAE +// const auto& first_sort_key = sort_keys_[0]; const auto num_rows = table_.num_rows(); if (num_rows == 0) { @@ -663,17 +634,9 @@ class TableSelector : public TypeVisitor { if (k_ > table_.num_rows()) { k_ = table_.num_rows(); } - std::function cmp; - SelectKComparator select_k_comparator; - cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { - auto chunk_left = first_sort_key.GetChunk(left); - auto chunk_right = first_sort_key.GetChunk(right); - auto value_left = chunk_left.Value(); - auto value_right = chunk_right.Value(); - if (value_left == value_right) { - return comparator.Compare(left, right, 1); - } - return select_k_comparator(value_left, value_right); + std::function cmp = + [&](const uint64_t& left, const uint64_t& right) -> bool { + return comparator.Compare(left, right, 0); }; using HeapContainer = std::priority_queue, decltype(cmp)>; @@ -683,13 +646,10 @@ class TableSelector : public TypeVisitor { uint64_t* indices_end = indices_begin + indices.size(); std::iota(indices_begin, indices_end, 0); - const auto p = - this->PartitionNullsInternal(indices_begin, indices_end, first_sort_key); - const auto end_iter = p.non_nulls_end; - auto kth_begin = std::min(indices_begin + k_, end_iter); + auto kth_begin = std::min(indices_begin + k_, indices_end); HeapContainer heap(indices_begin, kth_begin, cmp); - for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { + for (auto iter = 
kth_begin; iter != indices_end && !heap.empty(); ++iter) { uint64_t x_index = *iter; uint64_t top_item = heap.top(); if (cmp(x_index, top_item)) { From e429ddcc19d2a2b849ab4f0ebba52ad6fb446bd1 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Mon, 9 Feb 2026 12:00:55 +0100 Subject: [PATCH 69/83] clean up TableSelector --- .../arrow/compute/kernels/vector_select_k.cc | 28 ++++++------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 1c8039bb58c..ef102d425ff 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -76,7 +76,8 @@ class SelectKComparator { }; std::pair calculateNumberNonNullAndNullLikesToTake( - int64_t non_null_like_count, int64_t null_like_count, int64_t k, NullPlacement null_placement) { + int64_t non_null_like_count, int64_t null_like_count, int64_t k, + NullPlacement null_placement) { if (null_placement == NullPlacement::AtEnd) { int64_t l = std::min(k, non_null_like_count); int64_t m = std::min(k - l, null_like_count); @@ -188,8 +189,7 @@ class ArraySelector : public TypeVisitor { // k = l + m if enough elements in input auto [l, m] = calculateNumberNonNullAndNullLikesToTake( static_cast(p.non_nulls_end - p.non_nulls_begin), - static_cast(p.nulls_end - p.nulls_begin), k_, - null_placement_); + static_cast(p.nulls_end - p.nulls_begin), k_, null_placement_); ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(l + m, ctx_->memory_pool())); @@ -472,7 +472,9 @@ class RecordBatchSelector { // Need to subdivide into null and nan, we use non-nulllike / nulllike stratey // again for nan / null division auto [nan_count, null_count] = calculateNumberNonNullAndNullLikesToTake( - static_cast(p.nan_range.size()), static_cast(p.null_range.size()), m, first_remaining_sort_key.null_placement); + static_cast(p.nan_range.size()), + 
static_cast(p.null_range.size()), m, + first_remaining_sort_key.null_placement); uint64_t* nan_output_indices_begin; uint64_t* null_output_indices_begin; if (first_remaining_sort_key.null_placement == NullPlacement::AtEnd) { @@ -588,20 +590,10 @@ class TableSelector : public TypeVisitor { Status Run() { RETURN_NOT_OK(status_); - return sort_keys_[0].type->Accept(this); + return SelectKthInternal(); } protected: -#define VISIT(TYPE) \ - Status Visit(const TYPE& type) { \ - if (sort_keys_[0].order == SortOrder::Descending) \ - return SelectKthInternal(); \ - return SelectKthInternal(); \ - } - VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) - -#undef VISIT - static std::vector ResolveSortKeys( const Table& table, const std::vector& sort_keys, Status* status) { std::vector resolved; @@ -620,12 +612,8 @@ class TableSelector : public TypeVisitor { // XXX this implementation is rather inefficient as it computes chunk indices // at every comparison. Instead we should iterate over individual batches // and remember ChunkLocation entries in the max-heap. -// TODO.TAE remove sort_order? 
- template Status SelectKthInternal() { - auto& comparator = comparator_; -// TODO.TAE -// const auto& first_sort_key = sort_keys_[0]; + auto& comparator = comparator_; const auto num_rows = table_.num_rows(); if (num_rows == 0) { From 5465d95769615aa909a02c0e4ffbb8f7ca6d682e Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Mon, 9 Feb 2026 12:46:21 +0100 Subject: [PATCH 70/83] improve helper functions for clearer logic --- .../arrow/compute/kernels/vector_select_k.cc | 124 +++++++++++------- 1 file changed, 77 insertions(+), 47 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index ef102d425ff..293bb1b1b0c 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -75,18 +75,31 @@ class SelectKComparator { } }; -std::pair calculateNumberNonNullAndNullLikesToTake( - int64_t non_null_like_count, int64_t null_like_count, int64_t k, - NullPlacement null_placement) { +struct OutputRangesByNullLikeness { + std::span non_null_like_output; + std::span nan_output; + std::span null_output; +}; + +OutputRangesByNullLikeness calculateNumberNonNullAndNullLikesToTake( + int64_t non_null_like_count, int64_t nan_count, int64_t null_count, int64_t k, + NullPlacement null_placement, uint64_t* output_begin) { + int64_t non_null_like_to_take = 0; + int64_t nan_to_take = 0; + int64_t null_to_take = 0; if (null_placement == NullPlacement::AtEnd) { - int64_t l = std::min(k, non_null_like_count); - int64_t m = std::min(k - l, null_like_count); - return {l, m}; + non_null_like_to_take = std::min(k, non_null_like_count); + nan_to_take = std::min(k - non_null_like_to_take, nan_count); + null_to_take = std::min(k - non_null_like_to_take - nan_to_take, null_count); } else { - int64_t m = std::min(k, null_like_count); - int64_t l = std::min(k - m, non_null_like_count); - return {l, m}; + null_to_take = std::min(k, null_count); + nan_to_take = std::min(k 
- null_to_take, nan_count); + non_null_like_to_take = std::min(k - null_to_take - nan_to_take, non_null_like_count); } + return OutputRangesByNullLikeness{ + .non_null_like_output = {output_begin, output_begin + non_null_like_to_take}, + .nan_output = {output_begin, output_begin + nan_to_take}, + .null_output = {output_begin, output_begin + null_to_take}}; } template @@ -130,6 +143,7 @@ void HeapSortNonNullsToOutput(std::span non_null_range, int64_t l, } template +// TODO.TAE Could merge l and output into one span now void HeapSortNonNullsToOutput(std::span non_null_range, int64_t l, const typename TypeTraits::ArrayType& arr, SortOrder order, uint64_t* output) { @@ -183,26 +197,32 @@ class ArraySelector : public TypeVisitor { const auto p = PartitionNulls( indices_begin, indices_end, arr, 0, null_placement_); + ARROW_ASSIGN_OR_RAISE(auto take_indices, + MakeMutableUInt64Array(k_, ctx_->memory_pool())); + auto* output = take_indices->template GetMutableValues(1); + // From k, calculate // l = non-null elements to take from PartitionResult // m = null-like elements to take from PartitionResult // k = l + m if enough elements in input - auto [l, m] = calculateNumberNonNullAndNullLikesToTake( + auto output_ranges = calculateNumberNonNullAndNullLikesToTake( static_cast(p.non_nulls_end - p.non_nulls_begin), - static_cast(p.nulls_end - p.nulls_begin), k_, null_placement_); - - ARROW_ASSIGN_OR_RAISE(auto take_indices, - MakeMutableUInt64Array(l + m, ctx_->memory_pool())); - auto* output = take_indices->template GetMutableValues(1); + 0, // TODO.TAE it would be okay to consider these equal, but better not? 
+ static_cast(p.nulls_end - p.nulls_begin), k_, null_placement_, output); if (null_placement_ == NullPlacement::AtEnd) { - HeapSortNonNullsToOutput({p.non_nulls_begin, p.non_nulls_end}, - l, arr, output); - std::copy(p.nulls_begin, p.nulls_begin + m, output + l); + HeapSortNonNullsToOutput( + {p.non_nulls_begin, p.non_nulls_end}, output_ranges.non_null_like_output.size(), + arr, output); + std::copy(p.nulls_begin, p.nulls_begin + output_ranges.null_output.size(), + output_ranges.null_output.begin()); } else { - std::copy(p.nulls_begin, p.nulls_begin + m, output); - HeapSortNonNullsToOutput({p.non_nulls_begin, p.non_nulls_end}, - l, arr, output + m); + std::copy(p.nulls_begin, p.nulls_begin + output_ranges.null_output.size(), output); + HeapSortNonNullsToOutput( + {p.non_nulls_begin, p.non_nulls_end}, output_ranges.non_null_like_output.size(), + arr, + // TODO.TAE remove this &* + &*output_ranges.null_output.begin()); } *output_ = Datum(take_indices); @@ -339,8 +359,6 @@ class ChunkedArraySelector : public TypeVisitor { struct NullNanPartitionResult { std::span non_null_like_range; - std::span null_like_range; - // Also store the null/nan distribution within null_like_range std::span null_range; std::span nan_range; }; @@ -357,8 +375,6 @@ NullNanPartitionResult PartitionNullsAndNans(uint64_t* indices_begin, p.non_nulls_begin, p.non_nulls_end, values, offset, null_placement); return NullNanPartitionResult{ .non_null_like_range = {q.non_nulls_begin, q.non_nulls_end}, - .null_like_range = {std::min(q.nulls_begin, p.nulls_begin), - std::max(q.nulls_end, p.nulls_end)}, .null_range = {p.nulls_begin, p.nulls_end}, .nan_range = {q.nulls_begin, q.nulls_end}}; } @@ -430,38 +446,36 @@ class RecordBatchSelector { // m = null elements to take from PartitionResult // k = l + m because k was clipped to num_rows() - // TODO.TAE change this function to directly return TARGET/OUTPUT ranges - // -> no need for counts and begins (begins are below) - auto [l, m] = 
calculateNumberNonNullAndNullLikesToTake( + auto output_ranges = calculateNumberNonNullAndNullLikesToTake( static_cast(p.non_null_like_range.size()), - static_cast(p.null_like_range.size()), k_remaining_, - first_remaining_sort_key.null_placement); - - uint64_t* non_null_output_indices_begin; - uint64_t* nulllike_output_indices_begin; - if (first_remaining_sort_key.null_placement == NullPlacement::AtEnd) { - non_null_output_indices_begin = output_indices_; - nulllike_output_indices_begin = output_indices_ + l; - } else { - nulllike_output_indices_begin = output_indices_; - non_null_output_indices_begin = output_indices_ + m; - } + static_cast(p.nan_range.size()), + static_cast(p.null_range.size()), k_remaining_, + first_remaining_sort_key.null_placement, output_indices_); bool last_sort_key = start_sort_key_index_ + 1 == selector_->sort_keys_.size(); if (last_sort_key) { - if (l > 0) { - HeapSortNonNullsToOutput(p.non_null_like_range, l, arr, + if (!output_ranges.non_null_like_output.empty()) { + HeapSortNonNullsToOutput(p.non_null_like_range, + output_ranges.non_null_like_output.size(), arr, first_remaining_sort_key.order, - non_null_output_indices_begin); + // TODO.TAE remove this &* + &*output_ranges.non_null_like_output.begin()); } - if (m > 0) { + if (output_ranges.nan_output.size() > 0) { // We have the last sort_key, can just copy over the null values - std::copy(p.null_like_range.begin(), p.null_like_range.begin() + m, - nulllike_output_indices_begin); + std::copy(p.nan_range.begin(), + p.nan_range.begin() + output_ranges.nan_output.size(), + output_ranges.nan_output.begin()); + } + if (output_ranges.null_output.size() > 0) { + // We have the last sort_key, can just copy over the null values + std::copy(p.null_range.begin(), + p.null_range.begin() + output_ranges.null_output.size(), + output_ranges.null_output.begin()); } } else { - if (l > 0) { + if (!output_ranges.non_null_like_output.empty()) { auto cmp = [&](uint64_t left, uint64_t right) { return 
selector_->comparator_.Compare(left, right, start_sort_key_index_); }; @@ -495,6 +509,22 @@ class RecordBatchSelector { p.null_range, null_count, null_output_indices_begin)); } + HeapSortNonNullsToOutput(p.non_null_like_range, + output_ranges.non_null_like_output.size(), cmp, + // TODO.TAE remove this &* + &*output_ranges.non_null_like_output.begin()); + } + if (output_ranges.nan_output.size() > 0) { + ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey( + start_sort_key_index_ + 1, p.nan_range, output_ranges.nan_output.size(), + // TODO.TAE remove this &* + &*output_ranges.nan_output.begin())); + } + if (output_ranges.null_output.size() > 0) { + ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey( + start_sort_key_index_ + 1, p.null_range, output_ranges.null_output.size(), + // TODO.TAE remove this &* + &*output_ranges.null_output.begin())); } } From cea0ced2bab8139741cc3c82afa02b6605f89dd0 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Mon, 9 Feb 2026 14:13:58 +0100 Subject: [PATCH 71/83] fix helper function --- .../arrow/compute/kernels/vector_select_k.cc | 44 +++++-------------- 1 file changed, 10 insertions(+), 34 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 293bb1b1b0c..d063acb4132 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -91,15 +91,21 @@ OutputRangesByNullLikeness calculateNumberNonNullAndNullLikesToTake( non_null_like_to_take = std::min(k, non_null_like_count); nan_to_take = std::min(k - non_null_like_to_take, nan_count); null_to_take = std::min(k - non_null_like_to_take - nan_to_take, null_count); + // TODO.TAE make this prettier + return OutputRangesByNullLikeness{ + .non_null_like_output = {output_begin, output_begin + non_null_like_to_take}, + .nan_output = {output_begin + non_null_like_to_take, output_begin + non_null_like_to_take + nan_to_take}, + .null_output = {output_begin + 
non_null_like_to_take + nan_to_take, output_begin + non_null_like_to_take + nan_to_take + null_to_take}}; } else { null_to_take = std::min(k, null_count); nan_to_take = std::min(k - null_to_take, nan_count); non_null_like_to_take = std::min(k - null_to_take - nan_to_take, non_null_like_count); + // TODO.TAE make this prettier + return OutputRangesByNullLikeness{ + .non_null_like_output = {output_begin + null_to_take + nan_to_take, output_begin + null_to_take + nan_to_take + non_null_like_to_take}, + .nan_output = {output_begin + null_to_take, output_begin + null_to_take + nan_to_take}, + .null_output = {output_begin, output_begin + null_to_take}}; } - return OutputRangesByNullLikeness{ - .non_null_like_output = {output_begin, output_begin + non_null_like_to_take}, - .nan_output = {output_begin, output_begin + nan_to_take}, - .null_output = {output_begin, output_begin + null_to_take}}; } template @@ -479,36 +485,6 @@ class RecordBatchSelector { auto cmp = [&](uint64_t left, uint64_t right) { return selector_->comparator_.Compare(left, right, start_sort_key_index_); }; - HeapSortNonNullsToOutput(p.non_null_like_range, l, cmp, - non_null_output_indices_begin); - } - if (m > 0) { - // Need to subdivide into null and nan, we use non-nulllike / nulllike stratey - // again for nan / null division - auto [nan_count, null_count] = calculateNumberNonNullAndNullLikesToTake( - static_cast(p.nan_range.size()), - static_cast(p.null_range.size()), m, - first_remaining_sort_key.null_placement); - uint64_t* nan_output_indices_begin; - uint64_t* null_output_indices_begin; - if (first_remaining_sort_key.null_placement == NullPlacement::AtEnd) { - nan_output_indices_begin = nulllike_output_indices_begin; - null_output_indices_begin = nulllike_output_indices_begin + nan_count; - } else { - null_output_indices_begin = nulllike_output_indices_begin; - nan_output_indices_begin = nulllike_output_indices_begin + null_count; - } - - if (nan_count > 0) { - 
ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey(start_sort_key_index_ + 1, - p.nan_range, nan_count, - nan_output_indices_begin)); - } - if (null_count > 0) { - ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey(start_sort_key_index_ + 1, - p.null_range, null_count, - null_output_indices_begin)); - } HeapSortNonNullsToOutput(p.non_null_like_range, output_ranges.non_null_like_output.size(), cmp, // TODO.TAE remove this &* From d3f4229751c5d9772eef38f22569bbcf21edad25 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Mon, 9 Feb 2026 14:50:30 +0100 Subject: [PATCH 72/83] fix array function --- .../arrow/compute/kernels/vector_select_k.cc | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index d063acb4132..fcb98c73c4d 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -94,16 +94,22 @@ OutputRangesByNullLikeness calculateNumberNonNullAndNullLikesToTake( // TODO.TAE make this prettier return OutputRangesByNullLikeness{ .non_null_like_output = {output_begin, output_begin + non_null_like_to_take}, - .nan_output = {output_begin + non_null_like_to_take, output_begin + non_null_like_to_take + nan_to_take}, - .null_output = {output_begin + non_null_like_to_take + nan_to_take, output_begin + non_null_like_to_take + nan_to_take + null_to_take}}; + .nan_output = {output_begin + non_null_like_to_take, + output_begin + non_null_like_to_take + nan_to_take}, + .null_output = { + output_begin + non_null_like_to_take + nan_to_take, + output_begin + non_null_like_to_take + nan_to_take + null_to_take}}; } else { null_to_take = std::min(k, null_count); nan_to_take = std::min(k - null_to_take, nan_count); non_null_like_to_take = std::min(k - null_to_take - nan_to_take, non_null_like_count); // TODO.TAE make this prettier return OutputRangesByNullLikeness{ - .non_null_like_output = 
{output_begin + null_to_take + nan_to_take, output_begin + null_to_take + nan_to_take + non_null_like_to_take}, - .nan_output = {output_begin + null_to_take, output_begin + null_to_take + nan_to_take}, + .non_null_like_output = {output_begin + null_to_take + nan_to_take, + output_begin + null_to_take + nan_to_take + + non_null_like_to_take}, + .nan_output = {output_begin + null_to_take, + output_begin + null_to_take + nan_to_take}, .null_output = {output_begin, output_begin + null_to_take}}; } } @@ -194,6 +200,9 @@ class ArraySelector : public TypeVisitor { using ArrayType = typename TypeTraits::ArrayType; ArrayType arr(array_.data()); + + k_ = std::min(k_, arr.length()); + std::vector indices(arr.length()); uint64_t* indices_begin = indices.data(); @@ -223,12 +232,13 @@ class ArraySelector : public TypeVisitor { std::copy(p.nulls_begin, p.nulls_begin + output_ranges.null_output.size(), output_ranges.null_output.begin()); } else { - std::copy(p.nulls_begin, p.nulls_begin + output_ranges.null_output.size(), output); + std::copy(p.nulls_begin, p.nulls_begin + output_ranges.null_output.size(), + output_ranges.null_output.begin()); HeapSortNonNullsToOutput( {p.non_nulls_begin, p.non_nulls_end}, output_ranges.non_null_like_output.size(), arr, // TODO.TAE remove this &* - &*output_ranges.null_output.begin()); + &*output_ranges.non_null_like_output.begin()); } *output_ = Datum(take_indices); From 623a421ad354edc1161aa369857c6d753bafcf25 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Mon, 9 Feb 2026 14:54:49 +0100 Subject: [PATCH 73/83] cleaner helper function --- .../arrow/compute/kernels/vector_select_k.cc | 91 +++++++++++-------- 1 file changed, 55 insertions(+), 36 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index fcb98c73c4d..8c14abd07e1 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -115,11 +115,13 @@ 
OutputRangesByNullLikeness calculateNumberNonNullAndNullLikesToTake( } template -void HeapSortNonNullsToOutput(std::span non_null_range, int64_t l, - Comparator cmp, uint64_t* output) { - std::span heap{non_null_range.begin(), non_null_range.begin() + l}; +void HeapSortNonNullsToOutput(std::span non_null_input_range, Comparator cmp, + std::span output_range) { + std::span heap{non_null_input_range.begin(), + non_null_input_range.begin() + output_range.size()}; std::make_heap(heap.begin(), heap.end(), cmp); - for (auto iter = non_null_range.begin() + l; iter != non_null_range.end(); ++iter) { + for (auto iter = non_null_input_range.begin() + output_range.size(); + iter != non_null_input_range.end(); ++iter) { uint64_t x_index = *iter; if (cmp(x_index, heap.front())) { std::pop_heap(heap.begin(), heap.end(), cmp); @@ -130,10 +132,11 @@ void HeapSortNonNullsToOutput(std::span non_null_range, int64_t l, // fill output in reverse when destructing, // as the "worst" (next-to-would-have-been-replaced) element is at heap-top + // TODO.TAE remove these &* uint64_t* heap_begin = &*heap.begin(); - uint64_t* heap_end = &*heap.begin() + l; - for (auto reverse_out_iter = output + l - 1; reverse_out_iter >= output; - --reverse_out_iter) { + uint64_t* heap_end = &*heap.begin() + output_range.size(); + for (auto reverse_out_iter = output_range.rbegin(); + reverse_out_iter != output_range.rend(); reverse_out_iter++) { *reverse_out_iter = *heap_begin; // heap-top has the next element std::pop_heap(heap_begin, heap_end, cmp); --heap_end; @@ -141,9 +144,9 @@ void HeapSortNonNullsToOutput(std::span non_null_range, int64_t l, } template -void HeapSortNonNullsToOutput(std::span non_null_range, int64_t l, +void HeapSortNonNullsToOutput(std::span non_null_input_range, const typename TypeTraits::ArrayType& arr, - uint64_t* output) { + std::span output_range) { using GetView = GetViewType; SelectKComparator comparator; auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) { @@ -151,20 
+154,20 @@ void HeapSortNonNullsToOutput(std::span non_null_range, int64_t l, const auto rval = GetView::LogicalValue(arr.GetView(right)); return comparator(lval, rval); }; - HeapSortNonNullsToOutput(non_null_range, l, cmp, output); + HeapSortNonNullsToOutput(non_null_input_range, cmp, output_range); } template // TODO.TAE Could merge l and output into one span now -void HeapSortNonNullsToOutput(std::span non_null_range, int64_t l, +void HeapSortNonNullsToOutput(std::span non_null_input_range, const typename TypeTraits::ArrayType& arr, - SortOrder order, uint64_t* output) { + SortOrder order, std::span output_range) { if (order == SortOrder::Ascending) { - HeapSortNonNullsToOutput(non_null_range, l, arr, - output); + HeapSortNonNullsToOutput(non_null_input_range, arr, + output_range); } else { - HeapSortNonNullsToOutput(non_null_range, l, arr, - output); + HeapSortNonNullsToOutput(non_null_input_range, arr, + output_range); } } @@ -227,18 +230,14 @@ class ArraySelector : public TypeVisitor { if (null_placement_ == NullPlacement::AtEnd) { HeapSortNonNullsToOutput( - {p.non_nulls_begin, p.non_nulls_end}, output_ranges.non_null_like_output.size(), - arr, output); + {p.non_nulls_begin, p.non_nulls_end}, arr, output_ranges.non_null_like_output); std::copy(p.nulls_begin, p.nulls_begin + output_ranges.null_output.size(), output_ranges.null_output.begin()); } else { std::copy(p.nulls_begin, p.nulls_begin + output_ranges.null_output.size(), output_ranges.null_output.begin()); HeapSortNonNullsToOutput( - {p.non_nulls_begin, p.non_nulls_end}, output_ranges.non_null_like_output.size(), - arr, - // TODO.TAE remove this &* - &*output_ranges.non_null_like_output.begin()); + {p.non_nulls_begin, p.non_nulls_end}, arr, output_ranges.non_null_like_output); } *output_ = Datum(take_indices); @@ -262,6 +261,9 @@ struct TypedHeapItem { }; class ChunkedArraySelector : public TypeVisitor { + using ResolvedSortKey = ResolvedTableSortKey; + using Comparator = MultipleKeyComparator; + public: 
ChunkedArraySelector(ExecContext* ctx, const ChunkedArray& chunked_array, const SelectKOptions& options, Datum* output) @@ -288,6 +290,26 @@ class ChunkedArraySelector : public TypeVisitor { VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) #undef VISIT + // template + // int64_t ComputeNanCount(){ + // using GetView = GetViewType; + // using ArrayType = typename TypeTraits::ArrayType; + // if constexpr (has_null_like_values()) { + // int64_t nan_count = 0; + // for (const auto& chunk : physical_chunks_) { + // auto values = std::make_shared(chunk->data()); + // int64_t length = values->length(); + // for(int64_t index = 0; index < length; ++index){ + // if(std::isnan(values->GetView(index))){ + // nan_count++; + // } + // } + // } + // return nan_count; + // } + // return 0; + // } + template Status SelectKthInternal() { using GetView = GetViewType; @@ -301,6 +323,11 @@ class ChunkedArraySelector : public TypeVisitor { if (k_ > chunked_array_.length()) { k_ = chunked_array_.length(); } + // int64_t null_count = chunked_array_.null_count(); + // int64_t nan_count = ComputeNanCount(); + // TODO.TAE int64_t non_null_like_count = chunked_array_.length() - null_count - + // nan_count; + std::function cmp; SelectKComparator comparator; @@ -325,17 +352,12 @@ class ChunkedArraySelector : public TypeVisitor { uint64_t* indices_end = indices_begin + indices.size(); std::iota(indices_begin, indices_end, 0); - // TODO.TAE maybe just remove this partitioning? 
- const auto p = PartitionNulls( - indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); - const auto end_iter = p.non_nulls_end; - - auto kth_begin = std::min(indices_begin + k_, end_iter); + auto kth_begin = std::min(indices_begin + k_, indices_end); uint64_t* iter = indices_begin; for (; iter != kth_begin && heap.size() < static_cast(k_); ++iter) { heap.push(HeapItem{*iter, offset, &arr}); } - for (; iter != end_iter && !heap.empty(); ++iter) { + for (; iter != indices_end && !heap.empty(); ++iter) { uint64_t x_index = *iter; const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); auto top_item = heap.top(); @@ -472,11 +494,10 @@ class RecordBatchSelector { if (last_sort_key) { if (!output_ranges.non_null_like_output.empty()) { - HeapSortNonNullsToOutput(p.non_null_like_range, - output_ranges.non_null_like_output.size(), arr, + HeapSortNonNullsToOutput(p.non_null_like_range, arr, first_remaining_sort_key.order, // TODO.TAE remove this &* - &*output_ranges.non_null_like_output.begin()); + output_ranges.non_null_like_output); } if (output_ranges.nan_output.size() > 0) { // We have the last sort_key, can just copy over the null values @@ -495,10 +516,8 @@ class RecordBatchSelector { auto cmp = [&](uint64_t left, uint64_t right) { return selector_->comparator_.Compare(left, right, start_sort_key_index_); }; - HeapSortNonNullsToOutput(p.non_null_like_range, - output_ranges.non_null_like_output.size(), cmp, - // TODO.TAE remove this &* - &*output_ranges.non_null_like_output.begin()); + HeapSortNonNullsToOutput(p.non_null_like_range, cmp, + output_ranges.non_null_like_output); } if (output_ranges.nan_output.size() > 0) { ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey( From 9d0b0c2316f185da3ca063ace639513724097d7d Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Mon, 9 Feb 2026 16:28:23 +0100 Subject: [PATCH 74/83] fixed ChunkedArraySelector diff --git c/cpp/src/arrow/compute/kernels/vector_select_k.cc 
i/cpp/src/arrow/compute/kernels/vector_select_k.cc index 8c14abd07e..ed1a89b027 100644 --- c/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ i/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -16,6 +16,7 @@ // under the License. #include +#include #include #include "arrow/compute/function.h" @@ -82,8 +83,9 @@ struct OutputRangesByNullLikeness { }; OutputRangesByNullLikeness calculateNumberNonNullAndNullLikesToTake( - int64_t non_null_like_count, int64_t nan_count, int64_t null_count, int64_t k, - NullPlacement null_placement, uint64_t* output_begin) { + int64_t non_null_like_count, int64_t nan_count, int64_t null_count, + NullPlacement null_placement, std::span output_indices) { + int64_t k = output_indices.size(); int64_t non_null_like_to_take = 0; int64_t nan_to_take = 0; int64_t null_to_take = 0; @@ -91,38 +93,31 @@ OutputRangesByNullLikeness calculateNumberNonNullAndNullLikesToTake( non_null_like_to_take = std::min(k, non_null_like_count); nan_to_take = std::min(k - non_null_like_to_take, nan_count); null_to_take = std::min(k - non_null_like_to_take - nan_to_take, null_count); - // TODO.TAE make this prettier return OutputRangesByNullLikeness{ - .non_null_like_output = {output_begin, output_begin + non_null_like_to_take}, - .nan_output = {output_begin + non_null_like_to_take, - output_begin + non_null_like_to_take + nan_to_take}, - .null_output = { - output_begin + non_null_like_to_take + nan_to_take, - output_begin + non_null_like_to_take + nan_to_take + null_to_take}}; + .non_null_like_output = output_indices.subspan(0, non_null_like_to_take), + .nan_output = output_indices.subspan(non_null_like_to_take, nan_to_take), + .null_output = + output_indices.subspan(non_null_like_to_take + nan_to_take, null_to_take)}; } else { null_to_take = std::min(k, null_count); nan_to_take = std::min(k - null_to_take, nan_count); non_null_like_to_take = std::min(k - null_to_take - nan_to_take, non_null_like_count); - // TODO.TAE make this prettier return 
OutputRangesByNullLikeness{ - .non_null_like_output = {output_begin + null_to_take + nan_to_take, - output_begin + null_to_take + nan_to_take + - non_null_like_to_take}, - .nan_output = {output_begin + null_to_take, - output_begin + null_to_take + nan_to_take}, - .null_output = {output_begin, output_begin + null_to_take}}; + .non_null_like_output = + output_indices.subspan(null_to_take + nan_to_take, non_null_like_to_take), + .nan_output = output_indices.subspan(null_to_take, nan_to_take), + .null_output = output_indices.subspan(0, null_to_take)}; } } template void HeapSortNonNullsToOutput(std::span non_null_input_range, Comparator cmp, std::span output_range) { - std::span heap{non_null_input_range.begin(), - non_null_input_range.begin() + output_range.size()}; + std::span heap = non_null_input_range.subspan(0, output_range.size()); std::make_heap(heap.begin(), heap.end(), cmp); - for (auto iter = non_null_input_range.begin() + output_range.size(); - iter != non_null_input_range.end(); ++iter) { - uint64_t x_index = *iter; + + std::span remaining_input = non_null_input_range.subspan(output_range.size()); + for (uint64_t x_index : remaining_input) { if (cmp(x_index, heap.front())) { std::pop_heap(heap.begin(), heap.end(), cmp); heap.back() = x_index; @@ -132,14 +127,12 @@ void HeapSortNonNullsToOutput(std::span non_null_input_range, Comparat // fill output in reverse when destructing, // as the "worst" (next-to-would-have-been-replaced) element is at heap-top - // TODO.TAE remove these &* - uint64_t* heap_begin = &*heap.begin(); - uint64_t* heap_end = &*heap.begin() + output_range.size(); for (auto reverse_out_iter = output_range.rbegin(); reverse_out_iter != output_range.rend(); reverse_out_iter++) { - *reverse_out_iter = *heap_begin; // heap-top has the next element - std::pop_heap(heap_begin, heap_end, cmp); - --heap_end; + *reverse_out_iter = heap.front(); // heap-top has the next element + std::ranges::pop_heap(heap, cmp); + // Decrease heap-size by one + heap 
= heap.first(heap.size() - 1); } } @@ -158,7 +151,6 @@ void HeapSortNonNullsToOutput(std::span non_null_input_range, } template -// TODO.TAE Could merge l and output into one span now void HeapSortNonNullsToOutput(std::span non_null_input_range, const typename TypeTraits::ArrayType& arr, SortOrder order, std::span output_range) { @@ -171,6 +163,28 @@ void HeapSortNonNullsToOutput(std::span non_null_input_range, } } +struct NullNanPartitionResult { + std::span non_null_like_range; + std::span null_range; + std::span nan_range; +}; + +template +NullNanPartitionResult PartitionNullsAndNans(uint64_t* indices_begin, + uint64_t* indices_end, + const ArrayType& values, int64_t offset, + NullPlacement null_placement) { + // Partition nulls at start (resp. end), and null-like values just before (resp. after) + NullPartitionResult p = PartitionNullsOnly(indices_begin, indices_end, + values, offset, null_placement); + NullPartitionResult q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, values, offset, null_placement); + return NullNanPartitionResult{ + .non_null_like_range = {q.non_nulls_begin, q.non_nulls_end}, + .null_range = {p.nulls_begin, p.nulls_end}, + .nan_range = {q.nulls_begin, q.nulls_end}}; +} + class ArraySelector : public TypeVisitor { public: ArraySelector(ExecContext* ctx, const Array& array, const SelectKOptions& options, @@ -220,25 +234,22 @@ class ArraySelector : public TypeVisitor { auto* output = take_indices->template GetMutableValues(1); // From k, calculate - // l = non-null elements to take from PartitionResult - // m = null-like elements to take from PartitionResult - // k = l + m if enough elements in input + // l = non_null_like elements to take from PartitionResult + // m = nan elements to take from PartitionResult + // n = null elements to take from PartitionResult + // k = l + m + n because k was clipped to arr.length() + // And directly compute the ranges in {output, output+k} where we will need to place + // the selected elements 
from each group -> no longer need to track null_placement auto output_ranges = calculateNumberNonNullAndNullLikesToTake( static_cast(p.non_nulls_end - p.non_nulls_begin), - 0, // TODO.TAE it would be okay to consider these equal, but better not? - static_cast(p.nulls_end - p.nulls_begin), k_, null_placement_, output); + 0, // TODO.TAE it would be okay to consider null/nan equal, but better not? + static_cast(p.nulls_end - p.nulls_begin), null_placement_, + {output, output + k_}); - if (null_placement_ == NullPlacement::AtEnd) { - HeapSortNonNullsToOutput( - {p.non_nulls_begin, p.non_nulls_end}, arr, output_ranges.non_null_like_output); - std::copy(p.nulls_begin, p.nulls_begin + output_ranges.null_output.size(), - output_ranges.null_output.begin()); - } else { - std::copy(p.nulls_begin, p.nulls_begin + output_ranges.null_output.size(), - output_ranges.null_output.begin()); - HeapSortNonNullsToOutput( - {p.non_nulls_begin, p.non_nulls_end}, arr, output_ranges.non_null_like_output); - } + HeapSortNonNullsToOutput({p.non_nulls_begin, p.non_nulls_end}, + arr, output_ranges.non_null_like_output); + std::copy(p.nulls_begin, p.nulls_begin + output_ranges.null_output.size(), + output_ranges.null_output.begin()); *output_ = Datum(take_indices); return Status::OK(); @@ -290,25 +301,24 @@ class ChunkedArraySelector : public TypeVisitor { VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) #undef VISIT - // template - // int64_t ComputeNanCount(){ - // using GetView = GetViewType; - // using ArrayType = typename TypeTraits::ArrayType; - // if constexpr (has_null_like_values()) { - // int64_t nan_count = 0; - // for (const auto& chunk : physical_chunks_) { - // auto values = std::make_shared(chunk->data()); - // int64_t length = values->length(); - // for(int64_t index = 0; index < length; ++index){ - // if(std::isnan(values->GetView(index))){ - // nan_count++; - // } - // } - // } - // return nan_count; - // } - // return 0; - // } + template + int64_t ComputeNanCount() { + using ArrayType = 
typename TypeTraits::ArrayType; + if constexpr (has_null_like_values()) { + int64_t nan_count = 0; + for (const auto& chunk : physical_chunks_) { + auto values = std::make_shared(chunk->data()); + int64_t length = values->length(); + for (int64_t index = 0; index < length; ++index) { + if (std::isnan(values->GetView(index))) { + nan_count++; + } + } + } + return nan_count; + } + return 0; + } template Status SelectKthInternal() { @@ -323,14 +333,28 @@ class ChunkedArraySelector : public TypeVisitor { if (k_ > chunked_array_.length()) { k_ = chunked_array_.length(); } - // int64_t null_count = chunked_array_.null_count(); - // int64_t nan_count = ComputeNanCount(); - // TODO.TAE int64_t non_null_like_count = chunked_array_.length() - null_count - - // nan_count; + + ARROW_ASSIGN_OR_RAISE(auto take_indices, + MakeMutableUInt64Array(k_, ctx_->memory_pool())); + auto* output_begin = take_indices->GetMutableValues(1); + + int64_t null_count = chunked_array_.null_count(); + int64_t nan_count = ComputeNanCount(); + int64_t non_null_like_count = chunked_array_.length() - null_count - nan_count; + + auto output = calculateNumberNonNullAndNullLikesToTake( + non_null_like_count, nan_count, null_count, null_placement_, + {output_begin, output_begin + k_}); + + // Now we can independently fill the output with non_null, nan and null items. 
+ // For non_null, we do a heap_sort, the others can just be copied until + // nan_taken = output.nan_range.size() and + // null_taken = output.null_range.size() respectively + size_t nan_taken = 0; + size_t null_taken = 0; std::function cmp; SelectKComparator comparator; - cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool { const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); @@ -338,9 +362,9 @@ class ChunkedArraySelector : public TypeVisitor { }; using HeapContainer = std::priority_queue, decltype(cmp)>; - HeapContainer heap(cmp); std::vector> chunks_holder; + uint64_t offset = 0; for (const auto& chunk : physical_chunks_) { if (chunk->length() == 0) continue; @@ -352,12 +376,29 @@ class ChunkedArraySelector : public TypeVisitor { uint64_t* indices_end = indices_begin + indices.size(); std::iota(indices_begin, indices_end, 0); - auto kth_begin = std::min(indices_begin + k_, indices_end); - uint64_t* iter = indices_begin; - for (; iter != kth_begin && heap.size() < static_cast(k_); ++iter) { + const auto p = PartitionNullsAndNans( + indices_begin, indices_end, arr, 0, null_placement_); + + // First do nulls and nans + auto iter = p.null_range.begin(); + for (; iter != p.null_range.end() && null_taken < output.null_output.size(); + ++iter) { + output.null_output[null_taken] = offset + *iter; + null_taken++; + } + iter = p.nan_range.begin(); + for (; iter != p.nan_range.end() && nan_taken < output.nan_output.size(); ++iter) { + output.nan_output[nan_taken] = offset + *iter; + nan_taken++; + } + + iter = p.non_null_like_range.begin(); + for (; iter != p.non_null_like_range.end() && + heap.size() < output.non_null_like_output.size(); + ++iter) { heap.push(HeapItem{*iter, offset, &arr}); } - for (; iter != indices_end && !heap.empty(); ++iter) { + for (; iter != p.non_null_like_range.end() && !heap.empty(); ++iter) { uint64_t x_index = *iter; const 
auto& xval = GetView::LogicalValue(arr.GetView(x_index)); auto top_item = heap.top(); @@ -371,16 +412,17 @@ class ChunkedArraySelector : public TypeVisitor { offset += chunk->length(); } - auto out_size = static_cast(heap.size()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, - MakeMutableUInt64Array(out_size, ctx_->memory_pool())); - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.size() > 0) { - auto top_item = heap.top(); - *out_cbegin = top_item.index + top_item.offset; + // We sized output.non_null_like_output to hold exactly sufficient indices, + // so the heap must have been completely filled + assert(heap.size() == output.non_null_like_output.size()); + + for (auto reverse_out_iter = output.non_null_like_output.rbegin(); + reverse_out_iter != output.non_null_like_output.rend(); reverse_out_iter++) { + *reverse_out_iter = + heap.top().index + heap.top().offset; // heap-top has the next element heap.pop(); - --out_cbegin; } + *output_ = Datum(take_indices); return Status::OK(); } @@ -395,28 +437,6 @@ class ChunkedArraySelector : public TypeVisitor { Datum* output_; }; -struct NullNanPartitionResult { - std::span non_null_like_range; - std::span null_range; - std::span nan_range; -}; - -template -NullNanPartitionResult PartitionNullsAndNans(uint64_t* indices_begin, - uint64_t* indices_end, - const ArrayType& values, int64_t offset, - NullPlacement null_placement) { - // Partition nulls at start (resp. end), and null-like values just before (resp. 
after) - NullPartitionResult p = PartitionNullsOnly(indices_begin, indices_end, - values, offset, null_placement); - NullPartitionResult q = PartitionNullLikes( - p.non_nulls_begin, p.non_nulls_end, values, offset, null_placement); - return NullNanPartitionResult{ - .non_null_like_range = {q.non_nulls_begin, q.non_nulls_end}, - .null_range = {p.nulls_begin, p.nulls_end}, - .nan_range = {q.nulls_begin, q.nulls_end}}; -} - class RecordBatchSelector { private: using ResolvedSortKey = ResolvedRecordBatchSortKey; @@ -455,13 +475,11 @@ class RecordBatchSelector { class SelectKForKey : public TypeVisitor { public: SelectKForKey(RecordBatchSelector* selector, size_t start_sort_key_index, - std::span input_indices, int64_t k_remaining, - uint64_t* output_indices) + std::span input_indices, std::span output_indices) : TypeVisitor(), selector_(selector), start_sort_key_index_(start_sort_key_index), input_indices_(input_indices), - k_remaining_(k_remaining), output_indices_(output_indices) {} private: @@ -471,23 +489,24 @@ class RecordBatchSelector { const auto& first_remaining_sort_key = selector_->sort_keys_[start_sort_key_index_]; const auto& arr = checked_cast(first_remaining_sort_key.array); - // TODO.TAE uhh this could be prettier - uint64_t* input_indices_begin = &*input_indices_.begin(); - uint64_t* input_indices_end = input_indices_begin + input_indices_.size(); + uint64_t* input_indices_begin = input_indices_.data(); + uint64_t* input_indices_end = input_indices_.data() + input_indices_.size(); const auto p = PartitionNullsAndNans( input_indices_begin, input_indices_end, arr, 0, first_remaining_sort_key.null_placement); - // From k, calculate + // From k = output_range.size(), calculate // l = non_null elements to take from PartitionResult - // m = null elements to take from PartitionResult - // k = l + m because k was clipped to num_rows() - + // m = nan elements to take from PartitionResult + // n = null elements to take from PartitionResult + // k = l + m + n 
because k was clipped to num_rows() + // And directly compute the ranges in output_indices_ where we will need to place + // the selected elements from each group -> no longer need to track null_placement auto output_ranges = calculateNumberNonNullAndNullLikesToTake( static_cast(p.non_null_like_range.size()), static_cast(p.nan_range.size()), - static_cast(p.null_range.size()), k_remaining_, + static_cast(p.null_range.size()), first_remaining_sort_key.null_placement, output_indices_); bool last_sort_key = start_sort_key_index_ + 1 == selector_->sort_keys_.size(); @@ -496,7 +515,6 @@ class RecordBatchSelector { if (!output_ranges.non_null_like_output.empty()) { HeapSortNonNullsToOutput(p.non_null_like_range, arr, first_remaining_sort_key.order, - // TODO.TAE remove this &* output_ranges.non_null_like_output); } if (output_ranges.nan_output.size() > 0) { @@ -521,15 +539,11 @@ class RecordBatchSelector { } if (output_ranges.nan_output.size() > 0) { ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey( - start_sort_key_index_ + 1, p.nan_range, output_ranges.nan_output.size(), - // TODO.TAE remove this &* - &*output_ranges.nan_output.begin())); + start_sort_key_index_ + 1, p.nan_range, output_ranges.nan_output)); } if (output_ranges.null_output.size() > 0) { ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey( - start_sort_key_index_ + 1, p.null_range, output_ranges.null_output.size(), - // TODO.TAE remove this &* - &*output_ranges.null_output.begin())); + start_sort_key_index_ + 1, p.null_range, output_ranges.null_output)); } } @@ -545,14 +559,12 @@ class RecordBatchSelector { RecordBatchSelector* selector_; size_t start_sort_key_index_; std::span input_indices_; - int64_t k_remaining_; - uint64_t* output_indices_; + std::span output_indices_; }; Status DoSelectKForKey(size_t start_sort_key_index, std::span input_indices, - int64_t k_remaining, uint64_t* output_indices) { - SelectKForKey tmp(this, start_sort_key_index, input_indices, k_remaining, - output_indices); + std::span 
output_indices) { + SelectKForKey tmp(this, start_sort_key_index, input_indices, output_indices); return sort_keys_.at(start_sort_key_index).type->Accept(&tmp); } @@ -571,7 +583,8 @@ class RecordBatchSelector { auto* output_indices = take_indices->template GetMutableValues(1); std::span input_indices_span(input_indices); - ARROW_RETURN_NOT_OK(DoSelectKForKey(0, input_indices_span, k_, output_indices)); + ARROW_RETURN_NOT_OK( + DoSelectKForKey(0, input_indices_span, {output_indices, output_indices + k_})); *output_ = Datum(take_indices); return arrow::Status::OK(); } --- .../arrow/compute/kernels/vector_select_k.cc | 269 +++++++++--------- 1 file changed, 141 insertions(+), 128 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 8c14abd07e1..ed1a89b027d 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -16,6 +16,7 @@ // under the License. #include +#include #include #include "arrow/compute/function.h" @@ -82,8 +83,9 @@ struct OutputRangesByNullLikeness { }; OutputRangesByNullLikeness calculateNumberNonNullAndNullLikesToTake( - int64_t non_null_like_count, int64_t nan_count, int64_t null_count, int64_t k, - NullPlacement null_placement, uint64_t* output_begin) { + int64_t non_null_like_count, int64_t nan_count, int64_t null_count, + NullPlacement null_placement, std::span output_indices) { + int64_t k = output_indices.size(); int64_t non_null_like_to_take = 0; int64_t nan_to_take = 0; int64_t null_to_take = 0; @@ -91,38 +93,31 @@ OutputRangesByNullLikeness calculateNumberNonNullAndNullLikesToTake( non_null_like_to_take = std::min(k, non_null_like_count); nan_to_take = std::min(k - non_null_like_to_take, nan_count); null_to_take = std::min(k - non_null_like_to_take - nan_to_take, null_count); - // TODO.TAE make this prettier return OutputRangesByNullLikeness{ - .non_null_like_output = {output_begin, output_begin + 
non_null_like_to_take}, - .nan_output = {output_begin + non_null_like_to_take, - output_begin + non_null_like_to_take + nan_to_take}, - .null_output = { - output_begin + non_null_like_to_take + nan_to_take, - output_begin + non_null_like_to_take + nan_to_take + null_to_take}}; + .non_null_like_output = output_indices.subspan(0, non_null_like_to_take), + .nan_output = output_indices.subspan(non_null_like_to_take, nan_to_take), + .null_output = + output_indices.subspan(non_null_like_to_take + nan_to_take, null_to_take)}; } else { null_to_take = std::min(k, null_count); nan_to_take = std::min(k - null_to_take, nan_count); non_null_like_to_take = std::min(k - null_to_take - nan_to_take, non_null_like_count); - // TODO.TAE make this prettier return OutputRangesByNullLikeness{ - .non_null_like_output = {output_begin + null_to_take + nan_to_take, - output_begin + null_to_take + nan_to_take + - non_null_like_to_take}, - .nan_output = {output_begin + null_to_take, - output_begin + null_to_take + nan_to_take}, - .null_output = {output_begin, output_begin + null_to_take}}; + .non_null_like_output = + output_indices.subspan(null_to_take + nan_to_take, non_null_like_to_take), + .nan_output = output_indices.subspan(null_to_take, nan_to_take), + .null_output = output_indices.subspan(0, null_to_take)}; } } template void HeapSortNonNullsToOutput(std::span non_null_input_range, Comparator cmp, std::span output_range) { - std::span heap{non_null_input_range.begin(), - non_null_input_range.begin() + output_range.size()}; + std::span heap = non_null_input_range.subspan(0, output_range.size()); std::make_heap(heap.begin(), heap.end(), cmp); - for (auto iter = non_null_input_range.begin() + output_range.size(); - iter != non_null_input_range.end(); ++iter) { - uint64_t x_index = *iter; + + std::span remaining_input = non_null_input_range.subspan(output_range.size()); + for (uint64_t x_index : remaining_input) { if (cmp(x_index, heap.front())) { std::pop_heap(heap.begin(), heap.end(), 
cmp); heap.back() = x_index; @@ -132,14 +127,12 @@ void HeapSortNonNullsToOutput(std::span non_null_input_range, Comparat // fill output in reverse when destructing, // as the "worst" (next-to-would-have-been-replaced) element is at heap-top - // TODO.TAE remove these &* - uint64_t* heap_begin = &*heap.begin(); - uint64_t* heap_end = &*heap.begin() + output_range.size(); for (auto reverse_out_iter = output_range.rbegin(); reverse_out_iter != output_range.rend(); reverse_out_iter++) { - *reverse_out_iter = *heap_begin; // heap-top has the next element - std::pop_heap(heap_begin, heap_end, cmp); - --heap_end; + *reverse_out_iter = heap.front(); // heap-top has the next element + std::ranges::pop_heap(heap, cmp); + // Decrease heap-size by one + heap = heap.first(heap.size() - 1); } } @@ -158,7 +151,6 @@ void HeapSortNonNullsToOutput(std::span non_null_input_range, } template -// TODO.TAE Could merge l and output into one span now void HeapSortNonNullsToOutput(std::span non_null_input_range, const typename TypeTraits::ArrayType& arr, SortOrder order, std::span output_range) { @@ -171,6 +163,28 @@ void HeapSortNonNullsToOutput(std::span non_null_input_range, } } +struct NullNanPartitionResult { + std::span non_null_like_range; + std::span null_range; + std::span nan_range; +}; + +template +NullNanPartitionResult PartitionNullsAndNans(uint64_t* indices_begin, + uint64_t* indices_end, + const ArrayType& values, int64_t offset, + NullPlacement null_placement) { + // Partition nulls at start (resp. end), and null-like values just before (resp. 
after) + NullPartitionResult p = PartitionNullsOnly(indices_begin, indices_end, + values, offset, null_placement); + NullPartitionResult q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, values, offset, null_placement); + return NullNanPartitionResult{ + .non_null_like_range = {q.non_nulls_begin, q.non_nulls_end}, + .null_range = {p.nulls_begin, p.nulls_end}, + .nan_range = {q.nulls_begin, q.nulls_end}}; +} + class ArraySelector : public TypeVisitor { public: ArraySelector(ExecContext* ctx, const Array& array, const SelectKOptions& options, @@ -220,25 +234,22 @@ class ArraySelector : public TypeVisitor { auto* output = take_indices->template GetMutableValues(1); // From k, calculate - // l = non-null elements to take from PartitionResult - // m = null-like elements to take from PartitionResult - // k = l + m if enough elements in input + // l = non_null_like elements to take from PartitionResult + // m = nan elements to take from PartitionResult + // n = null elements to take from PartitionResult + // k = l + m + n because k was clipped to arr.length() + // And directly compute the ranges in {output, output+k} where we will need to place + // the selected elements from each group -> no longer need to track null_placement auto output_ranges = calculateNumberNonNullAndNullLikesToTake( static_cast(p.non_nulls_end - p.non_nulls_begin), - 0, // TODO.TAE it would be okay to consider these equal, but better not? 
- static_cast(p.nulls_end - p.nulls_begin), k_, null_placement_, output); - - if (null_placement_ == NullPlacement::AtEnd) { - HeapSortNonNullsToOutput( - {p.non_nulls_begin, p.non_nulls_end}, arr, output_ranges.non_null_like_output); - std::copy(p.nulls_begin, p.nulls_begin + output_ranges.null_output.size(), - output_ranges.null_output.begin()); - } else { - std::copy(p.nulls_begin, p.nulls_begin + output_ranges.null_output.size(), - output_ranges.null_output.begin()); - HeapSortNonNullsToOutput( - {p.non_nulls_begin, p.non_nulls_end}, arr, output_ranges.non_null_like_output); - } + 0, // TODO.TAE it would be okay to consider null/nan equal, but better not? + static_cast(p.nulls_end - p.nulls_begin), null_placement_, + {output, output + k_}); + + HeapSortNonNullsToOutput({p.non_nulls_begin, p.non_nulls_end}, + arr, output_ranges.non_null_like_output); + std::copy(p.nulls_begin, p.nulls_begin + output_ranges.null_output.size(), + output_ranges.null_output.begin()); *output_ = Datum(take_indices); return Status::OK(); @@ -290,25 +301,24 @@ class ChunkedArraySelector : public TypeVisitor { VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) #undef VISIT - // template - // int64_t ComputeNanCount(){ - // using GetView = GetViewType; - // using ArrayType = typename TypeTraits::ArrayType; - // if constexpr (has_null_like_values()) { - // int64_t nan_count = 0; - // for (const auto& chunk : physical_chunks_) { - // auto values = std::make_shared(chunk->data()); - // int64_t length = values->length(); - // for(int64_t index = 0; index < length; ++index){ - // if(std::isnan(values->GetView(index))){ - // nan_count++; - // } - // } - // } - // return nan_count; - // } - // return 0; - // } + template + int64_t ComputeNanCount() { + using ArrayType = typename TypeTraits::ArrayType; + if constexpr (has_null_like_values()) { + int64_t nan_count = 0; + for (const auto& chunk : physical_chunks_) { + auto values = std::make_shared(chunk->data()); + int64_t length = values->length(); + for 
(int64_t index = 0; index < length; ++index) { + if (std::isnan(values->GetView(index))) { + nan_count++; + } + } + } + return nan_count; + } + return 0; + } template Status SelectKthInternal() { @@ -323,14 +333,28 @@ class ChunkedArraySelector : public TypeVisitor { if (k_ > chunked_array_.length()) { k_ = chunked_array_.length(); } - // int64_t null_count = chunked_array_.null_count(); - // int64_t nan_count = ComputeNanCount(); - // TODO.TAE int64_t non_null_like_count = chunked_array_.length() - null_count - - // nan_count; + + ARROW_ASSIGN_OR_RAISE(auto take_indices, + MakeMutableUInt64Array(k_, ctx_->memory_pool())); + auto* output_begin = take_indices->GetMutableValues(1); + + int64_t null_count = chunked_array_.null_count(); + int64_t nan_count = ComputeNanCount(); + int64_t non_null_like_count = chunked_array_.length() - null_count - nan_count; + + auto output = calculateNumberNonNullAndNullLikesToTake( + non_null_like_count, nan_count, null_count, null_placement_, + {output_begin, output_begin + k_}); + + // Now we can independently fill the output with non_null, nan and null items. 
+ // For non_null, we do a heap_sort, the others can just be copied until + // nan_taken = output.nan_range.size() and + // null_taken = output.null_range.size() respectively + size_t nan_taken = 0; + size_t null_taken = 0; std::function cmp; SelectKComparator comparator; - cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool { const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); @@ -338,9 +362,9 @@ class ChunkedArraySelector : public TypeVisitor { }; using HeapContainer = std::priority_queue, decltype(cmp)>; - HeapContainer heap(cmp); std::vector> chunks_holder; + uint64_t offset = 0; for (const auto& chunk : physical_chunks_) { if (chunk->length() == 0) continue; @@ -352,12 +376,29 @@ class ChunkedArraySelector : public TypeVisitor { uint64_t* indices_end = indices_begin + indices.size(); std::iota(indices_begin, indices_end, 0); - auto kth_begin = std::min(indices_begin + k_, indices_end); - uint64_t* iter = indices_begin; - for (; iter != kth_begin && heap.size() < static_cast(k_); ++iter) { + const auto p = PartitionNullsAndNans( + indices_begin, indices_end, arr, 0, null_placement_); + + // First do nulls and nans + auto iter = p.null_range.begin(); + for (; iter != p.null_range.end() && null_taken < output.null_output.size(); + ++iter) { + output.null_output[null_taken] = offset + *iter; + null_taken++; + } + iter = p.nan_range.begin(); + for (; iter != p.nan_range.end() && nan_taken < output.nan_output.size(); ++iter) { + output.nan_output[nan_taken] = offset + *iter; + nan_taken++; + } + + iter = p.non_null_like_range.begin(); + for (; iter != p.non_null_like_range.end() && + heap.size() < output.non_null_like_output.size(); + ++iter) { heap.push(HeapItem{*iter, offset, &arr}); } - for (; iter != indices_end && !heap.empty(); ++iter) { + for (; iter != p.non_null_like_range.end() && !heap.empty(); ++iter) { uint64_t x_index = *iter; const 
auto& xval = GetView::LogicalValue(arr.GetView(x_index)); auto top_item = heap.top(); @@ -371,16 +412,17 @@ class ChunkedArraySelector : public TypeVisitor { offset += chunk->length(); } - auto out_size = static_cast(heap.size()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, - MakeMutableUInt64Array(out_size, ctx_->memory_pool())); - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.size() > 0) { - auto top_item = heap.top(); - *out_cbegin = top_item.index + top_item.offset; + // We sized output.non_null_like_output to hold exactly sufficient indices, + // so the heap must have been completely filled + assert(heap.size() == output.non_null_like_output.size()); + + for (auto reverse_out_iter = output.non_null_like_output.rbegin(); + reverse_out_iter != output.non_null_like_output.rend(); reverse_out_iter++) { + *reverse_out_iter = + heap.top().index + heap.top().offset; // heap-top has the next element heap.pop(); - --out_cbegin; } + *output_ = Datum(take_indices); return Status::OK(); } @@ -395,28 +437,6 @@ class ChunkedArraySelector : public TypeVisitor { Datum* output_; }; -struct NullNanPartitionResult { - std::span non_null_like_range; - std::span null_range; - std::span nan_range; -}; - -template -NullNanPartitionResult PartitionNullsAndNans(uint64_t* indices_begin, - uint64_t* indices_end, - const ArrayType& values, int64_t offset, - NullPlacement null_placement) { - // Partition nulls at start (resp. end), and null-like values just before (resp. 
after) - NullPartitionResult p = PartitionNullsOnly(indices_begin, indices_end, - values, offset, null_placement); - NullPartitionResult q = PartitionNullLikes( - p.non_nulls_begin, p.non_nulls_end, values, offset, null_placement); - return NullNanPartitionResult{ - .non_null_like_range = {q.non_nulls_begin, q.non_nulls_end}, - .null_range = {p.nulls_begin, p.nulls_end}, - .nan_range = {q.nulls_begin, q.nulls_end}}; -} - class RecordBatchSelector { private: using ResolvedSortKey = ResolvedRecordBatchSortKey; @@ -455,13 +475,11 @@ class RecordBatchSelector { class SelectKForKey : public TypeVisitor { public: SelectKForKey(RecordBatchSelector* selector, size_t start_sort_key_index, - std::span input_indices, int64_t k_remaining, - uint64_t* output_indices) + std::span input_indices, std::span output_indices) : TypeVisitor(), selector_(selector), start_sort_key_index_(start_sort_key_index), input_indices_(input_indices), - k_remaining_(k_remaining), output_indices_(output_indices) {} private: @@ -471,23 +489,24 @@ class RecordBatchSelector { const auto& first_remaining_sort_key = selector_->sort_keys_[start_sort_key_index_]; const auto& arr = checked_cast(first_remaining_sort_key.array); - // TODO.TAE uhh this could be prettier - uint64_t* input_indices_begin = &*input_indices_.begin(); - uint64_t* input_indices_end = input_indices_begin + input_indices_.size(); + uint64_t* input_indices_begin = input_indices_.data(); + uint64_t* input_indices_end = input_indices_.data() + input_indices_.size(); const auto p = PartitionNullsAndNans( input_indices_begin, input_indices_end, arr, 0, first_remaining_sort_key.null_placement); - // From k, calculate + // From k = output_range.size(), calculate // l = non_null elements to take from PartitionResult - // m = null elements to take from PartitionResult - // k = l + m because k was clipped to num_rows() - + // m = nan elements to take from PartitionResult + // n = null elements to take from PartitionResult + // k = l + m + n 
because k was clipped to num_rows() + // And directly compute the ranges in output_indices_ where we will need to place + // the selected elements from each group -> no longer need to track null_placement auto output_ranges = calculateNumberNonNullAndNullLikesToTake( static_cast(p.non_null_like_range.size()), static_cast(p.nan_range.size()), - static_cast(p.null_range.size()), k_remaining_, + static_cast(p.null_range.size()), first_remaining_sort_key.null_placement, output_indices_); bool last_sort_key = start_sort_key_index_ + 1 == selector_->sort_keys_.size(); @@ -496,7 +515,6 @@ class RecordBatchSelector { if (!output_ranges.non_null_like_output.empty()) { HeapSortNonNullsToOutput(p.non_null_like_range, arr, first_remaining_sort_key.order, - // TODO.TAE remove this &* output_ranges.non_null_like_output); } if (output_ranges.nan_output.size() > 0) { @@ -521,15 +539,11 @@ class RecordBatchSelector { } if (output_ranges.nan_output.size() > 0) { ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey( - start_sort_key_index_ + 1, p.nan_range, output_ranges.nan_output.size(), - // TODO.TAE remove this &* - &*output_ranges.nan_output.begin())); + start_sort_key_index_ + 1, p.nan_range, output_ranges.nan_output)); } if (output_ranges.null_output.size() > 0) { ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey( - start_sort_key_index_ + 1, p.null_range, output_ranges.null_output.size(), - // TODO.TAE remove this &* - &*output_ranges.null_output.begin())); + start_sort_key_index_ + 1, p.null_range, output_ranges.null_output)); } } @@ -545,14 +559,12 @@ class RecordBatchSelector { RecordBatchSelector* selector_; size_t start_sort_key_index_; std::span input_indices_; - int64_t k_remaining_; - uint64_t* output_indices_; + std::span output_indices_; }; Status DoSelectKForKey(size_t start_sort_key_index, std::span input_indices, - int64_t k_remaining, uint64_t* output_indices) { - SelectKForKey tmp(this, start_sort_key_index, input_indices, k_remaining, - output_indices); + std::span 
output_indices) { + SelectKForKey tmp(this, start_sort_key_index, input_indices, output_indices); return sort_keys_.at(start_sort_key_index).type->Accept(&tmp); } @@ -571,7 +583,8 @@ class RecordBatchSelector { auto* output_indices = take_indices->template GetMutableValues(1); std::span input_indices_span(input_indices); - ARROW_RETURN_NOT_OK(DoSelectKForKey(0, input_indices_span, k_, output_indices)); + ARROW_RETURN_NOT_OK( + DoSelectKForKey(0, input_indices_span, {output_indices, output_indices + k_})); *output_ = Datum(take_indices); return arrow::Status::OK(); } From aee16e1c2f336aab4103d5ffc396b1666daccd54 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Mon, 9 Feb 2026 16:35:01 +0100 Subject: [PATCH 75/83] fix test and add comment --- cpp/src/arrow/compute/kernels/select_k_test.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc index 4401376a86d..16e4ae712ea 100644 --- a/cpp/src/arrow/compute/kernels/select_k_test.cc +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -431,9 +431,10 @@ TYPED_TEST(TestSelectKWithChunkedArray, FullSelectKNullNaN) { std::vector sort_keys{SortKey("a", SortOrder::Descending)}; auto options = SelectKOptions(10, sort_keys); options.sort_keys[0].null_placement = NullPlacement::AtStart; - this->CheckIndices(chunked_array, options, "[3, 0, 6, 4, 5, 2, 7, 8, 1]"); + // These check that nulls and Nan are sorted in a stable way, but do we want that? 
+ this->CheckIndices(chunked_array, options, "[0, 3, 4, 6, 5, 2, 7, 8, 1]"); options.sort_keys[0].null_placement = NullPlacement::AtEnd; - this->CheckIndices(chunked_array, options, "[5, 2, 7, 8, 1, 6, 4, 3, 0]"); + this->CheckIndices(chunked_array, options, "[5, 2, 7, 8, 1, 4, 6, 0, 3]"); } template From 2ea9474222643857d68689b99f9a4b19fbf0b537 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Mon, 9 Feb 2026 17:09:54 +0100 Subject: [PATCH 76/83] improve naming --- .../arrow/compute/kernels/vector_select_k.cc | 125 +++++++++--------- 1 file changed, 62 insertions(+), 63 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index ed1a89b027d..e3201301a3c 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -77,15 +77,15 @@ class SelectKComparator { }; struct OutputRangesByNullLikeness { - std::span non_null_like_output; - std::span nan_output; - std::span null_output; + std::span non_null_like_range; + std::span nan_range; + std::span null_range; }; -OutputRangesByNullLikeness calculateNumberNonNullAndNullLikesToTake( +OutputRangesByNullLikeness CalculateOutputRangesByNullLikeness( int64_t non_null_like_count, int64_t nan_count, int64_t null_count, NullPlacement null_placement, std::span output_indices) { - int64_t k = output_indices.size(); + auto k = static_cast(output_indices.size()); int64_t non_null_like_to_take = 0; int64_t nan_to_take = 0; int64_t null_to_take = 0; @@ -94,19 +94,19 @@ OutputRangesByNullLikeness calculateNumberNonNullAndNullLikesToTake( nan_to_take = std::min(k - non_null_like_to_take, nan_count); null_to_take = std::min(k - non_null_like_to_take - nan_to_take, null_count); return OutputRangesByNullLikeness{ - .non_null_like_output = output_indices.subspan(0, non_null_like_to_take), - .nan_output = output_indices.subspan(non_null_like_to_take, nan_to_take), - .null_output = + .non_null_like_range = 
output_indices.subspan(0, non_null_like_to_take), + .nan_range = output_indices.subspan(non_null_like_to_take, nan_to_take), + .null_range = output_indices.subspan(non_null_like_to_take + nan_to_take, null_to_take)}; } else { null_to_take = std::min(k, null_count); nan_to_take = std::min(k - null_to_take, nan_count); non_null_like_to_take = std::min(k - null_to_take - nan_to_take, non_null_like_count); return OutputRangesByNullLikeness{ - .non_null_like_output = + .non_null_like_range = output_indices.subspan(null_to_take + nan_to_take, non_null_like_to_take), - .nan_output = output_indices.subspan(null_to_take, nan_to_take), - .null_output = output_indices.subspan(0, null_to_take)}; + .nan_range = output_indices.subspan(null_to_take, nan_to_take), + .null_range = output_indices.subspan(0, null_to_take)}; } } @@ -163,23 +163,24 @@ void HeapSortNonNullsToOutput(std::span non_null_input_range, } } -struct NullNanPartitionResult { +struct PartitionResultByNullLikeness { std::span non_null_like_range; std::span null_range; std::span nan_range; }; template -NullNanPartitionResult PartitionNullsAndNans(uint64_t* indices_begin, - uint64_t* indices_end, - const ArrayType& values, int64_t offset, - NullPlacement null_placement) { +PartitionResultByNullLikeness PartitionNullsAndNans(uint64_t* indices_begin, + uint64_t* indices_end, + const ArrayType& values, + int64_t offset, + NullPlacement null_placement) { // Partition nulls at start (resp. end), and null-like values just before (resp. 
after) NullPartitionResult p = PartitionNullsOnly(indices_begin, indices_end, values, offset, null_placement); NullPartitionResult q = PartitionNullLikes( p.non_nulls_begin, p.non_nulls_end, values, offset, null_placement); - return NullNanPartitionResult{ + return PartitionResultByNullLikeness{ .non_null_like_range = {q.non_nulls_begin, q.non_nulls_end}, .null_range = {p.nulls_begin, p.nulls_end}, .nan_range = {q.nulls_begin, q.nulls_end}}; @@ -226,12 +227,12 @@ class ArraySelector : public TypeVisitor { uint64_t* indices_end = indices_begin + indices.size(); std::iota(indices_begin, indices_end, 0); - const auto p = PartitionNulls( - indices_begin, indices_end, arr, 0, null_placement_); - ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(k_, ctx_->memory_pool())); - auto* output = take_indices->template GetMutableValues(1); + auto* output_begin = take_indices->template GetMutableValues(1); + + const auto p = PartitionNullsAndNans( + indices_begin, indices_end, arr, 0, null_placement_); // From k, calculate // l = non_null_like elements to take from PartitionResult @@ -240,16 +241,16 @@ class ArraySelector : public TypeVisitor { // k = l + m + n because k was clipped to arr.length() // And directly compute the ranges in {output, output+k} where we will need to place // the selected elements from each group -> no longer need to track null_placement - auto output_ranges = calculateNumberNonNullAndNullLikesToTake( - static_cast(p.non_nulls_end - p.non_nulls_begin), - 0, // TODO.TAE it would be okay to consider null/nan equal, but better not? 
- static_cast(p.nulls_end - p.nulls_begin), null_placement_, - {output, output + k_}); + auto output = CalculateOutputRangesByNullLikeness( + p.non_null_like_range.size(), p.nan_range.size(), p.null_range.size(), + null_placement_, {output_begin, output_begin + k_}); - HeapSortNonNullsToOutput({p.non_nulls_begin, p.non_nulls_end}, - arr, output_ranges.non_null_like_output); - std::copy(p.nulls_begin, p.nulls_begin + output_ranges.null_output.size(), - output_ranges.null_output.begin()); + HeapSortNonNullsToOutput(p.non_null_like_range, arr, + output.non_null_like_range); + std::copy(p.nan_range.begin(), p.nan_range.begin() + output.nan_range.size(), + output.nan_range.begin()); + std::copy(p.null_range.begin(), p.null_range.begin() + output.null_range.size(), + output.null_range.begin()); *output_ = Datum(take_indices); return Status::OK(); @@ -342,14 +343,14 @@ class ChunkedArraySelector : public TypeVisitor { int64_t nan_count = ComputeNanCount(); int64_t non_null_like_count = chunked_array_.length() - null_count - nan_count; - auto output = calculateNumberNonNullAndNullLikesToTake( - non_null_like_count, nan_count, null_count, null_placement_, - {output_begin, output_begin + k_}); + auto output = CalculateOutputRangesByNullLikeness(non_null_like_count, nan_count, + null_count, null_placement_, + {output_begin, output_begin + k_}); // Now we can independently fill the output with non_null, nan and null items. 
// For non_null, we do a heap_sort, the others can just be copied until - // nan_taken = output.nan_range.size() and - // null_taken = output.null_range.size() respectively + // nan_taken == output.nan_range.size() and + // null_taken == output.null_range.size() respectively size_t nan_taken = 0; size_t null_taken = 0; @@ -381,20 +382,20 @@ class ChunkedArraySelector : public TypeVisitor { // First do nulls and nans auto iter = p.null_range.begin(); - for (; iter != p.null_range.end() && null_taken < output.null_output.size(); + for (; iter != p.null_range.end() && null_taken < output.null_range.size(); ++iter) { - output.null_output[null_taken] = offset + *iter; + output.null_range[null_taken] = offset + *iter; null_taken++; } iter = p.nan_range.begin(); - for (; iter != p.nan_range.end() && nan_taken < output.nan_output.size(); ++iter) { - output.nan_output[nan_taken] = offset + *iter; + for (; iter != p.nan_range.end() && nan_taken < output.nan_range.size(); ++iter) { + output.nan_range[nan_taken] = offset + *iter; nan_taken++; } iter = p.non_null_like_range.begin(); for (; iter != p.non_null_like_range.end() && - heap.size() < output.non_null_like_output.size(); + heap.size() < output.non_null_like_range.size(); ++iter) { heap.push(HeapItem{*iter, offset, &arr}); } @@ -412,12 +413,12 @@ class ChunkedArraySelector : public TypeVisitor { offset += chunk->length(); } - // We sized output.non_null_like_output to hold exactly sufficient indices, + // We sized output.non_null_like_range to hold exactly sufficient indices, // so the heap must have been completely filled - assert(heap.size() == output.non_null_like_output.size()); + assert(heap.size() == output.non_null_like_range.size()); - for (auto reverse_out_iter = output.non_null_like_output.rbegin(); - reverse_out_iter != output.non_null_like_output.rend(); reverse_out_iter++) { + for (auto reverse_out_iter = output.non_null_like_range.rbegin(); + reverse_out_iter != output.non_null_like_range.rend(); 
reverse_out_iter++) { *reverse_out_iter = heap.top().index + heap.top().offset; // heap-top has the next element heap.pop(); @@ -497,13 +498,13 @@ class RecordBatchSelector { first_remaining_sort_key.null_placement); // From k = output_range.size(), calculate - // l = non_null elements to take from PartitionResult + // l = non_null_like elements to take from PartitionResult // m = nan elements to take from PartitionResult // n = null elements to take from PartitionResult // k = l + m + n because k was clipped to num_rows() // And directly compute the ranges in output_indices_ where we will need to place // the selected elements from each group -> no longer need to track null_placement - auto output_ranges = calculateNumberNonNullAndNullLikesToTake( + auto output = CalculateOutputRangesByNullLikeness( static_cast(p.non_null_like_range.size()), static_cast(p.nan_range.size()), static_cast(p.null_range.size()), @@ -512,38 +513,36 @@ class RecordBatchSelector { bool last_sort_key = start_sort_key_index_ + 1 == selector_->sort_keys_.size(); if (last_sort_key) { - if (!output_ranges.non_null_like_output.empty()) { + if (!output.non_null_like_range.empty()) { HeapSortNonNullsToOutput(p.non_null_like_range, arr, first_remaining_sort_key.order, - output_ranges.non_null_like_output); + output.non_null_like_range); } - if (output_ranges.nan_output.size() > 0) { + if (output.nan_range.size() > 0) { // We have the last sort_key, can just copy over the null values - std::copy(p.nan_range.begin(), - p.nan_range.begin() + output_ranges.nan_output.size(), - output_ranges.nan_output.begin()); + std::copy(p.nan_range.begin(), p.nan_range.begin() + output.nan_range.size(), + output.nan_range.begin()); } - if (output_ranges.null_output.size() > 0) { + if (output.null_range.size() > 0) { // We have the last sort_key, can just copy over the null values - std::copy(p.null_range.begin(), - p.null_range.begin() + output_ranges.null_output.size(), - output_ranges.null_output.begin()); + 
std::copy(p.null_range.begin(), p.null_range.begin() + output.null_range.size(), + output.null_range.begin()); } } else { - if (!output_ranges.non_null_like_output.empty()) { + if (!output.non_null_like_range.empty()) { auto cmp = [&](uint64_t left, uint64_t right) { return selector_->comparator_.Compare(left, right, start_sort_key_index_); }; HeapSortNonNullsToOutput(p.non_null_like_range, cmp, - output_ranges.non_null_like_output); + output.non_null_like_range); } - if (output_ranges.nan_output.size() > 0) { - ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey( - start_sort_key_index_ + 1, p.nan_range, output_ranges.nan_output)); + if (output.nan_range.size() > 0) { + ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey(start_sort_key_index_ + 1, + p.nan_range, output.nan_range)); } - if (output_ranges.null_output.size() > 0) { + if (output.null_range.size() > 0) { ARROW_RETURN_NOT_OK(selector_->DoSelectKForKey( - start_sort_key_index_ + 1, p.null_range, output_ranges.null_output)); + start_sort_key_index_ + 1, p.null_range, output.null_range)); } } From 577007ef2b60bc6ba8b77018bac87562e87c83bf Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Mon, 9 Feb 2026 17:30:56 +0100 Subject: [PATCH 77/83] supress deprecation warning --- cpp/src/arrow/compute/api_vector.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 3e635baf659..4edd32c114a 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -195,6 +195,7 @@ ArraySortOptions::ArraySortOptions(SortOrder order, NullPlacement null_placement null_placement(null_placement) {} constexpr char ArraySortOptions::kTypeName[]; +ARROW_SUPPRESS_DEPRECATION_WARNING SortOptions::SortOptions(std::vector sort_keys) : FunctionOptions(internal::kSortOptionsType), sort_keys(std::move(sort_keys)), @@ -210,6 +211,7 @@ SortOptions::SortOptions(const Ordering& ordering) sort_keys(ordering.sort_keys()), 
null_placement(ordering.null_placement()) {} constexpr char SortOptions::kTypeName[]; +ARROW_UNSUPPRESS_DEPRECATION_WARNING PartitionNthOptions::PartitionNthOptions(int64_t pivot, NullPlacement null_placement) : FunctionOptions(internal::kPartitionNthOptionsType), From 45a9419c9ebb54402199684d2461b9e4975ffd44 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Mon, 9 Feb 2026 17:45:24 +0100 Subject: [PATCH 78/83] improve heap functions for clarity --- cpp/src/arrow/compute/kernels/vector_select_k.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index e3201301a3c..534466bfdd7 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include "arrow/compute/function.h" @@ -114,22 +115,21 @@ template void HeapSortNonNullsToOutput(std::span non_null_input_range, Comparator cmp, std::span output_range) { std::span heap = non_null_input_range.subspan(0, output_range.size()); - std::make_heap(heap.begin(), heap.end(), cmp); + std::ranges::make_heap(heap, cmp); std::span remaining_input = non_null_input_range.subspan(output_range.size()); for (uint64_t x_index : remaining_input) { if (cmp(x_index, heap.front())) { - std::pop_heap(heap.begin(), heap.end(), cmp); + std::ranges::pop_heap(heap, cmp); heap.back() = x_index; - std::push_heap(heap.begin(), heap.end(), cmp); + std::ranges::push_heap(heap, cmp); } } // fill output in reverse when destructing, // as the "worst" (next-to-would-have-been-replaced) element is at heap-top - for (auto reverse_out_iter = output_range.rbegin(); - reverse_out_iter != output_range.rend(); reverse_out_iter++) { - *reverse_out_iter = heap.front(); // heap-top has the next element + for (auto& reverse_out_iter : std::ranges::reverse_view(output_range)) { + reverse_out_iter = heap.front(); // 
heap-top has the next element std::ranges::pop_heap(heap, cmp); // Decrease heap-size by one heap = heap.first(heap.size() - 1); From 2877ce3d78b3c02a14be1ceb6ca67938c4c24339 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Mon, 9 Feb 2026 17:47:28 +0100 Subject: [PATCH 79/83] consistent style for clipping k_ to Datum length --- cpp/src/arrow/compute/kernels/vector_select_k.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 534466bfdd7..d5a14ce840b 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -219,7 +219,9 @@ class ArraySelector : public TypeVisitor { ArrayType arr(array_.data()); - k_ = std::min(k_, arr.length()); + if (k_ > arr.length()) { + k_ = arr.length(); + } std::vector indices(arr.length()); From e96c7390bdc12cfac85b3077bc66dd46e3fc334b Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Mon, 9 Feb 2026 20:11:34 +0100 Subject: [PATCH 80/83] deprecate more null_placement options --- cpp/src/arrow/acero/order_by_node.cc | 2 +- cpp/src/arrow/compute/api_vector.cc | 9 +++ cpp/src/arrow/compute/api_vector.h | 61 ++++++++++--------- cpp/src/arrow/compute/kernels/vector_rank.cc | 3 + cpp/src/arrow/compute/kernels/vector_sort.cc | 14 ++--- .../arrow/compute/kernels/vector_sort_test.cc | 6 +- 6 files changed, 56 insertions(+), 39 deletions(-) diff --git a/cpp/src/arrow/acero/order_by_node.cc b/cpp/src/arrow/acero/order_by_node.cc index 213730e6f9a..f4a5acd88d5 100644 --- a/cpp/src/arrow/acero/order_by_node.cc +++ b/cpp/src/arrow/acero/order_by_node.cc @@ -117,7 +117,7 @@ class OrderByNode : public ExecNode, public TracedNode { ARROW_ASSIGN_OR_RAISE( auto table, Table::FromRecordBatches(output_schema_, std::move(accumulation_queue_))); - SortOptions sort_options(ordering_.sort_keys(), ordering_.null_placement()); + SortOptions sort_options(ordering_); 
ExecContext* ctx = plan_->query_context()->exec_context(); ARROW_ASSIGN_OR_RAISE(auto indices, SortIndices(table, sort_options, ctx)); ARROW_ASSIGN_OR_RAISE(Datum sorted, diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 4edd32c114a..3a4835b688c 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -240,6 +241,7 @@ CumulativeOptions::CumulativeOptions(std::shared_ptr start, bool skip_nu skip_nulls(skip_nulls) {} constexpr char CumulativeOptions::kTypeName[]; +ARROW_SUPPRESS_DEPRECATION_WARNING RankOptions::RankOptions(std::vector sort_keys, std::optional null_placement, RankOptions::Tiebreaker tiebreaker) @@ -247,7 +249,14 @@ RankOptions::RankOptions(std::vector sort_keys, sort_keys(std::move(sort_keys)), null_placement(null_placement), tiebreaker(tiebreaker) {} +RankOptions::RankOptions(std::vector sort_keys, + RankOptions::Tiebreaker tiebreaker) + : FunctionOptions(internal::kRankOptionsType), + sort_keys(std::move(sort_keys)), + null_placement(std::nullopt), + tiebreaker(tiebreaker) {} constexpr char RankOptions::kTypeName[]; +ARROW_UNSUPPRESS_DEPRECATION_WARNING RankQuantileOptions::RankQuantileOptions(std::vector sort_keys, std::optional null_placement) diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 198c1a0ef7e..a01520ff107 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -24,6 +24,7 @@ #include "arrow/compute/ordering.h" #include "arrow/result.h" #include "arrow/type_fwd.h" +#include "arrow/util/macros.h" namespace arrow { namespace compute { @@ -102,6 +103,8 @@ class ARROW_EXPORT ArraySortOptions : public FunctionOptions { NullPlacement null_placement; }; + +ARROW_SUPPRESS_DEPRECATION_WARNING class ARROW_EXPORT SortOptions : public FunctionOptions { public: explicit SortOptions(std::vector sort_keys = {}); @@ -121,12 
+124,8 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { Ordering AsOrdering() && { return {std::move(sort_keys)}; } Ordering AsOrdering() const& { return {sort_keys}; } - /// Column key(s) to order by and how to order by these sort keys. - std::vector sort_keys; - - ARROW_SUPPRESS_DEPRECATION_WARNING - // DEPRECATED(will be removed after null_placement has been removed) /// Get sort_keys with overwritten null_placement + /// Will be removed after deprecated null_placement has been removed std::vector GetSortKeys() const { if (!null_placement.has_value()) { return sort_keys; @@ -137,13 +136,17 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { } return overwritten_sort_keys; } - ARROW_UNSUPPRESS_DEPRECATION_WARNING + + /// Column key(s) to order by and how to order by these sort keys. + std::vector sort_keys; // DEPRECATED(Deprecated in arrow 24.0.0, use null_placement in sort_keys instead) /// Whether nulls and NaNs are placed at the start or at the end /// Will overwrite null ordering of sort keys - ARROW_DEPRECATED("Deprecated in arrow 24.0.0, use null_placement in sort_keys instead") std::optional null_placement; + ARROW_DEPRECATED("Deprecated in arrow 24.0.0, use null_placement in sort_keys instead") + std::optional null_placement; }; +ARROW_UNSUPPRESS_DEPRECATION_WARNING /// \brief SelectK options class ARROW_EXPORT SelectKOptions : public FunctionOptions { @@ -174,15 +177,15 @@ class ARROW_EXPORT SelectKOptions : public FunctionOptions { return SelectKOptions{k, keys}; } + /// Get sort_keys + /// will be removed after null_placement has been removed from other + /// SortOptions-like structs + std::vector GetSortKeys() const { return sort_keys; } + /// The number of `k` elements to keep. int64_t k; /// Column key(s) to order by and how to order by these sort keys. 
std::vector sort_keys; - - // DEPRECATED(will be removed after null_placement has been removed from other - // SortOptions-like structs) - /// Get sort_keys - std::vector GetSortKeys() const { return sort_keys; } }; /// \brief Rank options @@ -202,31 +205,29 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { Dense }; - explicit RankOptions(std::vector sort_keys = {}, + ARROW_DEPRECATED("Deprecated in arrow 24.0.0, use null_placement in sort_keys instead") + explicit RankOptions(std::vector sort_keys, std::optional null_placement = std::nullopt, Tiebreaker tiebreaker = RankOptions::First); /// Convenience constructor for array inputs explicit RankOptions(SortOrder order, - std::optional null_placement = std::nullopt, + NullPlacement null_placement, Tiebreaker tiebreaker = RankOptions::First) - : RankOptions({SortKey("", order)}, null_placement, tiebreaker) {} + : RankOptions({SortKey("", order, null_placement)}, tiebreaker) {} - explicit RankOptions(std::vector sort_keys, - Tiebreaker tiebreaker = RankOptions::First) - : RankOptions(std::move(sort_keys), std::nullopt, tiebreaker) {} + explicit RankOptions(std::vector sort_keys = {}, + Tiebreaker tiebreaker = RankOptions::First); /// Convenience constructor for array inputs explicit RankOptions(SortOrder order, Tiebreaker tiebreaker = RankOptions::First) - : RankOptions({SortKey("", order)}, std::nullopt, tiebreaker) {} + : RankOptions({SortKey("", order)}, tiebreaker) {} static constexpr const char kTypeName[] = "RankOptions"; static RankOptions Defaults() { return RankOptions(); } - /// Column key(s) to order by and how to order by these sort keys. 
- std::vector sort_keys; - - // DEPRECATED(will be removed after null_placement has been removed) + ARROW_SUPPRESS_DEPRECATION_WARNING /// Get sort_keys with overwritten null_placement + /// Will be removed after deprecated null_placement has been removed std::vector GetSortKeys() const { if (!null_placement.has_value()) { return sort_keys; @@ -237,11 +238,15 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { } return overwritten_sort_keys; } + ARROW_UNSUPPRESS_DEPRECATION_WARNING + + /// Column key(s) to order by and how to order by these sort keys. + std::vector sort_keys; // DEPRECATED(set null_placement in sort_keys instead) /// Whether nulls and NaNs are placed at the start or at the end /// Will overwrite null ordering of sort keys - std::optional null_placement; + ARROW_DEPRECATED("Deprecated in arrow 24.0.0, use null_placement in sort_keys instead") std::optional null_placement; /// Tiebreaker for dealing with equal values in ranks Tiebreaker tiebreaker; }; @@ -261,11 +266,8 @@ class ARROW_EXPORT RankQuantileOptions : public FunctionOptions { static constexpr const char kTypeName[] = "RankQuantileOptions"; static RankQuantileOptions Defaults() { return RankQuantileOptions(); } - /// Column key(s) to order by and how to order by these sort keys. - std::vector sort_keys; - - // DEPRECATED(will be removed after null_placement has been removed) /// Get sort_keys with overwritten null_placement + /// Will be removed after deprecated null_placement has been removed std::vector GetSortKeys() const { if (!null_placement.has_value()) { return sort_keys; @@ -277,6 +279,9 @@ class ARROW_EXPORT RankQuantileOptions : public FunctionOptions { return overwritten_sort_keys; } + /// Column key(s) to order by and how to order by these sort keys. 
+ std::vector sort_keys; + // DEPRECATED(set null_placement in sort_keys instead) /// Whether nulls and NaNs are placed at the start or at the end /// Will overwrite null ordering of sort keys diff --git a/cpp/src/arrow/compute/kernels/vector_rank.cc b/cpp/src/arrow/compute/kernels/vector_rank.cc index 1f17fc285e0..6b7717936c5 100644 --- a/cpp/src/arrow/compute/kernels/vector_rank.cc +++ b/cpp/src/arrow/compute/kernels/vector_rank.cc @@ -23,6 +23,7 @@ #include "arrow/compute/registry.h" #include "arrow/compute/registry_internal.h" #include "arrow/util/logging_internal.h" +#include "arrow/util/macros.h" #include "arrow/util/math_internal.h" namespace arrow::compute::internal { @@ -352,9 +353,11 @@ class RankMetaFunctionBase : public MetaFunction { order = options.sort_keys[0].order; null_placement = options.sort_keys[0].null_placement; } + ARROW_SUPPRESS_DEPRECATION_WARNING if (options.null_placement.has_value()) { null_placement = options.null_placement.value(); } + ARROW_UNSUPPRESS_DEPRECATION_WARNING int64_t length = input.length(); ARROW_ASSIGN_OR_RAISE(auto indices, diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 0b9d8939a51..ac9c5068c13 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -22,6 +22,7 @@ #include "arrow/compute/registry.h" #include "arrow/compute/registry_internal.h" #include "arrow/util/logging_internal.h" +#include "arrow/util/macros.h" namespace arrow { @@ -957,8 +958,6 @@ class SortIndicesMetaFunction : public MetaFunction { chunked_array->length()); } -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wdeprecated-declarations" Result SortIndices(const Array& values, const SortOptions& options, ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; @@ -967,17 +966,15 @@ class SortIndicesMetaFunction : public MetaFunction { order = options.sort_keys[0].order; null_placement = 
options.sort_keys[0].null_placement; } - // TODO.TAE this member is deprecated. Is there a way to implement it without it? +ARROW_SUPPRESS_DEPRECATION_WARNING if (options.null_placement.has_value()) { null_placement = options.null_placement.value(); } +ARROW_UNSUPPRESS_DEPRECATION_WARNING ArraySortOptions array_options(order, null_placement); return CallFunction("array_sort_indices", {values}, &array_options, ctx); } -#pragma clang diagnostic pop -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wdeprecated-declarations" Result SortIndices(const ChunkedArray& chunked_array, const SortOptions& options, ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; @@ -986,11 +983,11 @@ class SortIndicesMetaFunction : public MetaFunction { order = options.sort_keys[0].order; null_placement = options.sort_keys[0].null_placement; } - // TODO.TAE this member is deprecated. Is there a way to implement it without it? - // Ah the method is only private?? +ARROW_SUPPRESS_DEPRECATION_WARNING if (options.null_placement.has_value()) { null_placement = options.null_placement.value(); } +ARROW_UNSUPPRESS_DEPRECATION_WARNING auto out_type = uint64(); auto length = chunked_array.length(); @@ -1008,7 +1005,6 @@ class SortIndicesMetaFunction : public MetaFunction { SortChunkedArray(ctx, out_begin, out_end, chunked_array, order, null_placement)); return Datum(out); } -#pragma clang diagnostic pop Result SortIndices(const RecordBatch& batch, const SortOptions& options, ExecContext* ctx) const { diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index d5eae9c7024..ba76083680f 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -1221,9 +1221,11 @@ TEST_F(TestRecordBatchSortIndices, NoNull) { for (auto overwrite_null_placement : AllOptionalNullPlacements()){ for (auto null_placement : AllNullPlacements()) { + 
ARROW_SUPPRESS_DEPRECATION_WARNING SortOptions options({SortKey("a", SortOrder::Ascending, null_placement), SortKey("b", SortOrder::Descending, null_placement)}, overwrite_null_placement); + ARROW_UNSUPPRESS_DEPRECATION_WARNING AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]"); } @@ -1273,7 +1275,7 @@ TEST_F(TestRecordBatchSortIndices, MixedNullOrdering) { SortKey("a", SortOrder::Ascending, NullPlacement::AtEnd), SortKey("b", SortOrder::Descending, NullPlacement::AtEnd)}; - SortOptions options(sort_keys, std::nullopt); + SortOptions options(sort_keys); AssertSortIndices(batch, options, "[5, 1, 4, 6, 2, 0, 3]"); options.sort_keys.at(0).null_placement = NullPlacement::AtStart; @@ -2150,7 +2152,9 @@ TEST_P(TestTableSortIndicesRandom, Sort) { auto table = Table::Make(schema, std::move(columns)); for (auto overwrite_null_placement : AllOptionalNullPlacements()) { ARROW_SCOPED_TRACE("overwrite_null_placement = ", overwrite_null_placement); + ARROW_SUPPRESS_DEPRECATION_WARNING options.null_placement = overwrite_null_placement; + ARROW_UNSUPPRESS_DEPRECATION_WARNING ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*table), options)); Validate(*table, options, *checked_pointer_cast(offsets)); } From 9f7f748bc07cad1495d1ae960624655ee203b1c6 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Tue, 10 Feb 2026 08:35:25 +0100 Subject: [PATCH 81/83] remove superfluous comment --- cpp/src/arrow/compute/kernels/select_k_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc index 16e4ae712ea..5dcd6956f46 100644 --- a/cpp/src/arrow/compute/kernels/select_k_test.cc +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -422,7 +422,6 @@ TYPED_TEST(TestSelectKWithChunkedArray, PartialSelectKNullNaN) { options.sort_keys[0].null_placement = NullPlacement::AtEnd; expected = ChunkedArrayFromJSON(float64(), {"[10, 3, 2]"}); this->Check(chunked_array, options, expected); - // 
TODO.TAE more CheckIndices? } TYPED_TEST(TestSelectKWithChunkedArray, FullSelectKNullNaN) { From 01fb5f27ffda0f6eded41168785d56f8952ac081 Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Tue, 17 Feb 2026 20:27:41 +0100 Subject: [PATCH 82/83] add check against empty output span --- cpp/src/arrow/compute/kernels/vector_select_k.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index d5a14ce840b..4376f25526a 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -114,6 +114,9 @@ OutputRangesByNullLikeness CalculateOutputRangesByNullLikeness( template void HeapSortNonNullsToOutput(std::span non_null_input_range, Comparator cmp, std::span output_range) { + if(output_range.empty()){ + return; + } std::span heap = non_null_input_range.subspan(0, output_range.size()); std::ranges::make_heap(heap, cmp); From 7b59c4ffaabfcd4b5780bc3b90f1549d3f032fcc Mon Sep 17 00:00:00 2001 From: Alexander Taepper Date: Tue, 17 Feb 2026 21:07:24 +0100 Subject: [PATCH 83/83] formatting --- cpp/src/arrow/compute/api_vector.h | 11 +++++------ cpp/src/arrow/compute/kernels/vector_select_k.cc | 2 +- cpp/src/arrow/compute/kernels/vector_sort.cc | 8 ++++---- cpp/src/arrow/compute/kernels/vector_sort_test.cc | 10 +++++----- cpp/src/arrow/compute/ordering.h | 6 ++---- 5 files changed, 17 insertions(+), 20 deletions(-) diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index a01520ff107..6ef584f8fa5 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -103,7 +103,6 @@ class ARROW_EXPORT ArraySortOptions : public FunctionOptions { NullPlacement null_placement; }; - ARROW_SUPPRESS_DEPRECATION_WARNING class ARROW_EXPORT SortOptions : public FunctionOptions { public: @@ -143,7 +142,7 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { // 
DEPRECATED(Deprecated in arrow 24.0.0, use null_placement in sort_keys instead) /// Whether nulls and NaNs are placed at the start or at the end /// Will overwrite null ordering of sort keys - ARROW_DEPRECATED("Deprecated in arrow 24.0.0, use null_placement in sort_keys instead") + ARROW_DEPRECATED("Deprecated in arrow 24.0.0, use null_placement in sort_keys instead") std::optional null_placement; }; ARROW_UNSUPPRESS_DEPRECATION_WARNING @@ -205,13 +204,12 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { Dense }; - ARROW_DEPRECATED("Deprecated in arrow 24.0.0, use null_placement in sort_keys instead") + ARROW_DEPRECATED("Deprecated in arrow 24.0.0, use null_placement in sort_keys instead") explicit RankOptions(std::vector sort_keys, std::optional null_placement = std::nullopt, Tiebreaker tiebreaker = RankOptions::First); /// Convenience constructor for array inputs - explicit RankOptions(SortOrder order, - NullPlacement null_placement, + explicit RankOptions(SortOrder order, NullPlacement null_placement, Tiebreaker tiebreaker = RankOptions::First) : RankOptions({SortKey("", order, null_placement)}, tiebreaker) {} @@ -246,7 +244,8 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { // DEPRECATED(set null_placement in sort_keys instead) /// Whether nulls and NaNs are placed at the start or at the end /// Will overwrite null ordering of sort keys - ARROW_DEPRECATED("Deprecated in arrow 24.0.0, use null_placement in sort_keys instead") std::optional null_placement; + ARROW_DEPRECATED("Deprecated in arrow 24.0.0, use null_placement in sort_keys instead") + std::optional null_placement; /// Tiebreaker for dealing with equal values in ranks Tiebreaker tiebreaker; }; diff --git a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 4376f25526a..81f7aad0110 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -114,7 +114,7 @@ 
OutputRangesByNullLikeness CalculateOutputRangesByNullLikeness( template void HeapSortNonNullsToOutput(std::span non_null_input_range, Comparator cmp, std::span output_range) { - if(output_range.empty()){ + if (output_range.empty()) { return; } std::span heap = non_null_input_range.subspan(0, output_range.size()); diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index ac9c5068c13..57abc358ccc 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -966,11 +966,11 @@ class SortIndicesMetaFunction : public MetaFunction { order = options.sort_keys[0].order; null_placement = options.sort_keys[0].null_placement; } -ARROW_SUPPRESS_DEPRECATION_WARNING + ARROW_SUPPRESS_DEPRECATION_WARNING if (options.null_placement.has_value()) { null_placement = options.null_placement.value(); } -ARROW_UNSUPPRESS_DEPRECATION_WARNING + ARROW_UNSUPPRESS_DEPRECATION_WARNING ArraySortOptions array_options(order, null_placement); return CallFunction("array_sort_indices", {values}, &array_options, ctx); } @@ -983,11 +983,11 @@ ARROW_UNSUPPRESS_DEPRECATION_WARNING order = options.sort_keys[0].order; null_placement = options.sort_keys[0].null_placement; } -ARROW_SUPPRESS_DEPRECATION_WARNING + ARROW_SUPPRESS_DEPRECATION_WARNING if (options.null_placement.has_value()) { null_placement = options.null_placement.value(); } -ARROW_UNSUPPRESS_DEPRECATION_WARNING + ARROW_UNSUPPRESS_DEPRECATION_WARNING auto out_type = uint64(); auto length = chunked_array.length(); diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index ba76083680f..8da8b55be9c 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -77,10 +77,9 @@ std::ostream& operator<<(std::ostream& os, NullPlacement null_placement) { } std::ostream& operator<<(std::ostream& os, std::optional null_placement) { - 
if(null_placement.has_value()){ + if (null_placement.has_value()) { os << null_placement.value(); - } - else { + } else { os << "None"; } return os; @@ -1219,7 +1218,7 @@ TEST_F(TestRecordBatchSortIndices, NoNull) { {"a": 1, "b": 3} ])"); - for (auto overwrite_null_placement : AllOptionalNullPlacements()){ + for (auto overwrite_null_placement : AllOptionalNullPlacements()) { for (auto null_placement : AllNullPlacements()) { ARROW_SUPPRESS_DEPRECATION_WARNING SortOptions options({SortKey("a", SortOrder::Ascending, null_placement), @@ -2112,7 +2111,8 @@ TEST_P(TestTableSortIndicesRandom, Sort) { } } std::shuffle(sort_keys.begin(), sort_keys.end(), engine); - sort_keys.emplace(sort_keys.begin(), first_sort_key_name, generate_order(), generate_null_placement()); + sort_keys.emplace(sort_keys.begin(), first_sort_key_name, generate_order(), + generate_null_placement()); sort_keys.erase(sort_keys.begin() + n_sort_keys, sort_keys.end()); ASSERT_EQ(sort_keys.size(), n_sort_keys); diff --git a/cpp/src/arrow/compute/ordering.h b/cpp/src/arrow/compute/ordering.h index dc1f8cf29d1..efb7fd81034 100644 --- a/cpp/src/arrow/compute/ordering.h +++ b/cpp/src/arrow/compute/ordering.h @@ -63,12 +63,10 @@ class ARROW_EXPORT SortKey : public util::EqualityComparable { class ARROW_EXPORT Ordering : public util::EqualityComparable { public: - Ordering(std::vector sort_keys) - : sort_keys_(std::move(sort_keys)) {} + explicit Ordering(std::vector sort_keys) : sort_keys_(std::move(sort_keys)) {} // DEPRECATED(will be removed after removing null_placement from Ordering) - Ordering(std::vector sort_keys, - std::optional null_placement) + Ordering(std::vector sort_keys, std::optional null_placement) : sort_keys_(std::move(sort_keys)), null_placement_(null_placement) {} /// true if data ordered by other is also ordered by this ///