Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python): validate operator arithmetic with None, fix Series edge-case #13780

Merged
merged 3 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion py-polars/polars/datatypes/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,10 +486,11 @@ def maybe_cast(el: Any, dtype: PolarsDataType) -> Any:
_timedelta_to_pl_timedelta,
)

time_unit = getattr(dtype, "time_unit", None)
if isinstance(el, datetime):
time_unit = getattr(dtype, "time_unit", None)
return _datetime_to_pl_timestamp(el, time_unit)
elif isinstance(el, timedelta):
time_unit = getattr(dtype, "time_unit", None)
return _timedelta_to_pl_timedelta(el, time_unit)

py_type = dtype_to_py_type(dtype)
Expand Down
8 changes: 6 additions & 2 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,11 +958,14 @@ def _arithmetic(self, other: Any, op_s: str, op_ffi: str) -> Self:
if isinstance(other, pl.Expr):
# expand pl.lit, pl.datetime, pl.duration Exprs to compatible Series
other = self.to_frame().select_seq(other).to_series()
elif other is None:
other = pl.Series("", [None])

if isinstance(other, Series):
return self._from_pyseries(getattr(self._s, op_s)(other._s))
if _check_for_numpy(other) and isinstance(other, np.ndarray):
elif _check_for_numpy(other) and isinstance(other, np.ndarray):
return self._from_pyseries(getattr(self._s, op_s)(Series(other)._s))
if (
elif (
isinstance(other, (float, date, datetime, timedelta, str))
and not self.dtype.is_float()
):
Expand All @@ -971,6 +974,7 @@ def _arithmetic(self, other: Any, op_s: str, op_ffi: str) -> Self:
return self._from_pyseries(getattr(_s, op_s)(self._s))
else:
return self._from_pyseries(getattr(self._s, op_s)(_s))

if isinstance(other, (PyDecimal, int)) and self.dtype.is_decimal():
# Infer the number's scale. Then use the max of the inferred scale and the
# Series' scale. At present, this will cause arithmetic to fail with a
Expand Down
35 changes: 34 additions & 1 deletion py-polars/tests/unit/operations/test_arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import operator
from datetime import date, datetime, timedelta
from typing import Any

import numpy as np
import pytest

import polars as pl
from polars.testing import assert_series_equal
from polars.datatypes import FLOAT_DTYPES, INTEGER_DTYPES
from polars.testing import assert_frame_equal, assert_series_equal


def test_sqrt_neg_inf() -> None:
Expand Down Expand Up @@ -246,3 +249,33 @@ def test_arithmetic_null_count() -> None:
"broadcast_left": [1],
"broadcast_right": [1],
}


@pytest.mark.parametrize(
    "op",
    [
        operator.add,
        operator.floordiv,
        operator.mod,
        operator.mul,
        operator.sub,
    ],
)
def test_operator_arithmetic_with_nulls(op: Any) -> None:
    """Check that applying ``op`` with a null operand yields an all-null result.

    For every dtype in ``FLOAT_DTYPES | INTEGER_DTYPES`` a two-row column
    ``n`` is combined with a null through four routes: the Python operator on
    an expression, the expression method named after the operator, the
    operator on the DataFrame, and the operator on the Series. Each result is
    compared against an all-null frame/series built with the same schema.
    """
    # `operator.add.__name__` is "add", etc.; used to look up the expression
    # method with the same name as the operator function.
    method_name = op.__name__

    for dtype in FLOAT_DTYPES | INTEGER_DTYPES:
        schema = {"n": dtype}

        frame = pl.DataFrame({"n": [2, 3]}, schema=schema)
        series = frame.to_series()

        null_frame = pl.DataFrame({"n": [None, None]}, schema=schema)
        null_series = null_frame.to_series()

        # expression route: both a bare `None` and an explicit null literal
        for null_operand in (None, pl.lit(None)):
            via_operator = df_result = frame.select(op(pl.col("n"), null_operand))
            assert_frame_equal(null_frame, via_operator)

            expr_method = getattr(pl.col("n"), method_name)
            via_method = frame.select(expr_method(null_operand))
            assert_frame_equal(null_frame, via_method)

        # frame and series routes, operating directly against `None`
        frame_result = op(frame, None)
        assert_frame_equal(null_frame, frame_result)

        series_result = op(series, None)
        assert_series_equal(null_series, series_result)