From f6eacfb3ba9e4de3823d13c279a25ba0c29988ec Mon Sep 17 00:00:00 2001 From: Nic Crane Date: Tue, 17 Mar 2026 11:40:50 +0000 Subject: [PATCH] add new functions and extract out sections as code --- r/DESCRIPTION | 2 +- r/R/dplyr-funcs-conditional.R | 142 ++++++++++++++++-- r/R/dplyr-funcs-doc.R | 5 +- r/man/acero.Rd | 5 +- r/man/read_json_arrow.Rd | 2 +- r/man/schema.Rd | 2 +- .../testthat/test-dplyr-funcs-conditional.R | 132 ++++++++++++++++ 7 files changed, 271 insertions(+), 19 deletions(-) diff --git a/r/DESCRIPTION b/r/DESCRIPTION index 7513cc89715e..054af467dec9 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -44,7 +44,7 @@ Imports: utils, vctrs Roxygen: list(markdown = TRUE, r6 = FALSE, load = "source") -RoxygenNote: 7.3.3 +RoxygenNote: 7.3.3.9000 Config/testthat/edition: 3 Config/build/bootstrap: TRUE Suggests: diff --git a/r/R/dplyr-funcs-conditional.R b/r/R/dplyr-funcs-conditional.R index 25d7fbc668cf..b8f46615aefe 100644 --- a/r/R/dplyr-funcs-conditional.R +++ b/r/R/dplyr-funcs-conditional.R @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +# nolint start: cyclocomp_linter. register_bindings_conditional <- function() { register_binding("%in%", function(x, table) { # We use `is_in` here, unlike with Arrays, which use `is_in_meta_binary` @@ -134,21 +135,134 @@ register_bindings_conditional <- function() { validation_error(paste0("`.default` must have size 1, not size ", length(.default), ".")) } - query[n + 1] <- TRUE - value[n + 1] <- .default - } - Expression$create( - "case_when", - args = c( - Expression$create( - "make_struct", - args = query, - options = list(field_names = as.character(seq_along(query))) - ), - value - ) - ) + query[[n + 1]] <- TRUE + value[[n + 1]] <- .default + } + build_case_when_expr(query, value) }, notes = "`.ptype` and `.size` arguments not supported" ) + + register_binding("dplyr::replace_when", function(x, ...) { + formulas <- list2(...) + n <- length(formulas) + if (n == 0) { + return(x) + } + query <- vector("list", n + 1) + value <- vector("list", n + 1) + mask <- caller_env() + for (i in seq_len(n)) { + f <- formulas[[i]] + if (!inherits(f, "formula")) { + validation_error("Each argument to replace_when() must be a two-sided formula") + } + query[[i]] <- arrow_eval(f[[2]], mask) + value[[i]] <- arrow_eval(f[[3]], mask) + if (!call_binding("is.logical", query[[i]])) { + validation_error("Left side of each formula in replace_when() must be a logical expression") + } + } + query[[n + 1]] <- TRUE + value[[n + 1]] <- x + build_case_when_expr(query, value) + }) + + register_binding("dplyr::replace_values", function(x, ..., from = NULL, to = NULL) { + parsed <- parse_value_mapping(x, list2(...), from, to, caller_env(), "replace_values") + if (is.null(parsed)) { + return(x) + } + query <- parsed$query + value <- parsed$value + n <- length(query) + query[[n + 1]] <- TRUE + value[[n + 1]] <- x + build_case_when_expr(query, value) + }) + + register_binding( + "dplyr::recode_values", + function(x, ..., from = NULL, to = NULL, default = NULL, unmatched = "default", ptype = NULL) { + if (!is.null(ptype)) { + arrow_not_supported("`recode_values()` with `ptype` specified") + } + if (unmatched == "error") { + arrow_not_supported("`recode_values()` with `unmatched = \"error\"`") + } + + parsed <- parse_value_mapping(x, list2(...), from, to, caller_env(), "recode_values") + if (is.null(parsed)) { + query <- list() + value <- list() + } else { + query <- parsed$query + value <- parsed$value + } + + if (!is.null(default)) { + n <- length(query) + query[[n + 1]] <- TRUE + value[[n + 1]] <- Expression$scalar(default) + } + build_case_when_expr(query, value) + }, + notes = "`ptype` argument and `unmatched = \"error\"` not supported" + ) + + # Create case_when Expression from query/value lists + build_case_when_expr <- function(query, value) { + Expression$create( + "case_when", + args = c( + Expression$create( + "make_struct", + args = query, + options = list(field_names = as.character(seq_along(query))) + ), + value + ) + ) + } + + # Parse value ~ replacement formulas or from/to vectors into query/value lists + # Used by replace_values and recode_values + parse_value_mapping <- function(x, formulas, from, to, mask, fn) { + if (length(formulas) > 0 && !is.null(from)) { + validation_error(paste0("Can't use both `...` and `from`/`to` in ", fn, "()")) + } + + if (length(formulas) > 0) { + n <- length(formulas) + query <- vector("list", n) + value <- vector("list", n) + for (i in seq_len(n)) { + f <- formulas[[i]] + if (!inherits(f, "formula")) { + validation_error(paste0("Each argument to ", fn, "() must be a two-sided formula")) + } + lhs <- arrow_eval(f[[2]], mask) + rhs <- arrow_eval(f[[3]], mask) + query[[i]] <- x == lhs + value[[i]] <- rhs + } + list(query = query, value = value) + } else if (!is.null(from)) { + if (is.null(to)) { + validation_error("`to` must be provided when using `from`") + } + n <- length(from) + to <- vctrs::vec_recycle(to, n) + query <- vector("list", n) + value <- vector("list", n) + for (i in seq_len(n)) { + query[[i]] <- x == from[[i]] + value[[i]] <- Expression$scalar(to[[i]]) + } + list(query = query, value = value) + } else { + NULL + } + } } +# nolint end. diff --git a/r/R/dplyr-funcs-doc.R b/r/R/dplyr-funcs-doc.R index 9293d14c94c0..76f311e30e13 100644 --- a/r/R/dplyr-funcs-doc.R +++ b/r/R/dplyr-funcs-doc.R @@ -21,7 +21,7 @@ #' #' The `arrow` package contains methods for 38 `dplyr` table functions, many of #' which are "verbs" that do transformations to one or more tables. -#' The package also has mappings of 224 R functions to the corresponding +#' The package also has mappings of 227 R functions to the corresponding #' functions in the Arrow compute library. These allow you to write code inside #' of `dplyr` methods that call R functions, including many in packages like #' `stringr` and `lubridate`, and they will get translated to Arrow and run @@ -214,6 +214,9 @@ #' * [`if_else()`][dplyr::if_else()] #' * [`n()`][dplyr::n()] #' * [`n_distinct()`][dplyr::n_distinct()] +#' * [`recode_values()`][dplyr::recode_values()]: `ptype` argument and `unmatched = "error"` not supported +#' * [`replace_values()`][dplyr::replace_values()] +#' * [`replace_when()`][dplyr::replace_when()] #' #' ## hms #' diff --git a/r/man/acero.Rd b/r/man/acero.Rd index ee156cc9129b..e293751face7 100644 --- a/r/man/acero.Rd +++ b/r/man/acero.Rd @@ -9,7 +9,7 @@ \description{ The \code{arrow} package contains methods for 38 \code{dplyr} table functions, many of which are "verbs" that do transformations to one or more tables. -The package also has mappings of 224 R functions to the corresponding +The package also has mappings of 227 R functions to the corresponding functions in the Arrow compute library. These allow you to write code inside of \code{dplyr} methods that call R functions, including many in packages like \code{stringr} and \code{lubridate}, and they will get translated to Arrow and run @@ -207,6 +207,9 @@ Valid values are "s", "ms" (default), "us", "ns". \item \code{\link[dplyr:if_else]{if_else()}} \item \code{\link[dplyr:context]{n()}} \item \code{\link[dplyr:n_distinct]{n_distinct()}} +\item \code{\link[dplyr:recode-and-replace-values]{recode_values()}}: \code{ptype} argument and \code{unmatched = "error"} not supported +\item \code{\link[dplyr:recode-and-replace-values]{replace_values()}} +\item \code{\link[dplyr:case-and-replace-when]{replace_when()}} } } diff --git a/r/man/read_json_arrow.Rd b/r/man/read_json_arrow.Rd index b809a63bcc6f..abf6b8fc44a8 100644 --- a/r/man/read_json_arrow.Rd +++ b/r/man/read_json_arrow.Rd @@ -54,7 +54,7 @@ If \code{schema} is not provided, Arrow data types are inferred from the data: \item JSON numbers convert to \code{\link[=int64]{int64()}}, falling back to \code{\link[=float64]{float64()}} if a non-integer is encountered. \item JSON strings of the kind "YYYY-MM-DD" and "YYYY-MM-DD hh:mm:ss" convert to \code{\link[=timestamp]{timestamp(unit = "s")}}, falling back to \code{\link[=utf8]{utf8()}} if a conversion error occurs. -\item JSON arrays convert to a \code{\link[=list_of]{list_of()}} type, and inference proceeds recursively on the JSON arrays' values. +\item JSON arrays convert to a \code{\link[vctrs:list_of]{vctrs::list_of()}} type, and inference proceeds recursively on the JSON arrays' values. \item Nested JSON objects convert to a \code{\link[=struct]{struct()}} type, and inference proceeds recursively on the JSON objects' values. } diff --git a/r/man/schema.Rd b/r/man/schema.Rd index 65ab2eea0d27..ff77a05d84aa 100644 --- a/r/man/schema.Rd +++ b/r/man/schema.Rd @@ -7,7 +7,7 @@ schema(...) } \arguments{ -\item{...}{\link[=field]{fields}, field name/\link[=data-type]{data type} pairs (or a list of), or object from which to extract +\item{...}{\link[vctrs:fields]{fields}, field name/\link[=data-type]{data type} pairs (or a list of), or object from which to extract a schema} } \description{ diff --git a/r/tests/testthat/test-dplyr-funcs-conditional.R b/r/tests/testthat/test-dplyr-funcs-conditional.R index 58373db253fd..ee12b9464aae 100644 --- a/r/tests/testthat/test-dplyr-funcs-conditional.R +++ b/r/tests/testthat/test-dplyr-funcs-conditional.R @@ -517,3 +517,135 @@ test_that("external objects are found when they're not in the global environment tibble(x = c("a", "b"), x2 = c("foo", NA)) ) }) + +test_that("replace_when()", { + # replaces matching values, keeps original otherwise + compare_dplyr_binding( + .input |> + mutate(result = replace_when(int, int > 5 ~ 100L)) |> + collect(), + tbl + ) + + # multiple conditions + compare_dplyr_binding( + .input |> + mutate(result = replace_when(int, int > 7 ~ 100L, int < 3 ~ 0L)) |> + collect(), + tbl + ) + + # no formulas returns x unchanged + compare_dplyr_binding( + .input |> + mutate(result = replace_when(int)) |> + collect(), + tbl + ) + + # validation errors + expect_arrow_eval_error( + replace_when(int, TRUE), + "Each argument to replace_when\\(\\) must be a two-sided formula", + class = "validation_error" + ) + expect_arrow_eval_error( + replace_when(int, 0L ~ 100L), + "Left side of each formula in replace_when\\(\\) must be a logical expression", + class = "validation_error" + ) +}) + +test_that("replace_values()", { + # formula interface + compare_dplyr_binding( + .input |> + mutate(result = replace_values(chr, "a" ~ "A", "b" ~ "B")) |> + collect(), + tbl + ) + + # from/to interface + compare_dplyr_binding( + .input |> + mutate(result = replace_values(chr, from = c("a", "b"), to = c("A", "B"))) |> + collect(), + tbl + ) + + # unmatched values kept + compare_dplyr_binding( + .input |> + mutate(result = replace_values(chr, "a" ~ "A")) |> + collect(), + tbl + ) + + # no replacements returns x unchanged + compare_dplyr_binding( + .input |> + mutate(result = replace_values(chr)) |> + collect(), + tbl + ) + + # validation errors + expect_arrow_eval_error( + replace_values(chr, "a" ~ "A", from = "b"), + "Can't use both `...` and `from`/`to` in replace_values\\(\\)", + class = "validation_error" + ) + expect_arrow_eval_error( + replace_values(chr, from = "a"), + "`to` must be provided when using `from`", + class = "validation_error" + ) +}) + +test_that("recode_values()", { + # formula interface with default NA + compare_dplyr_binding( + .input |> + mutate(result = recode_values(chr, "a" ~ "A", "b" ~ "B")) |> + collect(), + tbl + ) + + # from/to interface + compare_dplyr_binding( + .input |> + mutate(result = recode_values(chr, from = c("a", "b"), to = c("A", "B"))) |> + collect(), + tbl + ) + + # custom default + compare_dplyr_binding( + .input |> + mutate(result = recode_values(chr, "a" ~ "A", default = "other")) |> + collect(), + tbl + ) + + # validation errors + expect_arrow_eval_error( + recode_values(chr, "a" ~ "A", from = "b"), + "Can't use both `...` and `from`/`to` in recode_values\\(\\)", + class = "validation_error" + ) + expect_arrow_eval_error( + recode_values(chr, from = "a"), + "`to` must be provided when using `from`", + class = "validation_error" + ) + expect_arrow_eval_error( + recode_values(chr, "a" ~ "A", ptype = character()), + "`recode_values\\(\\)` with `ptype` specified not supported in Arrow", + class = "arrow_not_supported" + ) + expect_arrow_eval_error( + recode_values(chr, "a" ~ "A", unmatched = "error"), + "`recode_values\\(\\)` with `unmatched = \"error\"` not supported in Arrow", + class = "arrow_not_supported" + ) +})