Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Add more tests for list arithmetic #19225

Merged
merged 2 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 1 addition & 94 deletions py-polars/tests/unit/operations/arithmetic/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
UInt32,
UInt64,
)
from polars.exceptions import ColumnNotFoundError, InvalidOperationError, ShapeError
from polars.exceptions import ColumnNotFoundError, InvalidOperationError
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.conftest import INTEGER_DTYPES, NUMERIC_DTYPES

Expand Down Expand Up @@ -610,99 +610,6 @@ def test_array_arithmetic_same_size(
)


@pytest.mark.parametrize(
("expected", "expr", "column_names"),
[
([[2, 4], [6]], lambda a, b: a + b, ("a", "a")),
([[0, 0], [0]], lambda a, b: a - b, ("a", "a")),
([[1, 4], [9]], lambda a, b: a * b, ("a", "a")),
([[1.0, 1.0], [1.0]], lambda a, b: a / b, ("a", "a")),
([[0, 0], [0]], lambda a, b: a % b, ("a", "a")),
(
[[3, 4], [7]],
lambda a, b: a + b,
("a", "uint8"),
),
],
)
def test_list_arithmetic_same_size(
expected: Any,
expr: Callable[[pl.Series | pl.Expr, pl.Series | pl.Expr], pl.Series],
column_names: tuple[str, str],
) -> None:
df = pl.DataFrame(
[
pl.Series("a", [[1, 2], [3]]),
pl.Series("uint8", [[2, 2], [4]], dtype=pl.List(pl.UInt8())),
pl.Series("nested", [[[1, 2]], [[3]]]),
pl.Series(
"nested_uint8", [[[1, 2]], [[3]]], dtype=pl.List(pl.List(pl.UInt8()))
),
]
)
# Expr-based arithmetic:
assert_frame_equal(
df.select(expr(pl.col(column_names[0]), pl.col(column_names[1]))),
pl.Series(column_names[0], expected).to_frame(),
)
# Direct arithmetic on the Series:
assert_series_equal(
expr(df[column_names[0]], df[column_names[1]]),
pl.Series(column_names[0], expected),
)


@pytest.mark.parametrize(
("a", "b", "expected"),
[
([[1, 2, 3]], [[1, None, 5]], [[2, None, 8]]),
([[2], None, [5]], [None, [3], [2]], [None, None, [7]]),
],
)
def test_list_arithmetic_nulls(a: list[Any], b: list[Any], expected: list[Any]) -> None:
series_a = pl.Series(a)
series_b = pl.Series(b)
series_expected = pl.Series(expected)

# Same dtype:
assert_series_equal(series_a + series_b, series_expected)

# Different dtype:
assert_series_equal(
series_a._recursive_cast_to_dtype(pl.Int32())
+ series_b._recursive_cast_to_dtype(pl.Int64()),
series_expected._recursive_cast_to_dtype(pl.Int64()),
)


def test_list_arithmetic_error_cases() -> None:
# Different series length:
with pytest.raises(InvalidOperationError, match="different lengths"):
_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], [3, 4]])
with pytest.raises(InvalidOperationError, match="different lengths"):
_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1, 2], None])

# Different list length:
with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"):
_ = pl.Series("a", [[1, 2], [1, 2], [1, 2]]) / pl.Series("b", [[1]])

with pytest.raises(ShapeError, match="lengths differed at index 0: 2 != 1"):
_ = pl.Series("a", [[1, 2], [2, 3]]) / pl.Series("b", [[1], None])

# Wrong types:
with pytest.raises(
InvalidOperationError, match="add operation not supported for dtypes"
):
_ = pl.Series("a", [[1, 2]]) + pl.Series("b", ["hello"])

# Different nesting:
with pytest.raises(
InvalidOperationError,
match="cannot add two list columns with non-numeric inner types",
):
_ = pl.Series("a", [[1]]) + pl.Series("b", [[[1]]])


def test_schema_owned_arithmetic_5669() -> None:
df = (
pl.LazyFrame({"A": [1, 2, 3]})
Expand Down
Loading