Skip to content

Commit

Permalink
fix(python): Require exact checking for Decimals in assertion utils (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Feb 8, 2024
1 parent 15707e0 commit 455e7bf
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 24 deletions.
10 changes: 0 additions & 10 deletions py-polars/polars/testing/asserts/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
FLOAT_DTYPES,
Array,
Categorical,
Decimal,
Float64,
List,
String,
Struct,
Expand Down Expand Up @@ -129,14 +127,6 @@ def _assert_series_values_equal(
if right.dtype == Categorical:
right = right.cast(String)

# Handle decimals
# TODO: Delete this branch when Decimal equality is implemented
# https://github.com/pola-rs/polars/issues/12118
if left.dtype == Decimal:
left = left.cast(Float64)
if right.dtype == Decimal:
right = right.cast(Float64)

# Determine unequal elements
try:
unequal = left.ne_missing(right)
Expand Down
20 changes: 6 additions & 14 deletions py-polars/tests/unit/testing/test_assert_series_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,24 +619,16 @@ def test_series_equal_nested_lengths_mismatch() -> None:
assert_series_equal(s1, s2)


def test_series_equal_decimals_exact() -> None:
s1 = pl.Series([D("1.00000"), D("2.00000")], dtype=pl.Decimal)
s2 = pl.Series([D("1.00000"), D("2.00001")], dtype=pl.Decimal)
with pytest.raises(AssertionError, match="exact value mismatch"):
assert_series_equal(s1, s2, check_exact=True)


def test_series_equal_decimals_inexact() -> None:
@pytest.mark.parametrize("check_exact", [True, False])
def test_series_equal_decimals(check_exact: bool) -> None:
s1 = pl.Series([D("1.00000"), D("2.00000")], dtype=pl.Decimal)
s2 = pl.Series([D("1.00000"), D("2.00001")], dtype=pl.Decimal)
assert_series_equal(s1, s2, check_exact=False)

assert_series_equal(s1, s1, check_exact=check_exact)
assert_series_equal(s2, s2, check_exact=check_exact)

def test_series_equal_decimals_inexact_fail() -> None:
s1 = pl.Series([D("1.00000"), D("2.00000")], dtype=pl.Decimal)
s2 = pl.Series([D("1.00000"), D("2.00001")], dtype=pl.Decimal)
with pytest.raises(AssertionError, match="value mismatch"):
assert_series_equal(s1, s2, check_exact=False, rtol=0)
with pytest.raises(AssertionError, match="exact value mismatch"):
assert_series_equal(s1, s2, check_exact=check_exact)


def test_assert_series_equal_w_large_integers_12328() -> None:
Expand Down

0 comments on commit 455e7bf

Please sign in to comment.