test: Extend and speed up scan tests (#15127)
mickvangelderen authored Mar 18, 2024
1 parent 4a91992 commit 634fdbe
Showing 1 changed file with 137 additions and 22 deletions.
159 changes: 137 additions & 22 deletions py-polars/tests/unit/io/test_scan.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+from math import ceil
 from typing import TYPE_CHECKING
 
 import pytest
@@ -82,35 +83,50 @@ class _DataFile:
     df: pl.DataFrame
 
 
+def df_with_chunk_size_limit(df: pl.DataFrame, limit: int) -> pl.DataFrame:
+    return pl.concat(
+        (
+            df.slice(i * limit, min(limit, df.height - i * limit))
+            for i in range(ceil(df.height / limit))
+        ),
+        rechunk=False,
+    )
+
+
 @pytest.fixture(scope="session")
 def data_file_single(session_tmp_dir: Path, data_file_extension: str) -> _DataFile:
+    max_rows_per_batch = 727
     file_path = (session_tmp_dir / "data").with_suffix(data_file_extension)
     df = pl.DataFrame(
         {
-            "seq_int": range(10000),
-            "seq_str": [f"{x}" for x in range(10000)],
+            "sequence": range(10000),
         }
     )
-    _write(df, file_path)
+    assert max_rows_per_batch < df.height
+    _write(df_with_chunk_size_limit(df, max_rows_per_batch), file_path)
     return _DataFile(path=file_path, df=df)
 
 
 @pytest.fixture(scope="session")
 def data_file_glob(session_tmp_dir: Path, data_file_extension: str) -> _DataFile:
+    max_rows_per_batch = 200
     row_counts = [
         100, 186, 95, 185, 90, 84, 115, 81, 87, 217, 126, 85, 98, 122, 129, 122, 1089, 82,
         234, 86, 93, 90, 91, 263, 87, 126, 86, 161, 191, 1368, 403, 192, 102, 98, 115, 81,
         111, 305, 92, 534, 431, 150, 90, 128, 152, 118, 127, 124, 229, 368, 81,
     ]  # fmt: skip
     assert sum(row_counts) == 10000
-    assert (
-        len(row_counts) < 100
-    )  # need to make sure we pad file names with enough zeros, otherwise the lexographical ordering of the file names is not what we want.
 
+    # Make sure we pad file names with enough zeros to ensure correct
+    # lexicographical ordering.
+    assert len(row_counts) < 100
+
+    # Make sure that some of our data frames consist of multiple chunks, which
+    # affects the output of certain file formats.
+    assert any(row_count > max_rows_per_batch for row_count in row_counts)
     df = pl.DataFrame(
         {
-            "seq_int": range(10000),
-            "seq_str": [str(x) for x in range(10000)],
+            "sequence": range(10000),
         }
     )
 
@@ -119,7 +135,12 @@ def data_file_glob(session_tmp_dir: Path, data_file_extension: str) -> _DataFile:
         file_path = (session_tmp_dir / f"data_{index:02}").with_suffix(
             data_file_extension
         )
-        _write(df.slice(row_offset, row_count), file_path)
+        _write(
+            df_with_chunk_size_limit(
+                df.slice(row_offset, row_count), max_rows_per_batch
+            ),
+            file_path,
+        )
         row_offset += row_count
     return _DataFile(
         path=(session_tmp_dir / "data_*").with_suffix(data_file_extension), df=df
@@ -147,48 +168,142 @@ def test_scan(data_file: _DataFile) -> None:

 @pytest.mark.write_disk()
 def test_scan_with_limit(data_file: _DataFile) -> None:
-    df = _scan(data_file.path, data_file.df.schema).limit(100).collect()
+    df = _scan(data_file.path, data_file.df.schema).limit(4483).collect()
     assert_frame_equal(
         df,
         pl.DataFrame(
             {
-                "seq_int": range(100),
-                "seq_str": [str(x) for x in range(100)],
+                "sequence": range(4483),
+            }
+        ),
+    )
+
+
+@pytest.mark.write_disk()
+def test_scan_with_filter(data_file: _DataFile) -> None:
+    df = (
+        _scan(data_file.path, data_file.df.schema)
+        .filter(pl.col("sequence") % 2 == 0)
+        .collect()
+    )
+    assert_frame_equal(
+        df,
+        pl.DataFrame(
+            {
+                "sequence": (2 * x for x in range(5000)),
             }
         ),
     )
 
 
 @pytest.mark.write_disk()
-def test_scan_with_row_index(data_file: _DataFile) -> None:
-    df = _scan(data_file.path, data_file.df.schema, row_index=_RowIndex()).collect()
+def test_scan_with_filter_and_limit(data_file: _DataFile) -> None:
+    df = (
+        _scan(data_file.path, data_file.df.schema)
+        .filter(pl.col("sequence") % 2 == 0)
+        .limit(4483)
+        .collect()
+    )
     assert_frame_equal(
         df,
         pl.DataFrame(
             {
-                "index": range(10000),
-                "seq_int": range(10000),
-                "seq_str": [str(x) for x in range(10000)],
+                "sequence": (2 * x for x in range(4483)),
             },
+        ),
+    )
+
+
+@pytest.mark.write_disk()
+def test_scan_with_limit_and_filter(data_file: _DataFile) -> None:
+    df = (
+        _scan(data_file.path, data_file.df.schema)
+        .limit(4483)
+        .filter(pl.col("sequence") % 2 == 0)
+        .collect()
+    )
+    assert_frame_equal(
+        df,
+        pl.DataFrame(
+            {
+                "sequence": (2 * x for x in range(2242)),
+            },
+        ),
+    )
+
+
+@pytest.mark.write_disk()
+def test_scan_with_row_index_and_limit(data_file: _DataFile) -> None:
+    df = (
+        _scan(data_file.path, data_file.df.schema, row_index=_RowIndex())
+        .limit(4483)
+        .collect()
+    )
+    assert_frame_equal(
+        df,
+        pl.DataFrame(
+            {
+                "index": range(4483),
+                "sequence": range(4483),
+            },
             schema_overrides={"index": pl.UInt32},
         ),
     )
 
 
 @pytest.mark.write_disk()
-def test_scan_with_row_index_and_predicate(data_file: _DataFile) -> None:
+def test_scan_with_row_index_and_filter(data_file: _DataFile) -> None:
     df = (
         _scan(data_file.path, data_file.df.schema, row_index=_RowIndex())
-        .filter(pl.col("seq_int") % 2 == 0)
+        .filter(pl.col("sequence") % 2 == 0)
         .collect()
     )
     assert_frame_equal(
         df,
         pl.DataFrame(
             {
-                "index": [2 * x for x in range(5000)],
-                "seq_int": [2 * x for x in range(5000)],
-                "seq_str": [str(2 * x) for x in range(5000)],
+                "index": (2 * x for x in range(5000)),
+                "sequence": (2 * x for x in range(5000)),
            },
             schema_overrides={"index": pl.UInt32},
         ),
+    )
+
+
+@pytest.mark.write_disk()
+def test_scan_with_row_index_limit_and_filter(data_file: _DataFile) -> None:
+    df = (
+        _scan(data_file.path, data_file.df.schema, row_index=_RowIndex())
+        .limit(4483)
+        .filter(pl.col("sequence") % 2 == 0)
+        .collect()
+    )
+    assert_frame_equal(
+        df,
+        pl.DataFrame(
+            {
+                "index": (2 * x for x in range(2242)),
+                "sequence": (2 * x for x in range(2242)),
+            },
+            schema_overrides={"index": pl.UInt32},
+        ),
+    )
+
+
+@pytest.mark.write_disk()
+def test_scan_with_row_index_filter_and_limit(data_file: _DataFile) -> None:
+    df = (
+        _scan(data_file.path, data_file.df.schema, row_index=_RowIndex())
+        .filter(pl.col("sequence") % 2 == 0)
+        .limit(4483)
+        .collect()
+    )
+    assert_frame_equal(
+        df,
+        pl.DataFrame(
+            {
+                "index": (2 * x for x in range(4483)),
+                "sequence": (2 * x for x in range(4483)),
+            },
+            schema_overrides={"index": pl.UInt32},
+        ),
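For context on the change: the new `df_with_chunk_size_limit` helper concatenates row slices with `rechunk=False`, so each written fixture frame is built from several chunks of at most `limit` rows. Below is a minimal standalone sketch of the intended behavior; the helper body is taken verbatim from the diff, while the `n_chunks`/`equals` usage around it is an assumption about a recent polars API, not part of the commit.

from math import ceil

import polars as pl


def df_with_chunk_size_limit(df: pl.DataFrame, limit: int) -> pl.DataFrame:
    # Slice off at most `limit` rows at a time; rechunk=False keeps each
    # slice as a separate chunk in the concatenated result.
    return pl.concat(
        (
            df.slice(i * limit, min(limit, df.height - i * limit))
            for i in range(ceil(df.height / limit))
        ),
        rechunk=False,
    )


df = pl.DataFrame({"sequence": range(10_000)})
chunked = df_with_chunk_size_limit(df, 727)
assert chunked.n_chunks() == ceil(10_000 / 727)  # 14 chunks
assert chunked.equals(df)  # same rows; only the chunk layout differs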

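The paired `*_filter_and_limit` and `*_limit_and_filter` tests pin down that the two operations do not commute: filtering `sequence % 2 == 0` and then taking 4483 rows yields the first 4483 even values, while limiting to the first 4483 rows (0 through 4482) and then filtering leaves only ceil(4483 / 2) = 2242 evens. A standalone sketch of the same arithmetic on a hypothetical in-memory frame, not the fixture-backed scan the tests use:

import polars as pl

lf = pl.LazyFrame({"sequence": range(10_000)})

# filter -> limit: 5000 even values exist, so the limit of 4483 binds.
assert lf.filter(pl.col("sequence") % 2 == 0).limit(4483).collect().height == 4483

# limit -> filter: only the evens among 0..4482 remain: 0, 2, ..., 4482.
assert lf.limit(4483).filter(pl.col("sequence") % 2 == 0).collect().height == 2242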

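Similarly, the row-index tests encode that the `index` column reflects row positions in the scanned data before any predicate runs, which is why a filtered scan is expected to return even index values rather than a fresh 0..n renumbering. A rough eager-mode analogue of that semantic, assuming `DataFrame.with_row_index` from a recent polars release (this is an illustration, not the scan code path itself):

import polars as pl

df = pl.DataFrame({"sequence": range(10)}).with_row_index("index")
out = df.filter(pl.col("sequence") % 2 == 0)

# The index keeps the original row positions instead of being recomputed.
assert out["index"].to_list() == [0, 2, 4, 6, 8]
assert out["index"].dtype == pl.UInt32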