Skip to content

Commit

Permalink
refactor: Make functions in expr/general non-anonymous (pola-rs#13832)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored and r-brink committed Jan 24, 2024
1 parent 1122938 commit 9cd1140
Show file tree
Hide file tree
Showing 13 changed files with 92 additions and 53 deletions.
3 changes: 2 additions & 1 deletion crates/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,8 @@ pub trait ChunkSort<T: PolarsDataType> {

pub type FillNullLimit = Option<IdxSize>;

#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, PartialEq, Hash)]
#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))]
pub enum FillNullStrategy {
/// previous value in array
Backward(FillNullLimit),
Expand Down
1 change: 1 addition & 0 deletions crates/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ chunked_ids = ["polars-plan/chunked_ids", "polars-core/chunked_ids", "polars-ops
list_to_struct = ["polars-plan/list_to_struct"]
python = ["pyo3", "polars-plan/python", "polars-core/python", "polars-io/python"]
row_hash = ["polars-plan/row_hash"]
reinterpret = ["polars-plan/reinterpret", "polars-ops/reinterpret"]
string_pad = ["polars-plan/string_pad"]
string_reverse = ["polars-plan/string_reverse"]
string_to_integer = ["polars-plan/string_to_integer"]
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ string_to_integer = ["polars-core/strings"]
extract_jsonpath = ["serde_json", "jsonpath_lib", "polars-json"]
log = []
hash = []
reinterpret = ["polars-core/reinterpret"]
group_by_list = ["polars-core/group_by_list"]
rolling_window = ["polars-core/rolling_window"]
moment = []
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-ops/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ mod moment;
mod pct_change;
#[cfg(feature = "rank")]
mod rank;
#[cfg(feature = "reinterpret")]
mod reinterpret;
#[cfg(feature = "replace")]
mod replace;
#[cfg(feature = "rle")]
Expand Down Expand Up @@ -97,6 +99,8 @@ pub use pct_change::*;
use polars_core::prelude::*;
#[cfg(feature = "rank")]
pub use rank::*;
#[cfg(feature = "reinterpret")]
pub use reinterpret::*;
#[cfg(feature = "replace")]
pub use replace::*;
#[cfg(feature = "rle")]
Expand Down
18 changes: 18 additions & 0 deletions crates/polars-ops/src/series/ops/reinterpret.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use polars_core::prelude::*;

/// Reinterpret a 32-bit or 64-bit integer [`Series`] with the requested signedness.
///
/// * If `s` already has the requested signedness, a clone of `s` is returned.
/// * Otherwise the typed chunked array is handed to `reinterpret_signed`
///   (when `signed` is `true`) or `reinterpret_unsigned` (when `signed` is `false`).
///
/// # Errors
///
/// Returns a `ComputeError` when the dtype of `s` is not one of
/// `Int32`, `UInt32`, `Int64` or `UInt64`.
pub fn reinterpret(s: &Series, signed: bool) -> PolarsResult<Series> {
    let reinterpreted = match (signed, s.dtype()) {
        // Signedness already matches the request: nothing to convert.
        (true, DataType::Int64 | DataType::Int32)
        | (false, DataType::UInt64 | DataType::UInt32) => s.clone(),
        // Unsigned -> signed. The dtype was just matched, so the typed accessor
        // is not expected to fail.
        (true, DataType::UInt64) => s.u64().unwrap().reinterpret_signed().into_series(),
        (true, DataType::UInt32) => s.u32().unwrap().reinterpret_signed().into_series(),
        // Signed -> unsigned.
        (false, DataType::Int64) => s.i64().unwrap().reinterpret_unsigned().into_series(),
        (false, DataType::Int32) => s.i32().unwrap().reinterpret_unsigned().into_series(),
        _ => polars_bail!(
            ComputeError:
            "reinterpret is only allowed for 64-bit/32-bit integers types, use cast otherwise"
        ),
    };
    Ok(reinterpreted)
}
1 change: 1 addition & 0 deletions crates/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ log = ["polars-ops/log"]
chunked_ids = ["polars-core/chunked_ids"]
list_to_struct = ["polars-ops/list_to_struct"]
row_hash = ["polars-core/row_hash", "polars-ops/hash"]
reinterpret = ["polars-core/reinterpret", "polars-ops/reinterpret"]
string_pad = ["polars-ops/string_pad"]
string_reverse = ["polars-ops/string_reverse"]
string_to_integer = ["polars-ops/string_to_integer"]
Expand Down
17 changes: 17 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,20 @@ pub(super) fn replace(s: &[Series], return_dtype: Option<DataType>) -> PolarsRes
let default = if let Some(s) = s.get(3) { s } else { &s[0] };
polars_ops::series::replace(&s[0], &s[1], &s[2], default, return_dtype)
}

/// Dispatch target for `FunctionExpr::FillNullWithStrategy`.
///
/// Forwards `strategy` to `fill_null` on `s` and returns its result
/// (including any error) unchanged.
pub(super) fn fill_null_with_strategy(
s: &Series,
strategy: FillNullStrategy,
) -> PolarsResult<Series> {
s.fill_null(strategy)
}

/// Dispatch target for `FunctionExpr::GatherEvery`.
///
/// Validates `n` and then delegates to `gather_every(n, offset)` on `s`.
///
/// # Errors
///
/// Returns an `InvalidOperation` error when `n` is zero.
pub(super) fn gather_every(s: &Series, n: usize, offset: usize) -> PolarsResult<Series> {
// `n` is unsigned, so `n > 0` only rejects a step of exactly zero.
polars_ensure!(n > 0, InvalidOperation: "gather_every(n): n should be positive");
Ok(s.gather_every(n, offset))
}

/// Dispatch target for `FunctionExpr::Reinterpret`.
///
/// Thin forwarder to `polars_ops::series::reinterpret`; only compiled when the
/// `reinterpret` feature is enabled.
#[cfg(feature = "reinterpret")]
pub(super) fn reinterpret(s: &Series, signed: bool) -> PolarsResult<Series> {
polars_ops::series::reinterpret(s, signed)
}
19 changes: 19 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ pub enum FunctionExpr {
FillNull {
super_type: DataType,
},
FillNullWithStrategy(FillNullStrategy),
#[cfg(feature = "rolling_window")]
RollingExpr(RollingFunction),
ShiftAndFill,
Expand Down Expand Up @@ -320,6 +321,12 @@ pub enum FunctionExpr {
Replace {
return_dtype: Option<DataType>,
},
GatherEvery {
n: usize,
offset: usize,
},
#[cfg(feature = "reinterpret")]
Reinterpret(bool),
}

impl Hash for FunctionExpr {
Expand Down Expand Up @@ -521,6 +528,10 @@ impl Hash for FunctionExpr {
},
#[cfg(feature = "replace")]
Replace { return_dtype } => return_dtype.hash(state),
FillNullWithStrategy(strategy) => strategy.hash(state),
GatherEvery { n, offset } => (n, offset).hash(state),
#[cfg(feature = "reinterpret")]
Reinterpret(signed) => signed.hash(state),
}
}
}
Expand Down Expand Up @@ -689,6 +700,10 @@ impl Display for FunctionExpr {
Hist { .. } => "hist",
#[cfg(feature = "replace")]
Replace { .. } => "replace",
FillNullWithStrategy(_) => "fill_null_with_strategy",
GatherEvery { .. } => "gather_every",
#[cfg(feature = "reinterpret")]
Reinterpret(_) => "reinterpret",
};
write!(f, "{s}")
}
Expand Down Expand Up @@ -1040,6 +1055,10 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
Replace { return_dtype } => {
map_as_slice!(dispatch::replace, return_dtype.clone())
},
FillNullWithStrategy(strategy) => map!(dispatch::fill_null_with_strategy, strategy),
GatherEvery { n, offset } => map!(dispatch::gather_every, n, offset),
#[cfg(feature = "reinterpret")]
Reinterpret(signed) => map!(dispatch::reinterpret, signed),
}
}
}
11 changes: 11 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,17 @@ impl FunctionExpr {
EwmVar { .. } => mapper.map_to_float_dtype(),
#[cfg(feature = "replace")]
Replace { return_dtype } => mapper.replace_dtype(return_dtype.clone()),
FillNullWithStrategy(_) => mapper.with_same_dtype(),
GatherEvery { .. } => mapper.with_same_dtype(),
#[cfg(feature = "reinterpret")]
Reinterpret(signed) => {
let dt = if *signed {
DataType::Int64
} else {
DataType::UInt64
};
mapper.with_dtype(dt)
},
}
}
}
Expand Down
13 changes: 13 additions & 0 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,10 @@ impl Expr {
self.fill_null_impl(fill_value.into())
}

