Skip to content

Commit

Permalink
improve round errors
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Jan 23, 2024
1 parent 8f51ba5 commit 3d21c8b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
14 changes: 7 additions & 7 deletions crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -768,15 +768,15 @@ impl SQLFunctionVisitor<'_> {
1 => self.visit_unary(|e| e.round(0)),
2 => self.try_visit_binary(|e, decimals| {
Ok(e.round(match decimals {
Expr::Literal(LiteralValue::Int64(n)) => n as u32,
_ => {
polars_bail!(InvalidOperation: "Invalid 'decimals' for Round: {}", function.args[1]);
}
Expr::Literal(LiteralValue::Int64(n)) => {
if n >= 0 { n as u32 } else {
polars_bail!(InvalidOperation: "Round does not (yet) support negative 'decimals': {}", function.args[1]);
}
},
_ => polars_bail!(InvalidOperation: "Invalid 'decimals' for Round: {}", function.args[1]);
}))
}),
_ => {
polars_bail!(InvalidOperation:"Invalid number of arguments for Round: {}", function.args.len());
},
_ => polars_bail!(InvalidOperation:"Invalid number of arguments for Round: {}", function.args.len());
},
Sign => self.visit_unary(Expr::sign),
Sqrt => self.visit_unary(Expr::sqrt),
Expand Down
13 changes: 9 additions & 4 deletions py-polars/tests/unit/sql/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,15 @@ def test_round_ndigits(decimals: int, expected: list[float]) -> None:

def test_round_ndigits_errors() -> None:
df = pl.DataFrame({"n": [99.999]})
with pl.SQLContext(df=df, eager_execution=True) as ctx, pytest.raises(
InvalidOperationError, match="Invalid 'decimals' for Round: -1"
):
ctx.execute("SELECT ROUND(n,-1) AS n FROM df")
with pl.SQLContext(df=df, eager_execution=True) as ctx:
with pytest.raises(
InvalidOperationError, match="Invalid 'decimals' for Round: ??"
):
ctx.execute("SELECT ROUND(n,'??') AS n FROM df")
with pytest.raises(
InvalidOperationError, match="Round .* negative 'decimals': -1"
):
ctx.execute("SELECT ROUND(n,-1) AS n FROM df")


def test_stddev_variance() -> None:
Expand Down

0 comments on commit 3d21c8b

Please sign in to comment.