Skip to content

Commit

Permalink
feat(python): boolean Series broadcast comparison (eq/neq) against …
Browse files Browse the repository at this point in the history
…scalar True/False (#6797)
  • Loading branch information
alexander-beedie authored Feb 11, 2023
1 parent c3d91b9 commit b5fa39c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 12 deletions.
13 changes: 10 additions & 3 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,22 +430,29 @@ def __rxor__(self, other: Series) -> Series:
return self.__xor__(other)

def _comp(self, other: Any, op: ComparisonOperator) -> Series:
# special edge-case; boolean broadcast series (eq/neq) is its own result
if self.dtype == Boolean and isinstance(other, bool) and op in ("eq", "neq"):
if (other is True and op == "eq") or (other is False and op == "neq"):
return self.clone()
elif (other is False and op == "eq") or (other is True and op == "neq"):
return ~self

if isinstance(other, datetime) and self.dtype == Datetime:
ts = _datetime_to_pl_timestamp(other, self.time_unit)
f = get_ffi_func(op + "_<>", Int64, self._s)
assert f is not None
return wrap_s(f(ts))
if isinstance(other, time) and self.dtype == Time:
elif isinstance(other, time) and self.dtype == Time:
d = _time_to_pl_time(other)
f = get_ffi_func(op + "_<>", Int64, self._s)
assert f is not None
return wrap_s(f(d))
if isinstance(other, date) and self.dtype == Date:
elif isinstance(other, date) and self.dtype == Date:
d = _date_to_pl_date(other)
f = get_ffi_func(op + "_<>", Int32, self._s)
assert f is not None
return wrap_s(f(d))
if self.dtype == Categorical and not isinstance(other, Series):
elif self.dtype == Categorical and not isinstance(other, Series):
other = Series([other])

if isinstance(other, Sequence) and not isinstance(other, str):
Expand Down
25 changes: 16 additions & 9 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1539,7 +1539,12 @@ def test_comparisons_float_series_to_int() -> None:

def test_comparisons_bool_series_to_int() -> None:
srs_bool = pl.Series([True, False])
# todo: do we want this to work?

# (native bool comparison should work...)
for t, f in ((True, False), (False, True)):
assert list(srs_bool == t) == list(srs_bool != f) == [t, f]

# TODO: do we want this to work?
assert_series_equal(srs_bool / 1, pl.Series([True, False], dtype=Float64))
match = (
r"cannot do arithmetic with series of dtype: Boolean"
Expand All @@ -1557,14 +1562,16 @@ def test_comparisons_bool_series_to_int() -> None:
srs_bool % 2
with pytest.raises(ValueError, match=match):
srs_bool * 1
with pytest.raises(
TypeError, match=r"'<' not supported between instances of 'Series' and 'int'"
):
srs_bool < 2 # noqa: B015
with pytest.raises(
TypeError, match=r"'>' not supported between instances of 'Series' and 'int'"
):
srs_bool > 2 # noqa: B015

from operator import ge, gt, le, lt

for op in (ge, gt, le, lt):
for scalar in (0, 1.0, True, False):
with pytest.raises(
TypeError,
match=r"'\W{1,2}' not supported .* 'Series' and '(int|bool|float)'",
):
op(srs_bool, scalar)


def test_abs() -> None:
Expand Down

0 comments on commit b5fa39c

Please sign in to comment.