From e53437033f00b49e4c4484ad942810349d1b6628 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Wed, 24 Jan 2024 16:27:29 +0100 Subject: [PATCH 1/4] Move streaming test to correct module --- py-polars/tests/unit/io/test_lazy_json.py | 15 --------------- .../tests/unit/streaming/test_streaming_io.py | 18 ++++++++++++++++++ 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/py-polars/tests/unit/io/test_lazy_json.py b/py-polars/tests/unit/io/test_lazy_json.py index a14f47690a5b..8d0967d10277 100644 --- a/py-polars/tests/unit/io/test_lazy_json.py +++ b/py-polars/tests/unit/io/test_lazy_json.py @@ -141,18 +141,3 @@ def test_anonymous_scan_explain(io_files_path: Path) -> None: assert "Anonymous" in q.explain() assert "Anonymous" in q.show_graph(raw_output=True) # type: ignore[operator] - -def test_sink_ndjson_should_write_same_data( - io_files_path: Path, tmp_path: Path -) -> None: - tmp_path.mkdir(exist_ok=True) - # Arrange - source_path = io_files_path / "foods1.csv" - target_path = tmp_path / "foods_test.ndjson" - expected = pl.read_csv(source_path) - lf = pl.scan_csv(source_path) - # Act - lf.sink_ndjson(target_path) - df = pl.read_ndjson(target_path) - # Assert - assert_frame_equal(df, expected) diff --git a/py-polars/tests/unit/streaming/test_streaming_io.py b/py-polars/tests/unit/streaming/test_streaming_io.py index 6b0b02f8d5b7..ed09cb9e704e 100644 --- a/py-polars/tests/unit/streaming/test_streaming_io.py +++ b/py-polars/tests/unit/streaming/test_streaming_io.py @@ -198,3 +198,21 @@ def test_streaming_cross_join_schema(tmp_path: Path) -> None: a.join(b, how="cross").sink_parquet(file_path) read = pl.read_parquet(file_path, parallel="none") assert read.to_dict(as_series=False) == {"a": [1, 2], "b": ["b", "b"]} + + +@pytest.mark.write_disk() +def test_sink_ndjson_should_write_same_data( + io_files_path: Path, tmp_path: Path +) -> None: + tmp_path.mkdir(exist_ok=True) + + source_path = io_files_path / "foods1.csv" + target_path = tmp_path / "foods_test.ndjson" + + expected = pl.read_csv(source_path) + + lf = pl.scan_csv(source_path) + lf.sink_ndjson(target_path) + df = pl.read_ndjson(target_path) + + assert_frame_equal(df, expected) From 907ceb0c17c6466ee84ab2399c7bb0be00bbd0d5 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Wed, 24 Jan 2024 16:46:27 +0100 Subject: [PATCH 2/4] Add missing marker --- .../unit/streaming/test_streaming_categoricals.py | 4 ++++ py-polars/tests/unit/streaming/test_streaming_sort.py | 10 ++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/py-polars/tests/unit/streaming/test_streaming_categoricals.py b/py-polars/tests/unit/streaming/test_streaming_categoricals.py index 0df920d4ee0d..65dd967abb76 100644 --- a/py-polars/tests/unit/streaming/test_streaming_categoricals.py +++ b/py-polars/tests/unit/streaming/test_streaming_categoricals.py @@ -1,5 +1,9 @@ +import pytest + import polars as pl +pytestmark = pytest.mark.xdist_group("streaming") + def test_streaming_nested_categorical() -> None: assert ( diff --git a/py-polars/tests/unit/streaming/test_streaming_sort.py b/py-polars/tests/unit/streaming/test_streaming_sort.py index 83957bdcffc8..c9befd8362c3 100644 --- a/py-polars/tests/unit/streaming/test_streaming_sort.py +++ b/py-polars/tests/unit/streaming/test_streaming_sort.py @@ -1,20 +1,18 @@ from __future__ import annotations +from collections import Counter from datetime import datetime from typing import TYPE_CHECKING, Any -if TYPE_CHECKING: - from pathlib import Path - - -from collections import Counter - import numpy as np import pytest import polars as pl from polars.testing import assert_frame_equal, assert_series_equal +if TYPE_CHECKING: + from pathlib import Path + pytestmark = pytest.mark.xdist_group("streaming") From f6f46c025e86dfb342543abae5939c35070627c5 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Wed, 24 Jan 2024 17:18:12 +0100 Subject: [PATCH 3/4] Move marker to streaming module level --- py-polars/tests/unit/streaming/__init__.py | 3 +++ py-polars/tests/unit/streaming/test_streaming.py | 2 -- .../tests/unit/streaming/test_streaming_categoricals.py | 4 ---- py-polars/tests/unit/streaming/test_streaming_cse.py | 2 -- py-polars/tests/unit/streaming/test_streaming_group_by.py | 2 -- py-polars/tests/unit/streaming/test_streaming_io.py | 7 ++----- py-polars/tests/unit/streaming/test_streaming_join.py | 3 --- py-polars/tests/unit/streaming/test_streaming_sort.py | 2 -- py-polars/tests/unit/streaming/test_streaming_unique.py | 2 -- 9 files changed, 5 insertions(+), 22 deletions(-) diff --git a/py-polars/tests/unit/streaming/__init__.py b/py-polars/tests/unit/streaming/__init__.py index e69de29bb2d1..0cd682cbde8b 100644 --- a/py-polars/tests/unit/streaming/__init__.py +++ b/py-polars/tests/unit/streaming/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytestmark = pytest.mark.xdist_group("streaming") diff --git a/py-polars/tests/unit/streaming/test_streaming.py b/py-polars/tests/unit/streaming/test_streaming.py index fd18289fdc86..c91262be01bf 100644 --- a/py-polars/tests/unit/streaming/test_streaming.py +++ b/py-polars/tests/unit/streaming/test_streaming.py @@ -17,8 +17,6 @@ if TYPE_CHECKING: from polars.type_aliases import JoinStrategy -pytestmark = pytest.mark.xdist_group("streaming") - def test_streaming_categoricals_5921() -> None: with pl.StringCache(): diff --git a/py-polars/tests/unit/streaming/test_streaming_categoricals.py b/py-polars/tests/unit/streaming/test_streaming_categoricals.py index 65dd967abb76..0df920d4ee0d 100644 --- a/py-polars/tests/unit/streaming/test_streaming_categoricals.py +++ b/py-polars/tests/unit/streaming/test_streaming_categoricals.py @@ -1,9 +1,5 @@ -import pytest - import polars as pl -pytestmark = pytest.mark.xdist_group("streaming") - def test_streaming_nested_categorical() -> None: assert ( diff --git a/py-polars/tests/unit/streaming/test_streaming_cse.py b/py-polars/tests/unit/streaming/test_streaming_cse.py index 909bbc272f3a..b650c646c552 100644 --- a/py-polars/tests/unit/streaming/test_streaming_cse.py +++ b/py-polars/tests/unit/streaming/test_streaming_cse.py @@ -7,8 +7,6 @@ import polars as pl from polars.testing import assert_frame_equal -pytestmark = pytest.mark.xdist_group("streaming") - def test_cse_expr_selection_streaming(monkeypatch: Any, capfd: Any) -> None: monkeypatch.setenv("POLARS_VERBOSE", "1") diff --git a/py-polars/tests/unit/streaming/test_streaming_group_by.py b/py-polars/tests/unit/streaming/test_streaming_group_by.py index c93ec2357ba5..dbb53ec66656 100644 --- a/py-polars/tests/unit/streaming/test_streaming_group_by.py +++ b/py-polars/tests/unit/streaming/test_streaming_group_by.py @@ -9,8 +9,6 @@ import polars as pl from polars.testing import assert_frame_equal -pytestmark = pytest.mark.xdist_group("streaming") - @pytest.mark.slow() def test_streaming_group_by_sorted_fast_path_nulls_10273() -> None: diff --git a/py-polars/tests/unit/streaming/test_streaming_io.py b/py-polars/tests/unit/streaming/test_streaming_io.py index ed09cb9e704e..b4e009a469fb 100644 --- a/py-polars/tests/unit/streaming/test_streaming_io.py +++ b/py-polars/tests/unit/streaming/test_streaming_io.py @@ -1,7 +1,7 @@ from __future__ import annotations -import unittest from typing import TYPE_CHECKING +from unittest.mock import patch import pytest @@ -12,9 +12,6 @@ from pathlib import Path -pytestmark = pytest.mark.xdist_group("streaming") - - @pytest.mark.write_disk() def test_streaming_parquet_glob_5900(df: pl.DataFrame, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -122,7 +119,7 @@ def test_sink_csv_with_options() -> None: passed into the rust-polars correctly. """ df = pl.LazyFrame({"dummy": ["abc"]}) - with unittest.mock.patch.object(df, "_ldf") as ldf: + with patch.object(df, "_ldf") as ldf: df.sink_csv( "path", include_bom=True, diff --git a/py-polars/tests/unit/streaming/test_streaming_join.py b/py-polars/tests/unit/streaming/test_streaming_join.py index 68719711d546..22baf6f6d4c2 100644 --- a/py-polars/tests/unit/streaming/test_streaming_join.py +++ b/py-polars/tests/unit/streaming/test_streaming_join.py @@ -4,13 +4,10 @@ import numpy as np import pandas as pd -import pytest import polars as pl from polars.testing import assert_frame_equal -pytestmark = pytest.mark.xdist_group("streaming") - def test_streaming_joins() -> None: n = 100 diff --git a/py-polars/tests/unit/streaming/test_streaming_sort.py b/py-polars/tests/unit/streaming/test_streaming_sort.py index c9befd8362c3..9b541543c063 100644 --- a/py-polars/tests/unit/streaming/test_streaming_sort.py +++ b/py-polars/tests/unit/streaming/test_streaming_sort.py @@ -13,8 +13,6 @@ if TYPE_CHECKING: from pathlib import Path -pytestmark = pytest.mark.xdist_group("streaming") - def assert_df_sorted_by( df: pl.DataFrame, diff --git a/py-polars/tests/unit/streaming/test_streaming_unique.py b/py-polars/tests/unit/streaming/test_streaming_unique.py index c79a734464a3..8afb8eef692c 100644 --- a/py-polars/tests/unit/streaming/test_streaming_unique.py +++ b/py-polars/tests/unit/streaming/test_streaming_unique.py @@ -10,8 +10,6 @@ if TYPE_CHECKING: from pathlib import Path -pytestmark = pytest.mark.xdist_group("streaming") - @pytest.mark.write_disk() @pytest.mark.slow() From 2cda97d31011e31c828458501d39c92168be68e3 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Wed, 24 Jan 2024 17:25:50 +0100 Subject: [PATCH 4/4] Move/mark more streaming tests --- py-polars/tests/unit/io/test_hive.py | 3 + py-polars/tests/unit/io/test_lazy_json.py | 1 - py-polars/tests/unit/io/test_lazy_parquet.py | 42 +--------- .../tests/unit/operations/test_group_by.py | 16 ---- py-polars/tests/unit/operations/test_join.py | 82 ------------------ .../unit/streaming/test_streaming_group_by.py | 16 ++++ .../tests/unit/streaming/test_streaming_io.py | 42 +++++++++- .../unit/streaming/test_streaming_join.py | 83 +++++++++++++++++++ 8 files changed, 144 insertions(+), 141 deletions(-) diff --git a/py-polars/tests/unit/io/test_hive.py b/py-polars/tests/unit/io/test_hive.py index 9b1de1ac8003..ad4145abf8df 100644 --- a/py-polars/tests/unit/io/test_hive.py +++ b/py-polars/tests/unit/io/test_hive.py @@ -12,6 +12,7 @@ @pytest.mark.skip( reason="Broken by pyarrow 15 release: https://github.com/pola-rs/polars/issues/13892" ) +@pytest.mark.xdist_group("streaming") @pytest.mark.write_disk() def test_hive_partitioned_predicate_pushdown( io_files_path: Path, tmp_path: Path, monkeypatch: Any, capfd: Any @@ -91,6 +92,7 @@ def test_hive_partitioned_predicate_pushdown_skips_correct_number_of_files( @pytest.mark.skip( reason="Broken by pyarrow 15 release: https://github.com/pola-rs/polars/issues/13892" ) +@pytest.mark.xdist_group("streaming") @pytest.mark.write_disk() def test_hive_partitioned_slice_pushdown(io_files_path: Path, tmp_path: Path) -> None: df = pl.read_ipc(io_files_path / "*.ipc") @@ -127,6 +129,7 @@ def test_hive_partitioned_slice_pushdown(io_files_path: Path, tmp_path: Path) -> @pytest.mark.skip( reason="Broken by pyarrow 15 release: https://github.com/pola-rs/polars/issues/13892" ) +@pytest.mark.xdist_group("streaming") @pytest.mark.write_disk() def test_hive_partitioned_projection_pushdown( io_files_path: Path, tmp_path: Path diff --git a/py-polars/tests/unit/io/test_lazy_json.py b/py-polars/tests/unit/io/test_lazy_json.py index 8d0967d10277..97e32f3eaee6 100644 --- a/py-polars/tests/unit/io/test_lazy_json.py +++ b/py-polars/tests/unit/io/test_lazy_json.py @@ -140,4 +140,3 @@ def test_anonymous_scan_explain(io_files_path: Path) -> None: q = pl.scan_ndjson(source=file) assert "Anonymous" in q.explain() assert "Anonymous" in q.show_graph(raw_output=True) # type: ignore[operator] - diff --git a/py-polars/tests/unit/io/test_lazy_parquet.py b/py-polars/tests/unit/io/test_lazy_parquet.py index 145ec0a42209..5d6cfbb64b00 100644 --- a/py-polars/tests/unit/io/test_lazy_parquet.py +++ b/py-polars/tests/unit/io/test_lazy_parquet.py @@ -201,46 +201,6 @@ def test_row_index_schema_parquet(parquet_file_path: Path) -> None: ).dtypes == [pl.UInt32, pl.String] -@pytest.mark.write_disk() -def test_parquet_eq_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) -> None: - tmp_path.mkdir(exist_ok=True) - - monkeypatch.setenv("POLARS_VERBOSE", "1") - - df = pl.DataFrame({"idx": pl.arange(100, 200, eager=True)}).with_columns( - (pl.col("idx") // 25).alias("part") - ) - df = pl.concat(df.partition_by("part", as_dict=False), rechunk=False) - assert df.n_chunks("all") == [4, 4] - - file_path = tmp_path / "stats.parquet" - df.write_parquet(file_path, statistics=True, use_pyarrow=False) - - file_path = tmp_path / "stats.parquet" - df.write_parquet(file_path, statistics=True, use_pyarrow=False) - - for streaming in [False, True]: - for pred in [ - pl.col("idx") == 50, - pl.col("idx") == 150, - pl.col("idx") == 210, - ]: - result = ( - pl.scan_parquet(file_path).filter(pred).collect(streaming=streaming) - ) - assert_frame_equal(result, df.filter(pred)) - - captured = capfd.readouterr().err - assert ( - "parquet file must be read, statistics not sufficient for predicate." - in captured - ) - assert ( - "parquet file can be skipped, the statistics were sufficient" - " to apply the predicate." in captured - ) - - @pytest.mark.write_disk() def test_parquet_is_in_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -314,7 +274,7 @@ def test_parquet_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) -> Non @pytest.mark.write_disk() -def test_streaming_categorical(tmp_path: Path) -> None: +def test_categorical(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) df = pl.DataFrame( diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 0b334568b073..ad68990ff7e2 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -831,22 +831,6 @@ def test_group_by_rolling_deprecated() -> None: assert_frame_equal(result_lazy, expected, check_row_order=False) -def test_group_by_multiple_keys_one_literal() -> None: - df = pl.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) - - expected = {"a": [1, 2], "literal": [1, 1], "b": [5, 6]} - for streaming in [True, False]: - assert ( - df.lazy() - .group_by("a", pl.lit(1)) - .agg(pl.col("b").max()) - .sort(["a", "b"]) - .collect(streaming=streaming) - .to_dict(as_series=False) - == expected - ) - - def test_group_by_list_scalar_11749() -> None: df = pl.DataFrame( { diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py index 22c0d1bc875b..2c908dc1806b 100644 --- a/py-polars/tests/unit/operations/test_join.py +++ b/py-polars/tests/unit/operations/test_join.py @@ -738,88 +738,6 @@ def test_outer_join_bool() -> None: } -@pytest.mark.parametrize("streaming", [False, True]) -def test_join_null_matches(streaming: bool) -> None: - # null values in joins should never find a match. - df_a = pl.LazyFrame( - { - "idx_a": [0, 1, 2], - "a": [None, 1, 2], - } - ) - - df_b = pl.LazyFrame( - { - "idx_b": [0, 1, 2, 3], - "a": [None, 2, 1, None], - } - ) - - expected = pl.DataFrame({"idx_a": [2, 1], "a": [2, 1], "idx_b": [1, 2]}) - assert_frame_equal( - df_a.join(df_b, on="a", how="inner").collect(streaming=streaming), expected - ) - expected = pl.DataFrame( - {"idx_a": [0, 1, 2], "a": [None, 1, 2], "idx_b": [None, 2, 1]} - ) - assert_frame_equal( - df_a.join(df_b, on="a", how="left").collect(streaming=streaming), expected - ) - expected = pl.DataFrame( - { - "idx_a": [None, 2, 1, None, 0], - "a": [None, 2, 1, None, None], - "idx_b": [0, 1, 2, 3, None], - "a_right": [None, 2, 1, None, None], - } - ) - assert_frame_equal(df_a.join(df_b, on="a", how="outer").collect(), expected) - - -@pytest.mark.parametrize("streaming", [False, True]) -def test_join_null_matches_multiple_keys(streaming: bool) -> None: - df_a = pl.LazyFrame( - { - "a": [None, 1, 2], - "idx": [0, 1, 2], - } - ) - - df_b = pl.LazyFrame( - { - "a": [None, 2, 1, None, 1], - "idx": [0, 1, 2, 3, 1], - "c": [10, 20, 30, 40, 50], - } - ) - - expected = pl.DataFrame({"a": [1], "idx": [1], "c": [50]}) - assert_frame_equal( - df_a.join(df_b, on=["a", "idx"], how="inner").collect(streaming=streaming), - expected, - ) - expected = pl.DataFrame( - {"a": [None, 1, 2], "idx": [0, 1, 2], "c": [None, 50, None]} - ) - assert_frame_equal( - df_a.join(df_b, on=["a", "idx"], how="left").collect(streaming=streaming), - expected, - ) - - expected = pl.DataFrame( - { - "a": [None, None, None, None, None, 1, 2], - "idx": [None, None, None, None, 0, 1, 2], - "a_right": [None, 2, 1, None, None, 1, None], - "idx_right": [0, 1, 2, 3, None, 1, None], - "c": [10, 20, 30, 40, None, 50, None], - } - ) - assert_frame_equal( - df_a.join(df_b, on=["a", "idx"], how="outer").sort("a").collect(), expected - ) - - def test_outer_join_coalesce_different_names_13450() -> None: df1 = pl.DataFrame({"L1": ["a", "b", "c"], "L3": ["b", "c", "d"], "L2": [1, 2, 3]}) df2 = pl.DataFrame({"L3": ["a", "c", "d"], "R2": [7, 8, 9]}) diff --git a/py-polars/tests/unit/streaming/test_streaming_group_by.py b/py-polars/tests/unit/streaming/test_streaming_group_by.py index dbb53ec66656..1f8424182190 100644 --- a/py-polars/tests/unit/streaming/test_streaming_group_by.py +++ b/py-polars/tests/unit/streaming/test_streaming_group_by.py @@ -420,3 +420,19 @@ def test_streaming_group_by_literal(literal: Any) -> None: "a_count": [20], "a_sum": [190], } + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_group_by_multiple_keys_one_literal(streaming: bool) -> None: + df = pl.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) + + expected = {"a": [1, 2], "literal": [1, 1], "b": [5, 6]} + assert ( + df.lazy() + .group_by("a", pl.lit(1)) + .agg(pl.col("b").max()) + .sort(["a", "b"]) + .collect(streaming=streaming) + .to_dict(as_series=False) + == expected + ) diff --git a/py-polars/tests/unit/streaming/test_streaming_io.py b/py-polars/tests/unit/streaming/test_streaming_io.py index b4e009a469fb..5c9d01537f55 100644 --- a/py-polars/tests/unit/streaming/test_streaming_io.py +++ b/py-polars/tests/unit/streaming/test_streaming_io.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from unittest.mock import patch import pytest @@ -213,3 +213,43 @@ def test_sink_ndjson_should_write_same_data( df = pl.read_ndjson(target_path) assert_frame_equal(df, expected) + + +@pytest.mark.write_disk() +def test_parquet_eq_statistics(monkeypatch: Any, capfd: Any, tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + + monkeypatch.setenv("POLARS_VERBOSE", "1") + + df = pl.DataFrame({"idx": pl.arange(100, 200, eager=True)}).with_columns( + (pl.col("idx") // 25).alias("part") + ) + df = pl.concat(df.partition_by("part", as_dict=False), rechunk=False) + assert df.n_chunks("all") == [4, 4] + + file_path = tmp_path / "stats.parquet" + df.write_parquet(file_path, statistics=True, use_pyarrow=False) + + file_path = tmp_path / "stats.parquet" + df.write_parquet(file_path, statistics=True, use_pyarrow=False) + + for streaming in [False, True]: + for pred in [ + pl.col("idx") == 50, + pl.col("idx") == 150, + pl.col("idx") == 210, + ]: + result = ( + pl.scan_parquet(file_path).filter(pred).collect(streaming=streaming) + ) + assert_frame_equal(result, df.filter(pred)) + + captured = capfd.readouterr().err + assert ( + "parquet file must be read, statistics not sufficient for predicate." + in captured + ) + assert ( + "parquet file can be skipped, the statistics were sufficient" + " to apply the predicate." in captured + ) diff --git a/py-polars/tests/unit/streaming/test_streaming_join.py b/py-polars/tests/unit/streaming/test_streaming_join.py index 22baf6f6d4c2..507baa859095 100644 --- a/py-polars/tests/unit/streaming/test_streaming_join.py +++ b/py-polars/tests/unit/streaming/test_streaming_join.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd +import pytest import polars as pl from polars.testing import assert_frame_equal @@ -105,3 +106,85 @@ def test_streaming_join_rechunk_12498() -> None: "A": [0, 1, 0, 1], "B": [0, 0, 1, 1], } + + +@pytest.mark.parametrize("streaming", [False, True]) +def test_join_null_matches(streaming: bool) -> None: + # null values in joins should never find a match. + df_a = pl.LazyFrame( + { + "idx_a": [0, 1, 2], + "a": [None, 1, 2], + } + ) + + df_b = pl.LazyFrame( + { + "idx_b": [0, 1, 2, 3], + "a": [None, 2, 1, None], + } + ) + + expected = pl.DataFrame({"idx_a": [2, 1], "a": [2, 1], "idx_b": [1, 2]}) + assert_frame_equal( + df_a.join(df_b, on="a", how="inner").collect(streaming=streaming), expected + ) + expected = pl.DataFrame( + {"idx_a": [0, 1, 2], "a": [None, 1, 2], "idx_b": [None, 2, 1]} + ) + assert_frame_equal( + df_a.join(df_b, on="a", how="left").collect(streaming=streaming), expected + ) + expected = pl.DataFrame( + { + "idx_a": [None, 2, 1, None, 0], + "a": [None, 2, 1, None, None], + "idx_b": [0, 1, 2, 3, None], + "a_right": [None, 2, 1, None, None], + } + ) + assert_frame_equal(df_a.join(df_b, on="a", how="outer").collect(), expected) + + +@pytest.mark.parametrize("streaming", [False, True]) +def test_join_null_matches_multiple_keys(streaming: bool) -> None: + df_a = pl.LazyFrame( + { + "a": [None, 1, 2], + "idx": [0, 1, 2], + } + ) + + df_b = pl.LazyFrame( + { + "a": [None, 2, 1, None, 1], + "idx": [0, 1, 2, 3, 1], + "c": [10, 20, 30, 40, 50], + } + ) + + expected = pl.DataFrame({"a": [1], "idx": [1], "c": [50]}) + assert_frame_equal( + df_a.join(df_b, on=["a", "idx"], how="inner").collect(streaming=streaming), + expected, + ) + expected = pl.DataFrame( + {"a": [None, 1, 2], "idx": [0, 1, 2], "c": [None, 50, None]} + ) + assert_frame_equal( + df_a.join(df_b, on=["a", "idx"], how="left").collect(streaming=streaming), + expected, + ) + + expected = pl.DataFrame( + { + "a": [None, None, None, None, None, 1, 2], + "idx": [None, None, None, None, 0, 1, 2], + "a_right": [None, 2, 1, None, None, 1, None], + "idx_right": [0, 1, 2, 3, None, 1, None], + "c": [10, 20, 30, 40, None, 50, None], + } + ) + assert_frame_equal( + df_a.join(df_b, on=["a", "idx"], how="outer").sort("a").collect(), expected + )