Skip to content

Commit

Permalink
fix(python): ensure kwargs filter behaviour matches docstring (expe…
Browse files Browse the repository at this point in the history
…ct equivalence with `eq`) (#13864)
  • Loading branch information
alexander-beedie authored Jan 20, 2024
1 parent 7c13fa4 commit 7bea512
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 7 deletions.
4 changes: 2 additions & 2 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def _scan_parquet(
return scan # type: ignore[return-value]

if storage_options:
storage_options = list(storage_options.items()) # type: ignore[assignment]
storage_options = list(storage_options.items()) # type: ignore[assignment]
else:
# Handle empty dict input
storage_options = None
Expand Down Expand Up @@ -2698,7 +2698,7 @@ def filter(

# unpack equality constraints from kwargs
all_predicates.extend(
F.col(name).eq_missing(value) for name, value in constraints.items()
F.col(name).eq(value) for name, value in constraints.items()
)
if not (all_predicates or boolean_masks):
msg = "at least one predicate or constraint must be provided"
Expand Down
6 changes: 3 additions & 3 deletions py-polars/tests/unit/functions/aggregation/test_horizontal.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ def test_all_any_horizontally() -> None:
assert_frame_equal(result, expected)

# note: a kwargs filter will use an internal call to all_horizontal
dfltr = df.lazy().filter(var1=None, var3=False)
assert dfltr.collect().rows() == [(None, None, False)]
dfltr = df.lazy().filter(var1=True, var3=False)
assert dfltr.collect().rows() == [(True, False, False)]

# confirm that we reduce the horizontal filter components
# confirm that we reduced the horizontal filter components
# (eg: explain does not contain an "all_horizontal" node)
assert "horizontal" not in dfltr.explain().lower()

Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1670,8 +1670,8 @@ def test_limit() -> None:
def test_filter() -> None:
s = pl.Series("a", [1, 2, 3])
mask = pl.Series("", [True, False, True])
assert_series_equal(s.filter(mask), pl.Series("a", [1, 3]))

assert_series_equal(s.filter(mask), pl.Series("a", [1, 3]))
assert_series_equal(s.filter([True, False, True]), pl.Series("a", [1, 3]))


Expand Down
27 changes: 26 additions & 1 deletion py-polars/tests/unit/test_predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,15 @@ def test_multi_alias_pushdown() -> None:
lf = pl.LazyFrame({"a": [1], "b": [1]})

actual = lf.with_columns(m="a", n="b").filter((pl.col("m") + pl.col("n")) < 2)

plan = actual.explain()

assert "FILTER" not in plan
assert r'SELECTION: "[([(col(\"a\")) + (col(\"b\"))]) < (2)]' in plan

with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
# confirm we aren't using `eq_missing` in the query plan (denoted as " ==v ")
assert " ==v " not in lf.select(pl.col("a").filter(a=None)).explain()


def test_predicate_pushdown_with_window_projections_12637() -> None:
lf = pl.LazyFrame(
Expand Down Expand Up @@ -466,3 +470,24 @@ def test_predicate_pd_join_13300() -> None:
lf = lf.join(lf_other, left_on="new_col", right_on="col4", how="left")
lf = lf.filter(pl.col("new_col") < 12)
assert lf.collect().to_dict(as_series=False) == {"col3": [10], "new_col": [11]}


def test_filter_eq_missing_13861() -> None:
lf = pl.LazyFrame({"a": [1, None, 3], "b": ["xx", "yy", None]})

with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
assert lf.collect().filter(a=None).rows() == []

with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
lff = lf.filter(a=None)
assert lff.collect().rows() == []
assert " ==v " not in lff.explain() # check no `eq_missing` op

with pytest.warns(UserWarning, match="Comparisons with None always result in null"):
assert lf.filter(pl.col("a").eq(None)).collect().rows() == []

for filter_expr in (
pl.col("a").eq_missing(None),
pl.col("a").is_null(),
):
assert lf.collect().filter(filter_expr).rows() == [(None, "yy")]

0 comments on commit 7bea512

Please sign in to comment.