Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Improved assert equal messages #6737

Merged
merged 1 commit into from
Feb 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 53 additions & 44 deletions py-polars/polars/testing/asserts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import textwrap
from typing import Any

import polars.internals as pli
Expand Down Expand Up @@ -61,28 +62,28 @@ def assert_frame_equal(
>>> df1 = pl.DataFrame({"a": [1, 2, 3]})
>>> df2 = pl.DataFrame({"a": [2, 3, 4]})
>>> assert_frame_equal(df1, df2) # doctest: +SKIP

AssertionError: Values for column 'a' are different.
"""
if isinstance(left, pli.LazyFrame) and isinstance(right, pli.LazyFrame):
left, right = left.collect(), right.collect()
obj = "LazyFrames"
else:
elif isinstance(left, pli.DataFrame) and isinstance(right, pli.DataFrame):
obj = "DataFrames"
else:
raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right))

if not (isinstance(left, pli.DataFrame) and isinstance(right, pli.DataFrame)):
raise_assert_detail(obj, "Type mismatch", type(left), type(right))
elif left.shape[0] != right.shape[0]:
raise_assert_detail(obj, "Length mismatch", left.shape, right.shape)
if left.shape[0] != right.shape[0]: # type: ignore[union-attr]
raise_assert_detail(obj, "Length mismatch", left.shape, right.shape) # type: ignore[union-attr]

left_not_right = [c for c in left.columns if c not in right.columns]
if left_not_right:
raise AssertionError(
f"Columns {left_not_right} in left frame, but not in right"
f"Columns {left_not_right} in left frame, but not in right."
)
right_not_left = [c for c in right.columns if c not in left.columns]
if right_not_left:
raise AssertionError(
f"Columns {right_not_left} in right frame, but not in left"
f"Columns {right_not_left} in right frame, but not in left."
)

if check_column_order and left.columns != right.columns:
Expand All @@ -94,23 +95,26 @@ def assert_frame_equal(
try:
left = left.sort(by=left.columns)
right = right.sort(by=left.columns)
except PanicException as err:
except PanicException as exc:
raise InvalidAssert(
"Cannot set 'check_row_order=False' on frame with unsortable columns"
) from err
"Cannot set 'check_row_order=False' on frame with unsortable columns."
) from exc

# note: does not assume a particular column order
for c in left.columns:
_assert_series_inner(
left[c], # type: ignore[arg-type, index]
right[c], # type: ignore[arg-type, index]
check_dtype,
check_exact,
nans_compare_equal,
atol,
rtol,
obj,
)
try:
_assert_series_inner(
left[c], # type: ignore[arg-type, index]
right[c], # type: ignore[arg-type, index]
check_dtype,
check_exact,
nans_compare_equal,
atol,
rtol,
)
except AssertionError as exc:
msg = f"Values for column {c!r} are different."
raise AssertionError(msg) from exc


def assert_frame_not_equal(
Expand Down Expand Up @@ -174,8 +178,8 @@ def assert_frame_not_equal(
)
except AssertionError:
return

raise AssertionError("Expected the two frames to compare unequal")
else:
raise AssertionError("Expected the input frames to be unequal.")


def assert_series_equal(
Expand Down Expand Up @@ -219,22 +223,20 @@ def assert_series_equal(
>>> assert_series_equal(s1, s2) # doctest: +SKIP

"""
obj = "Series"

if not (
isinstance(left, pli.Series) # type: ignore[redundant-expr]
and isinstance(right, pli.Series)
):
raise_assert_detail(obj, "Type mismatch", type(left), type(right))
raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right))

if left.shape != right.shape:
raise_assert_detail(obj, "Shape mismatch", left.shape, right.shape)
if len(left) != len(right):
raise_assert_detail("Series", "Length mismatch", len(left), len(right))

if check_names and left.name != right.name:
raise_assert_detail(obj, "Name mismatch", left.name, right.name)
raise_assert_detail("Series", "Name mismatch", left.name, right.name)

_assert_series_inner(
left, right, check_dtype, check_exact, nans_compare_equal, atol, rtol, obj
left, right, check_dtype, check_exact, nans_compare_equal, atol, rtol
)


Expand Down Expand Up @@ -292,8 +294,8 @@ def assert_series_not_equal(
)
except AssertionError:
return

raise AssertionError("Expected the two series to compare unequal")
else:
raise AssertionError("Expected the input Series to be unequal.")


