From 5301be4a1cf96b55ecdfd7b85790817e8867ff10 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Thu, 1 Feb 2024 14:43:44 +0800 Subject: [PATCH] fix: Fix for `set_operations` of binary dtype (#14152) --- .../polars-ops/src/chunked_array/list/sets.rs | 2 +- .../tests/unit/namespaces/list/__init__.py | 0 .../unit/namespaces/{ => list}/test_list.py | 129 -------------- .../namespaces/list/test_set_operations.py | 163 ++++++++++++++++++ 4 files changed, 164 insertions(+), 130 deletions(-) create mode 100644 py-polars/tests/unit/namespaces/list/__init__.py rename py-polars/tests/unit/namespaces/{ => list}/test_list.py (86%) create mode 100644 py-polars/tests/unit/namespaces/list/test_set_operations.py diff --git a/crates/polars-ops/src/chunked_array/list/sets.rs b/crates/polars-ops/src/chunked_array/list/sets.rs index 9fb0373d9d6e..b33be852b48d 100644 --- a/crates/polars-ops/src/chunked_array/list/sets.rs +++ b/crates/polars-ops/src/chunked_array/list/sets.rs @@ -368,7 +368,7 @@ fn array_set_operation( binary(&a, &b, offsets_a, offsets_b, set_op, validity, true) }, - ArrowDataType::LargeBinary => { + ArrowDataType::BinaryView => { let a = values_a.as_any().downcast_ref::().unwrap(); let b = values_b.as_any().downcast_ref::().unwrap(); binary(a, b, offsets_a, offsets_b, set_op, validity, false) diff --git a/py-polars/tests/unit/namespaces/list/__init__.py b/py-polars/tests/unit/namespaces/list/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/list/test_list.py similarity index 86% rename from py-polars/tests/unit/namespaces/test_list.py rename to py-polars/tests/unit/namespaces/list/test_list.py index 4e66160d5d37..b5bfc438073b 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/list/test_list.py @@ -617,135 +617,6 @@ def test_list_count_matches_boolean_nulls_9141() -> None: assert a.select(pl.col("a").list.count_matches(True))["a"].to_list() == [1] -def test_list_set_oob() -> None: - df = pl.DataFrame({"a": [42, 23]}) - assert df.select(pl.col("a").list.set_intersection([])).to_dict( - as_series=False - ) == {"a": [[], []]} - - -def test_list_set_operations_float() -> None: - df = pl.DataFrame( - {"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]}, - schema={"a": pl.List(pl.Float32), "b": pl.List(pl.Float32)}, - ) - - assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ - [1.0, 2.0, 3.0, 4.0], - [1.0, 2.0, 12.0], - [4.0], - ] - assert df.select(pl.col("a").list.set_intersection("b"))["a"].to_list() == [ - [1.0, 2.0], - [1.0], - [4.0], - ] - assert df.select(pl.col("a").list.set_difference("b"))["a"].to_list() == [ - [3.0], - [], - [], - ] - assert df.select(pl.col("b").list.set_difference("a"))["b"].to_list() == [ - [4.0], - [2.0, 12.0], - [], - ] - - -def test_list_set_operations() -> None: - df = pl.DataFrame( - {"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]} - ) - - assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ - [1, 2, 3, 4], - [1, 2, 12], - [4], - ] - assert df.select(pl.col("a").list.set_intersection("b"))["a"].to_list() == [ - [1, 2], - [1], - [4], - ] - assert df.select(pl.col("a").list.set_difference("b"))["a"].to_list() == [ - [3], - [], - [], - ] - assert df.select(pl.col("b").list.set_difference("a"))["b"].to_list() == [ - [4], - [2, 12], - [], - ] - - # check logical types - dtype = pl.List(pl.Date) - assert ( - df.select(pl.col("b").cast(dtype).list.set_difference(pl.col("a").cast(dtype)))[ - "b" - ].dtype - == dtype - ) - - df = pl.DataFrame( - { - "a": [["a", "b", "c"], ["b", "e", "z"]], - "b": [["b", "s", "a"], ["a", "e", "f"]], - } - ) - - assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ - ["a", "b", "c", "s"], - ["b", "e", "z", "a", "f"], - ] - - df = pl.DataFrame( - { - "a": [[2, 3, 3], [3, 1], [1, 2, 3]], - "b": [[2, 3, 4], [3, 3, 1], [3, 3]], - } - ) - r1 = df.with_columns(pl.col("a").list.set_intersection("b"))["a"].to_list() - r2 = df.with_columns(pl.col("b").list.set_intersection("a"))["b"].to_list() - exp = [[2, 3], [3, 1], [3]] - assert r1 == exp - assert r2 == exp - - -def test_list_set_operations_broadcast() -> None: - df = pl.DataFrame( - { - "a": [[2, 3, 3], [3, 1], [1, 2, 3]], - } - ) - - assert df.with_columns( - pl.col("a").list.set_intersection(pl.lit(pl.Series([[1, 2]]))) - ).to_dict(as_series=False) == {"a": [[2], [1], [1, 2]]} - assert df.with_columns( - pl.col("a").list.set_union(pl.lit(pl.Series([[1, 2]]))) - ).to_dict(as_series=False) == {"a": [[2, 3, 1], [3, 1, 2], [1, 2, 3]]} - assert df.with_columns( - pl.col("a").list.set_difference(pl.lit(pl.Series([[1, 2]]))) - ).to_dict(as_series=False) == {"a": [[3], [3], [3]]} - assert df.with_columns( - pl.lit(pl.Series("a", [[1, 2]])).list.set_difference("a") - ).to_dict(as_series=False) == {"a": [[1], [2], []]} - - -def test_list_set_operation_different_length_chunk_12734() -> None: - df = pl.DataFrame( - { - "a": [[2, 3, 3], [4, 1], [1, 2, 3]], - } - ) - - df = pl.concat([df.slice(0, 1), df.slice(1, 1), df.slice(2, 1)], rechunk=False) - assert df.with_columns( - pl.col("a").list.set_difference(pl.lit(pl.Series([[1, 2]]))) - ).to_dict(as_series=False) == {"a": [[3], [4], [3]]} - - def test_list_gather_oob_10079() -> None: df = pl.DataFrame( { diff --git a/py-polars/tests/unit/namespaces/list/test_set_operations.py b/py-polars/tests/unit/namespaces/list/test_set_operations.py new file mode 100644 index 000000000000..aa4b5e9f561a --- /dev/null +++ b/py-polars/tests/unit/namespaces/list/test_set_operations.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import polars as pl + + +def test_list_set_oob() -> None: + df = pl.DataFrame({"a": [42, 23]}) + assert df.select(pl.col("a").list.set_intersection([])).to_dict( + as_series=False + ) == {"a": [[], []]} + + +def test_list_set_operations_float() -> None: + df = pl.DataFrame( + {"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]}, + schema={"a": pl.List(pl.Float32), "b": pl.List(pl.Float32)}, + ) + + assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ + [1.0, 2.0, 3.0, 4.0], + [1.0, 2.0, 12.0], + [4.0], + ] + assert df.select(pl.col("a").list.set_intersection("b"))["a"].to_list() == [ + [1.0, 2.0], + [1.0], + [4.0], + ] + assert df.select(pl.col("a").list.set_difference("b"))["a"].to_list() == [ + [3.0], + [], + [], + ] + assert df.select(pl.col("b").list.set_difference("a"))["b"].to_list() == [ + [4.0], + [2.0, 12.0], + [], + ] + + +def test_list_set_operations() -> None: + df = pl.DataFrame( + {"a": [[1, 2, 3], [1, 1, 1], [4]], "b": [[4, 2, 1], [2, 1, 12], [4]]} + ) + + assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ + [1, 2, 3, 4], + [1, 2, 12], + [4], + ] + assert df.select(pl.col("a").list.set_intersection("b"))["a"].to_list() == [ + [1, 2], + [1], + [4], + ] + assert df.select(pl.col("a").list.set_difference("b"))["a"].to_list() == [ + [3], + [], + [], + ] + assert df.select(pl.col("b").list.set_difference("a"))["b"].to_list() == [ + [4], + [2, 12], + [], + ] + + # check logical types + dtype = pl.List(pl.Date) + assert ( + df.select(pl.col("b").cast(dtype).list.set_difference(pl.col("a").cast(dtype)))[ + "b" + ].dtype + == dtype + ) + + df = pl.DataFrame( + { + "a": [["a", "b", "c"], ["b", "e", "z"]], + "b": [["b", "s", "a"], ["a", "e", "f"]], + } + ) + + assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ + ["a", "b", "c", "s"], + ["b", "e", "z", "a", "f"], + ] + + df = pl.DataFrame( + { + "a": [[2, 3, 3], [3, 1], [1, 2, 3]], + "b": [[2, 3, 4], [3, 3, 1], [3, 3]], + } + ) + r1 = df.with_columns(pl.col("a").list.set_intersection("b"))["a"].to_list() + r2 = df.with_columns(pl.col("b").list.set_intersection("a"))["b"].to_list() + exp = [[2, 3], [3, 1], [3]] + assert r1 == exp + assert r2 == exp + + +def test_list_set_operations_broadcast() -> None: + df = pl.DataFrame( + { + "a": [[2, 3, 3], [3, 1], [1, 2, 3]], + } + ) + + assert df.with_columns( + pl.col("a").list.set_intersection(pl.lit(pl.Series([[1, 2]]))) + ).to_dict(as_series=False) == {"a": [[2], [1], [1, 2]]} + assert df.with_columns( + pl.col("a").list.set_union(pl.lit(pl.Series([[1, 2]]))) + ).to_dict(as_series=False) == {"a": [[2, 3, 1], [3, 1, 2], [1, 2, 3]]} + assert df.with_columns( + pl.col("a").list.set_difference(pl.lit(pl.Series([[1, 2]]))) + ).to_dict(as_series=False) == {"a": [[3], [3], [3]]} + assert df.with_columns( + pl.lit(pl.Series("a", [[1, 2]])).list.set_difference("a") + ).to_dict(as_series=False) == {"a": [[1], [2], []]} + + +def test_list_set_operation_different_length_chunk_12734() -> None: + df = pl.DataFrame( + { + "a": [[2, 3, 3], [4, 1], [1, 2, 3]], + } + ) + + df = pl.concat([df.slice(0, 1), df.slice(1, 1), df.slice(2, 1)], rechunk=False) + assert df.with_columns( + pl.col("a").list.set_difference(pl.lit(pl.Series([[1, 2]]))) + ).to_dict(as_series=False) == {"a": [[3], [4], [3]]} + + +def test_list_set_operations_binary() -> None: + df = pl.DataFrame( + { + "a": [[b"1", b"2", b"3"], [b"1", b"1", b"1"], [b"4"]], + "b": [[b"4", b"2", b"1"], [b"2", b"1", b"12"], [b"4"]], + }, + schema={"a": pl.List(pl.Binary), "b": pl.List(pl.Binary)}, + ) + + assert df.select(pl.col("a").list.set_union("b"))["a"].to_list() == [ + [b"1", b"2", b"3", b"4"], + [b"1", b"2", b"12"], + [b"4"], + ] + assert df.select(pl.col("a").list.set_intersection("b"))["a"].to_list() == [ + [b"1", b"2"], + [b"1"], + [b"4"], + ] + assert df.select(pl.col("a").list.set_difference("b"))["a"].to_list() == [ + [b"3"], + [], + [], + ] + assert df.select(pl.col("b").list.set_difference("a"))["b"].to_list() == [ + [b"4"], + [b"2", b"12"], + [], + ]