From cb86827b7d48b863acbbf457dfe0d72cb8af82b7 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Sun, 7 Jan 2024 20:56:22 +0400 Subject: [PATCH] feat(rust,python,cli): add SQL engine support for `MOD` function (#13502) --- crates/polars-sql/src/functions.rs | 8 +++ py-polars/tests/unit/sql/test_numeric.py | 78 +++++++++++++++++------- 2 files changed, 64 insertions(+), 22 deletions(-) diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index a1ad0562a1a2..f14a904ccf92 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -89,6 +89,12 @@ pub(crate) enum PolarsSQLFunctions { /// SELECT POW(column_1, 2) from df; /// ``` Pow, + /// SQL 'mod' function + /// Returns the remainder of a numeric expression divided by another numeric expression. + /// ```sql + /// SELECT MOD(column_1, 2) from df; + /// ``` + Mod, /// SQL 'sqrt' function /// Returns the square root (√) of a number. /// ```sql @@ -601,6 +607,7 @@ impl PolarsSQLFunctions { "log10" => Self::Log10, "log1p" => Self::Log1p, "log2" => Self::Log2, + "mod" => Self::Mod, "pi" => Self::Pi, "pow" | "power" => Self::Pow, "round" => Self::Round, @@ -742,6 +749,7 @@ impl SQLFunctionVisitor<'_> { Log10 => self.visit_unary(|e| e.log(10.0)), Log1p => self.visit_unary(Expr::log1p), Log2 => self.visit_unary(|e| e.log(2.0)), + Mod => self.visit_binary(|e1, e2| e1 % e2), Pow => self.visit_binary::(Expr::pow), Sqrt => self.visit_unary(Expr::sqrt), Cbrt => self.visit_unary(Expr::cbrt), diff --git a/py-polars/tests/unit/sql/test_numeric.py b/py-polars/tests/unit/sql/test_numeric.py index 8d704e699011..6dec74d05ca5 100644 --- a/py-polars/tests/unit/sql/test_numeric.py +++ b/py-polars/tests/unit/sql/test_numeric.py @@ -7,28 +7,23 @@ from polars.testing import assert_frame_equal, assert_series_equal -def test_stddev_variance() -> None: +def test_modulo() -> None: df = pl.DataFrame( { - "v1": [-1.0, 0.0, 1.0], - "v2": [5.5, 0.0, 3.0], - "v3": [-10, None, 10], - "v4": [-100, 0.0, -50.0], + "a": [1.5, None, 3.0, 13 / 3, 5.0], + "b": [6, 7, 8, 9, 10], + "c": [11, 12, 13, 14, 15], + "d": [16.5, 17.0, 18.5, None, 20.0], } ) with pl.SQLContext(df=df) as ctx: - # note: we support all common aliases for std/var out = ctx.execute( """ SELECT - STDEV(v1) AS "v1_std", - STDDEV(v2) AS "v2_std", - STDEV_SAMP(v3) AS "v3_std", - STDDEV_SAMP(v4) AS "v4_std", - VAR(v1) AS "v1_var", - VARIANCE(v2) AS "v2_var", - VARIANCE(v3) AS "v3_var", - VAR_SAMP(v4) AS "v4_var" + a % 2 AS a2, + b % 3 AS b3, + MOD(c, 4) AS c4, + MOD(d, 5.5) AS d55 FROM df """ ).collect() @@ -37,14 +32,10 @@ def test_stddev_variance() -> None: out, pl.DataFrame( { - "v1_std": [1.0], - "v2_std": [2.7537852736431], - "v3_std": [14.142135623731], - "v4_std": [50.0], - "v1_var": [1.0], - "v2_var": [7.5833333333333], - "v3_var": [200.0], - "v4_var": [2500.0], + "a2": [1.5, None, 1.0, 1 / 3, 1.0], + "b3": [0, 1, 2, 0, 1], + "c4": [3, 0, 1, 2, 3], + "d55": [0.0, 0.5, 2.0, None, 3.5], } ), ) @@ -79,3 +70,46 @@ def test_round_ndigits_errors() -> None: InvalidOperationError, match="Invalid 'decimals' for Round: -1" ): ctx.execute("SELECT ROUND(n,-1) AS n FROM df") + + +def test_stddev_variance() -> None: + df = pl.DataFrame( + { + "v1": [-1.0, 0.0, 1.0], + "v2": [5.5, 0.0, 3.0], + "v3": [-10, None, 10], + "v4": [-100, 0.0, -50.0], + } + ) + with pl.SQLContext(df=df) as ctx: + # note: we support all common aliases for std/var + out = ctx.execute( + """ + SELECT + STDEV(v1) AS "v1_std", + STDDEV(v2) AS "v2_std", + STDEV_SAMP(v3) AS "v3_std", + STDDEV_SAMP(v4) AS "v4_std", + VAR(v1) AS "v1_var", + VARIANCE(v2) AS "v2_var", + VARIANCE(v3) AS "v3_var", + VAR_SAMP(v4) AS "v4_var" + FROM df + """ + ).collect() + + assert_frame_equal( + out, + pl.DataFrame( + { + "v1_std": [1.0], + "v2_std": [2.7537852736431], + "v3_std": [14.142135623731], + "v4_std": [50.0], + "v1_var": [1.0], + "v2_var": [7.5833333333333], + "v3_var": [200.0], + "v4_var": [2500.0], + } + ), + )