From 7b8ab8c5b6c34d4210309d2877fc230ead073c8b Mon Sep 17 00:00:00 2001 From: blaginin Date: Mon, 2 Feb 2026 13:17:07 +0000 Subject: [PATCH] List contains over scalars Co-authored-by: Claude Signed-off-by: blaginin --- vortex-array/src/expr/exprs/list_contains.rs | 53 ++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/vortex-array/src/expr/exprs/list_contains.rs b/vortex-array/src/expr/exprs/list_contains.rs index e0811d81400..e08948fbc9d 100644 --- a/vortex-array/src/expr/exprs/list_contains.rs +++ b/vortex-array/src/expr/exprs/list_contains.rs @@ -8,6 +8,7 @@ use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; +use vortex_scalar::Scalar; use vortex_session::VortexSession; use crate::ArrayRef; @@ -105,6 +106,13 @@ impl VTable for ListContains { .try_into() .map_err(|_| vortex_err!("Wrong number of arguments for ListContains expression"))?; + if let Some(list_scalar) = list_array.as_constant() + && let Some(value_scalar) = value_array.as_constant() + { + let result = compute_contains_scalar(&list_scalar, &value_scalar)?; + return Ok(ExecutionResult::constant(result, args.row_count)); + } + compute_list_contains(list_array.as_ref(), value_array.as_ref())?.execute(args.ctx) } @@ -158,6 +166,23 @@ impl VTable for ListContains { } } +fn compute_contains_scalar(list: &Scalar, needle: &Scalar) -> VortexResult { + let nullability = list.dtype().nullability() | needle.dtype().nullability(); + + // Handle null list or null needle + if list.is_null() || needle.is_null() { + return Ok(Scalar::null(DType::Bool(nullability))); + } + + let list_scalar = list.as_list(); + let elements = list_scalar + .elements() + .ok_or_else(|| vortex_err!("Expected non-null list"))?; + + let contains = elements.iter().any(|elem| elem == needle); + Ok(Scalar::bool(contains, nullability)) +} + /// Creates an expression that checks if a value is contained in a list. /// /// Returns a boolean array indicating whether the value appears in each list. @@ -379,4 +404,32 @@ mod tests { let expr2 = list_contains(root(), lit(42)); assert_eq!(expr2.to_string(), "contains($, 42i32)"); } + + #[test] + pub fn test_constant_scalars() { + let arr = test_array(); + + // Both list and needle are constants - should use scalar optimization + let list_scalar = Scalar::list( + Arc::new(DType::Primitive(I32, Nullability::NonNullable)), + vec![1.into(), 2.into(), 3.into()], + Nullability::NonNullable, + ); + + // Test contains true + let expr = list_contains(lit(list_scalar.clone()), lit(2i32)); + let result = arr.apply(&expr).unwrap(); + assert_eq!( + result.scalar_at(0).unwrap(), + Scalar::bool(true, Nullability::NonNullable) + ); + + // Test contains false + let expr = list_contains(lit(list_scalar), lit(42i32)); + let result = arr.apply(&expr).unwrap(); + assert_eq!( + result.scalar_at(0).unwrap(), + Scalar::bool(false, Nullability::NonNullable) + ); + } }