Skip to content

Commit

Permalink
refactor(rust): Expose many more function expressions to python IR (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- authored and ritchie46 committed Aug 23, 2024
1 parent e8cbe81 commit bbab1a7
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 80 deletions.
2 changes: 1 addition & 1 deletion crates/polars-python/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl NodeTraverser {
// Increment major on breaking changes to the IR (e.g. renaming
// fields, reordering tuples), minor on backwards compatible
// changes (e.g. exposing a new expression node).
const VERSION: Version = (1, 0);
const VERSION: Version = (1, 1);

pub(crate) fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self {
Self {
Expand Down
152 changes: 73 additions & 79 deletions crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use polars::datatypes::TimeUnit;
use polars::series::ops::NullBehavior;
use polars_core::prelude::{NonExistent, QuantileInterpolOptions};
use polars_core::series::IsSorted;
use polars_ops::prelude::ClosedInterval;
use polars_ops::series::InterpolationMethod;
#[cfg(feature = "search_sorted")]
use polars_ops::series::SearchSortedSide;
use polars_plan::dsl::function_expr::rolling::RollingFunction;
use polars_plan::dsl::function_expr::rolling_by::RollingFunctionBy;
use polars_plan::dsl::{BooleanFunction, StringFunction, TemporalFunction};
Expand Down Expand Up @@ -1054,21 +1058,31 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
},
FunctionExpr::Abs => ("abs",).to_object(py),
#[cfg(feature = "hist")]
FunctionExpr::Hist { .. } => return Err(PyNotImplementedError::new_err("hist")),
FunctionExpr::Hist {
bin_count,
include_category,
include_breakpoint,
} => ("hist", bin_count, include_category, include_breakpoint).to_object(py),
FunctionExpr::NullCount => ("null_count",).to_object(py),
FunctionExpr::Pow(f) => match f {
PowFunction::Generic => ("pow",).to_object(py),
PowFunction::Sqrt => ("sqrt",).to_object(py),
PowFunction::Cbrt => ("cbrt",).to_object(py),
},
FunctionExpr::Hash(_, _, _, _) => {
return Err(PyNotImplementedError::new_err("hash"))
FunctionExpr::Hash(seed, seed_1, seed_2, seed_3) => {
("hash", seed, seed_1, seed_2, seed_3).to_object(py)
},
FunctionExpr::ArgWhere => ("argwhere",).to_object(py),
#[cfg(feature = "search_sorted")]
FunctionExpr::SearchSorted(_) => {
return Err(PyNotImplementedError::new_err("search sorted"))
},
FunctionExpr::SearchSorted(side) => (
"search_sorted",
match side {
SearchSortedSide::Any => "any",
SearchSortedSide::Left => "left",
SearchSortedSide::Right => "right",
},
)
.to_object(py),
FunctionExpr::Range(_) => return Err(PyNotImplementedError::new_err("range")),
#[cfg(feature = "trigonometry")]
FunctionExpr::Trigonometry(trigfun) => {
Expand Down Expand Up @@ -1147,17 +1161,13 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
return Err(PyNotImplementedError::new_err("rolling std by"))
},
},
FunctionExpr::ShiftAndFill => {
return Err(PyNotImplementedError::new_err("shift and fill"))
},
FunctionExpr::ShiftAndFill => ("shift_and_fill",).to_object(py),
FunctionExpr::Shift => ("shift",).to_object(py),
FunctionExpr::DropNans => ("drop_nans",).to_object(py),
FunctionExpr::DropNulls => ("drop_nulls",).to_object(py),
FunctionExpr::Mode => ("mode",).to_object(py),
FunctionExpr::Skew(_) => return Err(PyNotImplementedError::new_err("skew")),
FunctionExpr::Kurtosis(_, _) => {
return Err(PyNotImplementedError::new_err("kurtosis"))
},
FunctionExpr::Skew(bias) => ("skew", bias).to_object(py),
FunctionExpr::Kurtosis(fisher, bias) => ("kurtosis", fisher, bias).to_object(py),
FunctionExpr::Reshape(_, _) => {
return Err(PyNotImplementedError::new_err("reshape"))
},
Expand All @@ -1168,11 +1178,8 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
options: _,
seed: _,
} => return Err(PyNotImplementedError::new_err("rank")),
FunctionExpr::Clip {
has_min: _,
has_max: _,
} => return Err(PyNotImplementedError::new_err("clip")),
FunctionExpr::AsStruct => return Err(PyNotImplementedError::new_err("as struct")),
FunctionExpr::Clip { has_min, has_max } => ("clip", has_min, has_max).to_object(py),
FunctionExpr::AsStruct => ("as_struct",).to_object(py),
#[cfg(feature = "top_k")]
FunctionExpr::TopK { descending } => ("top_k", descending).to_object(py),
FunctionExpr::CumCount { reverse } => ("cum_count", reverse).to_object(py),
Expand All @@ -1182,37 +1189,41 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
FunctionExpr::CumMax { reverse } => ("cum_max", reverse).to_object(py),
FunctionExpr::Reverse => ("reverse",).to_object(py),
FunctionExpr::ValueCounts {
sort: _,
parallel: _,
name: _,
normalize: _,
} => return Err(PyNotImplementedError::new_err("value counts")),
sort,
parallel,
name,
normalize,
} => ("value_counts", sort, parallel, name, normalize).to_object(py),
FunctionExpr::UniqueCounts => ("unique_counts",).to_object(py),
FunctionExpr::ApproxNUnique => {
return Err(PyNotImplementedError::new_err("approx nunique"))
},
FunctionExpr::ApproxNUnique => ("approx_n_unique",).to_object(py),
FunctionExpr::Coalesce => ("coalesce",).to_object(py),
FunctionExpr::ShrinkType => {
return Err(PyNotImplementedError::new_err("shrink type"))
},
FunctionExpr::Diff(_, _) => return Err(PyNotImplementedError::new_err("diff")),
FunctionExpr::ShrinkType => ("shrink_dtype",).to_object(py),
FunctionExpr::Diff(n, null_behaviour) => (
"diff",
n,
match null_behaviour {
NullBehavior::Drop => "drop",
NullBehavior::Ignore => "ignore",
},
)
.to_object(py),
#[cfg(feature = "pct_change")]
FunctionExpr::PctChange => {
return Err(PyNotImplementedError::new_err("pct change"))
},
FunctionExpr::Interpolate(_) => {
return Err(PyNotImplementedError::new_err("interpolate"))
},
FunctionExpr::InterpolateBy => {
return Err(PyNotImplementedError::new_err("interpolate_by"))
FunctionExpr::PctChange => ("pct_change",).to_object(py),
FunctionExpr::Interpolate(method) => (
"interpolate",
match method {
InterpolationMethod::Linear => "linear",
InterpolationMethod::Nearest => "nearest",
},
)
.to_object(py),
FunctionExpr::InterpolateBy => ("interpolate_by",).to_object(py),
FunctionExpr::Entropy { base, normalize } => {
("entropy", base, normalize).to_object(py)
},
FunctionExpr::Entropy {
base: _,
normalize: _,
} => return Err(PyNotImplementedError::new_err("entropy")),
FunctionExpr::Log { base: _ } => return Err(PyNotImplementedError::new_err("log")),
FunctionExpr::Log1p => return Err(PyNotImplementedError::new_err("log1p")),
FunctionExpr::Exp => return Err(PyNotImplementedError::new_err("exp")),
FunctionExpr::Log { base } => ("log", base).to_object(py),
FunctionExpr::Log1p => ("log1p",).to_object(py),
FunctionExpr::Exp => ("exp",).to_object(py),
FunctionExpr::Unique(maintain_order) => ("unique", maintain_order).to_object(py),
FunctionExpr::Round { decimals } => ("round", decimals).to_object(py),
FunctionExpr::RoundSF { digits } => ("round_sig_figs", digits).to_object(py),
Expand All @@ -1228,20 +1239,18 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
return Err(PyNotImplementedError::new_err("corr"))
},
#[cfg(feature = "peaks")]
FunctionExpr::PeakMin => return Err(PyNotImplementedError::new_err("peak min")),
FunctionExpr::PeakMin => ("peak_max",).to_object(py),
#[cfg(feature = "peaks")]
FunctionExpr::PeakMax => return Err(PyNotImplementedError::new_err("peak max")),
FunctionExpr::PeakMax => ("peak_min",).to_object(py),
#[cfg(feature = "cutqcut")]
FunctionExpr::Cut { .. } => return Err(PyNotImplementedError::new_err("cut")),
#[cfg(feature = "cutqcut")]
FunctionExpr::QCut { .. } => return Err(PyNotImplementedError::new_err("qcut")),
#[cfg(feature = "rle")]
FunctionExpr::RLE => return Err(PyNotImplementedError::new_err("rle")),
FunctionExpr::RLE => ("rle",).to_object(py),
#[cfg(feature = "rle")]
FunctionExpr::RLEID => return Err(PyNotImplementedError::new_err("rleid")),
FunctionExpr::ToPhysical => {
return Err(PyNotImplementedError::new_err("to physical"))
},
FunctionExpr::RLEID => ("rle_id",).to_object(py),
FunctionExpr::ToPhysical => ("to_physical",).to_object(py),
FunctionExpr::Random { .. } => {
return Err(PyNotImplementedError::new_err("random"))
},
Expand All @@ -1258,24 +1267,12 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
FunctionExpr::FfiPlugin { .. } => {
return Err(PyNotImplementedError::new_err("ffi plugin"))
},
FunctionExpr::BackwardFill { limit: _ } => {
return Err(PyNotImplementedError::new_err("backward fill"))
},
FunctionExpr::ForwardFill { limit: _ } => {
return Err(PyNotImplementedError::new_err("forward fill"))
},
FunctionExpr::SumHorizontal => {
return Err(PyNotImplementedError::new_err("sum horizontal"))
},
FunctionExpr::MaxHorizontal => {
return Err(PyNotImplementedError::new_err("max horizontal"))
},
FunctionExpr::MeanHorizontal => {
return Err(PyNotImplementedError::new_err("mean horizontal"))
},
FunctionExpr::MinHorizontal => {
return Err(PyNotImplementedError::new_err("min horizontal"))
},
FunctionExpr::BackwardFill { limit } => ("backward_fill", limit).to_object(py),
FunctionExpr::ForwardFill { limit } => ("forward_fill", limit).to_object(py),
FunctionExpr::SumHorizontal => ("sum_horizontal",).to_object(py),
FunctionExpr::MaxHorizontal => ("max_horizontal",).to_object(py),
FunctionExpr::MeanHorizontal => ("mean_horizontal",).to_object(py),
FunctionExpr::MinHorizontal => ("min_horizontal",).to_object(py),
FunctionExpr::EwmMean { options: _ } => {
return Err(PyNotImplementedError::new_err("ewm mean"))
},
Expand All @@ -1285,23 +1282,20 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
FunctionExpr::EwmVar { options: _ } => {
return Err(PyNotImplementedError::new_err("ewm var"))
},
FunctionExpr::Replace => return Err(PyNotImplementedError::new_err("replace")),
FunctionExpr::Replace => ("replace",).to_object(py),
FunctionExpr::ReplaceStrict { return_dtype: _ } => {
return Err(PyNotImplementedError::new_err("replace_strict"))
// Can ignore the return dtype because it is encoded in the schema.
("replace_strict",).to_object(py)
},
FunctionExpr::Negate => return Err(PyNotImplementedError::new_err("negate")),
FunctionExpr::Negate => ("negate",).to_object(py),
FunctionExpr::FillNullWithStrategy(_) => {
return Err(PyNotImplementedError::new_err("fill null with strategy"))
},
FunctionExpr::GatherEvery { n, offset } => {
("gather_every", offset, n).to_object(py)
},
FunctionExpr::Reinterpret(_) => {
return Err(PyNotImplementedError::new_err("reinterpret"))
},
FunctionExpr::ExtendConstant => {
return Err(PyNotImplementedError::new_err("extend constant"))
},
FunctionExpr::Reinterpret(signed) => ("reinterpret", signed).to_object(py),
FunctionExpr::ExtendConstant => ("extend_constant",).to_object(py),
FunctionExpr::Business(_) => {
return Err(PyNotImplementedError::new_err("business"))
},
Expand Down

0 comments on commit bbab1a7

Please sign in to comment.