/// Fill null values according to the given [`FillNullStrategy`].
///
/// Builds a `FunctionExpr::FillNullWithStrategy` node via `apply_private`.
pub fn fill_null_with_strategy(self, strategy: FillNullStrategy) -> Self {
self.apply_private(FunctionExpr::FillNullWithStrategy(strategy))
}

/// Replace the floating point `NaN` values by a value.
pub fn fill_nan<E: Into<Expr>>(self, fill_value: E) -> Self {
// we take the not branch so that self is truthy value of `when -> then -> otherwise`
Expand Down Expand Up @@ -1617,6 +1621,15 @@ impl Expr {
self.map_private(FunctionExpr::ToPhysical)
}

/// Gather values from this expression using step `n` and starting `offset`.
///
/// Builds a `FunctionExpr::GatherEvery { n, offset }` node via `apply_private`.
/// `n` is not validated here; the dispatch function for this node rejects
/// `n == 0` with an `InvalidOperation` error when the expression is evaluated.
pub fn gather_every(self, n: usize, offset: usize) -> Expr {
self.apply_private(FunctionExpr::GatherEvery { n, offset })
}

/// Reinterpret an integer column as signed (`signed == true`) or unsigned
/// (`signed == false`).
///
/// Builds a `FunctionExpr::Reinterpret(signed)` node via `map_private`. The
/// underlying kernel only accepts 32-bit and 64-bit integer dtypes and returns
/// a `ComputeError` for anything else.
// NOTE(review): the schema arm for `Reinterpret` reports `Int64`/`UInt64`
// regardless of the input dtype, while the kernel also handles `Int32`/`UInt32`
// inputs. The planned output dtype may therefore disagree with the actual one
// for 32-bit columns; confirm and derive the dtype from the input if so.
#[cfg(feature = "reinterpret")]
pub fn reinterpret(self, signed: bool) -> Expr {
self.map_private(FunctionExpr::Reinterpret(signed))
}

#[cfg(feature = "strings")]
/// Get the [`string::StringNameSpace`]
pub fn str(self) -> string::StringNameSpace {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ product = ["polars-core/product"]
propagate_nans = ["polars-lazy?/propagate_nans"]
range = ["polars-lazy?/range"]
rank = ["polars-lazy?/rank", "polars-ops/rank"]
reinterpret = ["polars-core/reinterpret"]
reinterpret = ["polars-core/reinterpret", "polars-lazy?/reinterpret", "polars-ops/reinterpret"]
repeat_by = ["polars-ops/repeat_by", "polars-lazy?/repeat_by"]
replace = ["polars-ops/replace", "polars-lazy?/replace"]
rle = ["polars-lazy?/rle"]
Expand Down
36 changes: 4 additions & 32 deletions py-polars/src/expr/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use pyo3::types::PyBytes;
use crate::conversion::{parse_fill_null_strategy, Wrap};
use crate::error::PyPolarsErr;
use crate::map::lazy::map_single;
use crate::utils::reinterpret;
use crate::PyExpr;

#[pymethods]
Expand Down Expand Up @@ -356,16 +355,8 @@ impl PyExpr {
}

fn fill_null_with_strategy(&self, strategy: &str, limit: FillNullLimit) -> PyResult<Self> {
let strat = parse_fill_null_strategy(strategy, limit)?;
Ok(self
.inner
.clone()
.apply(
move |s| s.fill_null(strat).map(Some),
GetOutput::same_type(),
)
.with_fmt("fill_null_with_strategy")
.into())
let strategy = parse_fill_null_strategy(strategy, limit)?;
Ok(self.inner.clone().fill_null_with_strategy(strategy).into())
}

fn fill_nan(&self, expr: Self) -> Self {
Expand Down Expand Up @@ -424,17 +415,7 @@ impl PyExpr {
}

fn gather_every(&self, n: usize, offset: usize) -> Self {
self.inner
.clone()
.apply(
move |s: Series| {
polars_ensure!(n > 0, InvalidOperation: "gather_every(n): n can't be zero");
Ok(Some(s.gather_every(n, offset)))
},
GetOutput::same_type(),
)
.with_fmt("gather_every")
.into()
self.inner.clone().gather_every(n, offset).into()
}
fn tail(&self, n: usize) -> Self {
self.inner.clone().tail(Some(n)).into()
Expand Down Expand Up @@ -689,16 +670,7 @@ impl PyExpr {
}

fn reinterpret(&self, signed: bool) -> Self {
let function = move |s: Series| reinterpret(&s, signed).map(Some);
let dt = if signed {
DataType::Int64
} else {
DataType::UInt64
};
self.inner
.clone()
.map(function, GetOutput::from_type(dt))
.into()
self.inner.clone().reinterpret(signed).into()
}
fn mode(&self) -> Self {
self.inner.clone().mode().into()
Expand Down
19 changes: 0 additions & 19 deletions py-polars/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,3 @@
use polars::prelude::*;

/// Reinterpret a 32-bit or 64-bit integer [`Series`] with the requested signedness.
///
/// A series that already has the requested signedness is returned as a clone;
/// otherwise the typed chunked array is passed to `reinterpret_signed` or
/// `reinterpret_unsigned`.
///
/// # Errors
///
/// Returns a `ComputeError` for any dtype other than `Int32`, `UInt32`,
/// `Int64` or `UInt64`.
pub fn reinterpret(s: &Series, signed: bool) -> PolarsResult<Series> {
    let reinterpreted = match (signed, s.dtype()) {
        // Requested signedness already holds: no conversion needed.
        (true, DataType::Int64 | DataType::Int32)
        | (false, DataType::UInt64 | DataType::UInt32) => s.clone(),
        // Unsigned -> signed; the dtype was matched, so the accessor is not expected to fail.
        (true, DataType::UInt64) => s.u64().unwrap().reinterpret_signed().into_series(),
        (true, DataType::UInt32) => s.u32().unwrap().reinterpret_signed().into_series(),
        // Signed -> unsigned.
        (false, DataType::Int64) => s.i64().unwrap().reinterpret_unsigned().into_series(),
        (false, DataType::Int32) => s.i32().unwrap().reinterpret_unsigned().into_series(),
        _ => polars_bail!(
            ComputeError:
            "reinterpret is only allowed for 64-bit/32-bit integers types, use cast otherwise"
        ),
    };
    Ok(reinterpreted)
}

// was redefined because I could not get feature flags activated?
#[macro_export]
macro_rules! apply_method_all_arrow_series2 {
Expand Down

0 comments on commit 9cd1140

Please sign in to comment.