Skip to content

Commit

Permalink
feat(python): Improved assert equal messages
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Feb 8, 2023
1 parent 870a818 commit a0c02e0
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 57 deletions.
97 changes: 53 additions & 44 deletions py-polars/polars/testing/asserts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import textwrap
from typing import Any

import polars.internals as pli
Expand Down Expand Up @@ -61,28 +62,28 @@ def assert_frame_equal(
>>> df1 = pl.DataFrame({"a": [1, 2, 3]})
>>> df2 = pl.DataFrame({"a": [2, 3, 4]})
>>> assert_frame_equal(df1, df2) # doctest: +SKIP
AssertionError: Values for column 'a' are different.
"""
if isinstance(left, pli.LazyFrame) and isinstance(right, pli.LazyFrame):
left, right = left.collect(), right.collect()
obj = "LazyFrames"
else:
elif isinstance(left, pli.DataFrame) and isinstance(right, pli.DataFrame):
obj = "DataFrames"
else:
raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right))

if not (isinstance(left, pli.DataFrame) and isinstance(right, pli.DataFrame)):
raise_assert_detail(obj, "Type mismatch", type(left), type(right))
elif left.shape[0] != right.shape[0]:
raise_assert_detail(obj, "Length mismatch", left.shape, right.shape)
if left.shape[0] != right.shape[0]: # type: ignore[union-attr]
raise_assert_detail(obj, "Length mismatch", left.shape, right.shape) # type: ignore[union-attr]

left_not_right = [c for c in left.columns if c not in right.columns]
if left_not_right:
raise AssertionError(
f"Columns {left_not_right} in left frame, but not in right"
f"Columns {left_not_right} in left frame, but not in right."
)
right_not_left = [c for c in right.columns if c not in left.columns]
if right_not_left:
raise AssertionError(
f"Columns {right_not_left} in right frame, but not in left"
f"Columns {right_not_left} in right frame, but not in left."
)

if check_column_order and left.columns != right.columns:
Expand All @@ -94,23 +95,26 @@ def assert_frame_equal(
try:
left = left.sort(by=left.columns)
right = right.sort(by=left.columns)
except PanicException as err:
except PanicException as exc:
raise InvalidAssert(
"Cannot set 'check_row_order=False' on frame with unsortable columns"
) from err
"Cannot set 'check_row_order=False' on frame with unsortable columns."
) from exc

# note: does not assume a particular column order
for c in left.columns:
_assert_series_inner(
left[c], # type: ignore[arg-type, index]
right[c], # type: ignore[arg-type, index]
check_dtype,
check_exact,
nans_compare_equal,
atol,
rtol,
obj,
)
try:
_assert_series_inner(
left[c], # type: ignore[arg-type, index]
right[c], # type: ignore[arg-type, index]
check_dtype,
check_exact,
nans_compare_equal,
atol,
rtol,
)
except AssertionError as exc:
msg = f"Values for column {c!r} are different."
raise AssertionError(msg) from exc


def assert_frame_not_equal(
Expand Down Expand Up @@ -174,8 +178,8 @@ def assert_frame_not_equal(
)
except AssertionError:
return

raise AssertionError("Expected the two frames to compare unequal")
else:
raise AssertionError("Expected the input frames to be unequal.")


def assert_series_equal(
Expand Down Expand Up @@ -219,22 +223,20 @@ def assert_series_equal(
>>> assert_series_equal(s1, s2) # doctest: +SKIP
"""
obj = "Series"

if not (
isinstance(left, pli.Series) # type: ignore[redundant-expr]
and isinstance(right, pli.Series)
):
raise_assert_detail(obj, "Type mismatch", type(left), type(right))
raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right))

if left.shape != right.shape:
raise_assert_detail(obj, "Shape mismatch", left.shape, right.shape)
if len(left) != len(right):
raise_assert_detail("Series", "Length mismatch", len(left), len(right))

if check_names and left.name != right.name:
raise_assert_detail(obj, "Name mismatch", left.name, right.name)
raise_assert_detail("Series", "Name mismatch", left.name, right.name)

_assert_series_inner(
left, right, check_dtype, check_exact, nans_compare_equal, atol, rtol, obj
left, right, check_dtype, check_exact, nans_compare_equal, atol, rtol
)


Expand Down Expand Up @@ -292,8 +294,8 @@ def assert_series_not_equal(
)
except AssertionError:
return

raise AssertionError("Expected the two series to compare unequal")
else:
raise AssertionError("Expected the input Series to be unequal.")


