Skip to content

Commit

Permalink
fix: Fix for set_operations of binary dtype (#14152)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Feb 1, 2024
1 parent e989418 commit 5301be4
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 130 deletions.
2 changes: 1 addition & 1 deletion crates/polars-ops/src/chunked_array/list/sets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ fn array_set_operation(

binary(&a, &b, offsets_a, offsets_b, set_op, validity, true)
},
ArrowDataType::LargeBinary => {
ArrowDataType::BinaryView => {
let a = values_a.as_any().downcast_ref::<BinaryViewArray>().unwrap();
let b = values_b.as_any().downcast_ref::<BinaryViewArray>().unwrap();
binary(a, b, offsets_a, offsets_b, set_op, validity, false)
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -617,135 +617,6 @@ def test_list_count_matches_boolean_nulls_9141() -> None:
assert a.select(pl.col("a").list.count_matches(True))["a"].to_list() == [1]


def test_list_set_oob() -> None:
    """Intersect a column built from plain integers with an empty list literal.

    The expected result is one empty list per input row.
    """
    frame = pl.DataFrame({"a": [42, 23]})
    intersected = frame.select(pl.col("a").list.set_intersection([]))
    expected = {"a": [[], []]}
    assert intersected.to_dict(as_series=False) == expected


def test_list_set_operations_float() -> None:
    """Check union, intersection and difference on ``List(Float32)`` columns."""
    float_list = pl.List(pl.Float32)
    frame = pl.DataFrame(
        {"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]},
        schema={"a": float_list, "b": float_list},
    )
    col_a = pl.col("a")
    col_b = pl.col("b")

    union = frame.select(col_a.list.set_union("b"))["a"]
    assert union.to_list() == [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 12.0], [4.0]]

    common = frame.select(col_a.list.set_intersection("b"))["a"]
    assert common.to_list() == [[1.0, 2.0], [1.0], [4.0]]

    only_a = frame.select(col_a.list.set_difference("b"))["a"]
    assert only_a.to_list() == [[3.0], [], []]

    only_b = frame.select(col_b.list.set_difference("a"))["b"]
    assert only_b.to_list() == [[4.0], [2.0, 12.0], []]


def test_list_set_operations() -> None:
    """Exercise list set operations on integer, date-cast and string lists."""
    col_a = pl.col("a")
    col_b = pl.col("b")

    ints = pl.DataFrame(
        {"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]}
    )

    union = ints.select(col_a.list.set_union("b"))["a"]
    assert union.to_list() == [[1, 2, 3, 4], [1, 2, 12], [4]]

    common = ints.select(col_a.list.set_intersection("b"))["a"]
    assert common.to_list() == [[1, 2], [1], [4]]

    only_a = ints.select(col_a.list.set_difference("b"))["a"]
    assert only_a.to_list() == [[3], [], []]

    only_b = ints.select(col_b.list.set_difference("a"))["b"]
    assert only_b.to_list() == [[4], [2, 12], []]

    # Logical types: the output dtype must still be List(Date) after the
    # set operation.
    date_list = pl.List(pl.Date)
    diff_expr = col_b.cast(date_list).list.set_difference(col_a.cast(date_list))
    assert ints.select(diff_expr)["b"].dtype == date_list

    strings = pl.DataFrame(
        {
            "a": [["a", "b", "c"], ["b", "e", "z"]],
            "b": [["b", "s", "a"], ["a", "e", "f"]],
        }
    )
    str_union = strings.select(col_a.list.set_union("b"))["a"]
    assert str_union.to_list() == [
        ["a", "b", "c", "s"],
        ["b", "e", "z", "a", "f"],
    ]

    # Lists containing duplicates: both operand orders are expected to give
    # the same lists for this data.
    dupes = pl.DataFrame(
        {
            "a": [[2, 3, 3], [3, 1], [1, 2, 3]],
            "b": [[2, 3, 4], [3, 3, 1], [3, 3]],
        }
    )
    a_with_b = dupes.with_columns(col_a.list.set_intersection("b"))["a"].to_list()
    b_with_a = dupes.with_columns(col_b.list.set_intersection("a"))["b"].to_list()
    expected = [[2, 3], [3, 1], [3]]
    assert a_with_b == expected
    assert b_with_a == expected


def test_list_set_operations_broadcast() -> None:
    """Apply set operations between a three-row list column and a one-row literal."""
    frame = pl.DataFrame({"a": [[2, 3, 3], [3, 1], [1, 2, 3]]})
    col_a = pl.col("a")
    one_two = pl.lit(pl.Series([[1, 2]]))

    intersected = frame.with_columns(col_a.list.set_intersection(one_two))
    assert intersected.to_dict(as_series=False) == {"a": [[2], [1], [1, 2]]}

    unioned = frame.with_columns(col_a.list.set_union(one_two))
    assert unioned.to_dict(as_series=False) == {
        "a": [[2, 3, 1], [3, 1, 2], [1, 2, 3]]
    }

    minus_literal = frame.with_columns(col_a.list.set_difference(one_two))
    assert minus_literal.to_dict(as_series=False) == {"a": [[3], [3], [3]]}

    # Here the single-row literal is the left operand; it is named "a" so the
    # result lands in column "a".
    named_literal = pl.lit(pl.Series("a", [[1, 2]]))
    literal_minus = frame.with_columns(named_literal.list.set_difference("a"))
    assert literal_minus.to_dict(as_series=False) == {"a": [[1], [2], []]}


def test_list_set_operation_different_length_chunk_12734() -> None:
    """Run ``set_difference`` on a frame concatenated from one-row slices."""
    base = pl.DataFrame({"a": [[2, 3, 3], [4, 1], [1, 2, 3]]})

    # Reassemble the frame from three one-row slices with rechunk=False.
    pieces = [base.slice(offset, 1) for offset in range(3)]
    df = pl.concat(pieces, rechunk=False)

    literal = pl.lit(pl.Series([[1, 2]]))
    result = df.with_columns(pl.col("a").list.set_difference(literal))
    assert result.to_dict(as_series=False) == {"a": [[3], [4], [3]]}


def test_list_gather_oob_10079() -> None:
df = pl.DataFrame(
{
Expand Down
163 changes: 163 additions & 0 deletions py-polars/tests/unit/namespaces/list/test_set_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from __future__ import annotations

import polars as pl


def test_list_set_oob() -> None:
    """Intersect a column built from plain integers with an empty list literal.

    The expected result is one empty list per input row.
    """
    frame = pl.DataFrame({"a": [42, 23]})
    intersected = frame.select(pl.col("a").list.set_intersection([]))
    expected = {"a": [[], []]}
    assert intersected.to_dict(as_series=False) == expected


def test_list_set_operations_float() -> None:
    """Check union, intersection and difference on ``List(Float32)`` columns."""
    float_list = pl.List(pl.Float32)
    frame = pl.DataFrame(
        {"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]},
        schema={"a": float_list, "b": float_list},
    )
    col_a = pl.col("a")
    col_b = pl.col("b")

    union = frame.select(col_a.list.set_union("b"))["a"]
    assert union.to_list() == [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 12.0], [4.0]]

    common = frame.select(col_a.list.set_intersection("b"))["a"]
    assert common.to_list() == [[1.0, 2.0], [1.0], [4.0]]

    only_a = frame.select(col_a.list.set_difference("b"))["a"]
    assert only_a.to_list() == [[3.0], [], []]

    only_b = frame.select(col_b.list.set_difference("a"))["b"]
    assert only_b.to_list() == [[4.0], [2.0, 12.0], []]


def test_list_set_operations() -> None:
    """Exercise list set operations on integer, date-cast and string lists."""
    col_a = pl.col("a")
    col_b = pl.col("b")

    ints = pl.DataFrame(
        {"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]}
    )

    union = ints.select(col_a.list.set_union("b"))["a"]
    assert union.to_list() == [[1, 2, 3, 4], [1, 2, 12], [4]]

    common = ints.select(col_a.list.set_intersection("b"))["a"]
    assert common.to_list() == [[1, 2], [1], [4]]

    only_a = ints.select(col_a.list.set_difference("b"))["a"]
    assert only_a.to_list() == [[3], [], []]

    only_b = ints.select(col_b.list.set_difference("a"))["b"]
    assert only_b.to_list() == [[4], [2, 12], []]

    # Logical types: the output dtype must still be List(Date) after the
    # set operation.
    date_list = pl.List(pl.Date)
    diff_expr = col_b.cast(date_list).list.set_difference(col_a.cast(date_list))
    assert ints.select(diff_expr)["b"].dtype == date_list

    strings = pl.DataFrame(
        {
            "a": [["a", "b", "c"], ["b", "e", "z"]],
            "b": [["b", "s", "a"], ["a", "e", "f"]],
        }
    )
    str_union = strings.select(col_a.list.set_union("b"))["a"]
    assert str_union.to_list() == [
        ["a", "b", "c", "s"],
        ["b", "e", "z", "a", "f"],
    ]

    # Lists containing duplicates: both operand orders are expected to give
    # the same lists for this data.
    dupes = pl.DataFrame(
        {
            "a": [[2, 3, 3], [3, 1], [1, 2, 3]],
            "b": [[2, 3, 4], [3, 3, 1], [3, 3]],
        }
    )
    a_with_b = dupes.with_columns(col_a.list.set_intersection("b"))["a"].to_list()
    b_with_a = dupes.with_columns(col_b.list.set_intersection("a"))["b"].to_list()
    expected = [[2, 3], [3, 1], [3]]
    assert a_with_b == expected
    assert b_with_a == expected


def test_list_set_operations_broadcast() -> None:
    """Apply set operations between a three-row list column and a one-row literal."""
    frame = pl.DataFrame({"a": [[2, 3, 3], [3, 1], [1, 2, 3]]})
    col_a = pl.col("a")
    one_two = pl.lit(pl.Series([[1, 2]]))

    intersected = frame.with_columns(col_a.list.set_intersection(one_two))
    assert intersected.to_dict(as_series=False) == {"a": [[2], [1], [1, 2]]}

    unioned = frame.with_columns(col_a.list.set_union(one_two))
    assert unioned.to_dict(as_series=False) == {
        "a": [[2, 3, 1], [3, 1, 2], [1, 2, 3]]
    }

    minus_literal = frame.with_columns(col_a.list.set_difference(one_two))
    assert minus_literal.to_dict(as_series=False) == {"a": [[3], [3], [3]]}

    # Here the single-row literal is the left operand; it is named "a" so the
    # result lands in column "a".
    named_literal = pl.lit(pl.Series("a", [[1, 2]]))
    literal_minus = frame.with_columns(named_literal.list.set_difference("a"))
    assert literal_minus.to_dict(as_series=False) == {"a": [[1], [2], []]}


def test_list_set_operation_different_length_chunk_12734() -> None:
    """Run ``set_difference`` on a frame concatenated from one-row slices."""
    base = pl.DataFrame({"a": [[2, 3, 3], [4, 1], [1, 2, 3]]})

    # Reassemble the frame from three one-row slices with rechunk=False.
    pieces = [base.slice(offset, 1) for offset in range(3)]
    df = pl.concat(pieces, rechunk=False)

    literal = pl.lit(pl.Series([[1, 2]]))
    result = df.with_columns(pl.col("a").list.set_difference(literal))
    assert result.to_dict(as_series=False) == {"a": [[3], [4], [3]]}


def test_list_set_operations_binary() -> None:
    """Check union, intersection and difference on ``List(Binary)`` columns."""
    binary_list = pl.List(pl.Binary)
    frame = pl.DataFrame(
        {
            "a": [[b"1", b"2", b"3"], [b"1", b"1", b"1"], [b"4"]],
            "b": [[b"4", b"2", b"1"], [b"2", b"1", b"12"], [b"4"]],
        },
        schema={"a": binary_list, "b": binary_list},
    )
    col_a = pl.col("a")
    col_b = pl.col("b")

    union = frame.select(col_a.list.set_union("b"))["a"]
    assert union.to_list() == [
        [b"1", b"2", b"3", b"4"],
        [b"1", b"2", b"12"],
        [b"4"],
    ]

    common = frame.select(col_a.list.set_intersection("b"))["a"]
    assert common.to_list() == [[b"1", b"2"], [b"1"], [b"4"]]

    only_a = frame.select(col_a.list.set_difference("b"))["a"]
    assert only_a.to_list() == [[b"3"], [], []]

    only_b = frame.select(col_b.list.set_difference("a"))["b"]
    assert only_b.to_list() == [[b"4"], [b"2", b"12"], []]

0 comments on commit 5301be4

Please sign in to comment.