From a0c02e09def8427958c42de7a9e6098271c598ee Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Wed, 8 Feb 2023 23:51:11 +0100 Subject: [PATCH] feat(python): Improved assert equal messages --- py-polars/polars/testing/asserts.py | 97 +++++++++++++++------------- py-polars/tests/unit/test_testing.py | 28 ++++---- 2 files changed, 68 insertions(+), 57 deletions(-) diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py index 5891b64080ff..5e8dfe7a09c0 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts.py @@ -1,5 +1,6 @@ from __future__ import annotations +import textwrap from typing import Any import polars.internals as pli @@ -61,28 +62,28 @@ def assert_frame_equal( >>> df1 = pl.DataFrame({"a": [1, 2, 3]}) >>> df2 = pl.DataFrame({"a": [2, 3, 4]}) >>> assert_frame_equal(df1, df2) # doctest: +SKIP - + AssertionError: Values for column 'a' are different. """ if isinstance(left, pli.LazyFrame) and isinstance(right, pli.LazyFrame): left, right = left.collect(), right.collect() obj = "LazyFrames" - else: + elif isinstance(left, pli.DataFrame) and isinstance(right, pli.DataFrame): obj = "DataFrames" + else: + raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right)) - if not (isinstance(left, pli.DataFrame) and isinstance(right, pli.DataFrame)): - raise_assert_detail(obj, "Type mismatch", type(left), type(right)) - elif left.shape[0] != right.shape[0]: - raise_assert_detail(obj, "Length mismatch", left.shape, right.shape) + if left.shape[0] != right.shape[0]: # type: ignore[union-attr] + raise_assert_detail(obj, "Length mismatch", left.shape, right.shape) # type: ignore[union-attr] left_not_right = [c for c in left.columns if c not in right.columns] if left_not_right: raise AssertionError( - f"Columns {left_not_right} in left frame, but not in right" + f"Columns {left_not_right} in left frame, but not in right." ) right_not_left = [c for c in right.columns if c not in left.columns] if right_not_left: raise AssertionError( - f"Columns {right_not_left} in right frame, but not in left" + f"Columns {right_not_left} in right frame, but not in left." ) if check_column_order and left.columns != right.columns: @@ -94,23 +95,26 @@ def assert_frame_equal( try: left = left.sort(by=left.columns) right = right.sort(by=left.columns) - except PanicException as err: + except PanicException as exc: raise InvalidAssert( - "Cannot set 'check_row_order=False' on frame with unsortable columns" - ) from err + "Cannot set 'check_row_order=False' on frame with unsortable columns." + ) from exc # note: does not assume a particular column order for c in left.columns: - _assert_series_inner( - left[c], # type: ignore[arg-type, index] - right[c], # type: ignore[arg-type, index] - check_dtype, - check_exact, - nans_compare_equal, - atol, - rtol, - obj, - ) + try: + _assert_series_inner( + left[c], # type: ignore[arg-type, index] + right[c], # type: ignore[arg-type, index] + check_dtype, + check_exact, + nans_compare_equal, + atol, + rtol, + ) + except AssertionError as exc: + msg = f"Values for column {c!r} are different." + raise AssertionError(msg) from exc def assert_frame_not_equal( @@ -174,8 +178,8 @@ def assert_frame_not_equal( ) except AssertionError: return - - raise AssertionError("Expected the two frames to compare unequal") + else: + raise AssertionError("Expected the input frames to be unequal.") def assert_series_equal( @@ -219,22 +223,20 @@ def assert_series_equal( >>> assert_series_equal(s1, s2) # doctest: +SKIP """ - obj = "Series" - if not ( isinstance(left, pli.Series) # type: ignore[redundant-expr] and isinstance(right, pli.Series) ): - raise_assert_detail(obj, "Type mismatch", type(left), type(right)) + raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right)) - if left.shape != right.shape: - raise_assert_detail(obj, "Shape mismatch", left.shape, right.shape) + if len(left) != len(right): + raise_assert_detail("Series", "Length mismatch", len(left), len(right)) if check_names and left.name != right.name: - raise_assert_detail(obj, "Name mismatch", left.name, right.name) + raise_assert_detail("Series", "Name mismatch", left.name, right.name) _assert_series_inner( - left, right, check_dtype, check_exact, nans_compare_equal, atol, rtol, obj + left, right, check_dtype, check_exact, nans_compare_equal, atol, rtol ) @@ -292,8 +294,8 @@ def assert_series_not_equal( ) except AssertionError: return - - raise AssertionError("Expected the two series to compare unequal") + else: + raise AssertionError("Expected the input Series to be unequal.") def _assert_series_inner( @@ -304,7 +306,6 @@ def _assert_series_inner( nans_compare_equal: bool, atol: float, rtol: float, - obj: str, ) -> None: """Compare Series dtype + values.""" try: @@ -314,7 +315,7 @@ def _assert_series_inner( check_exact = check_exact or not can_be_subtracted or left.dtype == Boolean if check_dtype and left.dtype != right.dtype: - raise_assert_detail(obj, "Dtype mismatch", left.dtype, right.dtype) + raise_assert_detail("Series", "Dtype mismatch", left.dtype, right.dtype) # confirm that we can call 'is_nan' on both sides left_is_float = left.dtype in (Float32, Float64) @@ -333,7 +334,7 @@ def _assert_series_inner( if unequal.any(): if check_exact: raise_assert_detail( - obj, "Exact value mismatch", left=list(left), right=list(right) + "Series", "Exact value mismatch", left=list(left), right=list(right) ) else: # apply check with tolerance (to the known-unequal matches). @@ -354,27 +355,34 @@ def _assert_series_inner( if mismatch: raise_assert_detail( - obj, f"Value mismatch{nan_info}", left=list(left), right=list(right) + "Series", + f"Value mismatch{nan_info}", + left=list(left), + right=list(right), ) def raise_assert_detail( obj: str, - message: str, + detail: str, left: Any, right: Any, + exc: AssertionError | None = None, ) -> None: + """Raise a detailed assertion error.""" __tracebackhide__ = True - msg = f"""{obj} are different - -{message}""" + error_msg = textwrap.dedent( + f"""\ + {obj} are different. - msg += f""" -[left]: {left} -[right]: {right}""" + {detail} + [left]: {left} + [right]: {right}\ + """ + ) - raise AssertionError(msg) + raise AssertionError(error_msg) from exc def is_categorical_dtype(data_type: Any) -> bool: @@ -389,6 +397,7 @@ def is_categorical_dtype(data_type: Any) -> bool: def assert_frame_equal_local_categoricals( df_a: pli.DataFrame, df_b: pli.DataFrame ) -> None: + """Assert frame equal for frames containing categoricals.""" for (a_name, a_value), (b_name, b_value) in zip( df_a.schema.items(), df_b.schema.items() ): diff --git a/py-polars/tests/unit/test_testing.py b/py-polars/tests/unit/test_testing.py index cce128ff880c..453cf2c0dfc4 100644 --- a/py-polars/tests/unit/test_testing.py +++ b/py-polars/tests/unit/test_testing.py @@ -16,7 +16,7 @@ def test_compare_series_value_mismatch() -> None: srs2 = pl.Series([2, 3, 4]) assert_series_not_equal(srs1, srs2) - with pytest.raises(AssertionError, match="Series are different\n\nValue mismatch"): + with pytest.raises(AssertionError, match="Series are different.\n\nValue mismatch"): assert_series_equal(srs1, srs2) @@ -97,7 +97,7 @@ def test_compare_series_value_mismatch_string() -> None: srs1 = pl.Series(["hello", "no"]) srs2 = pl.Series(["hello", "yes"]) with pytest.raises( - AssertionError, match="Series are different\n\nExact value mismatch" + AssertionError, match="Series are different.\n\nExact value mismatch" ): assert_series_equal(srs1, srs2) @@ -105,25 +105,29 @@ def test_compare_series_value_mismatch_string() -> None: def test_compare_series_type_mismatch() -> None: srs1 = pl.Series([1, 2, 3]) srs2 = pl.DataFrame({"col1": [2, 3, 4]}) - with pytest.raises(AssertionError, match="Series are different\n\nType mismatch"): + with pytest.raises( + AssertionError, match="Inputs are different.\n\nUnexpected input types" + ): assert_series_equal(srs1, srs2) # type: ignore[arg-type] srs3 = pl.Series([1.0, 2.0, 3.0]) - with pytest.raises(AssertionError, match="Series are different\n\nDtype mismatch"): + with pytest.raises(AssertionError, match="Series are different.\n\nDtype mismatch"): assert_series_equal(srs1, srs3) def test_compare_series_name_mismatch() -> None: srs1 = pl.Series(values=[1, 2, 3], name="srs1") srs2 = pl.Series(values=[1, 2, 3], name="srs2") - with pytest.raises(AssertionError, match="Series are different\n\nName mismatch"): + with pytest.raises(AssertionError, match="Series are different.\n\nName mismatch"): assert_series_equal(srs1, srs2) def test_compare_series_shape_mismatch() -> None: srs1 = pl.Series(values=[1, 2, 3, 4], name="srs1") srs2 = pl.Series(values=[1, 2, 3], name="srs2") - with pytest.raises(AssertionError, match="Series are different\n\nShape mismatch"): + with pytest.raises( + AssertionError, match="Series are different.\n\nLength mismatch" + ): assert_series_equal(srs1, srs2) @@ -131,7 +135,7 @@ def test_compare_series_value_exact_mismatch() -> None: srs1 = pl.Series([1.0, 2.0, 3.0]) srs2 = pl.Series([1.0, 2.0 + 1e-7, 3.0]) with pytest.raises( - AssertionError, match="Series are different\n\nExact value mismatch" + AssertionError, match="Series are different.\n\nExact value mismatch" ): assert_series_equal(srs1, srs2, check_exact=True) @@ -150,9 +154,7 @@ def test_compare_frame_equal_nans() -> None: data={"x": [1.0, nan], "y": [None, 2.0]}, schema=[("x", pl.Float32), ("y", pl.Float64)], ) - with pytest.raises( - AssertionError, match="DataFrames are different\n\nExact value mismatch" - ): + with pytest.raises(AssertionError, match="Values for column 'y' are different"): assert_frame_equal(df1, df2, check_exact=True) @@ -166,7 +168,7 @@ def test_assert_frame_equal_types() -> None: df1 = pl.DataFrame({"a": [1, 2]}) srs1 = pl.Series(values=[1, 2], name="a") with pytest.raises( - AssertionError, match="DataFrames are different\n\nType mismatch" + AssertionError, match="Inputs are different.\n\nUnexpected input types" ): assert_frame_equal(df1, srs1) # type: ignore[arg-type] @@ -175,7 +177,7 @@ def test_assert_frame_equal_length_mismatch() -> None: df1 = pl.DataFrame({"a": [1, 2]}) df2 = pl.DataFrame({"a": [1, 2, 3]}) with pytest.raises( - AssertionError, match="DataFrames are different\n\nLength mismatch" + AssertionError, match="DataFrames are different.\n\nLength mismatch" ): assert_frame_equal(df1, df2) @@ -216,7 +218,7 @@ def test_assert_frame_equal_ignore_row_order() -> None: df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]}) df2 = pl.DataFrame({"a": [2, 1], "b": [3, 4]}) df3 = pl.DataFrame({"b": [3, 4], "a": [2, 1]}) - with pytest.raises(AssertionError, match="Value mismatch"): + with pytest.raises(AssertionError, match="Values for column 'a' are different."): assert_frame_equal(df1, df2) assert_frame_equal(df1, df2, check_row_order=False)