Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions vortex-array/src/expr/exprs/list_contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use vortex_dtype::DType;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use vortex_scalar::Scalar;
use vortex_session::VortexSession;

use crate::ArrayRef;
Expand Down Expand Up @@ -105,6 +106,13 @@ impl VTable for ListContains {
.try_into()
.map_err(|_| vortex_err!("Wrong number of arguments for ListContains expression"))?;

if let Some(list_scalar) = list_array.as_constant()
&& let Some(value_scalar) = value_array.as_constant()
{
let result = compute_contains_scalar(&list_scalar, &value_scalar)?;
return Ok(ExecutionResult::constant(result, args.row_count));
}

compute_list_contains(list_array.as_ref(), value_array.as_ref())?.execute(args.ctx)
}

Expand Down Expand Up @@ -158,6 +166,23 @@ impl VTable for ListContains {
}
}

/// Evaluates "list contains value" for a single constant list and a single constant needle,
/// producing one boolean scalar instead of a per-row array.
///
/// The nullability of the result dtype is the nullability of the list dtype combined (`|`)
/// with the nullability of the needle dtype. If either input is null, the result is a null
/// boolean of that dtype; otherwise it is `true` exactly when at least one list element
/// compares equal (`==`) to `needle`.
///
/// # Errors
///
/// Returns an error if the list scalar exposes no elements even though it passed the
/// null check.
fn compute_contains_scalar(list: &Scalar, needle: &Scalar) -> VortexResult<Scalar> {
// Computed up front because both the null result and the boolean result use it.
let nullability = list.dtype().nullability() | needle.dtype().nullability();

// A null list or a null needle short-circuits to a null boolean.
if list.is_null() || needle.is_null() {
return Ok(Scalar::null(DType::Bool(nullability)));
}

// `elements()` yields an `Option`; `None` is not expected after the null check above,
// so it is reported as an error rather than unwrapped.
let list_scalar = list.as_list();
let elements = list_scalar
.elements()
.ok_or_else(|| vortex_err!("Expected non-null list"))?;

// Element-by-element comparison using `Scalar` equality; `any` stops at the first match.
let contains = elements.iter().any(|elem| elem == needle);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do this with a compare and an any true. maybe keep what you have for len less than some value?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we'll always fallback to https://github.com/apache/arrow-rs/blob/181007df053e9004f7a211f7de086c4bbbd0a9e9/arrow-ord/src/cmp.rs#L395 - don't think it has any performance wins over that implementation

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will use scalar equality not rust primitive type equality and not vectorize

Ok(Scalar::bool(contains, nullability))
}

/// Creates an expression that checks if a value is contained in a list.
///
/// Returns a boolean array indicating whether the value appears in each list.
Expand Down Expand Up @@ -379,4 +404,32 @@ mod tests {
let expr2 = list_contains(root(), lit(42));
assert_eq!(expr2.to_string(), "contains($, 42i32)");
}

/// When both the list and the needle are literals, `list_contains` should take the
/// constant-scalar path and yield a non-nullable boolean for every row.
#[test]
pub fn test_constant_scalars() {
    let input = test_array();

    // Constant, non-nullable haystack `[1, 2, 3]` of `i32`.
    let haystack = Scalar::list(
        Arc::new(DType::Primitive(I32, Nullability::NonNullable)),
        vec![1.into(), 2.into(), 3.into()],
        Nullability::NonNullable,
    );

    // Applies `contains(haystack, needle)` to the test array and reads back the first row.
    let first_row = |needle: i32| {
        let expr = list_contains(lit(haystack.clone()), lit(needle));
        let applied = input.apply(&expr).unwrap();
        applied.scalar_at(0).unwrap()
    };

    // Needle is present in the list.
    assert_eq!(first_row(2), Scalar::bool(true, Nullability::NonNullable));

    // Needle is absent from the list.
    assert_eq!(first_row(42), Scalar::bool(false, Nullability::NonNullable));
}
}
Loading