Skip to content

Commit

Permalink
fix(python): fix potential OverflowError in testing asserts with hu…
Browse files Browse the repository at this point in the history
…ge `UInt64` diffs (#10437)
  • Loading branch information
alexander-beedie authored Aug 13, 2023
1 parent 6634621 commit bdaad51
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
5 changes: 4 additions & 1 deletion py-polars/polars/testing/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DataTypeClass,
List,
Struct,
UInt64,
Utf8,
dtype_to_py_type,
unpack_dtypes,
Expand Down Expand Up @@ -407,7 +408,9 @@ def _assert_series_inner(

if all(tp in UNSIGNED_INTEGER_DTYPES for tp in (left.dtype, right.dtype)):
# avoid potential "subtract-with-overflow" panic on uint math
s_diff = Series("diff", [abs(v1 - v2) for v1, v2 in zip(left, right)])
s_diff = Series(
"diff", [abs(v1 - v2) for v1, v2 in zip(left, right)], dtype=UInt64
)
else:
s_diff = (left - right).abs()

Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,19 @@ def test_assert_series_equal_uint_overflow() -> None:
assert_series_equal(s1, s2, atol=0)
assert_series_equal(s1, s2, atol=1)

# confirm no OverflowError in the below test case:
# as "(left-right).abs()" > max(Int64)
left = pl.Series(
values=[2810428175213635359],
dtype=pl.UInt64,
)
right = pl.Series(
values=[15807433754238349345],
dtype=pl.UInt64,
)
with pytest.raises(AssertionError):
assert_series_equal(left, right)


@pytest.mark.parametrize(
("data1", "data2"),
Expand Down

0 comments on commit bdaad51

Please sign in to comment.