Skip to content

Commit

Permalink
fix(python): validate operator arithmetic with None, fix Series e…
Browse files Browse the repository at this point in the history
…dge-case (#13780)

Co-authored-by: Stijn de Gooijer <stijn@degooijer.io>
  • Loading branch information
alexander-beedie and stinodego authored Jan 18, 2024
1 parent c9a1468 commit 2728e0b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 4 deletions.
3 changes: 2 additions & 1 deletion py-polars/polars/datatypes/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,10 +486,11 @@ def maybe_cast(el: Any, dtype: PolarsDataType) -> Any:
_timedelta_to_pl_timedelta,
)

time_unit = getattr(dtype, "time_unit", None)
if isinstance(el, datetime):
time_unit = getattr(dtype, "time_unit", None)
return _datetime_to_pl_timestamp(el, time_unit)
elif isinstance(el, timedelta):
time_unit = getattr(dtype, "time_unit", None)
return _timedelta_to_pl_timedelta(el, time_unit)

py_type = dtype_to_py_type(dtype)
Expand Down
8 changes: 6 additions & 2 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,11 +958,14 @@ def _arithmetic(self, other: Any, op_s: str, op_ffi: str) -> Self:
if isinstance(other, pl.Expr):
# expand pl.lit, pl.datetime, pl.duration Exprs to compatible Series
other = self.to_frame().select_seq(other).to_series()
elif other is None:
other = pl.Series("", [None])

if isinstance(other, Series):
return self._from_pyseries(getattr(self._s, op_s)(other._s))
if _check_for_numpy(other) and isinstance(other, np.ndarray):
elif _check_for_numpy(other) and isinstance(other, np.ndarray):
return self._from_pyseries(getattr(self._s, op_s)(Series(other)._s))
if (
elif (
isinstance(other, (float, date, datetime, timedelta, str))
and not self.dtype.is_float()
):
Expand All @@ -971,6 +974,7 @@ def _arithmetic(self, other: Any, op_s: str, op_ffi: str) -> Self:
return self._from_pyseries(getattr(_s, op_s)(self._s))
else:
return self._from_pyseries(getattr(self._s, op_s)(_s))

if isinstance(other, (PyDecimal, int)) and self.dtype.is_decimal():
# Infer the number's scale. Then use the max of the inferred scale and the
# Series' scale. At present, this will cause arithmetic to fail with a
Expand Down
35 changes: 34 additions & 1 deletion py-polars/tests/unit/operations/test_arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import operator
from datetime import date, datetime, timedelta
from typing import Any

import numpy as np
import pytest

import polars as pl
from polars.testing import assert_series_equal
from polars.datatypes import FLOAT_DTYPES, INTEGER_DTYPES
from polars.testing import assert_frame_equal, assert_series_equal


def test_sqrt_neg_inf() -> None:
Expand Down Expand Up @@ -246,3 +249,33 @@ def test_arithmetic_null_count() -> None:
"broadcast_left": [1],
"broadcast_right": [1],
}


@pytest.mark.parametrize(
"op",
[
operator.add,
operator.floordiv,
operator.mod,
operator.mul,
operator.sub,
],
)
def test_operator_arithmetic_with_nulls(op: Any) -> None:
for dtype in FLOAT_DTYPES | INTEGER_DTYPES:
df = pl.DataFrame({"n": [2, 3]}, schema={"n": dtype})
s = df.to_series()

df_expected = pl.DataFrame({"n": [None, None]}, schema={"n": dtype})
s_expected = df_expected.to_series()

# validate expr, frame, and series behaviour with null value arithmetic
op_name = op.__name__
for null_expr in (None, pl.lit(None)):
assert_frame_equal(df_expected, df.select(op(pl.col("n"), null_expr)))
assert_frame_equal(
df_expected, df.select(getattr(pl.col("n"), op_name)(null_expr))
)

assert_frame_equal(df_expected, op(df, None))
assert_series_equal(s_expected, op(s, None))

0 comments on commit 2728e0b

Please sign in to comment.