diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py index 12b87aa37274..99b1424fcce0 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts.py @@ -12,6 +12,7 @@ DataTypeClass, List, Struct, + UInt64, Utf8, dtype_to_py_type, unpack_dtypes, @@ -407,7 +408,9 @@ def _assert_series_inner( if all(tp in UNSIGNED_INTEGER_DTYPES for tp in (left.dtype, right.dtype)): # avoid potential "subtract-with-overflow" panic on uint math - s_diff = Series("diff", [abs(v1 - v2) for v1, v2 in zip(left, right)]) + s_diff = Series( + "diff", [abs(v1 - v2) for v1, v2 in zip(left, right)], dtype=UInt64 + ) else: s_diff = (left - right).abs() diff --git a/py-polars/tests/unit/test_testing.py b/py-polars/tests/unit/test_testing.py index 6d9f808f08c0..d763047a0dab 100644 --- a/py-polars/tests/unit/test_testing.py +++ b/py-polars/tests/unit/test_testing.py @@ -596,6 +596,19 @@ def test_assert_series_equal_uint_overflow() -> None: assert_series_equal(s1, s2, atol=0) assert_series_equal(s1, s2, atol=1) + # confirm no OverflowError in the below test case: + # as "(left-right).abs()" > max(Int64) + left = pl.Series( + values=[2810428175213635359], + dtype=pl.UInt64, + ) + right = pl.Series( + values=[15807433754238349345], + dtype=pl.UInt64, + ) + with pytest.raises(AssertionError): + assert_series_equal(left, right) + @pytest.mark.parametrize( ("data1", "data2"),