Skip to content

Commit

Permalink
feat(python): support boolean Series broadcast comparison (eq/neq) ag…
Browse files Browse the repository at this point in the history
…ainst scalar True/False
  • Loading branch information
alexander-beedie committed Feb 11, 2023
1 parent c6a3df0 commit 5020dde
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 14 deletions.
16 changes: 11 additions & 5 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,27 +435,33 @@ def _comp(self, other: Any, op: ComparisonOperator) -> Series:
f = get_ffi_func(op + "_<>", Int64, self._s)
assert f is not None
return wrap_s(f(ts))
if isinstance(other, time) and self.dtype == Time:
elif isinstance(other, time) and self.dtype == Time:
d = _time_to_pl_time(other)
f = get_ffi_func(op + "_<>", Int64, self._s)
assert f is not None
return wrap_s(f(d))
if isinstance(other, date) and self.dtype == Date:
elif isinstance(other, date) and self.dtype == Date:
d = _date_to_pl_date(other)
f = get_ffi_func(op + "_<>", Int32, self._s)
assert f is not None
return wrap_s(f(d))
if self.dtype == Categorical and not isinstance(other, Series):
elif self.dtype == Categorical and not isinstance(other, Series):
other = Series([other])

if isinstance(other, Sequence) and not isinstance(other, str):
other = Series("", other, dtype_if_empty=self.dtype)
if isinstance(other, Series):
return wrap_s(getattr(self._s, op)(other._s))

f = get_ffi_func(op + "_<>", self.dtype, self._s)
if f is None and self.dtype == Boolean and op in ("eq", "neq"):
# TODO: ffi func for boolean? (until then, broadcast)
other = Series("", [other], dtype_if_empty=self.dtype)
return wrap_s(getattr(self._s, op)(other._s))

if other is not None:
other = maybe_cast(other, self.dtype, self.time_unit)
f = get_ffi_func(op + "_<>", self.dtype, self._s)

if f is None:
return NotImplemented

Expand Down Expand Up @@ -2827,7 +2833,7 @@ def to_arrow(self) -> pa.Array:
"""
return self._s.to_arrow()

def to_pandas( # noqa: D417
def to_pandas(
self, *args: Any, use_pyarrow_extension_array: bool = False, **kwargs: Any
) -> pd.Series:
"""
Expand Down
25 changes: 16 additions & 9 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,7 +1539,12 @@ def test_comparisons_float_series_to_int() -> None:

def test_comparisons_bool_series_to_int() -> None:
srs_bool = pl.Series([True, False])
# todo: do we want this to work?

# (native bool comparison should work...)
for t, f in ((True, False), (False, True)):
assert list(srs_bool == t) == list(srs_bool != f) == [t, f]

# TODO: do we want this to work?
assert_series_equal(srs_bool / 1, pl.Series([True, False], dtype=Float64))
match = (
r"cannot do arithmetic with series of dtype: Boolean"
Expand All @@ -1557,14 +1562,16 @@ def test_comparisons_bool_series_to_int() -> None:
srs_bool % 2
with pytest.raises(ValueError, match=match):
srs_bool * 1
with pytest.raises(
TypeError, match=r"'<' not supported between instances of 'Series' and 'int'"
):
srs_bool < 2 # noqa: B015
with pytest.raises(
TypeError, match=r"'>' not supported between instances of 'Series' and 'int'"
):
srs_bool > 2 # noqa: B015

from operator import ge, gt, le, lt

for op in (ge, gt, le, lt):
for scalar in (0, 1.0, True, False):
with pytest.raises(
TypeError,
match=r"'\W{1,2}' not supported .* 'Series' and '(int|bool|float)'",
):
op(srs_bool, scalar)


def test_abs() -> None:
Expand Down

0 comments on commit 5020dde

Please sign in to comment.