From a1a2d19ab1a7242aa9ce1db2c04e0a89a5b8e698 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 19 Mar 2026 14:09:30 -0700 Subject: [PATCH 1/2] Add slicing to BufferHandle Signed-off-by: Nicholas Gates --- Cargo.lock | 1 + encodings/bytebool/Cargo.toml | 1 + encodings/bytebool/public-api.lock | 4 ++ encodings/bytebool/src/rules.rs | 2 + encodings/bytebool/src/slice.rs | 20 +++++++ vortex-array/public-api.lock | 36 ++++++++++++ .../src/arrays/decimal/compute/rules.rs | 29 ++++++++++ .../src/arrays/primitive/compute/rules.rs | 2 + .../src/arrays/primitive/compute/slice.rs | 22 ++++++++ .../src/arrays/varbinview/compute/rules.rs | 2 + .../src/arrays/varbinview/compute/slice.rs | 22 ++++++++ vortex-layout/src/layouts/flat/writer.rs | 55 ++++++++++--------- 12 files changed, 170 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8a09716070c..72b187dd9e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9914,6 +9914,7 @@ dependencies = [ "vortex-array", "vortex-buffer", "vortex-error", + "vortex-mask", "vortex-session", ] diff --git a/encodings/bytebool/Cargo.toml b/encodings/bytebool/Cargo.toml index 53197452ff5..bca2ed71aa9 100644 --- a/encodings/bytebool/Cargo.toml +++ b/encodings/bytebool/Cargo.toml @@ -21,6 +21,7 @@ num-traits = { workspace = true } vortex-array = { workspace = true } vortex-buffer = { workspace = true } vortex-error = { workspace = true } +vortex-mask = { workspace = true } vortex-session = { workspace = true } [dev-dependencies] diff --git a/encodings/bytebool/public-api.lock b/encodings/bytebool/public-api.lock index 192d025cf34..a7af1f45d43 100644 --- a/encodings/bytebool/public-api.lock +++ b/encodings/bytebool/public-api.lock @@ -14,6 +14,10 @@ impl vortex_array::arrays::dict::take::TakeExecute for vortex_bytebool::ByteBool pub fn vortex_bytebool::ByteBool::take(array: &vortex_bytebool::ByteBoolArray, indices: &vortex_array::array::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::kernel::FilterReduce for vortex_bytebool::ByteBool + +pub fn vortex_bytebool::ByteBool::filter(array: &vortex_bytebool::ByteBoolArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> + impl vortex_array::arrays::slice::SliceReduce for vortex_bytebool::ByteBool pub fn vortex_bytebool::ByteBool::slice(array: &vortex_bytebool::ByteBoolArray, range: core::ops::range::Range) -> vortex_error::VortexResult> diff --git a/encodings/bytebool/src/rules.rs b/encodings/bytebool/src/rules.rs index f67d3567326..b4dc61fa1c7 100644 --- a/encodings/bytebool/src/rules.rs +++ b/encodings/bytebool/src/rules.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use vortex_array::arrays::filter::FilterReduceAdaptor; use vortex_array::arrays::slice::SliceReduceAdaptor; use vortex_array::optimizer::rules::ParentRuleSet; use vortex_array::scalar_fn::fns::cast::CastReduceAdaptor; @@ -12,4 +13,5 @@ pub(crate) static RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&CastReduceAdaptor(ByteBool)), ParentRuleSet::lift(&MaskReduceAdaptor(ByteBool)), ParentRuleSet::lift(&SliceReduceAdaptor(ByteBool)), + ParentRuleSet::lift(&FilterReduceAdaptor(ByteBool)), ]); diff --git a/encodings/bytebool/src/slice.rs b/encodings/bytebool/src/slice.rs index c80be024fd1..8cf24bda5e4 100644 --- a/encodings/bytebool/src/slice.rs +++ b/encodings/bytebool/src/slice.rs @@ -5,9 +5,11 @@ use std::ops::Range; use vortex_array::ArrayRef; use vortex_array::IntoArray; +use vortex_array::arrays::filter::FilterReduce; use vortex_array::arrays::slice::SliceReduce; use vortex_array::vtable::ValidityHelper; use vortex_error::VortexResult; +use vortex_mask::Mask; use crate::ByteBool; use crate::ByteBoolArray; @@ -23,3 +25,21 @@ impl SliceReduce for ByteBool { )) } } + +impl FilterReduce for ByteBool { + fn filter(array: &ByteBoolArray, mask: &Mask) -> VortexResult> { + let ranges: Vec> = mask + .slices() + .unwrap_or_else(|| unreachable!(), || unreachable!()) + .iter() + .map(|&(s, e)| s..e) + .collect(); + Ok(Some( + ByteBoolArray::new( + array.buffer().filter_typed::(&ranges)?, + array.validity().filter(mask)?, + ) + .into_array(), + )) + } +} diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 67f1955e17c..c8a15145e59 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -1434,6 +1434,10 @@ impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Decimal pub fn vortex_array::arrays::Decimal::take(array: &vortex_array::arrays::DecimalArray, indices: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Decimal + +pub fn vortex_array::arrays::Decimal::filter(array: &vortex_array::arrays::DecimalArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> + impl vortex_array::arrays::slice::SliceReduce for vortex_array::arrays::Decimal pub fn vortex_array::arrays::Decimal::slice(array: &Self::Array, range: core::ops::range::Range) -> vortex_error::VortexResult> @@ -2386,6 +2390,10 @@ impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Consta pub fn vortex_array::arrays::Constant::filter(array: &vortex_array::arrays::ConstantArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Decimal + +pub fn vortex_array::arrays::Decimal::filter(array: &vortex_array::arrays::DecimalArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> + impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Extension pub fn vortex_array::arrays::Extension::filter(array: &vortex_array::arrays::ExtensionArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> @@ -2394,6 +2402,14 @@ impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Masked pub fn vortex_array::arrays::Masked::filter(array: &vortex_array::arrays::MaskedArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Primitive + +pub fn vortex_array::arrays::Primitive::filter(array: &vortex_array::arrays::PrimitiveArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> + +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::VarBinView + +pub fn vortex_array::arrays::VarBinView::filter(array: &vortex_array::arrays::VarBinViewArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> + impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::dict::Dict pub fn vortex_array::arrays::dict::Dict::filter(array: &vortex_array::arrays::dict::DictArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> @@ -3256,6 +3272,10 @@ impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Primitive pub fn vortex_array::arrays::Primitive::take(array: &vortex_array::arrays::PrimitiveArray, indices: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Primitive + +pub fn vortex_array::arrays::Primitive::filter(array: &vortex_array::arrays::PrimitiveArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> + impl vortex_array::arrays::slice::SliceReduce for vortex_array::arrays::Primitive pub fn vortex_array::arrays::Primitive::slice(array: &Self::Array, range: core::ops::range::Range) -> vortex_error::VortexResult> @@ -4662,6 +4682,10 @@ impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::VarBinVie pub fn vortex_array::arrays::VarBinView::take(array: &vortex_array::arrays::VarBinViewArray, indices: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::VarBinView + +pub fn vortex_array::arrays::VarBinView::filter(array: &vortex_array::arrays::VarBinViewArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> + impl vortex_array::arrays::slice::SliceReduce for vortex_array::arrays::VarBinView pub fn vortex_array::arrays::VarBinView::slice(array: &Self::Array, range: core::ops::range::Range) -> vortex_error::VortexResult> @@ -5478,6 +5502,10 @@ impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Decimal pub fn vortex_array::arrays::Decimal::take(array: &vortex_array::arrays::DecimalArray, indices: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Decimal + +pub fn vortex_array::arrays::Decimal::filter(array: &vortex_array::arrays::DecimalArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> + impl vortex_array::arrays::slice::SliceReduce for vortex_array::arrays::Decimal pub fn vortex_array::arrays::Decimal::slice(array: &Self::Array, range: core::ops::range::Range) -> vortex_error::VortexResult> @@ -6818,6 +6846,10 @@ impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::Primitive pub fn vortex_array::arrays::Primitive::take(array: &vortex_array::arrays::PrimitiveArray, indices: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Primitive + +pub fn vortex_array::arrays::Primitive::filter(array: &vortex_array::arrays::PrimitiveArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> + impl vortex_array::arrays::slice::SliceReduce for vortex_array::arrays::Primitive pub fn vortex_array::arrays::Primitive::slice(array: &Self::Array, range: core::ops::range::Range) -> vortex_error::VortexResult> @@ -7854,6 +7886,10 @@ impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::VarBinVie pub fn vortex_array::arrays::VarBinView::take(array: &vortex_array::arrays::VarBinViewArray, indices: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::VarBinView + +pub fn vortex_array::arrays::VarBinView::filter(array: &vortex_array::arrays::VarBinViewArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> + impl vortex_array::arrays::slice::SliceReduce for vortex_array::arrays::VarBinView pub fn vortex_array::arrays::VarBinView::slice(array: &Self::Array, range: core::ops::range::Range) -> vortex_error::VortexResult> diff --git a/vortex-array/src/arrays/decimal/compute/rules.rs b/vortex-array/src/arrays/decimal/compute/rules.rs index a7d7dab51e0..f84df1dac95 100644 --- a/vortex-array/src/arrays/decimal/compute/rules.rs +++ b/vortex-array/src/arrays/decimal/compute/rules.rs @@ -4,6 +4,7 @@ use std::ops::Range; use vortex_error::VortexResult; +use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; @@ -11,6 +12,8 @@ use crate::arrays::Decimal; use crate::arrays::DecimalArray; use crate::arrays::Masked; use crate::arrays::MaskedArray; +use crate::arrays::filter::FilterReduce; +use crate::arrays::filter::FilterReduceAdaptor; use crate::arrays::slice::SliceReduce; use crate::arrays::slice::SliceReduceAdaptor; use crate::match_each_decimal_value_type; @@ -23,6 +26,7 @@ pub(crate) static RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&DecimalMaskedValidityRule), ParentRuleSet::lift(&MaskReduceAdaptor(Decimal)), ParentRuleSet::lift(&SliceReduceAdaptor(Decimal)), + ParentRuleSet::lift(&FilterReduceAdaptor(Decimal)), ]); /// Rule to push down validity masking from MaskedArray parent into DecimalArray child. @@ -72,3 +76,28 @@ impl SliceReduce for Decimal { Ok(Some(result)) } } + +impl FilterReduce for Decimal { + fn filter(array: &DecimalArray, mask: &Mask) -> VortexResult> { + let ranges: Vec> = mask + .slices() + .unwrap_or_else(|| unreachable!(), || unreachable!()) + .iter() + .map(|&(s, e)| s..e) + .collect(); + let result = match_each_decimal_value_type!(array.values_type(), |D| { + // SAFETY: Filtering preserves all DecimalArray invariants — values within + // precision bounds remain valid, and we correctly filter the validity. + unsafe { + DecimalArray::new_unchecked_handle( + array.buffer_handle().filter_typed::(&ranges)?, + array.values_type(), + array.decimal_dtype(), + array.validity().filter(mask)?, + ) + } + .into_array() + }); + Ok(Some(result)) + } +} diff --git a/vortex-array/src/arrays/primitive/compute/rules.rs b/vortex-array/src/arrays/primitive/compute/rules.rs index df6eb35d888..f3a423cea20 100644 --- a/vortex-array/src/arrays/primitive/compute/rules.rs +++ b/vortex-array/src/arrays/primitive/compute/rules.rs @@ -9,6 +9,7 @@ use crate::arrays::Masked; use crate::arrays::MaskedArray; use crate::arrays::Primitive; use crate::arrays::PrimitiveArray; +use crate::arrays::filter::FilterReduceAdaptor; use crate::arrays::slice::SliceReduceAdaptor; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; @@ -19,6 +20,7 @@ pub(crate) const RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&PrimitiveMaskedValidityRule), ParentRuleSet::lift(&MaskReduceAdaptor(Primitive)), ParentRuleSet::lift(&SliceReduceAdaptor(Primitive)), + ParentRuleSet::lift(&FilterReduceAdaptor(Primitive)), ]); /// Rule to push down validity masking from MaskedArray parent into PrimitiveArray child. diff --git a/vortex-array/src/arrays/primitive/compute/slice.rs b/vortex-array/src/arrays/primitive/compute/slice.rs index 2844163e557..38115d9f527 100644 --- a/vortex-array/src/arrays/primitive/compute/slice.rs +++ b/vortex-array/src/arrays/primitive/compute/slice.rs @@ -4,11 +4,13 @@ use std::ops::Range; use vortex_error::VortexResult; +use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::Primitive; use crate::arrays::PrimitiveArray; +use crate::arrays::filter::FilterReduce; use crate::arrays::slice::SliceReduce; use crate::dtype::NativePType; use crate::match_each_native_ptype; @@ -27,3 +29,23 @@ impl SliceReduce for Primitive { Ok(Some(result)) } } + +impl FilterReduce for Primitive { + fn filter(array: &PrimitiveArray, mask: &Mask) -> VortexResult> { + let ranges: Vec> = mask + .slices() + .unwrap_or_else(|| unreachable!(), || unreachable!()) + .iter() + .map(|&(s, e)| s..e) + .collect(); + let result = match_each_native_ptype!(array.ptype(), |T| { + PrimitiveArray::from_buffer_handle( + array.buffer_handle().filter_typed::(&ranges)?, + T::PTYPE, + array.validity().filter(mask)?, + ) + .into_array() + }); + Ok(Some(result)) + } +} diff --git a/vortex-array/src/arrays/varbinview/compute/rules.rs b/vortex-array/src/arrays/varbinview/compute/rules.rs index 5ec24dca7de..3a7b98cd5c5 100644 --- a/vortex-array/src/arrays/varbinview/compute/rules.rs +++ b/vortex-array/src/arrays/varbinview/compute/rules.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors use crate::arrays::VarBinView; +use crate::arrays::filter::FilterReduceAdaptor; use crate::arrays::slice::SliceReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; use crate::scalar_fn::fns::cast::CastReduceAdaptor; @@ -10,4 +11,5 @@ pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&CastReduceAdaptor(VarBinView)), ParentRuleSet::lift(&MaskReduceAdaptor(VarBinView)), ParentRuleSet::lift(&SliceReduceAdaptor(VarBinView)), + ParentRuleSet::lift(&FilterReduceAdaptor(VarBinView)), ]); diff --git a/vortex-array/src/arrays/varbinview/compute/slice.rs b/vortex-array/src/arrays/varbinview/compute/slice.rs index 02582841601..47082ea1add 100644 --- a/vortex-array/src/arrays/varbinview/compute/slice.rs +++ b/vortex-array/src/arrays/varbinview/compute/slice.rs @@ -5,11 +5,13 @@ use std::ops::Range; use std::sync::Arc; use vortex_error::VortexResult; +use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::VarBinView; use crate::arrays::VarBinViewArray; +use crate::arrays::filter::FilterReduce; use crate::arrays::slice::SliceReduce; use crate::arrays::varbinview::BinaryView; @@ -28,3 +30,23 @@ impl SliceReduce for VarBinView { )) } } + +impl FilterReduce for VarBinView { + fn filter(array: &VarBinViewArray, mask: &Mask) -> VortexResult> { + let ranges: Vec> = mask + .slices() + .unwrap_or_else(|| unreachable!(), || unreachable!()) + .iter() + .map(|&(s, e)| s..e) + .collect(); + Ok(Some( + VarBinViewArray::new_handle( + array.views_handle().filter_typed::(&ranges)?, + Arc::clone(array.buffers()), + array.dtype().clone(), + array.validity()?.filter(mask)?, + ) + .into_array(), + )) + } +} diff --git a/vortex-layout/src/layouts/flat/writer.rs b/vortex-layout/src/layouts/flat/writer.rs index 0e83f717f49..2b3f06d7944 100644 --- a/vortex-layout/src/layouts/flat/writer.rs +++ b/vortex-layout/src/layouts/flat/writer.rs @@ -207,6 +207,7 @@ mod tests { use vortex_array::arrays::Primitive; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; + use vortex_array::assert_arrays_eq; use vortex_array::builders::ArrayBuilder; use vortex_array::builders::VarBinViewBuilder; use vortex_array::dtype::DType; @@ -413,40 +414,42 @@ mod tests { } #[test] - fn flat_invalid_array_fails() -> VortexResult<()> { + fn flat_filter_array_reduces_to_primitive() -> VortexResult<()> { block_on(|handle| async { let prim: PrimitiveArray = (0..10).collect(); let filter = prim.filter(Mask::from_indices(10, vec![2, 3]))?; let ctx = ArrayContext::empty(); - // Write the array into a byte buffer. - let (layout, _segments) = { - let segments = Arc::new(TestSegments::default()); - let (ptr, eof) = SequenceId::root().split(); - // Only allow primitive encodings - filter arrays should fail. - let allowed = ArrayRegistry::default(); - allowed.register(Primitive::ID, Primitive); - let layout = FlatLayoutStrategy::default() - .with_allow_encodings(allowed) - .write_stream( - ctx, - segments.clone(), - filter.to_array_stream().sequenced(ptr), - eof, - handle, - ) - .await; + // FilterReduce reduces FilterArray(PrimitiveArray) → PrimitiveArray during + // optimization, so the write should succeed even with only Primitive allowed. + let segments = Arc::new(TestSegments::default()); + let (ptr, eof) = SequenceId::root().split(); + let allowed = ArrayRegistry::default(); + allowed.register(Primitive::ID, Primitive); + let layout = FlatLayoutStrategy::default() + .with_allow_encodings(allowed) + .write_stream( + ctx, + segments.clone(), + filter.to_array_stream().sequenced(ptr), + eof, + handle, + ) + .await?; - (layout, segments) - }; + let result = layout + .new_reader("".into(), segments, &SESSION)? + .projection_evaluation( + &(0..layout.row_count()), + &root(), + MaskFuture::new_true(layout.row_count().try_into()?), + )? + .await?; - let err = layout.expect_err("expected error"); - assert!( - err.to_string() - .contains("normalize forbids encoding (vortex.filter)"), - "unexpected error: {err}" - ); + let expected = + PrimitiveArray::new(buffer![2i32, 3], Validity::NonNullable).into_array(); + assert_arrays_eq!(result, expected); Ok(()) }) From a804764c3d2ac4d6898116fe16b42df50c852ede Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 19 Mar 2026 14:14:30 -0700 Subject: [PATCH 2/2] Add slicing to BufferHandle Signed-off-by: Nicholas Gates --- vortex-array/public-api.lock | 12 ++++++ .../arrays/fixed_size_list/compute/rules.rs | 2 + .../arrays/fixed_size_list/compute/slice.rs | 38 +++++++++++++++++++ 3 files changed, 52 insertions(+) diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index c8a15145e59..92520709be3 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -2398,6 +2398,10 @@ impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Extens pub fn vortex_array::arrays::Extension::filter(array: &vortex_array::arrays::ExtensionArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::FixedSizeList + +pub fn vortex_array::arrays::FixedSizeList::filter(array: &vortex_array::arrays::FixedSizeListArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> + impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::Masked pub fn vortex_array::arrays::Masked::filter(array: &vortex_array::arrays::MaskedArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> @@ -2434,6 +2438,10 @@ impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::FixedSize pub fn vortex_array::arrays::FixedSizeList::take(array: &vortex_array::arrays::FixedSizeListArray, indices: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::FixedSizeList + +pub fn vortex_array::arrays::FixedSizeList::filter(array: &vortex_array::arrays::FixedSizeListArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> + impl vortex_array::arrays::slice::SliceReduce for vortex_array::arrays::FixedSizeList pub fn vortex_array::arrays::FixedSizeList::slice(array: &Self::Array, range: core::ops::range::Range) -> vortex_error::VortexResult> @@ -6124,6 +6132,10 @@ impl vortex_array::arrays::dict::TakeExecute for vortex_array::arrays::FixedSize pub fn vortex_array::arrays::FixedSizeList::take(array: &vortex_array::arrays::FixedSizeListArray, indices: &vortex_array::ArrayRef, ctx: &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult> +impl vortex_array::arrays::filter::FilterReduce for vortex_array::arrays::FixedSizeList + +pub fn vortex_array::arrays::FixedSizeList::filter(array: &vortex_array::arrays::FixedSizeListArray, mask: &vortex_mask::Mask) -> vortex_error::VortexResult> + impl vortex_array::arrays::slice::SliceReduce for vortex_array::arrays::FixedSizeList pub fn vortex_array::arrays::FixedSizeList::slice(array: &Self::Array, range: core::ops::range::Range) -> vortex_error::VortexResult> diff --git a/vortex-array/src/arrays/fixed_size_list/compute/rules.rs b/vortex-array/src/arrays/fixed_size_list/compute/rules.rs index da1d91423e2..149b4bfe6f3 100644 --- a/vortex-array/src/arrays/fixed_size_list/compute/rules.rs +++ b/vortex-array/src/arrays/fixed_size_list/compute/rules.rs @@ -2,6 +2,7 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use crate::arrays::FixedSizeList; +use crate::arrays::filter::FilterReduceAdaptor; use crate::arrays::slice::SliceReduceAdaptor; use crate::optimizer::rules::ParentRuleSet; use crate::scalar_fn::fns::cast::CastReduceAdaptor; @@ -11,4 +12,5 @@ pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new ParentRuleSet::lift(&CastReduceAdaptor(FixedSizeList)), ParentRuleSet::lift(&MaskReduceAdaptor(FixedSizeList)), ParentRuleSet::lift(&SliceReduceAdaptor(FixedSizeList)), + ParentRuleSet::lift(&FilterReduceAdaptor(FixedSizeList)), ]); diff --git a/vortex-array/src/arrays/fixed_size_list/compute/slice.rs b/vortex-array/src/arrays/fixed_size_list/compute/slice.rs index 9db5d8a7c08..1b06fb7651f 100644 --- a/vortex-array/src/arrays/fixed_size_list/compute/slice.rs +++ b/vortex-array/src/arrays/fixed_size_list/compute/slice.rs @@ -4,11 +4,13 @@ use std::ops::Range; use vortex_error::VortexResult; +use vortex_mask::Mask; use crate::ArrayRef; use crate::IntoArray; use crate::arrays::FixedSizeList; use crate::arrays::FixedSizeListArray; +use crate::arrays::filter::FilterReduce; use crate::arrays::slice::SliceReduce; use crate::vtable::ValidityHelper; @@ -33,3 +35,39 @@ impl SliceReduce for FixedSizeList { )) } } + +impl FilterReduce for FixedSizeList { + fn filter(array: &FixedSizeListArray, mask: &Mask) -> VortexResult> { + let list_size = array.list_size() as usize; + let new_len = mask.true_count(); + + let filtered_elements = if list_size == 0 { + // Degenerate case: elements array is empty regardless of filter. + array.elements().clone() + } else { + let elements_len = array.elements().len(); + let expanded_slices: Vec<(usize, usize)> = mask + .slices() + .unwrap_or_else(|| unreachable!(), || unreachable!()) + .iter() + .map(|&(s, e)| (s * list_size, e * list_size)) + .collect(); + let elements_mask = Mask::from_slices(elements_len, expanded_slices); + array.elements().filter(elements_mask)? + }; + + // SAFETY: Filtering preserves FixedSizeListArray invariants — each selected list's + // elements are contiguously preserved, maintaining elements.len() == new_len * list_size. + Ok(Some( + unsafe { + FixedSizeListArray::new_unchecked( + filtered_elements, + array.list_size(), + array.validity().filter(mask)?, + new_len, + ) + } + .into_array(), + )) + } +}