From a846f659c055beeceb07ee4ada895f2611e55366 Mon Sep 17 00:00:00 2001 From: Rizky Mirzaviandy Priambodo <142987522+Xavrir@users.noreply.github.com> Date: Wed, 18 Mar 2026 07:25:15 +0700 Subject: [PATCH 1/2] docs: clarify NULL handling for array_remove functions (#21014) --- datafusion/functions-nested/src/remove.rs | 352 +++++++++++++++++- docs/source/user-guide/expressions.md | 6 +- .../source/user-guide/sql/scalar_functions.md | 29 +- 3 files changed, 369 insertions(+), 18 deletions(-) diff --git a/datafusion/functions-nested/src/remove.rs b/datafusion/functions-nested/src/remove.rs index 3d4076800e1e9..4e28fd0fcebe8 100644 --- a/datafusion/functions-nested/src/remove.rs +++ b/datafusion/functions-nested/src/remove.rs @@ -40,13 +40,13 @@ make_udf_expr_and_func!( ArrayRemove, array_remove, array element, - "removes the first element from the array equal to the given value.", + "removes the first element from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.", array_remove_udf ); #[user_doc( doc_section(label = "Array Functions"), - description = "Removes the first element from the array equal to the given value.", + description = "Removes the first element from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.", syntax_example = "array_remove(array, element)", sql_example = r#"```sql > select array_remove([1, 2, 2, 3, 2, 1, 4], 2); @@ -55,6 +55,13 @@ make_udf_expr_and_func!( +----------------------------------------------+ | [1, 2, 3, 2, 1, 4] | +----------------------------------------------+ + +> select array_remove([1, 2, NULL, 2, 4], 2); ++---------------------------------------------------+ +| array_remove(List([1,2,NULL,2,4]),Int64(2)) | ++---------------------------------------------------+ +| [1, NULL, 2, 4] | ++---------------------------------------------------+ ```"#, argument( name = "array", @@ -130,14 +137,14 @@ make_udf_expr_and_func!( ArrayRemoveN, array_remove_n, array element max, - "removes the first `max` elements from the array equal to the given value.", + "removes the first `max` elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.", array_remove_n_udf ); #[user_doc( doc_section(label = "Array Functions"), - description = "Removes the first `max` elements from the array equal to the given value.", - syntax_example = "array_remove_n(array, element, max))", + description = "Removes the first `max` elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.", + syntax_example = "array_remove_n(array, element, max)", sql_example = r#"```sql > select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2); +---------------------------------------------------------+ @@ -145,6 +152,13 @@ make_udf_expr_and_func!( +---------------------------------------------------------+ | [1, 3, 2, 1, 4] | +---------------------------------------------------------+ + +> select array_remove_n([1, 2, NULL, 2, 4], 2, 2); ++----------------------------------------------------------+ +| array_remove_n(List([1,2,NULL,2,4]),Int64(2),Int64(2)) | ++----------------------------------------------------------+ +| [1, NULL, 4] | ++----------------------------------------------------------+ ```"#, argument( name = "array", @@ -225,13 +239,13 @@ make_udf_expr_and_func!( ArrayRemoveAll, array_remove_all, array element, - "removes all elements from the array equal to the given value.", + "removes all elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.", array_remove_all_udf ); #[user_doc( doc_section(label = "Array Functions"), - description = "Removes all elements from the array equal to the given value.", + description = "Removes all elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries.", syntax_example = "array_remove_all(array, element)", sql_example = r#"```sql > select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2); @@ -240,6 +254,13 @@ make_udf_expr_and_func!( +--------------------------------------------------+ | [1, 3, 1, 4] | +--------------------------------------------------+ + +> select array_remove_all([1, 2, NULL, 2, 4], 2); ++-----------------------------------------------------+ +| array_remove_all(List([1,2,NULL,2,4]),Int64(2)) | ++-----------------------------------------------------+ +| [1, NULL, 4] | ++-----------------------------------------------------+ ```"#, argument( name = "array", @@ -462,7 +483,8 @@ fn general_remove( mod tests { use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN}; use arrow::array::{ - Array, ArrayRef, AsArray, GenericListArray, ListArray, OffsetSizeTrait, + Array, ArrayRef, AsArray, GenericListArray, Int32Array, ListArray, + OffsetSizeTrait, }; use arrow::datatypes::{DataType, Field, Int32Type}; use datafusion_common::ScalarValue; @@ -621,13 +643,59 @@ mod tests { assert_array_remove(input_list, expected_list, element_to_remove); } + #[test] + fn test_array_remove_null_element_returns_null() { + let input_list = Arc::new(ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + Some(2), + None, + Some(4), + ])]), + )); + let expected_list = ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + None::>>, + ]), + ); + + assert_array_remove(input_list, expected_list, ScalarValue::Int32(None)); + } + + #[test] + fn test_array_remove_row_wise_null_element_returns_null() { + let input_list = Arc::new(ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), None, Some(2), Some(4)]), + Some(vec![Some(5), None, Some(6)]), + ]), + )); + let element_to_remove = Arc::new(Int32Array::from(vec![Some(2), None])); + let expected_list = ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), None, Some(2), Some(4)]), + None::>>, + ]), + ); + + assert_array_remove_array_arg(input_list, element_to_remove, expected_list); + } + fn assert_array_remove( input_list: ArrayRef, expected_list: GenericListArray, element_to_remove: ScalarValue, ) { assert_eq!(input_list.data_type(), expected_list.data_type()); - assert_eq!(expected_list.value_type(), element_to_remove.data_type()); + assert_eq!( + expected_list.value_type(), + element_to_remove.data_type().clone() + ); let input_list_len = input_list.len(); let input_list_data_type = input_list.data_type().clone(); @@ -672,6 +740,60 @@ mod tests { } } + fn assert_array_remove_array_arg( + input_list: ArrayRef, + element_to_remove: ArrayRef, + expected_list: GenericListArray, + ) { + assert_eq!(input_list.data_type(), expected_list.data_type()); + assert_eq!( + expected_list.value_type(), + element_to_remove.data_type().clone() + ); + let input_list_len = input_list.len(); + let input_list_data_type = input_list.data_type().clone(); + + let udf = ArrayRemove::new(); + let args_fields = vec![ + Arc::new(Field::new("num", input_list.data_type().clone(), false)), + Arc::new(Field::new( + "el", + element_to_remove.data_type().clone(), + true, + )), + ]; + let scalar_args = vec![None, None]; + + let return_field = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &args_fields, + scalar_arguments: &scalar_args, + }) + .unwrap(); + + let result = udf + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(input_list), + ColumnarValue::Array(element_to_remove), + ], + arg_fields: args_fields, + number_rows: input_list_len, + return_field, + config_options: Arc::new(Default::default()), + }) + .unwrap(); + + assert_eq!(result.data_type(), input_list_data_type); + match result { + ColumnarValue::Array(array) => { + let result_list = array.as_list::(); + assert_eq!(result_list, &expected_list); + } + _ => panic!("Expected ColumnarValue::Array"), + } + } + #[test] fn test_array_remove_n_non_nullable() { let input_list = Arc::new(ensure_field_nullability( @@ -724,6 +846,54 @@ mod tests { assert_array_remove_n(input_list, expected_list, element_to_remove, 2); } + #[test] + fn test_array_remove_n_null_element_returns_null() { + let input_list = Arc::new(ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + Some(2), + None, + Some(4), + ])]), + )); + let expected_list = ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + None::>>, + ]), + ); + + assert_array_remove_n(input_list, expected_list, ScalarValue::Int32(None), 2); + } + + #[test] + fn test_array_remove_n_row_wise_null_element_returns_null() { + let input_list = Arc::new(ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), None, Some(2), Some(4)]), + Some(vec![Some(5), None, Some(6)]), + ]), + )); + let element_to_remove = Arc::new(Int32Array::from(vec![Some(2), None])); + let expected_list = ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), None, Some(4)]), + None::>>, + ]), + ); + + assert_array_remove_n_array_arg( + input_list, + element_to_remove, + ScalarValue::Int64(Some(2)), + expected_list, + ); + } + fn assert_array_remove_n( input_list: ArrayRef, expected_list: GenericListArray, @@ -731,7 +901,10 @@ mod tests { n: i64, ) { assert_eq!(input_list.data_type(), expected_list.data_type()); - assert_eq!(expected_list.value_type(), element_to_remove.data_type()); + assert_eq!( + expected_list.value_type(), + element_to_remove.data_type().clone() + ); let input_list_len = input_list.len(); let input_list_data_type = input_list.data_type().clone(); @@ -780,6 +953,63 @@ mod tests { } } + fn assert_array_remove_n_array_arg( + input_list: ArrayRef, + element_to_remove: ArrayRef, + n: ScalarValue, + expected_list: GenericListArray, + ) { + assert_eq!(input_list.data_type(), expected_list.data_type()); + assert_eq!( + expected_list.value_type(), + element_to_remove.data_type().clone() + ); + let input_list_len = input_list.len(); + let input_list_data_type = input_list.data_type().clone(); + + let udf = ArrayRemoveN::new(); + let args_fields = vec![ + Arc::new(Field::new("num", input_list.data_type().clone(), false)), + Arc::new(Field::new( + "el", + element_to_remove.data_type().clone(), + true, + )), + Arc::new(Field::new("count", DataType::Int64, false)), + ]; + let scalar_args = vec![None, None, Some(&n)]; + + let return_field = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &args_fields, + scalar_arguments: &scalar_args, + }) + .unwrap(); + + let result = udf + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(input_list), + ColumnarValue::Array(element_to_remove), + ColumnarValue::Scalar(n), + ], + arg_fields: args_fields, + number_rows: input_list_len, + return_field, + config_options: Arc::new(Default::default()), + }) + .unwrap(); + + assert_eq!(result.data_type(), input_list_data_type); + match result { + ColumnarValue::Array(array) => { + let result_list = array.as_list::(); + assert_eq!(result_list, &expected_list); + } + _ => panic!("Expected ColumnarValue::Array"), + } + } + #[test] fn test_array_remove_all_non_nullable() { let input_list = Arc::new(ensure_field_nullability( @@ -832,13 +1062,59 @@ mod tests { assert_array_remove_all(input_list, expected_list, element_to_remove); } + #[test] + fn test_array_remove_all_null_element_returns_null() { + let input_list = Arc::new(ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + None, + Some(2), + None, + Some(4), + ])]), + )); + let expected_list = ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + None::>>, + ]), + ); + + assert_array_remove_all(input_list, expected_list, ScalarValue::Int32(None)); + } + + #[test] + fn test_array_remove_all_row_wise_null_element_returns_null() { + let input_list = Arc::new(ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), None, Some(2), Some(4)]), + Some(vec![Some(5), None, Some(6)]), + ]), + )); + let element_to_remove = Arc::new(Int32Array::from(vec![Some(2), None])); + let expected_list = ensure_field_nullability( + true, + ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), None, Some(4)]), + None::>>, + ]), + ); + + assert_array_remove_all_array_arg(input_list, element_to_remove, expected_list); + } + fn assert_array_remove_all( input_list: ArrayRef, expected_list: GenericListArray, element_to_remove: ScalarValue, ) { assert_eq!(input_list.data_type(), expected_list.data_type()); - assert_eq!(expected_list.value_type(), element_to_remove.data_type()); + assert_eq!( + expected_list.value_type(), + element_to_remove.data_type().clone() + ); let input_list_len = input_list.len(); let input_list_data_type = input_list.data_type().clone(); @@ -882,4 +1158,58 @@ mod tests { _ => panic!("Expected ColumnarValue::Array"), } } + + fn assert_array_remove_all_array_arg( + input_list: ArrayRef, + element_to_remove: ArrayRef, + expected_list: GenericListArray, + ) { + assert_eq!(input_list.data_type(), expected_list.data_type()); + assert_eq!( + expected_list.value_type(), + element_to_remove.data_type().clone() + ); + let input_list_len = input_list.len(); + let input_list_data_type = input_list.data_type().clone(); + + let udf = ArrayRemoveAll::new(); + let args_fields = vec![ + Arc::new(Field::new("num", input_list.data_type().clone(), false)), + Arc::new(Field::new( + "el", + element_to_remove.data_type().clone(), + true, + )), + ]; + let scalar_args = vec![None, None]; + + let return_field = udf + .return_field_from_args(ReturnFieldArgs { + arg_fields: &args_fields, + scalar_arguments: &scalar_args, + }) + .unwrap(); + + let result = udf + .invoke_with_args(ScalarFunctionArgs { + args: vec![ + ColumnarValue::Array(input_list), + ColumnarValue::Array(element_to_remove), + ], + arg_fields: args_fields, + number_rows: input_list_len, + return_field, + config_options: Arc::new(Default::default()), + }) + .unwrap(); + + assert_eq!(result.data_type(), input_list_data_type); + match result { + ColumnarValue::Array(array) => { + let result_list = array.as_list::(); + assert_eq!(result_list, &expected_list); + } + _ => panic!("Expected ColumnarValue::Array"), + } + } } diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 56d78ac473f14..af9053c9919e9 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -230,9 +230,9 @@ select log(-1), log(0), sqrt(-1); | array_positions(array, element) | Searches for an element in the array, returns all occurrences. `array_positions([1, 2, 2, 3, 4], 2) -> [2, 3]` | | array_prepend(element, array) | Prepends an element to the beginning of an array. `array_prepend(1, [2, 3, 4]) -> [1, 2, 3, 4]` | | array_repeat(element, count) | Returns an array containing element `count` times. `array_repeat(1, 3) -> [1, 1, 1]` | -| array_remove(array, element) | Removes the first element from the array equal to the given value. `array_remove([1, 2, 2, 3, 2, 1, 4], 2) -> [1, 2, 3, 2, 1, 4]` | -| array_remove_n(array, element, max) | Removes the first `max` elements from the array equal to the given value. `array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2) -> [1, 3, 2, 1, 4]` | -| array_remove_all(array, element) | Removes all elements from the array equal to the given value. `array_remove_all([1, 2, 2, 3, 2, 1, 4], 2) -> [1, 3, 1, 4]` | +| array_remove(array, element) | Removes the first element from the array equal to the given value. `NULL` elements already in the array are preserved when removing a non-`NULL` value, and `array_remove(array, NULL)` returns `NULL`. `array_remove([1, 2, NULL, 2, 4], 2) -> [1, NULL, 2, 4]` | +| array_remove_n(array, element, max) | Removes the first `max` elements from the array equal to the given value. `NULL` elements already in the array are preserved when removing a non-`NULL` value, and `array_remove_n(array, NULL, max)` returns `NULL`. `array_remove_n([1, 2, NULL, 2, 4], 2, 2) -> [1, NULL, 4]` | +| array_remove_all(array, element) | Removes all elements from the array equal to the given value. `NULL` elements already in the array are preserved when removing a non-`NULL` value, and `array_remove_all(array, NULL)` returns `NULL`. `array_remove_all([1, 2, NULL, 2, 4], 2) -> [1, NULL, 4]` | | array_replace(array, from, to) | Replaces the first occurrence of the specified element with another specified element. `array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 2, 3, 2, 1, 4]` | | array_replace_n(array, from, to, max) | Replaces the first `max` occurrences of the specified element with another specified element. `array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2) -> [1, 5, 5, 3, 2, 1, 4]` | | array_replace_all(array, from, to) | Replaces all occurrences of the specified element with another specified element. `array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 5, 3, 5, 1, 4]` | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 254151c2c20eb..6b39ea263fea4 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3880,7 +3880,7 @@ _Alias of [array_prepend](#array_prepend)._ ### `array_remove` -Removes the first element from the array equal to the given value. +Removes the first element from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries. ```sql array_remove(array, element) @@ -3900,6 +3900,13 @@ array_remove(array, element) +----------------------------------------------+ | [1, 2, 3, 2, 1, 4] | +----------------------------------------------+ + +> select array_remove([1, 2, NULL, 2, 4], 2); ++---------------------------------------------------+ +| array_remove(List([1,2,NULL,2,4]),Int64(2)) | ++---------------------------------------------------+ +| [1, NULL, 2, 4] | ++---------------------------------------------------+ ``` #### Aliases @@ -3908,7 +3915,7 @@ array_remove(array, element) ### `array_remove_all` -Removes all elements from the array equal to the given value. +Removes all elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries. ```sql array_remove_all(array, element) @@ -3928,6 +3935,13 @@ array_remove_all(array, element) +--------------------------------------------------+ | [1, 3, 1, 4] | +--------------------------------------------------+ + +> select array_remove_all([1, 2, NULL, 2, 4], 2); ++-----------------------------------------------------+ +| array_remove_all(List([1,2,NULL,2,4]),Int64(2)) | ++-----------------------------------------------------+ +| [1, NULL, 4] | ++-----------------------------------------------------+ ``` #### Aliases @@ -3936,10 +3950,10 @@ array_remove_all(array, element) ### `array_remove_n` -Removes the first `max` elements from the array equal to the given value. +Removes the first `max` elements from the array equal to the given value. NULL elements already in the array are preserved when removing a non-NULL value. If `element` evaluates to NULL, the result is NULL rather than removing NULL entries. ```sql -array_remove_n(array, element, max)) +array_remove_n(array, element, max) ``` #### Arguments @@ -3957,6 +3971,13 @@ array_remove_n(array, element, max)) +---------------------------------------------------------+ | [1, 3, 2, 1, 4] | +---------------------------------------------------------+ + +> select array_remove_n([1, 2, NULL, 2, 4], 2, 2); ++----------------------------------------------------------+ +| array_remove_n(List([1,2,NULL,2,4]),Int64(2),Int64(2)) | ++----------------------------------------------------------+ +| [1, NULL, 4] | ++----------------------------------------------------------+ ``` #### Aliases From 38fa4686005263d89926a720034b4e5a2bf06080 Mon Sep 17 00:00:00 2001 From: Rizky Mirzaviandy Priambodo <142987522+Xavrir@users.noreply.github.com> Date: Wed, 18 Mar 2026 16:34:31 +0700 Subject: [PATCH 2/2] docs: trim duplicated tests from array_remove docs PR --- datafusion/functions-nested/src/remove.rs | 317 +--------------------- 1 file changed, 4 insertions(+), 313 deletions(-) diff --git a/datafusion/functions-nested/src/remove.rs b/datafusion/functions-nested/src/remove.rs index 4e28fd0fcebe8..54dec8ca18f4f 100644 --- a/datafusion/functions-nested/src/remove.rs +++ b/datafusion/functions-nested/src/remove.rs @@ -483,8 +483,7 @@ fn general_remove( mod tests { use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN}; use arrow::array::{ - Array, ArrayRef, AsArray, GenericListArray, Int32Array, ListArray, - OffsetSizeTrait, + Array, ArrayRef, AsArray, GenericListArray, ListArray, OffsetSizeTrait, }; use arrow::datatypes::{DataType, Field, Int32Type}; use datafusion_common::ScalarValue; @@ -643,59 +642,13 @@ mod tests { assert_array_remove(input_list, expected_list, element_to_remove); } - #[test] - fn test_array_remove_null_element_returns_null() { - let input_list = Arc::new(ensure_field_nullability( - true, - ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - None, - Some(2), - None, - Some(4), - ])]), - )); - let expected_list = ensure_field_nullability( - true, - ListArray::from_iter_primitive::(vec![ - None::>>, - ]), - ); - - assert_array_remove(input_list, expected_list, ScalarValue::Int32(None)); - } - - #[test] - fn test_array_remove_row_wise_null_element_returns_null() { - let input_list = Arc::new(ensure_field_nullability( - true, - ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1), Some(2), None, Some(2), Some(4)]), - Some(vec![Some(5), None, Some(6)]), - ]), - )); - let element_to_remove = Arc::new(Int32Array::from(vec![Some(2), None])); - let expected_list = ensure_field_nullability( - true, - ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1), None, Some(2), Some(4)]), - None::>>, - ]), - ); - - assert_array_remove_array_arg(input_list, element_to_remove, expected_list); - } - fn assert_array_remove( input_list: ArrayRef, expected_list: GenericListArray, element_to_remove: ScalarValue, ) { assert_eq!(input_list.data_type(), expected_list.data_type()); - assert_eq!( - expected_list.value_type(), - element_to_remove.data_type().clone() - ); + assert_eq!(expected_list.value_type(), element_to_remove.data_type()); let input_list_len = input_list.len(); let input_list_data_type = input_list.data_type().clone(); @@ -740,60 +693,6 @@ mod tests { } } - fn assert_array_remove_array_arg( - input_list: ArrayRef, - element_to_remove: ArrayRef, - expected_list: GenericListArray, - ) { - assert_eq!(input_list.data_type(), expected_list.data_type()); - assert_eq!( - expected_list.value_type(), - element_to_remove.data_type().clone() - ); - let input_list_len = input_list.len(); - let input_list_data_type = input_list.data_type().clone(); - - let udf = ArrayRemove::new(); - let args_fields = vec![ - Arc::new(Field::new("num", input_list.data_type().clone(), false)), - Arc::new(Field::new( - "el", - element_to_remove.data_type().clone(), - true, - )), - ]; - let scalar_args = vec![None, None]; - - let return_field = udf - .return_field_from_args(ReturnFieldArgs { - arg_fields: &args_fields, - scalar_arguments: &scalar_args, - }) - .unwrap(); - - let result = udf - .invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(input_list), - ColumnarValue::Array(element_to_remove), - ], - arg_fields: args_fields, - number_rows: input_list_len, - return_field, - config_options: Arc::new(Default::default()), - }) - .unwrap(); - - assert_eq!(result.data_type(), input_list_data_type); - match result { - ColumnarValue::Array(array) => { - let result_list = array.as_list::(); - assert_eq!(result_list, &expected_list); - } - _ => panic!("Expected ColumnarValue::Array"), - } - } - #[test] fn test_array_remove_n_non_nullable() { let input_list = Arc::new(ensure_field_nullability( @@ -846,54 +745,6 @@ mod tests { assert_array_remove_n(input_list, expected_list, element_to_remove, 2); } - #[test] - fn test_array_remove_n_null_element_returns_null() { - let input_list = Arc::new(ensure_field_nullability( - true, - ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - None, - Some(2), - None, - Some(4), - ])]), - )); - let expected_list = ensure_field_nullability( - true, - ListArray::from_iter_primitive::(vec![ - None::>>, - ]), - ); - - assert_array_remove_n(input_list, expected_list, ScalarValue::Int32(None), 2); - } - - #[test] - fn test_array_remove_n_row_wise_null_element_returns_null() { - let input_list = Arc::new(ensure_field_nullability( - true, - ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1), Some(2), None, Some(2), Some(4)]), - Some(vec![Some(5), None, Some(6)]), - ]), - )); - let element_to_remove = Arc::new(Int32Array::from(vec![Some(2), None])); - let expected_list = ensure_field_nullability( - true, - ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1), None, Some(4)]), - None::>>, - ]), - ); - - assert_array_remove_n_array_arg( - input_list, - element_to_remove, - ScalarValue::Int64(Some(2)), - expected_list, - ); - } - fn assert_array_remove_n( input_list: ArrayRef, expected_list: GenericListArray, @@ -901,10 +752,7 @@ mod tests { n: i64, ) { assert_eq!(input_list.data_type(), expected_list.data_type()); - assert_eq!( - expected_list.value_type(), - element_to_remove.data_type().clone() - ); + assert_eq!(expected_list.value_type(), element_to_remove.data_type()); let input_list_len = input_list.len(); let input_list_data_type = input_list.data_type().clone(); @@ -953,63 +801,6 @@ mod tests { } } - fn assert_array_remove_n_array_arg( - input_list: ArrayRef, - element_to_remove: ArrayRef, - n: ScalarValue, - expected_list: GenericListArray, - ) { - assert_eq!(input_list.data_type(), expected_list.data_type()); - assert_eq!( - expected_list.value_type(), - element_to_remove.data_type().clone() - ); - let input_list_len = input_list.len(); - let input_list_data_type = input_list.data_type().clone(); - - let udf = ArrayRemoveN::new(); - let args_fields = vec![ - Arc::new(Field::new("num", input_list.data_type().clone(), false)), - Arc::new(Field::new( - "el", - element_to_remove.data_type().clone(), - true, - )), - Arc::new(Field::new("count", DataType::Int64, false)), - ]; - let scalar_args = vec![None, None, Some(&n)]; - - let return_field = udf - .return_field_from_args(ReturnFieldArgs { - arg_fields: &args_fields, - scalar_arguments: &scalar_args, - }) - .unwrap(); - - let result = udf - .invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(input_list), - ColumnarValue::Array(element_to_remove), - ColumnarValue::Scalar(n), - ], - arg_fields: args_fields, - number_rows: input_list_len, - return_field, - config_options: Arc::new(Default::default()), - }) - .unwrap(); - - assert_eq!(result.data_type(), input_list_data_type); - match result { - ColumnarValue::Array(array) => { - let result_list = array.as_list::(); - assert_eq!(result_list, &expected_list); - } - _ => panic!("Expected ColumnarValue::Array"), - } - } - #[test] fn test_array_remove_all_non_nullable() { let input_list = Arc::new(ensure_field_nullability( @@ -1062,59 +853,13 @@ mod tests { assert_array_remove_all(input_list, expected_list, element_to_remove); } - #[test] - fn test_array_remove_all_null_element_returns_null() { - let input_list = Arc::new(ensure_field_nullability( - true, - ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - None, - Some(2), - None, - Some(4), - ])]), - )); - let expected_list = ensure_field_nullability( - true, - ListArray::from_iter_primitive::(vec![ - None::>>, - ]), - ); - - assert_array_remove_all(input_list, expected_list, ScalarValue::Int32(None)); - } - - #[test] - fn test_array_remove_all_row_wise_null_element_returns_null() { - let input_list = Arc::new(ensure_field_nullability( - true, - ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1), Some(2), None, Some(2), Some(4)]), - Some(vec![Some(5), None, Some(6)]), - ]), - )); - let element_to_remove = Arc::new(Int32Array::from(vec![Some(2), None])); - let expected_list = ensure_field_nullability( - true, - ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1), None, Some(4)]), - None::>>, - ]), - ); - - assert_array_remove_all_array_arg(input_list, element_to_remove, expected_list); - } - fn assert_array_remove_all( input_list: ArrayRef, expected_list: GenericListArray, element_to_remove: ScalarValue, ) { assert_eq!(input_list.data_type(), expected_list.data_type()); - assert_eq!( - expected_list.value_type(), - element_to_remove.data_type().clone() - ); + assert_eq!(expected_list.value_type(), element_to_remove.data_type()); let input_list_len = input_list.len(); let input_list_data_type = input_list.data_type().clone(); @@ -1158,58 +903,4 @@ mod tests { _ => panic!("Expected ColumnarValue::Array"), } } - - fn assert_array_remove_all_array_arg( - input_list: ArrayRef, - element_to_remove: ArrayRef, - expected_list: GenericListArray, - ) { - assert_eq!(input_list.data_type(), expected_list.data_type()); - assert_eq!( - expected_list.value_type(), - element_to_remove.data_type().clone() - ); - let input_list_len = input_list.len(); - let input_list_data_type = input_list.data_type().clone(); - - let udf = ArrayRemoveAll::new(); - let args_fields = vec![ - Arc::new(Field::new("num", input_list.data_type().clone(), false)), - Arc::new(Field::new( - "el", - element_to_remove.data_type().clone(), - true, - )), - ]; - let scalar_args = vec![None, None]; - - let return_field = udf - .return_field_from_args(ReturnFieldArgs { - arg_fields: &args_fields, - scalar_arguments: &scalar_args, - }) - .unwrap(); - - let result = udf - .invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(input_list), - ColumnarValue::Array(element_to_remove), - ], - arg_fields: args_fields, - number_rows: input_list_len, - return_field, - config_options: Arc::new(Default::default()), - }) - .unwrap(); - - assert_eq!(result.data_type(), input_list_data_type); - match result { - ColumnarValue::Array(array) => { - let result_list = array.as_list::(); - assert_eq!(result_list, &expected_list); - } - _ => panic!("Expected ColumnarValue::Array"), - } - } }