diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 44a6572f53d26..6c81fcc11c6c0 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -99,11 +99,18 @@ impl StaticFilter for ArrayStaticFilter { )); } + // Unwrap dictionary-encoded needles when the value type matches + // in_array, evaluating against the dictionary values and mapping + // back via keys. downcast_dictionary_array! { v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) + // Only unwrap when the haystack (in_array) type matches + // the dictionary value type + if v.values().data_type() == self.in_array.data_type() { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())); + } } _ => {} } @@ -3724,4 +3731,348 @@ mod tests { assert_eq!(result, &BooleanArray::from(vec![true, false, false])); Ok(()) } + /// Tests that short-circuit evaluation produces correct results. + /// When all rows match after the first list item, remaining items + /// should be skipped without affecting correctness. 
+    #[test]
+    fn test_in_list_with_columns_short_circuit() -> Result<()> {
+        // a IN (b, c) where b already matches every row of a
+        // The short-circuit should skip evaluating c
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+            Field::new("c", DataType::Int32, false),
+        ]);
+        let batch = RecordBatch::try_new(
+            Arc::new(schema.clone()),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2, 3])),
+                Arc::new(Int32Array::from(vec![1, 2, 3])), // b == a for all rows
+                Arc::new(Int32Array::from(vec![99, 99, 99])),
+            ],
+        )?;
+
+        let col_a = col("a", &schema)?;
+        let list = vec![col("b", &schema)?, col("c", &schema)?];
+        let expr = make_in_list_with_columns(col_a, list, false);
+
+        let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
+        let result = as_boolean_array(&result);
+        assert_eq!(result, &BooleanArray::from(vec![true, true, true]));
+        Ok(())
+    }
+
+    /// Short-circuit must NOT skip when nulls are present (three-valued logic).
+    /// Even if all non-null values are true, null rows keep the result as null.
+    #[test]
+    fn test_in_list_with_columns_short_circuit_with_nulls() -> Result<()> {
+        // a IN (b, c) where a has nulls
+        // Even if b matches all non-null rows, result should preserve nulls
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, false),
+            Field::new("c", DataType::Int32, false),
+        ]);
+        let batch = RecordBatch::try_new(
+            Arc::new(schema.clone()),
+            vec![
+                Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])),
+                Arc::new(Int32Array::from(vec![1, 2, 3])), // matches non-null rows
+                Arc::new(Int32Array::from(vec![99, 99, 99])),
+            ],
+        )?;
+
+        let col_a = col("a", &schema)?;
+        let list = vec![col("b", &schema)?, col("c", &schema)?];
+        let expr = make_in_list_with_columns(col_a, list, false);
+
+        let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
+        let result = as_boolean_array(&result);
+        // row 0: 1 IN (1, 99) → true
+        // row 1: NULL IN (2, 99) → NULL
+        // row 2: 3 IN (3, 99) → true
+        assert_eq!(
+            result,
+            &BooleanArray::from(vec![Some(true), None, Some(true)])
+        );
+        Ok(())
+    }
+
+    /// Tests the make_comparator + collect_bool fallback path using
+    /// struct column references (nested types don't support arrow_eq).
+    #[test]
+    fn test_in_list_with_columns_struct() -> Result<()> {
+        let struct_fields = Fields::from(vec![
+            Field::new("x", DataType::Int32, false),
+            Field::new("y", DataType::Utf8, false),
+        ]);
+        let struct_dt = DataType::Struct(struct_fields.clone());
+
+        let schema = Schema::new(vec![
+            Field::new("a", struct_dt.clone(), true),
+            Field::new("b", struct_dt.clone(), false),
+            Field::new("c", struct_dt.clone(), false),
+        ]);
+
+        // a: [{1,"a"}, {2,"b"}, NULL, {4,"d"}]
+        // b: [{1,"a"}, {9,"z"}, {3,"c"}, {4,"d"}]
+        // c: [{9,"z"}, {2,"b"}, {9,"z"}, {9,"z"}]
+        let a = Arc::new(StructArray::new(
+            struct_fields.clone(),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2, 3, 4])),
+                Arc::new(StringArray::from(vec!["a", "b", "c", "d"])),
+            ],
+            Some(vec![true, true, false, true].into()),
+        ));
+        let b = Arc::new(StructArray::new(
+            struct_fields.clone(),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 9, 3, 4])),
+                Arc::new(StringArray::from(vec!["a", "z", "c", "d"])),
+            ],
+            None,
+        ));
+        let c = Arc::new(StructArray::new(
+            struct_fields.clone(),
+            vec![
+                Arc::new(Int32Array::from(vec![9, 2, 9, 9])),
+                Arc::new(StringArray::from(vec!["z", "b", "z", "z"])),
+            ],
+            None,
+        ));
+
+        let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b, c])?;
+
+        let col_a = col("a", &schema)?;
+        let list = vec![col("b", &schema)?, col("c", &schema)?];
+        let expr = make_in_list_with_columns(col_a, list, false);
+
+        let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
+        let result = as_boolean_array(&result);
+        // row 0: {1,"a"} IN ({1,"a"}, {9,"z"}) → true (matches b)
+        // row 1: {2,"b"} IN ({9,"z"}, {2,"b"}) → true (matches c)
+        // row 2: NULL IN ({3,"c"}, {9,"z"}) → NULL
+        // row 3: {4,"d"} IN ({4,"d"}, {9,"z"}) → true (matches b)
+        assert_eq!(
+            result,
+            &BooleanArray::from(vec![Some(true), Some(true), None, Some(true)])
+        );
+
+        // Also test NOT IN
+        let col_a = col("a", &schema)?;
+        let list = vec![col("b", &schema)?, col("c", &schema)?];
+        let expr = make_in_list_with_columns(col_a, list, true);
+
+        let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
+        let result = as_boolean_array(&result);
+        // row 0: {1,"a"} NOT IN ({1,"a"}, {9,"z"}) → false
+        // row 1: {2,"b"} NOT IN ({9,"z"}, {2,"b"}) → false
+        // row 2: NULL NOT IN ({3,"c"}, {9,"z"}) → NULL
+        // row 3: {4,"d"} NOT IN ({4,"d"}, {9,"z"}) → false
+        assert_eq!(
+            result,
+            &BooleanArray::from(vec![Some(false), Some(false), None, Some(false)])
+        );
+        Ok(())
+    }
+
+    // -----------------------------------------------------------------------
+    // Tests for try_new_from_array: evaluates `needle IN in_array`.
+    //
+    // This exercises the code path used by HashJoin dynamic filter pushdown,
+    // where in_array is built directly from the join's build-side arrays.
+    // Unlike try_new (used by SQL IN expressions), which always produces a
+    // non-Dictionary in_array because evaluate_list() flattens Dictionary
+    // scalars, try_new_from_array passes the array directly and can produce
+    // a Dictionary in_array.
+    // -----------------------------------------------------------------------
+
+    fn wrap_in_dict(array: ArrayRef) -> ArrayRef {
+        let keys = Int32Array::from((0..array.len() as i32).collect::<Vec<_>>());
+        Arc::new(DictionaryArray::new(keys, array))
+    }
+
+    /// Evaluates `needle IN in_array` via try_new_from_array, the same
+    /// path used by HashJoin dynamic filter pushdown (not the SQL literal
+    /// IN path which goes through try_new).
+    fn eval_in_list_from_array(
+        needle: ArrayRef,
+        in_array: ArrayRef,
+    ) -> Result<BooleanArray> {
+        let schema =
+            Schema::new(vec![Field::new("a", needle.data_type().clone(), false)]);
+        let col_a = col("a", &schema)?;
+        let expr = Arc::new(InListExpr::try_new_from_array(col_a, in_array, false)?)
+            as Arc<dyn PhysicalExpr>;
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![needle])?;
+        let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?;
+        Ok(as_boolean_array(&result).clone())
+    }
+
+    #[test]
+    fn test_in_list_from_array_type_combinations() -> Result<()> {
+        use arrow::compute::cast;
+
+        // All cases: needle[0] and needle[2] match, needle[1] does not.
+        let expected = BooleanArray::from(vec![Some(true), Some(false), Some(true)]);
+
+        // Base arrays cast to each target type
+        let base_in = Arc::new(Int64Array::from(vec![1i64, 2, 3])) as ArrayRef;
+        let base_needle = Arc::new(Int64Array::from(vec![1i64, 4, 2])) as ArrayRef;
+
+        // Test all specializations in instantiate_static_filter
+        let primitive_types = vec![
+            DataType::Int8,
+            DataType::Int16,
+            DataType::Int32,
+            DataType::Int64,
+            DataType::UInt8,
+            DataType::UInt16,
+            DataType::UInt32,
+            DataType::UInt64,
+            DataType::Float32,
+            DataType::Float64,
+        ];
+
+        for dt in &primitive_types {
+            let in_array = cast(&base_in, dt)?;
+            let needle = cast(&base_needle, dt)?;
+
+            // T in_array, T needle
+            assert_eq!(
+                expected,
+                eval_in_list_from_array(Arc::clone(&needle), Arc::clone(&in_array))?,
+                "same-type failed for {dt:?}"
+            );
+
+            // T in_array, Dict(Int32, T) needle
+            assert_eq!(
+                expected,
+                eval_in_list_from_array(wrap_in_dict(needle), in_array)?,
+                "dict-needle failed for {dt:?}"
+            );
+        }
+
+        // Utf8 (falls through to ArrayStaticFilter)
+        let utf8_in = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
+        let utf8_needle = Arc::new(StringArray::from(vec!["a", "d", "b"])) as ArrayRef;
+
+        // Utf8 in_array, Utf8 needle
+        assert_eq!(
+            expected,
+            eval_in_list_from_array(Arc::clone(&utf8_needle), Arc::clone(&utf8_in),)?
+        );
+
+        // Utf8 in_array, Dict(Utf8) needle
+        assert_eq!(
+            expected,
+            eval_in_list_from_array(
+                wrap_in_dict(Arc::clone(&utf8_needle)),
+                Arc::clone(&utf8_in),
+            )?
+        );
+
+        // Dict(Utf8) in_array, Dict(Utf8) needle: the #20937 bug
+        assert_eq!(
+            expected,
+            eval_in_list_from_array(
+                wrap_in_dict(Arc::clone(&utf8_needle)),
+                wrap_in_dict(Arc::clone(&utf8_in)),
+            )?
+        );
+
+        // Struct in_array, Struct needle: multi-column join
+        let struct_fields = Fields::from(vec![
+            Field::new("c0", DataType::Utf8, true),
+            Field::new("c1", DataType::Int64, true),
+        ]);
+        let make_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
+            let pairs: Vec<(FieldRef, ArrayRef)> =
+                struct_fields.iter().cloned().zip([c0, c1]).collect();
+            Arc::new(StructArray::from(pairs))
+        };
+        assert_eq!(
+            expected,
+            eval_in_list_from_array(
+                make_struct(
+                    Arc::clone(&utf8_needle),
+                    Arc::new(Int64Array::from(vec![1, 4, 2])),
+                ),
+                make_struct(
+                    Arc::clone(&utf8_in),
+                    Arc::new(Int64Array::from(vec![1, 2, 3])),
+                ),
+            )?
+        );
+
+        // Struct with Dict fields: multi-column Dict join
+        let dict_struct_fields = Fields::from(vec![
+            Field::new(
+                "c0",
+                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+                true,
+            ),
+            Field::new("c1", DataType::Int64, true),
+        ]);
+        let make_dict_struct = |c0: ArrayRef, c1: ArrayRef| -> ArrayRef {
+            let pairs: Vec<(FieldRef, ArrayRef)> =
+                dict_struct_fields.iter().cloned().zip([c0, c1]).collect();
+            Arc::new(StructArray::from(pairs))
+        };
+        assert_eq!(
+            expected,
+            eval_in_list_from_array(
+                make_dict_struct(
+                    wrap_in_dict(Arc::clone(&utf8_needle)),
+                    Arc::new(Int64Array::from(vec![1, 4, 2])),
+                ),
+                make_dict_struct(
+                    wrap_in_dict(Arc::clone(&utf8_in)),
+                    Arc::new(Int64Array::from(vec![1, 2, 3])),
+                ),
+            )?
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_in_list_from_array_type_mismatch_errors() -> Result<()> {
+        // Utf8 needle, Dict(Utf8) in_array
+        let err = eval_in_list_from_array(
+            Arc::new(StringArray::from(vec!["a", "d", "b"])),
+            wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
+        )
+        .unwrap_err()
+        .to_string();
+        assert!(
+            err.contains("Can't compare arrays of different types"),
+            "{err}"
+        );
+
+        // Dict(Utf8) needle, Int64 in_array: specialized Int64StaticFilter
+        // rejects the Utf8 dictionary values at construction time
+        let err = eval_in_list_from_array(
+            wrap_in_dict(Arc::new(StringArray::from(vec!["a", "d", "b"]))),
+            Arc::new(Int64Array::from(vec![1, 2, 3])),
+        )
+        .unwrap_err()
+        .to_string();
+        assert!(err.contains("Failed to downcast"), "{err}");
+
+        // Dict(Int64) needle, Dict(Utf8) in_array: both Dict but different
+        // value types, make_comparator rejects the comparison
+        let err = eval_in_list_from_array(
+            wrap_in_dict(Arc::new(Int64Array::from(vec![1, 4, 2]))),
+            wrap_in_dict(Arc::new(StringArray::from(vec!["a", "b", "c"]))),
+        )
+        .unwrap_err()
+        .to_string();
+        assert!(
+            err.contains("Can't compare arrays of different types"),
+            "{err}"
+        );
+        Ok(())
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt
index 6c4383f997f81..85f9549357138 100644
--- a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt
+++ b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt
@@ -889,3 +889,54 @@ set datafusion.execution.parquet.pushdown_filters = false;
 
 statement ok
 DROP TABLE t_struct_filter;
+
+##########
+# Regression test for https://github.com/apache/datafusion/issues/20937
+#
+# Dynamic filter pushdown fails when joining VALUES against
+# Dictionary-encoded Parquet columns. The InListExpr's ArrayStaticFilter
+# unwraps the needle Dictionary but not the stored in_array, causing a
+# make_comparator(Utf8, Dictionary) type mismatch.
+##########
+
+statement ok
+set datafusion.execution.parquet.pushdown_filters = true;
+
+statement ok
+set datafusion.execution.parquet.reorder_filters = true;
+
+statement ok
+COPY (
+  SELECT
+    arrow_cast(chr(65 + (row_num % 26)), 'Dictionary(Int32, Utf8)') as tag1,
+    row_num * 1.0 as value
+  FROM (SELECT unnest(range(0, 10000)) as row_num)
+) TO 'test_files/scratch/parquet_filter_pushdown/dict_filter_bug.parquet';
+
+statement ok
+CREATE EXTERNAL TABLE dict_filter_bug
+STORED AS PARQUET
+LOCATION 'test_files/scratch/parquet_filter_pushdown/dict_filter_bug.parquet';
+
+query TR
+SELECT t.tag1, t.value
+FROM dict_filter_bug t
+JOIN (VALUES ('A'), ('B')) AS v(c1)
+ON t.tag1 = v.c1
+ORDER BY t.tag1, t.value
+LIMIT 4;
+----
+A 0
+A 26
+A 52
+A 78
+
+# Cleanup
+statement ok
+set datafusion.execution.parquet.pushdown_filters = false;
+
+statement ok
+set datafusion.execution.parquet.reorder_filters = false;
+
+statement ok
+DROP TABLE dict_filter_bug;