Skip to content

Commit

Permalink
feat(python,rust,cli): add SQL round support (#9330)
Browse files Browse the repository at this point in the history
Co-authored-by: ritchie <ritchie46@gmail.com>
  • Loading branch information
alexander-beedie and ritchie46 authored Jun 18, 2023
1 parent 0d399ec commit 05d1195
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 5 deletions.
35 changes: 32 additions & 3 deletions polars/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use polars_core::prelude::{polars_bail, polars_err, PolarsError, PolarsResult};
use polars_lazy::dsl::Expr;
use polars_plan::dsl::count;
use polars_plan::logical_plan::LiteralValue;
use sqlparser::ast::{
Expr as SqlExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, Value as SqlValue,
WindowSpec, WindowType,
Expand Down Expand Up @@ -84,6 +85,11 @@ pub(crate) enum PolarsSqlFunctions {
/// SELECT POW(column_1, 2) from df;
/// ```
Pow,
/// SQL 'round' function
/// ```sql
/// SELECT ROUND(column_1, 3) from df;
/// ```
Round,
// ----
// String functions
// ----
Expand Down Expand Up @@ -271,6 +277,7 @@ impl PolarsSqlFunctions {
"max",
"min",
"pow",
"round",
"rtrim",
"starts_with",
"stddev",
Expand Down Expand Up @@ -303,6 +310,7 @@ impl TryFrom<&'_ SQLFunction> for PolarsSqlFunctions {
"log1p" => Self::Log1p,
"log2" => Self::Log2,
"pow" => Self::Pow,
"round" => Self::Round,
// ----
// String functions
// ----
Expand Down Expand Up @@ -366,6 +374,20 @@ impl SqlFunctionVisitor<'_> {
Log1p => self.visit_unary(Expr::log1p),
Log2 => self.visit_unary(|e| e.log(2.0)),
Pow => self.visit_binary::<Expr>(Expr::pow),
Round => match function.args.len() {
1 => self.visit_unary(|e| e.round(0)),
2 => self.try_visit_binary(|e, decimals| {
Ok(e.round(match decimals {
Expr::Literal(LiteralValue::Int64(n)) => n as u32,
_ => {
polars_bail!(InvalidOperation: "Invalid 'decimals' for Round: {}", function.args[1]);
}
}))
}),
_ => {
polars_bail!(InvalidOperation:"Invalid number of arguments for Round: {}",function.args.len());
},
},
// ----
// String functions
// ----
Expand All @@ -374,15 +396,15 @@ impl SqlFunctionVisitor<'_> {
LTrim => match function.args.len() {
1 => self.visit_unary(|e| e.str().lstrip(None)),
2 => self.visit_binary(|e, s| e.str().lstrip(Some(s))),
_ => panic!(
_ => polars_bail!(InvalidOperation:
"Invalid number of arguments for LTrim: {}",
function.args.len()
),
},
RTrim => match function.args.len() {
1 => self.visit_unary(|e| e.str().rstrip(None)),
2 => self.visit_binary(|e, s| e.str().rstrip(Some(s))),
_ => panic!(
_ => polars_bail!(InvalidOperation:
"Invalid number of arguments for RTrim: {}",
function.args.len()
),
Expand Down Expand Up @@ -488,14 +510,21 @@ impl SqlFunctionVisitor<'_> {
}

/// Builds a two-argument SQL function call from an infallible constructor.
///
/// `f` receives the first argument as an `Expr` and the second converted to
/// `Arg`; its result is wrapped in `Ok` and the actual argument handling is
/// delegated to `try_visit_binary`, so any error comes from that method.
fn visit_binary<Arg: FromSqlExpr>(&self, f: impl Fn(Expr, Arg) -> Expr) -> PolarsResult<Expr> {
    // Lift the infallible builder into the fallible signature expected below.
    let lifted = |lhs: Expr, rhs: Arg| -> PolarsResult<Expr> { Ok(f(lhs, rhs)) };
    self.try_visit_binary(lifted)
}

fn try_visit_binary<Arg: FromSqlExpr>(
&self,
f: impl Fn(Expr, Arg) -> PolarsResult<Expr>,
) -> PolarsResult<Expr> {
let function = self.func;
let args = extract_args(function);
if let FunctionArgExpr::Expr(sql_expr) = args[0] {
let expr =
self.apply_window_spec(parse_sql_expr(sql_expr, self.ctx)?, &function.over)?;
if let FunctionArgExpr::Expr(sql_expr) = args[1] {
let expr2 = Arg::from_sql_expr(sql_expr, self.ctx)?;
Ok(f(expr, expr2))
f(expr, expr2)
} else {
not_supported_error(function.name.0[0].value.as_str(), &args)
}
Expand Down
1 change: 0 additions & 1 deletion polars/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ impl SqlExprVisitor<'_> {
UnaryOperator::Plus => lit(0) + expr,
UnaryOperator::Minus => lit(0) - expr,
UnaryOperator::Not => expr.not(),
UnaryOperator::PGSquareRoot => expr.pow(0.5),
other => polars_bail!(InvalidOperation: "Unary operator {:?} is not supported", other),
})
}
Expand Down
33 changes: 32 additions & 1 deletion py-polars/tests/unit/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import polars as pl
import polars.selectors as cs
from polars.testing import assert_frame_equal
from polars.testing import assert_frame_equal, assert_series_equal


# TODO: Do not rely on I/O for these tests
Expand Down Expand Up @@ -306,6 +306,37 @@ def test_sql_regex_error() -> None:
ctx.execute("SELECT * FROM df WHERE sval !~* abcde")


@pytest.mark.parametrize(
    ("decimals", "expected"),
    [
        (0, [-8192.0, -4.0, -2.0, 2.0, 4.0, 8193.0]),
        (1, [-8192.5, -4.0, -1.5, 2.5, 3.6, 8192.5]),
        (2, [-8192.5, -3.96, -1.54, 2.46, 3.6, 8192.5]),
        (3, [-8192.499, -3.955, -1.543, 2.457, 3.599, 8192.5]),
        (4, [-8192.499, -3.955, -1.5432, 2.4568, 3.599, 8192.5001]),
    ],
)
def test_sql_round_ndigits(decimals: int, expected: list[float]) -> None:
    """Check SQL ROUND against the expected values for each 'decimals' setting."""
    values = [-8192.499, -3.9550, -1.54321, 2.45678, 3.59901, 8192.5001]
    frame = pl.DataFrame({"n": values})
    expected_series = pl.Series("n", values=expected)

    with pl.SQLContext(df=frame, eager_execution=True) as ctx:
        # The single-argument form must give the same result as decimals=0.
        if decimals == 0:
            result = ctx.execute("SELECT ROUND(n) AS n FROM df")
            assert_series_equal(result["n"], expected_series)

        result = ctx.execute(f"""SELECT ROUND("n",{decimals}) AS n FROM df""")
        assert_series_equal(result["n"], expected_series)


def test_sql_round_ndigits_errors() -> None:
    """ROUND with a 'decimals' argument of -1 must raise InvalidOperationError."""
    frame = pl.DataFrame({"n": [99.999]})
    with pl.SQLContext(df=frame, eager_execution=True) as ctx:
        # The error message is expected to echo the offending argument.
        with pytest.raises(
            pl.InvalidOperationError, match="Invalid 'decimals' for Round: -1"
        ):
            ctx.execute("SELECT ROUND(n,-1) AS n FROM df")


def test_sql_trim(foods_ipc_path: Path) -> None:
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
Expand Down

0 comments on commit 05d1195

Please sign in to comment.