Skip to content

Commit

Permalink
feat(rust,python,cli): add SQL engine support for MOD function (#13502
Browse files Browse the repository at this point in the history
)
  • Loading branch information
alexander-beedie committed Jan 7, 2024
1 parent f0ba057 commit cb86827
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 22 deletions.
8 changes: 8 additions & 0 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ pub(crate) enum PolarsSQLFunctions {
/// SELECT POW(column_1, 2) from df;
/// ```
Pow,
/// SQL 'mod' function
/// Returns the remainder of a numeric expression divided by another numeric expression.
/// ```sql
/// SELECT MOD(column_1, 2) from df;
/// ```
Mod,
/// SQL 'sqrt' function
/// Returns the square root (√) of a number.
/// ```sql
Expand Down Expand Up @@ -601,6 +607,7 @@ impl PolarsSQLFunctions {
"log10" => Self::Log10,
"log1p" => Self::Log1p,
"log2" => Self::Log2,
"mod" => Self::Mod,
"pi" => Self::Pi,
"pow" | "power" => Self::Pow,
"round" => Self::Round,
Expand Down Expand Up @@ -742,6 +749,7 @@ impl SQLFunctionVisitor<'_> {
Log10 => self.visit_unary(|e| e.log(10.0)),
Log1p => self.visit_unary(Expr::log1p),
Log2 => self.visit_unary(|e| e.log(2.0)),
Mod => self.visit_binary(|e1, e2| e1 % e2),
Pow => self.visit_binary::<Expr>(Expr::pow),
Sqrt => self.visit_unary(Expr::sqrt),
Cbrt => self.visit_unary(Expr::cbrt),
Expand Down
78 changes: 56 additions & 22 deletions py-polars/tests/unit/sql/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,23 @@
from polars.testing import assert_frame_equal, assert_series_equal


def test_stddev_variance() -> None:
def test_modulo() -> None:
df = pl.DataFrame(
{
"v1": [-1.0, 0.0, 1.0],
"v2": [5.5, 0.0, 3.0],
"v3": [-10, None, 10],
"v4": [-100, 0.0, -50.0],
"a": [1.5, None, 3.0, 13 / 3, 5.0],
"b": [6, 7, 8, 9, 10],
"c": [11, 12, 13, 14, 15],
"d": [16.5, 17.0, 18.5, None, 20.0],
}
)
with pl.SQLContext(df=df) as ctx:
# note: we support all common aliases for std/var
out = ctx.execute(
"""
SELECT
STDEV(v1) AS "v1_std",
STDDEV(v2) AS "v2_std",
STDEV_SAMP(v3) AS "v3_std",
STDDEV_SAMP(v4) AS "v4_std",
VAR(v1) AS "v1_var",
VARIANCE(v2) AS "v2_var",
VARIANCE(v3) AS "v3_var",
VAR_SAMP(v4) AS "v4_var"
a % 2 AS a2,
b % 3 AS b3,
MOD(c, 4) AS c4,
MOD(d, 5.5) AS d55
FROM df
"""
).collect()
Expand All @@ -37,14 +32,10 @@ def test_stddev_variance() -> None:
out,
pl.DataFrame(
{
"v1_std": [1.0],
"v2_std": [2.7537852736431],
"v3_std": [14.142135623731],
"v4_std": [50.0],
"v1_var": [1.0],
"v2_var": [7.5833333333333],
"v3_var": [200.0],
"v4_var": [2500.0],
"a2": [1.5, None, 1.0, 1 / 3, 1.0],
"b3": [0, 1, 2, 0, 1],
"c4": [3, 0, 1, 2, 3],
"d55": [0.0, 0.5, 2.0, None, 3.5],
}
),
)
Expand Down Expand Up @@ -79,3 +70,46 @@ def test_round_ndigits_errors() -> None:
InvalidOperationError, match="Invalid 'decimals' for Round: -1"
):
ctx.execute("SELECT ROUND(n,-1) AS n FROM df")


def test_stddev_variance() -> None:
df = pl.DataFrame(
{
"v1": [-1.0, 0.0, 1.0],
"v2": [5.5, 0.0, 3.0],
"v3": [-10, None, 10],
"v4": [-100, 0.0, -50.0],
}
)
with pl.SQLContext(df=df) as ctx:
# note: we support all common aliases for std/var
out = ctx.execute(
"""
SELECT
STDEV(v1) AS "v1_std",
STDDEV(v2) AS "v2_std",
STDEV_SAMP(v3) AS "v3_std",
STDDEV_SAMP(v4) AS "v4_std",
VAR(v1) AS "v1_var",
VARIANCE(v2) AS "v2_var",
VARIANCE(v3) AS "v3_var",
VAR_SAMP(v4) AS "v4_var"
FROM df
"""
).collect()

assert_frame_equal(
out,
pl.DataFrame(
{
"v1_std": [1.0],
"v2_std": [2.7537852736431],
"v3_std": [14.142135623731],
"v4_std": [50.0],
"v1_var": [1.0],
"v2_var": [7.5833333333333],
"v3_var": [200.0],
"v4_var": [2500.0],
}
),
)

0 comments on commit cb86827

Please sign in to comment.