def _assert_series_inner(
Expand All @@ -304,7 +306,6 @@ def _assert_series_inner(
nans_compare_equal: bool,
atol: float,
rtol: float,
obj: str,
) -> None:
"""Compare Series dtype + values."""
try:
Expand All @@ -314,7 +315,7 @@ def _assert_series_inner(

check_exact = check_exact or not can_be_subtracted or left.dtype == Boolean
if check_dtype and left.dtype != right.dtype:
raise_assert_detail(obj, "Dtype mismatch", left.dtype, right.dtype)
raise_assert_detail("Series", "Dtype mismatch", left.dtype, right.dtype)

# confirm that we can call 'is_nan' on both sides
left_is_float = left.dtype in (Float32, Float64)
Expand All @@ -333,7 +334,7 @@ def _assert_series_inner(
if unequal.any():
if check_exact:
raise_assert_detail(
obj, "Exact value mismatch", left=list(left), right=list(right)
"Series", "Exact value mismatch", left=list(left), right=list(right)
)
else:
# apply check with tolerance (to the known-unequal matches).
Expand All @@ -354,27 +355,34 @@ def _assert_series_inner(

if mismatch:
raise_assert_detail(
obj, f"Value mismatch{nan_info}", left=list(left), right=list(right)
"Series",
f"Value mismatch{nan_info}",
left=list(left),
right=list(right),
)


def raise_assert_detail(
obj: str,
message: str,
detail: str,
left: Any,
right: Any,
exc: AssertionError | None = None,
) -> None:
"""Raise a detailed assertion error."""
__tracebackhide__ = True

msg = f"""{obj} are different
{message}"""
error_msg = textwrap.dedent(
f"""\
{obj} are different.
msg += f"""
[left]: {left}
[right]: {right}"""
{detail}
[left]: {left}
[right]: {right}\
"""
)

raise AssertionError(msg)
raise AssertionError(error_msg) from exc


def is_categorical_dtype(data_type: Any) -> bool:
Expand All @@ -389,6 +397,7 @@ def is_categorical_dtype(data_type: Any) -> bool:
def assert_frame_equal_local_categoricals(
df_a: pli.DataFrame, df_b: pli.DataFrame
) -> None:
"""Assert frame equal for frames containing categoricals."""
for (a_name, a_value), (b_name, b_value) in zip(
df_a.schema.items(), df_b.schema.items()
):
Expand Down
28 changes: 15 additions & 13 deletions py-polars/tests/unit/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_compare_series_value_mismatch() -> None:
srs2 = pl.Series([2, 3, 4])

assert_series_not_equal(srs1, srs2)
with pytest.raises(AssertionError, match="Series are different\n\nValue mismatch"):
with pytest.raises(AssertionError, match="Series are different.\n\nValue mismatch"):
assert_series_equal(srs1, srs2)


Expand Down Expand Up @@ -97,41 +97,45 @@ def test_compare_series_value_mismatch_string() -> None:
srs1 = pl.Series(["hello", "no"])
srs2 = pl.Series(["hello", "yes"])
with pytest.raises(
AssertionError, match="Series are different\n\nExact value mismatch"
AssertionError, match="Series are different.\n\nExact value mismatch"
):
assert_series_equal(srs1, srs2)


def test_compare_series_type_mismatch() -> None:
srs1 = pl.Series([1, 2, 3])
srs2 = pl.DataFrame({"col1": [2, 3, 4]})
with pytest.raises(AssertionError, match="Series are different\n\nType mismatch"):
with pytest.raises(
AssertionError, match="Inputs are different.\n\nUnexpected input types"
):
assert_series_equal(srs1, srs2) # type: ignore[arg-type]

srs3 = pl.Series([1.0, 2.0, 3.0])
with pytest.raises(AssertionError, match="Series are different\n\nDtype mismatch"):
with pytest.raises(AssertionError, match="Series are different.\n\nDtype mismatch"):
assert_series_equal(srs1, srs3)


def test_compare_series_name_mismatch() -> None:
srs1 = pl.Series(values=[1, 2, 3], name="srs1")
srs2 = pl.Series(values=[1, 2, 3], name="srs2")
with pytest.raises(AssertionError, match="Series are different\n\nName mismatch"):
with pytest.raises(AssertionError, match="Series are different.\n\nName mismatch"):
assert_series_equal(srs1, srs2)


def test_compare_series_shape_mismatch() -> None:
srs1 = pl.Series(values=[1, 2, 3, 4], name="srs1")
srs2 = pl.Series(values=[1, 2, 3], name="srs2")
with pytest.raises(AssertionError, match="Series are different\n\nShape mismatch"):
with pytest.raises(
AssertionError, match="Series are different.\n\nLength mismatch"
):
assert_series_equal(srs1, srs2)


def test_compare_series_value_exact_mismatch() -> None:
srs1 = pl.Series([1.0, 2.0, 3.0])
srs2 = pl.Series([1.0, 2.0 + 1e-7, 3.0])
with pytest.raises(
AssertionError, match="Series are different\n\nExact value mismatch"
AssertionError, match="Series are different.\n\nExact value mismatch"
):
assert_series_equal(srs1, srs2, check_exact=True)

Expand All @@ -150,9 +154,7 @@ def test_compare_frame_equal_nans() -> None:
data={"x": [1.0, nan], "y": [None, 2.0]},
schema=[("x", pl.Float32), ("y", pl.Float64)],
)
with pytest.raises(
AssertionError, match="DataFrames are different\n\nExact value mismatch"
):
with pytest.raises(AssertionError, match="Values for column 'y' are different"):
assert_frame_equal(df1, df2, check_exact=True)


Expand All @@ -166,7 +168,7 @@ def test_assert_frame_equal_types() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
srs1 = pl.Series(values=[1, 2], name="a")
with pytest.raises(
AssertionError, match="DataFrames are different\n\nType mismatch"
AssertionError, match="Inputs are different.\n\nUnexpected input types"
):
assert_frame_equal(df1, srs1) # type: ignore[arg-type]

Expand All @@ -175,7 +177,7 @@ def test_assert_frame_equal_length_mismatch() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 2, 3]})
with pytest.raises(
AssertionError, match="DataFrames are different\n\nLength mismatch"
AssertionError, match="DataFrames are different.\n\nLength mismatch"
):
assert_frame_equal(df1, df2)

Expand Down Expand Up @@ -216,7 +218,7 @@ def test_assert_frame_equal_ignore_row_order() -> None:
df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]})
df2 = pl.DataFrame({"a": [2, 1], "b": [3, 4]})
df3 = pl.DataFrame({"b": [3, 4], "a": [2, 1]})
with pytest.raises(AssertionError, match="Value mismatch"):
with pytest.raises(AssertionError, match="Values for column 'a' are different."):
assert_frame_equal(df1, df2)

assert_frame_equal(df1, df2, check_row_order=False)
Expand Down

0 comments on commit a0c02e0

Please sign in to comment.