From 4fb12c91880bbf868f4a63c407a503dc29191e3b Mon Sep 17 00:00:00 2001 From: sorhawell Date: Fri, 17 Nov 2023 22:59:18 +0100 Subject: [PATCH] minor conversion --- R/expr__expr.R | 18 +++++--- R/extendr-wrappers.R | 8 ++-- src/rust/src/lazy/dsl.rs | 77 ++++++++++++++++------------------ src/rust/src/rdataframe/mod.rs | 4 +- 4 files changed, 53 insertions(+), 54 deletions(-) diff --git a/R/expr__expr.R b/R/expr__expr.R index 0f9357b5d..4ce493ea7 100644 --- a/R/expr__expr.R +++ b/R/expr__expr.R @@ -1300,7 +1300,8 @@ Expr_rechunk = "use_extendr_wrapper" #' pl$col("a")$cum_sum(reverse = TRUE)$alias("cum_sum_reversed") #' ) Expr_cum_sum = function(reverse = FALSE) { - .pr$Expr$cum_sum(self, reverse) + .pr$Expr$cum_sum(self, reverse) |> + unwrap("in cum_sum():") } @@ -1322,7 +1323,8 @@ Expr_cum_sum = function(reverse = FALSE) { #' pl$col("a")$cum_prod(reverse = TRUE)$alias("cum_prod_reversed") #' ) Expr_cum_prod = function(reverse = FALSE) { - .pr$Expr$cum_prod(self, reverse) + .pr$Expr$cum_prod(self, reverse) |> + unwrap("in cum_prod():") } #' Cumulative minimum @@ -1344,7 +1346,8 @@ Expr_cum_prod = function(reverse = FALSE) { #' pl$col("a")$cum_min(reverse = TRUE)$alias("cum_min_reversed") #' ) Expr_cum_min = function(reverse = FALSE) { - .pr$Expr$cum_min(self, reverse) + .pr$Expr$cum_min(self, reverse) |> + unwrap("in cum_min():") } #' Cumulative maximum @@ -1366,7 +1369,8 @@ Expr_cum_min = function(reverse = FALSE) { #' pl$col("a")$cum_max(reverse = TRUE)$alias("cum_max_reversed") #' ) Expr_cum_max = function(reverse = FALSE) { - .pr$Expr$cum_max(self, reverse) + .pr$Expr$cum_max(self, reverse) |> + unwrap("in cum_max():") } #' Cumulative count @@ -1390,7 +1394,8 @@ Expr_cum_max = function(reverse = FALSE) { #' pl$col("a")$cum_count(reverse = TRUE)$alias("cum_count_reversed") #' ) Expr_cum_count = function(reverse = FALSE) { - .pr$Expr$cum_count(self, reverse) + .pr$Expr$cum_count(self, reverse) |> + unwrap("in cum_count():") } @@ -1730,7 +1735,8 @@ Expr_sort_by = function(by, descending = FALSE) { #' @examples #' pl$select(pl$lit(0:10)$gather(c(1, 8, 0, 7))) Expr_gather = function(indices) { - .pr$Expr$gather(self, pl$lit(indices)) + .pr$Expr$gather(self, pl$lit(indices)) |> + unwrap("in $gather():") } diff --git a/R/extendr-wrappers.R b/R/extendr-wrappers.R index d4906c25e..0d09cdecf 100644 --- a/R/extendr-wrappers.R +++ b/R/extendr-wrappers.R @@ -11,14 +11,14 @@ #' @useDynLib polars, .registration = TRUE NULL -min_horizontal <- function(dotdotdot) .Call(wrap__min_horizontal, dotdotdot) - -max_horizontal <- function(dotdotdot) .Call(wrap__max_horizontal, dotdotdot) - all_horizontal <- function(dotdotdot) .Call(wrap__all_horizontal, dotdotdot) any_horizontal <- function(dotdotdot) .Call(wrap__any_horizontal, dotdotdot) +min_horizontal <- function(dotdotdot) .Call(wrap__min_horizontal, dotdotdot) + +max_horizontal <- function(dotdotdot) .Call(wrap__max_horizontal, dotdotdot) + sum_horizontal <- function(dotdotdot) .Call(wrap__sum_horizontal, dotdotdot) coalesce_exprs <- function(exprs) .Call(wrap__coalesce_exprs, exprs) diff --git a/src/rust/src/lazy/dsl.rs b/src/rust/src/lazy/dsl.rs index 4e8827a1b..10d3b2339 100644 --- a/src/rust/src/lazy/dsl.rs +++ b/src/rust/src/lazy/dsl.rs @@ -14,9 +14,7 @@ use crate::utils::extendr_helpers::robj_inherits; use crate::utils::parse_fill_null_strategy; use crate::utils::wrappers::null_to_opt; use crate::utils::{r_error_list, r_ok_list, r_result_list, robj_to_binary_vec}; -use crate::utils::{ - try_f64_into_i64, try_f64_into_u32, try_f64_into_usize, try_f64_into_usize_no_zero, -}; +use crate::utils::{try_f64_into_i64, try_f64_into_u32, try_f64_into_usize}; use crate::CONFIG; use extendr_api::{extendr, prelude::*, rprintln, Deref, DerefMut, Rinternals}; use pl::PolarsError as pl_error; @@ -287,8 +285,8 @@ impl Expr { .into() } - pub fn gather(&self, idx: &Expr) -> Self { - self.clone().0.gather(idx.0.clone()).into() + pub fn gather(&self, idx: Robj) -> RResult { + Ok(self.clone().0.gather(robj_to!(PLExpr, idx)?).into()) } pub fn sort_by(&self, by: Robj, descending: Robj) -> RResult { @@ -433,24 +431,20 @@ impl Expr { self.clone().0.explode().into() } - pub fn gather_every(&self, n: f64) -> List { - use pl::*; //dunno what set of traits needed just take all - - let result = try_f64_into_usize_no_zero(n) - .map_err(|err| format!("Invalid n argument in gather_every: {}", err)) - .map(|n| { - Expr( - self.clone() - .0 - .map( - move |s: Series| Ok(Some(s.gather_every(n))), - pl::GetOutput::same_type(), - ) - .with_fmt("gather_every"), - ) - }); - - r_result_list(result) + pub fn gather_every(&self, n: Robj) -> RResult { + let n = robj_to!(usize, n).and_then(|n| match n { + 0 => rerr().bad_arg("n").bad_val("n can't be zero"), + _ => Ok(n), + })?; + Ok(self + .0 + .clone() + .map( + move |s: pl::Series| Ok(Some(s.gather_every(n))), + pl::GetOutput::same_type(), + ) + .with_fmt("gather_every") + .into()) } pub fn hash( @@ -1074,7 +1068,7 @@ impl Expr { .take(robj_to!(PLExprCol, index)?, robj_to!(bool, null_on_oob)?) .into()) } - + fn list_get(&self, index: &Expr) -> Self { self.0.clone().list().get(index.clone().0).into() } @@ -1533,24 +1527,24 @@ impl Expr { self.0.clone().drop_nans().into() } - pub fn cum_sum(&self, reverse: bool) -> Self { - self.0.clone().cum_sum(reverse).into() + pub fn cum_sum(&self, reverse: Robj) -> RResult { + Ok(self.0.clone().cum_sum(robj_to!(bool, reverse)?).into()) } - pub fn cum_prod(&self, reverse: bool) -> Self { - self.0.clone().cum_prod(reverse).into() + pub fn cum_prod(&self, reverse: Robj) -> RResult { + Ok(self.0.clone().cum_prod(robj_to!(bool, reverse)?).into()) } - pub fn cum_min(&self, reverse: bool) -> Self { - self.0.clone().cum_min(reverse).into() + pub fn cum_min(&self, reverse: Robj) -> RResult { + Ok(self.0.clone().cum_min(robj_to!(bool, reverse)?).into()) } - pub fn cum_max(&self, reverse: bool) -> Self { - self.0.clone().cum_max(reverse).into() + pub fn cum_max(&self, reverse: Robj) -> RResult { + Ok(self.0.clone().cum_max(robj_to!(bool, reverse)?).into()) } - pub fn cum_count(&self, reverse: bool) -> Self { - self.0.clone().cum_count(reverse).into() + pub fn cum_count(&self, reverse: Robj) -> RResult { + Ok(self.0.clone().cum_count(robj_to!(bool, reverse)?).into()) } pub fn floor(&self) -> Self { @@ -1863,14 +1857,13 @@ impl Expr { .into() } - pub fn str_concat(&self, delimiter: &str, ignore_nulls: Robj) -> RResult { - Ok( - self.0 + pub fn str_concat(&self, delimiter: Robj, ignore_nulls: Robj) -> RResult { + Ok(self + .0 .clone() .str() - .concat(delimiter, robj_to!(bool, ignore_nulls)?) - .into() - ) + .concat(robj_to!(str, delimiter)?, robj_to!(bool, ignore_nulls)?) + .into()) } pub fn str_to_uppercase(&self) -> Self { @@ -2136,7 +2129,7 @@ impl Expr { Ok(self.0.clone().str().explode().into()) } - pub fn str_parse_int(&self, radix: Robj, strict: Robj) -> Result { + pub fn str_parse_int(&self, radix: Robj, strict: Robj) -> RResult { Ok(self .0 .clone() @@ -2146,7 +2139,7 @@ impl Expr { .into()) } - pub fn bin_contains(&self, lit: Robj) -> Result { + pub fn bin_contains(&self, lit: Robj) -> RResult { Ok(self .0 .clone() diff --git a/src/rust/src/rdataframe/mod.rs b/src/rust/src/rdataframe/mod.rs index e6d883b7a..5fa5b1dc1 100644 --- a/src/rust/src/rdataframe/mod.rs +++ b/src/rust/src/rdataframe/mod.rs @@ -463,7 +463,7 @@ impl DataFrame { pub fn write_csv( &self, path: Robj, - include_bom: bool, + include_bom: Robj, include_header: Robj, separator: Robj, line_terminator: Robj, @@ -479,7 +479,7 @@ impl DataFrame { let path = robj_to!(str, path)?; let f = std::fs::File::create(path)?; pl::CsvWriter::new(f) - .include_bom(include_bom) + .include_bom(robj_to!(bool, include_bom)?) .include_header(robj_to!(bool, include_header)?) .with_separator(robj_to!(Utf8Byte, separator)?) .with_line_terminator(robj_to!(String, line_terminator)?)