diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index e6ef7e7620e8..ca89a3ab1ef4 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -99,11 +99,18 @@ impl StaticFilter for ArrayStaticFilter {
             ));
         }
 
+        // Unwrap dictionary-encoded needles when the value type matches
+        // in_array, evaluating against the dictionary values and mapping
+        // back via keys.
         downcast_dictionary_array! {
             v => {
-                let values_contains = self.contains(v.values().as_ref(), negated)?;
-                let result = take(&values_contains, v.keys(), None)?;
-                return Ok(downcast_array(result.as_ref()))
+                // Only unwrap when the haystack (in_array) type matches
+                // the dictionary value type
+                if v.values().data_type() == self.in_array.data_type() {
+                    let values_contains = self.contains(v.values().as_ref(), negated)?;
+                    let result = take(&values_contains, v.keys(), None)?;
+                    return Ok(downcast_array(result.as_ref()));
+                }
             }
             _ => {}
         }
@@ -3878,4 +3885,212 @@ mod tests {
         );
         Ok(())
     }
+
+    // -----------------------------------------------------------------------
+    // Tests for try_new_from_array: evaluates `needle IN in_array`.
+    //
+    // This exercises the code path used by HashJoin dynamic filter pushdown,
+    // where in_array is built directly from the join's build-side arrays.
+    // Unlike try_new (used by SQL IN expressions), which always produces a
+    // non-Dictionary in_array because evaluate_list() flattens Dictionary
+    // scalars, try_new_from_array passes the array directly and can produce
+    // a Dictionary in_array.
+    // -----------------------------------------------------------------------
+
+    /// Wraps `array` as the values of a `DictionaryArray` with `Int32` keys
+    /// `0..array.len()`, so row `i` of the result refers to value `i` and
+    /// the logical contents are unchanged.
+    fn wrap_in_dict(array: ArrayRef) -> ArrayRef {
+        let keys = Int32Array::from((0..array.len() as i32).collect::<Vec<_>>());
+        Arc::new(DictionaryArray::new(keys, array))
+    }
+
+    /// Evaluates `needle IN in_array` via try_new_from_array, the same
+    /// path used by HashJoin dynamic filter pushdown (not the SQL literal
+    /// IN path which goes through try_new).
+    fn eval_in_list_from_array(
+        needle: ArrayRef,
+        in_array: ArrayRef,
+    ) -> Result<BooleanArray> {
+        let schema =
+            Schema::new(vec![Field::new("a", needle.data_type().clone(), false)]);
+        let col_a = col("a", &schema)?;
+        let expr = Arc::new(InListExpr::try_new_from_array(col_a, in_array, false)?)
+            as Arc<dyn PhysicalExpr>;
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![needle])?;
+        let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
+        Ok(as_boolean_array(&result).clone())
+    }
+
+    /// Checks that `needle IN in_array` yields `[true, false, true]` for
+    /// each needle/in_array pairing below: primitives, Utf8, Dictionary,
+    /// and Struct (with and without Dictionary fields).
+    #[test]
+    fn test_in_list_from_array_type_combinations() -> Result<()> {
+        use arrow::compute::cast;
+
+        // All cases: needle[0] and needle[2] match, needle[1] does not.
+        let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]);
+
+        // Base arrays cast to each target type
+        let base_in = Arc::new(Int64Array::from(vec![1i64, 2, 3])) as ArrayRef;
+        let base_needle = Arc::new(Int64Array::from(vec![1i64, 4, 2])) as ArrayRef;
+
+        // Test all specializations in instantiate_static_filter
+        let primitive_types = vec![
+            DataType::Int8,
+            DataType::Int16,
+            DataType::Int32,
+            DataType::Int64,
+            DataType::UInt8,
+            DataType::UInt16,
+            DataType::UInt32,
+            DataType::UInt64,
+            DataType::Float32,
+            DataType::Float64,
+        ];
+
+        for dt in &primitive_types {
+            let in_array = cast(&base_in, dt)?;
+            let needle = cast(&base_needle, dt)?;
+
+            // T in_array, T needle
+            assert_eq!(
+                expected,
+                eval_in_list_from_array(Arc::clone(&needle), Arc::clone(&in_array))?,
+                "same-type failed for {dt:?}"
+            );
+
+            // T in_array, Dict(Int32, T) needle
+            assert_eq!(
+                expected,
+                eval_in_list_from_array(wrap_in_dict(needle), in_array)?,
+                "dict-needle failed for {dt:?}"
+            );
+        }
+
+        // Utf8 (falls through to ArrayStaticFilter)
+        let utf8_in = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
+        let utf8_needle = Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef;
+
+        // Utf8 in_array, Utf8 needle
+        assert_eq!(
+            expected,
+            eval_in_list_from_array(Arc::clone(&utf8_needle), Arc::clone(&utf8_in),)?
+        );
+
+        // Utf8 in_array, Dict(Utf8) needle
+        assert_eq!(
+            expected,
+            eval_in_list_from_array(
+                wrap_in_dict(Arc::clone(&utf8_needle)),
+                Arc::clone(&utf8_in),
+            )?
+        );
+
+        // Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug
+        assert_eq!(
+            expected,
+            eval_in_list_from_array(
+                wrap_in_dict(Arc::clone(&utf8_needle)),
+                wrap_in_dict(Arc::clone(&utf8_in)),
+            )?
+        );
+
+        // Struct in_array, Struct needle: multi-column join
+        let struct_fields = Fields::from(vec![
+            Field::new("c0", DataType::Utf8, true),
+            Field::new("c1", DataType::Int64, true),
+        ]);
+        let make_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
+            let pairs: Vec<(FieldRef, ArrayRef)> =
+                struct_fields.iter().cloned().zip([c0, c1]).collect();
+            Arc::new(StructArray::from(pairs))
+        };
+        assert_eq!(
+            expected,
+            eval_in_list_from_array(
+                make_struct(
+                    Arc::clone(&utf8_needle),
+                    Arc::new(Int64Array::from(vec![1, 4, 2])),
+                ),
+                make_struct(
+                    Arc::clone(&utf8_in),
+                    Arc::new(Int64Array::from(vec![1, 2, 3])),
+                ),
+            )?
+        );
+
+        // Struct with Dict fields: multi-column Dict join
+        let dict_struct_fields = Fields::from(vec![
+            Field::new(
+                "c0",
+                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+                true,
+            ),
+            Field::new("c1", DataType::Int64, true),
+        ]);
+        let make_dict_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
+            let pairs: Vec<(FieldRef, ArrayRef)> =
+                dict_struct_fields.iter().cloned().zip([c0, c1]).collect();
+            Arc::new(StructArray::from(pairs))
+        };
+        assert_eq!(
+            expected,
+            eval_in_list_from_array(
+                make_dict_struct(
+                    wrap_in_dict(Arc::clone(&utf8_needle)),
+                    Arc::new(Int64Array::from(vec![1, 4, 2])),
+                ),
+                make_dict_struct(
+                    wrap_in_dict(Arc::clone(&utf8_in)),
+                    Arc::new(Int64Array::from(vec![1, 2, 3])),
+                ),
+            )?
+        );
+
+        Ok(())
+    }
+
+    /// Checks that pairings whose needle and `in_array` types are
+    /// incompatible return an error rather than a result.
+    #[test]
+    fn test_in_list_from_array_type_mismatch_errors() -> Result<()> {
+        // Utf8 needle, Dict(Utf8) in_array
+        let err = eval_in_list_from_array(
+            Arc::new(StringArray::from(vec!["a", "d", "b"])),
+            wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
+        )
+        .unwrap_err()
+        .to_string();
+        assert!(
+            err.contains("Can't compare arrays of different types"),
+            "{err}"
+        );
+
+        // Dict(Utf8) needle, Int64 in_array: specialized Int64StaticFilter
+        // rejects the Utf8 dictionary values at construction time
+        let err = eval_in_list_from_array(
+            wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))),
+            Arc::new(Int64Array::from(vec![1, 2, 3])),
+        )
+        .unwrap_err()
+        .to_string();
+        assert!(err.contains("Failed to downcast"), "{err}");
+
+        // Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different
+        // value types, make_comparator rejects the comparison
+        let err = eval_in_list_from_array(
+            wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))),
+            wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
+        )
+        .unwrap_err()
+        .to_string();
+        assert!(
+            err.contains("Can't compare arrays of different types"),
+            "{err}"
+        );
+
+        Ok(())
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt
index 00d7851befd1..85f954935713 100644
--- a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt
+++ b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt
@@ -918,13 +918,18 @@ CREATE EXTERNAL TABLE dict_filter_bug
 STORED AS PARQUET
 LOCATION 'test_files/scratch/parquet_filter_pushdown/dict_filter_bug.parquet';
 
-query error Can't compare arrays of different types
+query TR
 SELECT t.tag1, t.value
 FROM dict_filter_bug t
 JOIN (VALUES ('A'), ('B')) AS v(c1)
 ON t.tag1 = v.c1
 ORDER BY t.tag1, t.value
 LIMIT 4;
+----
+A 0
+A 26
+A 52
+A 78
 
 # Cleanup
 statement ok