diff --git a/py-polars/tests/unit/io/test_scan.py b/py-polars/tests/unit/io/test_scan.py
index 1cccd07f67c1..2ea3bfa91a60 100644
--- a/py-polars/tests/unit/io/test_scan.py
+++ b/py-polars/tests/unit/io/test_scan.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from math import ceil
 from typing import TYPE_CHECKING
 
 import pytest
@@ -82,35 +83,50 @@ class _DataFile:
     df: pl.DataFrame
 
 
+def df_with_chunk_size_limit(df: pl.DataFrame, limit: int) -> pl.DataFrame:
+    return pl.concat(
+        (
+            df.slice(i * limit, min(limit, df.height - i * limit))
+            for i in range(ceil(df.height / limit))
+        ),
+        rechunk=False,
+    )
+
+
 @pytest.fixture(scope="session")
 def data_file_single(session_tmp_dir: Path, data_file_extension: str) -> _DataFile:
+    max_rows_per_batch = 727
     file_path = (session_tmp_dir / "data").with_suffix(data_file_extension)
     df = pl.DataFrame(
         {
-            "seq_int": range(10000),
-            "seq_str": [f"{x}" for x in range(10000)],
+            "sequence": range(10000),
         }
     )
-    _write(df, file_path)
+    assert max_rows_per_batch < df.height
+    _write(df_with_chunk_size_limit(df, max_rows_per_batch), file_path)
     return _DataFile(path=file_path, df=df)
 
 
 @pytest.fixture(scope="session")
 def data_file_glob(session_tmp_dir: Path, data_file_extension: str) -> _DataFile:
+    max_rows_per_batch = 200
     row_counts = [
         100, 186, 95, 185, 90, 84, 115, 81, 87, 217, 126, 85, 98, 122, 129, 122, 1089, 82,
         234, 86, 93, 90, 91, 263, 87, 126, 86, 161, 191, 1368, 403, 192, 102, 98, 115, 81,
         111, 305, 92, 534, 431, 150, 90, 128, 152, 118, 127, 124, 229, 368, 81,
     ]  # fmt: skip
     assert sum(row_counts) == 10000
-    assert (
-        len(row_counts) < 100
-    )  # need to make sure we pad file names with enough zeros, otherwise the lexographical ordering of the file names is not what we want.
+    # Make sure we pad file names with enough zeros to ensure correct
+    # lexicographical ordering.
+    assert len(row_counts) < 100
+
+    # Make sure that some of our data frames consist of multiple chunks,
+    # which affects the output of certain file formats.
+    assert any(row_count > max_rows_per_batch for row_count in row_counts)
 
     df = pl.DataFrame(
         {
-            "seq_int": range(10000),
-            "seq_str": [str(x) for x in range(10000)],
+            "sequence": range(10000),
         }
     )
 
@@ -119,7 +135,12 @@ def data_file_glob(session_tmp_dir: Path, data_file_extension: str) -> _DataFile
         file_path = (session_tmp_dir / f"data_{index:02}").with_suffix(
             data_file_extension
         )
-        _write(df.slice(row_offset, row_count), file_path)
+        _write(
+            df_with_chunk_size_limit(
+                df.slice(row_offset, row_count), max_rows_per_batch
+            ),
+            file_path,
+        )
         row_offset += row_count
     return _DataFile(
         path=(session_tmp_dir / "data_*").with_suffix(data_file_extension), df=df
@@ -147,28 +168,122 @@ def test_scan(data_file: _DataFile) -> None:
 
 @pytest.mark.write_disk()
 def test_scan_with_limit(data_file: _DataFile) -> None:
-    df = _scan(data_file.path, data_file.df.schema).limit(100).collect()
+    df = _scan(data_file.path, data_file.df.schema).limit(4483).collect()
+    assert_frame_equal(
+        df,
+        pl.DataFrame(
+            {
+                "sequence": range(4483),
+            }
+        ),
+    )
+
+
+@pytest.mark.write_disk()
+def test_scan_with_filter(data_file: _DataFile) -> None:
+    df = (
+        _scan(data_file.path, data_file.df.schema)
+        .filter(pl.col("sequence") % 2 == 0)
+        .collect()
+    )
     assert_frame_equal(
         df,
         pl.DataFrame(
             {
-                "seq_int": range(100),
-                "seq_str": [str(x) for x in range(100)],
+                "sequence": (2 * x for x in range(5000)),
             }
         ),
     )
 
 
 @pytest.mark.write_disk()
-def test_scan_with_row_index(data_file: _DataFile) -> None:
-    df = _scan(data_file.path, data_file.df.schema, row_index=_RowIndex()).collect()
+def test_scan_with_filter_and_limit(data_file: _DataFile) -> None:
+    df = (
+        _scan(data_file.path, data_file.df.schema)
+        .filter(pl.col("sequence") % 2 == 0)
+        .limit(4483)
+        .collect()
+    )
+    assert_frame_equal(
+        df,
+        pl.DataFrame(
+            {
+                "sequence": (2 * x for x in range(4483)),
+            },
+        ),
+    )
+
+
+@pytest.mark.write_disk()
+def test_scan_with_limit_and_filter(data_file: _DataFile) -> None:
+    df = (
+        _scan(data_file.path, data_file.df.schema)
+        .limit(4483)
+        .filter(pl.col("sequence") % 2 == 0)
+        .collect()
+    )
+    assert_frame_equal(
+        df,
+        pl.DataFrame(
+            {
+                "sequence": (2 * x for x in range(2242)),
+            },
+        ),
+    )
+
+
+@pytest.mark.write_disk()
+def test_scan_with_row_index_and_limit(data_file: _DataFile) -> None:
+    df = (
+        _scan(data_file.path, data_file.df.schema, row_index=_RowIndex())
+        .limit(4483)
+        .collect()
+    )
+    assert_frame_equal(
+        df,
+        pl.DataFrame(
+            {
+                "index": range(4483),
+                "sequence": range(4483),
+            },
+            schema_overrides={"index": pl.UInt32},
+        ),
+    )
+
+
+@pytest.mark.write_disk()
+def test_scan_with_row_index_and_filter(data_file: _DataFile) -> None:
+    df = (
+        _scan(data_file.path, data_file.df.schema, row_index=_RowIndex())
+        .filter(pl.col("sequence") % 2 == 0)
+        .collect()
+    )
+    assert_frame_equal(
+        df,
+        pl.DataFrame(
+            {
+                "index": (2 * x for x in range(5000)),
+                "sequence": (2 * x for x in range(5000)),
+            },
+            schema_overrides={"index": pl.UInt32},
+        ),
+    )
+
+
+@pytest.mark.write_disk()
+def test_scan_with_row_index_limit_and_filter(data_file: _DataFile) -> None:
+    df = (
+        _scan(data_file.path, data_file.df.schema, row_index=_RowIndex())
+        .limit(4483)
+        .filter(pl.col("sequence") % 2 == 0)
+        .collect()
+    )
     assert_frame_equal(
         df,
         pl.DataFrame(
             {
-                "index": range(10000),
-                "seq_int": range(10000),
-                "seq_str": [str(x) for x in range(10000)],
+                "index": (2 * x for x in range(2242)),
+                "sequence": (2 * x for x in range(2242)),
             },
             schema_overrides={"index": pl.UInt32},
         ),
@@ -176,19 +291,19 @@ def test_scan_with_row_index(data_file: _DataFile) -> None:
 
 
 @pytest.mark.write_disk()
-def test_scan_with_row_index_and_predicate(data_file: _DataFile) -> None:
+def test_scan_with_row_index_filter_and_limit(data_file: _DataFile) -> None:
     df = (
         _scan(data_file.path, data_file.df.schema, row_index=_RowIndex())
-        .filter(pl.col("seq_int") % 2 == 0)
+        .filter(pl.col("sequence") % 2 == 0)
+        .limit(4483)
         .collect()
     )
     assert_frame_equal(
         df,
         pl.DataFrame(
             {
-                "index": [2 * x for x in range(5000)],
-                "seq_int": [2 * x for x in range(5000)],
-                "seq_str": [str(2 * x) for x in range(5000)],
+                "index": (2 * x for x in range(4483)),
+                "sequence": (2 * x for x in range(4483)),
             },
             schema_overrides={"index": pl.UInt32},
        ),
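
Below is a small, self-contained sketch of why the fixtures route their frames through the new df_with_chunk_size_limit helper: concatenating the slices with rechunk=False yields a frame backed by multiple chunks, which is what chunk-sensitive file writers then serialize batch by batch. Only the helper body is taken from the diff; the surrounding assertions, and the arithmetic behind the 4483/2242 row counts in the new tests, are illustrative and assume a reasonably recent polars.

# Illustrative sketch; the helper body mirrors the diff, everything else
# around it is an assumption for demonstration purposes.
from math import ceil

import polars as pl
from polars.testing import assert_frame_equal


def df_with_chunk_size_limit(df: pl.DataFrame, limit: int) -> pl.DataFrame:
    # Split the frame into ceil(height / limit) slices of at most `limit`
    # rows, then concatenate with rechunk=False so each slice remains a
    # separate chunk in the result.
    return pl.concat(
        (
            df.slice(i * limit, min(limit, df.height - i * limit))
            for i in range(ceil(df.height / limit))
        ),
        rechunk=False,
    )


df = pl.DataFrame({"sequence": range(10000)})
chunked = df_with_chunk_size_limit(df, 727)

# Same data as before, but spread over ceil(10000 / 727) == 14 chunks
# instead of a single one.
assert_frame_equal(chunked, df)
assert chunked.n_chunks() == 14

# The row counts asserted in the new tests follow from the same arithmetic:
# limit(4483) keeps rows 0..4482, of which 2242 are even, so limit-then-filter
# yields 2242 rows, while filter-then-limit keeps 4483 of the 5000 even rows.
assert sum(1 for x in range(4483) if x % 2 == 0) == 2242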