From bbab1a70176488213877bab3b919fca573cea1c0 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 23 Aug 2024 10:47:24 +0100 Subject: [PATCH] refactor(rust): Expose many more function expressions to python IR (#18317) --- crates/polars-python/src/lazyframe/visit.rs | 2 +- .../src/lazyframe/visitor/expr_nodes.rs | 152 +++++++++--------- 2 files changed, 74 insertions(+), 80 deletions(-) diff --git a/crates/polars-python/src/lazyframe/visit.rs b/crates/polars-python/src/lazyframe/visit.rs index 32585c4cc887e..36d8e6e4b7935 100644 --- a/crates/polars-python/src/lazyframe/visit.rs +++ b/crates/polars-python/src/lazyframe/visit.rs @@ -57,7 +57,7 @@ impl NodeTraverser { // Increment major on breaking changes to the IR (e.g. renaming // fields, reordering tuples), minor on backwards compatible // changes (e.g. exposing a new expression node). - const VERSION: Version = (1, 0); + const VERSION: Version = (1, 1); pub(crate) fn new(root: Node, lp_arena: Arena, expr_arena: Arena) -> Self { Self { diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index fe85d23b6fb7f..d282e6d528e32 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -1,7 +1,11 @@ use polars::datatypes::TimeUnit; +use polars::series::ops::NullBehavior; use polars_core::prelude::{NonExistent, QuantileInterpolOptions}; use polars_core::series::IsSorted; use polars_ops::prelude::ClosedInterval; +use polars_ops::series::InterpolationMethod; +#[cfg(feature = "search_sorted")] +use polars_ops::series::SearchSortedSide; use polars_plan::dsl::function_expr::rolling::RollingFunction; use polars_plan::dsl::function_expr::rolling_by::RollingFunctionBy; use polars_plan::dsl::{BooleanFunction, StringFunction, TemporalFunction}; @@ -1054,21 +1058,31 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { }, FunctionExpr::Abs => ("abs",).to_object(py), #[cfg(feature = "hist")] - FunctionExpr::Hist { .. } => return Err(PyNotImplementedError::new_err("hist")), + FunctionExpr::Hist { + bin_count, + include_category, + include_breakpoint, + } => ("hist", bin_count, include_category, include_breakpoint).to_object(py), FunctionExpr::NullCount => ("null_count",).to_object(py), FunctionExpr::Pow(f) => match f { PowFunction::Generic => ("pow",).to_object(py), PowFunction::Sqrt => ("sqrt",).to_object(py), PowFunction::Cbrt => ("cbrt",).to_object(py), }, - FunctionExpr::Hash(_, _, _, _) => { - return Err(PyNotImplementedError::new_err("hash")) + FunctionExpr::Hash(seed, seed_1, seed_2, seed_3) => { + ("hash", seed, seed_1, seed_2, seed_3).to_object(py) }, FunctionExpr::ArgWhere => ("argwhere",).to_object(py), #[cfg(feature = "search_sorted")] - FunctionExpr::SearchSorted(_) => { - return Err(PyNotImplementedError::new_err("search sorted")) - }, + FunctionExpr::SearchSorted(side) => ( + "search_sorted", + match side { + SearchSortedSide::Any => "any", + SearchSortedSide::Left => "left", + SearchSortedSide::Right => "right", + }, + ) + .to_object(py), FunctionExpr::Range(_) => return Err(PyNotImplementedError::new_err("range")), #[cfg(feature = "trigonometry")] FunctionExpr::Trigonometry(trigfun) => { @@ -1147,17 +1161,13 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { return Err(PyNotImplementedError::new_err("rolling std by")) }, }, - FunctionExpr::ShiftAndFill => { - return Err(PyNotImplementedError::new_err("shift and fill")) - }, + FunctionExpr::ShiftAndFill => ("shift_and_fill",).to_object(py), FunctionExpr::Shift => ("shift",).to_object(py), FunctionExpr::DropNans => ("drop_nans",).to_object(py), FunctionExpr::DropNulls => ("drop_nulls",).to_object(py), FunctionExpr::Mode => ("mode",).to_object(py), - FunctionExpr::Skew(_) => return Err(PyNotImplementedError::new_err("skew")), - FunctionExpr::Kurtosis(_, _) => { - return Err(PyNotImplementedError::new_err("kurtosis")) - }, + FunctionExpr::Skew(bias) => ("skew", bias).to_object(py), + FunctionExpr::Kurtosis(fisher, bias) => ("kurtosis", fisher, bias).to_object(py), FunctionExpr::Reshape(_, _) => { return Err(PyNotImplementedError::new_err("reshape")) }, @@ -1168,11 +1178,8 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { options: _, seed: _, } => return Err(PyNotImplementedError::new_err("rank")), - FunctionExpr::Clip { - has_min: _, - has_max: _, - } => return Err(PyNotImplementedError::new_err("clip")), - FunctionExpr::AsStruct => return Err(PyNotImplementedError::new_err("as struct")), + FunctionExpr::Clip { has_min, has_max } => ("clip", has_min, has_max).to_object(py), + FunctionExpr::AsStruct => ("as_struct",).to_object(py), #[cfg(feature = "top_k")] FunctionExpr::TopK { descending } => ("top_k", descending).to_object(py), FunctionExpr::CumCount { reverse } => ("cum_count", reverse).to_object(py), @@ -1182,37 +1189,41 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { FunctionExpr::CumMax { reverse } => ("cum_max", reverse).to_object(py), FunctionExpr::Reverse => ("reverse",).to_object(py), FunctionExpr::ValueCounts { - sort: _, - parallel: _, - name: _, - normalize: _, - } => return Err(PyNotImplementedError::new_err("value counts")), + sort, + parallel, + name, + normalize, + } => ("value_counts", sort, parallel, name, normalize).to_object(py), FunctionExpr::UniqueCounts => ("unique_counts",).to_object(py), - FunctionExpr::ApproxNUnique => { - return Err(PyNotImplementedError::new_err("approx nunique")) - }, + FunctionExpr::ApproxNUnique => ("approx_n_unique",).to_object(py), FunctionExpr::Coalesce => ("coalesce",).to_object(py), - FunctionExpr::ShrinkType => { - return Err(PyNotImplementedError::new_err("shrink type")) - }, - FunctionExpr::Diff(_, _) => return Err(PyNotImplementedError::new_err("diff")), + FunctionExpr::ShrinkType => ("shrink_dtype",).to_object(py), + FunctionExpr::Diff(n, null_behaviour) => ( + "diff", + n, + match null_behaviour { + NullBehavior::Drop => "drop", + NullBehavior::Ignore => "ignore", + }, + ) + .to_object(py), #[cfg(feature = "pct_change")] - FunctionExpr::PctChange => { - return Err(PyNotImplementedError::new_err("pct change")) - }, - FunctionExpr::Interpolate(_) => { - return Err(PyNotImplementedError::new_err("interpolate")) - }, - FunctionExpr::InterpolateBy => { - return Err(PyNotImplementedError::new_err("interpolate_by")) + FunctionExpr::PctChange => ("pct_change",).to_object(py), + FunctionExpr::Interpolate(method) => ( + "interpolate", + match method { + InterpolationMethod::Linear => "linear", + InterpolationMethod::Nearest => "nearest", + }, + ) + .to_object(py), + FunctionExpr::InterpolateBy => ("interpolate_by",).to_object(py), + FunctionExpr::Entropy { base, normalize } => { + ("entropy", base, normalize).to_object(py) }, - FunctionExpr::Entropy { - base: _, - normalize: _, - } => return Err(PyNotImplementedError::new_err("entropy")), - FunctionExpr::Log { base: _ } => return Err(PyNotImplementedError::new_err("log")), - FunctionExpr::Log1p => return Err(PyNotImplementedError::new_err("log1p")), - FunctionExpr::Exp => return Err(PyNotImplementedError::new_err("exp")), + FunctionExpr::Log { base } => ("log", base).to_object(py), + FunctionExpr::Log1p => ("log1p",).to_object(py), + FunctionExpr::Exp => ("exp",).to_object(py), FunctionExpr::Unique(maintain_order) => ("unique", maintain_order).to_object(py), FunctionExpr::Round { decimals } => ("round", decimals).to_object(py), FunctionExpr::RoundSF { digits } => ("round_sig_figs", digits).to_object(py), @@ -1228,20 +1239,18 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { return Err(PyNotImplementedError::new_err("corr")) }, #[cfg(feature = "peaks")] - FunctionExpr::PeakMin => return Err(PyNotImplementedError::new_err("peak min")), + FunctionExpr::PeakMin => ("peak_max",).to_object(py), #[cfg(feature = "peaks")] - FunctionExpr::PeakMax => return Err(PyNotImplementedError::new_err("peak max")), + FunctionExpr::PeakMax => ("peak_min",).to_object(py), #[cfg(feature = "cutqcut")] FunctionExpr::Cut { .. } => return Err(PyNotImplementedError::new_err("cut")), #[cfg(feature = "cutqcut")] FunctionExpr::QCut { .. } => return Err(PyNotImplementedError::new_err("qcut")), #[cfg(feature = "rle")] - FunctionExpr::RLE => return Err(PyNotImplementedError::new_err("rle")), + FunctionExpr::RLE => ("rle",).to_object(py), #[cfg(feature = "rle")] - FunctionExpr::RLEID => return Err(PyNotImplementedError::new_err("rleid")), - FunctionExpr::ToPhysical => { - return Err(PyNotImplementedError::new_err("to physical")) - }, + FunctionExpr::RLEID => ("rle_id",).to_object(py), + FunctionExpr::ToPhysical => ("to_physical",).to_object(py), FunctionExpr::Random { .. } => { return Err(PyNotImplementedError::new_err("random")) }, @@ -1258,24 +1267,12 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { FunctionExpr::FfiPlugin { .. } => { return Err(PyNotImplementedError::new_err("ffi plugin")) }, - FunctionExpr::BackwardFill { limit: _ } => { - return Err(PyNotImplementedError::new_err("backward fill")) - }, - FunctionExpr::ForwardFill { limit: _ } => { - return Err(PyNotImplementedError::new_err("forward fill")) - }, - FunctionExpr::SumHorizontal => { - return Err(PyNotImplementedError::new_err("sum horizontal")) - }, - FunctionExpr::MaxHorizontal => { - return Err(PyNotImplementedError::new_err("max horizontal")) - }, - FunctionExpr::MeanHorizontal => { - return Err(PyNotImplementedError::new_err("mean horizontal")) - }, - FunctionExpr::MinHorizontal => { - return Err(PyNotImplementedError::new_err("min horizontal")) - }, + FunctionExpr::BackwardFill { limit } => ("backward_fill", limit).to_object(py), + FunctionExpr::ForwardFill { limit } => ("forward_fill", limit).to_object(py), + FunctionExpr::SumHorizontal => ("sum_horizontal",).to_object(py), + FunctionExpr::MaxHorizontal => ("max_horizontal",).to_object(py), + FunctionExpr::MeanHorizontal => ("mean_horizontal",).to_object(py), + FunctionExpr::MinHorizontal => ("min_horizontal",).to_object(py), FunctionExpr::EwmMean { options: _ } => { return Err(PyNotImplementedError::new_err("ewm mean")) }, @@ -1285,23 +1282,20 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { FunctionExpr::EwmVar { options: _ } => { return Err(PyNotImplementedError::new_err("ewm var")) }, - FunctionExpr::Replace => return Err(PyNotImplementedError::new_err("replace")), + FunctionExpr::Replace => ("replace",).to_object(py), FunctionExpr::ReplaceStrict { return_dtype: _ } => { - return Err(PyNotImplementedError::new_err("replace_strict")) + // Can ignore the return dtype because it is encoded in the schema. + ("replace_strict",).to_object(py) }, - FunctionExpr::Negate => return Err(PyNotImplementedError::new_err("negate")), + FunctionExpr::Negate => ("negate",).to_object(py), FunctionExpr::FillNullWithStrategy(_) => { return Err(PyNotImplementedError::new_err("fill null with strategy")) }, FunctionExpr::GatherEvery { n, offset } => { ("gather_every", offset, n).to_object(py) }, - FunctionExpr::Reinterpret(_) => { - return Err(PyNotImplementedError::new_err("reinterpret")) - }, - FunctionExpr::ExtendConstant => { - return Err(PyNotImplementedError::new_err("extend constant")) - }, + FunctionExpr::Reinterpret(signed) => ("reinterpret", signed).to_object(py), + FunctionExpr::ExtendConstant => ("extend_constant",).to_object(py), FunctionExpr::Business(_) => { return Err(PyNotImplementedError::new_err("business")) },