def _assert_series_inner(
Expand All @@ -304,7 +306,6 @@ def _assert_series_inner(
nans_compare_equal: bool,
atol: float,
rtol: float,
obj: str,
) -> None:
"""Compare Series dtype + values."""
try:
Expand All @@ -314,7 +315,7 @@ def _assert_series_inner(

check_exact = check_exact or not can_be_subtracted or left.dtype == Boolean
if check_dtype and left.dtype != right.dtype:
raise_assert_detail(obj, "Dtype mismatch", left.dtype, right.dtype)
raise_assert_detail("Series", "Dtype mismatch", left.dtype, right.dtype)

# confirm that we can call 'is_nan' on both sides
left_is_float = left.dtype in (Float32, Float64)
Expand All @@ -333,7 +334,7 @@ def _assert_series_inner(
if unequal.any():
if check_exact:
raise_assert_detail(
obj, "Exact value mismatch", left=list(left), right=list(right)
"Series", "Exact value mismatch", left=list(left), right=list(right)
)
else:
# apply check with tolerance (to the known-unequal matches).
Expand All @@ -354,27 +355,34 @@ def _assert_series_inner(

if mismatch:
raise_assert_detail(
obj, f"Value mismatch{nan_info}", left=list(left), right=list(right)
"Series",
f"Value mismatch{nan_info}",
left=list(left),
right=list(right),
)


def raise_assert_detail(
obj: str,
message: str,
detail: str,
left: Any,
right: Any,
exc: AssertionError | None = None,
) -> None:
"""Raise a detailed assertion error."""
__tracebackhide__ = True

msg = f"""{obj} are different

{message}"""
error_msg = textwrap.dedent(
f"""\
{obj} are different.

msg += f"""
[left]: {left}
[right]: {right}"""
{detail}
[left]: {left}
[right]: {right}\
"""
)

raise AssertionError(msg)
raise AssertionError(error_msg) from exc


def is_categorical_dtype(data_type: Any) -> bool:
Expand All @@ -389,6 +397,7 @@ def is_categorical_dtype(data_type: Any) -> bool:
def assert_frame_equal_local_categoricals(
df_a: pli.DataFrame, df_b: pli.DataFrame
) -> None:
"""Assert frame equal for frames containing categoricals."""
for (a_name, a_value), (b_name, b_value) in zip(
df_a.schema.items(), df_b.schema.items()
):
Expand Down
28 changes: 15 additions & 13 deletions py-polars/tests/unit/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_compare_series_value_mismatch() -> None:
srs2 = pl.Series([2, 3, 4])

assert_series_not_equal(srs1, srs2)
with pytest.raises(AssertionError, match="Series are different\n\nValue mismatch"):
with pytest.raises(AssertionError, match="Series are different.\n\nValue mismatch"):
assert_series_equal(srs1, srs2)


Expand Down Expand Up @@ -97,41 +97,45 @@ def test_compare_series_value_mismatch_string() -> None:
srs1 = pl.Series(["hello", "no"])
srs2 = pl.Series(["hello", "yes"])
with pytest.raises(
AssertionError, match="Series are different\n\nExact value mismatch"
AssertionError, match="Series are different.\n\nExact value mismatch"
):
assert_series_equal(srs1, srs2)


def test_compare_series_type_mismatch() -> None:
srs1 = pl.Series([1, 2, 3])
srs2 = pl.DataFrame({"col1": [2, 3, 4]})
with pytest.raises(AssertionError, match="Series are different\n\nType mismatch"):
with pytest.raises(
AssertionError, match="Inputs are different.\n\nUnexpected input types"
):
assert_series_equal(srs1, srs2) # type: ignore[arg-type]

srs3 = pl.Series([1.0, 2.0, 3.0])
with pytest.raises(AssertionError, match="Series are different\n\nDtype mismatch"):
with pytest.raises(AssertionError, match="Series are different.\n\nDtype mismatch"):
assert_series_equal(srs1, srs3)


def test_compare_series_name_mismatch() -> None:
srs1 = pl.Series(values=[1, 2, 3], name="srs1")
srs2 = pl.Series(values=[1, 2, 3], name="srs2")
with pytest.raises(AssertionError, match="Series are different\n\nName mismatch"):
with pytest.raises(AssertionError, match="Series are different.\n\nName mismatch"):
assert_series_equal(srs1, srs2)


def test_compare_series_shape_mismatch() -> None:
srs1 = pl.Series(values=[1, 2, 3, 4], name="srs1")
srs2 = pl.Series(values=[1, 2, 3], name="srs2")
with pytest.raises(AssertionError, match="Series are different\n\nShape mismatch"):
with pytest.raises(
AssertionError, match="Series are different.\n\nLength mismatch"
):
assert_series_equal(srs1, srs2)


def test_compare_series_value_exact_mismatch() -> None:
srs1 = pl.Series([1.0, 2.0, 3.0])
srs2 = pl.Series([1.0, 2.0 + 1e-7, 3.0])
with pytest.raises(
AssertionError, match="Series are different\n\nExact value mismatch"
AssertionError, match="Series are different.\n\nExact value mismatch"
):
assert_series_equal(srs1, srs2, check_exact=True)

Expand All @@ -150,9 +154,7 @@ def test_compare_frame_equal_nans() -> None:
data={"x": [1.0, nan], "y": [None, 2.0]},
schema=[("x", pl.Float32), ("y", pl.Float64)],
)
with pytest.raises(
AssertionError, match="DataFrames are different\n\nExact value mismatch"
):
with pytest.raises(AssertionError, match="Values for column 'y' are different"):
assert_frame_equal(df1, df2, check_exact=True)


Expand All @@ -166,7 +168,7 @@ def test_assert_frame_equal_types() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
srs1 = pl.Series(values=[1, 2], name="a")
with pytest.raises(
AssertionError, match="DataFrames are different\n\nType mismatch"
AssertionError, match="Inputs are different.\n\nUnexpected input types"
):
assert_frame_equal(df1, srs1) # type: ignore[arg-type]

Expand All @@ -175,7 +177,7 @@ def test_assert_frame_equal_length_mismatch() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 2, 3]})
with pytest.raises(
AssertionError, match="DataFrames are different\n\nLength mismatch"
AssertionError, match="DataFrames are different.\n\nLength mismatch"
):
assert_frame_equal(df1, df2)

Expand Down Expand Up @@ -216,7 +218,7 @@ def test_assert_frame_equal_ignore_row_order() -> None:
df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]})
df2 = pl.DataFrame({"a": [2, 1], "b": [3, 4]})
df3 = pl.DataFrame({"b": [3, 4], "a": [2, 1]})
with pytest.raises(AssertionError, match="Value mismatch"):
with pytest.raises(AssertionError, match="Values for column 'a' are different."):
assert_frame_equal(df1, df2)

assert_frame_equal(df1, df2, check_row_order=False)
Expand Down