diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 58fb6bf8ca92a..9c19169500676 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -768,15 +768,15 @@ impl SQLFunctionVisitor<'_> { 1 => self.visit_unary(|e| e.round(0)), 2 => self.try_visit_binary(|e, decimals| { Ok(e.round(match decimals { - Expr::Literal(LiteralValue::Int64(n)) => n as u32, - _ => { - polars_bail!(InvalidOperation: "Invalid 'decimals' for Round: {}", function.args[1]); - } + Expr::Literal(LiteralValue::Int64(n)) => { + if n >= 0 { n as u32 } else { + polars_bail!(InvalidOperation: "Round does not (yet) support negative 'decimals': {}", function.args[1]); + } + }, + _ => polars_bail!(InvalidOperation: "Invalid 'decimals' for Round: {}", function.args[1]); })) }), - _ => { - polars_bail!(InvalidOperation:"Invalid number of arguments for Round: {}", function.args.len()); - }, + _ => polars_bail!(InvalidOperation:"Invalid number of arguments for Round: {}", function.args.len()); }, Sign => self.visit_unary(Expr::sign), Sqrt => self.visit_unary(Expr::sqrt), diff --git a/py-polars/tests/unit/sql/test_numeric.py b/py-polars/tests/unit/sql/test_numeric.py index 7d7938465ca57..09cbcc237739f 100644 --- a/py-polars/tests/unit/sql/test_numeric.py +++ b/py-polars/tests/unit/sql/test_numeric.py @@ -107,10 +107,15 @@ def test_round_ndigits(decimals: int, expected: list[float]) -> None: def test_round_ndigits_errors() -> None: df = pl.DataFrame({"n": [99.999]}) - with pl.SQLContext(df=df, eager_execution=True) as ctx, pytest.raises( - InvalidOperationError, match="Invalid 'decimals' for Round: -1" - ): - ctx.execute("SELECT ROUND(n,-1) AS n FROM df") + with pl.SQLContext(df=df, eager_execution=True) as ctx: + with pytest.raises( + InvalidOperationError, match="Invalid 'decimals' for Round: ??" + ): + ctx.execute("SELECT ROUND(n,'??') AS n FROM df") + with pytest.raises( + InvalidOperationError, match="Round .* negative 'decimals': -1" + ): + ctx.execute("SELECT ROUND(n,-1) AS n FROM df") def test_stddev_variance() -> None: