refactor: Fix a bunch of tests for new-streaming (#18659)
orlp authored Sep 10, 2024
1 parent 655c781 commit 2e92f0a
Showing 5 changed files with 34 additions and 32 deletions.
3 changes: 1 addition & 2 deletions crates/polars-stream/src/nodes/reduce.rs
@@ -59,9 +59,8 @@ impl ReduceNode
        scope.spawn_task(TaskPriority::High, async move {
            while let Ok(morsel) = recv.recv().await {
                for (reducer, selector) in local_reducers.iter_mut().zip(selectors) {
-                   // TODO: don't convert to physical representation here.
                    let input = selector.evaluate(morsel.df(), state).await?;
-                   reducer.update(&input.to_physical_repr())?;
+                   reducer.update(&input)?;
                }
            }
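For context on the dropped conversion, here is a minimal Python sketch (not part of this commit; it only illustrates the concept via the py-polars API): logical dtypes such as Date are stored on top of a primitive physical dtype, and the reducers now receive the logical series directly instead of that physical view.

from datetime import date

import polars as pl

# Minimal sketch: a logical Date series is backed by a physical Int32
# (days since the UNIX epoch); to_physical() exposes that backing view.
s = pl.Series("d", [date(2024, 1, 1), date(2024, 1, 2)])
print(s.dtype)                # Date
print(s.to_physical().dtype)  # Int32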

18 changes: 7 additions & 11 deletions py-polars/tests/unit/datatypes/test_list.py
@@ -114,20 +114,16 @@ def test_cast_inner() -> None:


def test_list_empty_group_by_result_3521() -> None:
-    # Create a left relation where the join column contains a null value
-    left = pl.DataFrame().with_columns(
-        pl.lit(1).alias("group_by_column"),
-        pl.lit(None).cast(pl.Int32).alias("join_column"),
+    # Create a left relation where the join column contains a null value.
+    left = pl.DataFrame(
+        {"group_by_column": [1], "join_column": [None]},
+        schema_overrides={"join_column": pl.Int64},
    )

-    # Create a right relation where there is a column to count distinct on
-    right = pl.DataFrame().with_columns(
-        pl.lit(1).alias("join_column"),
-        pl.lit(1).alias("n_unique_column"),
-    )
+    # Create a right relation where there is a column to count distinct on.
+    right = pl.DataFrame({"join_column": [1], "n_unique_column": [1]})

-    # Calculate n_unique after dropping nulls
-    # This will panic on polars version 0.13.38 and 0.13.39
+    # Calculate n_unique after dropping nulls.
    result = (
        left.join(right, on="join_column", how="left")
        .group_by("group_by_column")
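As a side note on the rewritten construction, a small hedged example (not part of the commit) of why the schema_overrides argument is needed: an all-null column is otherwise inferred as the Null dtype, which would not match the dtype of the join key on the right side.

import polars as pl

# Without an override, an all-null column is inferred as Null.
print(pl.DataFrame({"join_column": [None]}).schema)  # join_column: Null

# schema_overrides pins the dtype the join expects.
df = pl.DataFrame(
    {"join_column": [None]},
    schema_overrides={"join_column": pl.Int64},
)
print(df.schema)  # join_column: Int64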
1 change: 1 addition & 0 deletions py-polars/tests/unit/datatypes/test_struct.py
@@ -265,6 +265,7 @@ def test_from_dicts_struct() -> None:
    ]


+@pytest.mark.may_fail_auto_streaming
def test_list_to_struct() -> None:
    df = pl.DataFrame({"a": [[1, 2, 3], [1, 2]]})
    assert df.select([pl.col("a").list.to_struct()]).to_series().to_list() == [
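A hedged note on why the marker is likely needed here (my reading, not stated in the commit): with its default field-width strategy, list.to_struct() infers the number of struct fields from the first non-null list it evaluates, so the result can depend on the order in which the data is processed.

import polars as pl

# Sketch of the order dependence: the struct width is taken from the first
# non-null list, so [[1, 2, 3], [1, 2]] and [[1, 2], [1, 2, 3]] can yield
# structs of different widths.
df = pl.DataFrame({"a": [[1, 2, 3], [1, 2]]})
print(df.select(pl.col("a").list.to_struct()).to_series().to_list())
# e.g. [{'field_0': 1, 'field_1': 2, 'field_2': 3},
#       {'field_0': 1, 'field_1': 2, 'field_2': None}]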
4 changes: 2 additions & 2 deletions py-polars/tests/unit/datatypes/test_temporal.py
@@ -1399,12 +1399,12 @@ def test_replace_time_zone_sortedness_expressions(
    from_tz: str | None, expected_sortedness: bool, ambiguous: str
) -> None:
    df = (
-       pl.Series("ts", [1603584000000000, 1603587600000000])
+       pl.Series("ts", [1603584000000000, 1603584060000000, 1603587600000000])
        .cast(pl.Datetime("us", from_tz))
        .sort()
        .to_frame()
    )
-   df = df.with_columns(ambiguous=pl.Series([ambiguous] * 2))
+   df = df.with_columns(ambiguous=pl.Series([ambiguous] * 3))
    assert df["ts"].flags["SORTED_ASC"]
    result = df.select(
        pl.col("ts").dt.replace_time_zone("UTC", ambiguous=pl.col("ambiguous"))
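For orientation, a small hedged sketch (not part of the commit) of what the raw integers encode: they are microseconds since the UNIX epoch and fall in the early morning of 2020-10-25 UTC, around the European DST fall-back, which is exactly the situation the `ambiguous` policy exists for.

import polars as pl

# Decode the test's epoch-microsecond inputs; the values land around the
# 2020-10-25 European DST transition.
s = pl.Series("ts", [1603584000000000, 1603584060000000, 1603587600000000])
print(s.cast(pl.Datetime("us", "UTC")).to_list())
# approx. 2020-10-25 00:00, 00:01 and 01:00 UTC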
@@ -138,15 +138,13 @@ def test_local_date_sortedness(time_zone: str | None, expected: bool) -> None:
    ser = (pl.Series([datetime(2022, 1, 1, 23)]).dt.replace_time_zone(time_zone)).sort()
    result = ser.dt.date()
    assert result.flags["SORTED_ASC"]
-   assert result.flags["SORTED_DESC"] is False

    # 2 elements - depends on time zone
    ser = (
        pl.Series([datetime(2022, 1, 1, 23)] * 2).dt.replace_time_zone(time_zone)
    ).sort()
    result = ser.dt.date()
-   assert result.flags["SORTED_ASC"] == expected
-   assert result.flags["SORTED_DESC"] is False
+   assert result.flags["SORTED_ASC"] >= expected
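As background for these assertions, a minimal sketch (not part of the commit) of the metadata being checked: Series.flags reports the sortedness flags the engine tracks, and the relaxed `>= expected` form only requires the flag to be at least as strong as the parametrized expectation rather than exactly equal to it.

from datetime import datetime

import polars as pl

# Sortedness flags are tracked metadata, not a recomputation of order.
ser = pl.Series([datetime(2022, 1, 1, 23)]).sort()
print(ser.flags)  # e.g. {'SORTED_ASC': True, 'SORTED_DESC': False}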


@pytest.mark.parametrize("time_zone", [None, "Asia/Kathmandu", "UTC"])
@@ -155,11 +153,16 @@ def test_local_time_sortedness(time_zone: str | None) -> None:
    ser = (pl.Series([datetime(2022, 1, 1, 23)]).dt.replace_time_zone(time_zone)).sort()
    result = ser.dt.time()
    assert result.flags["SORTED_ASC"]
-   assert not result.flags["SORTED_DESC"]

-   # two elements - not sorted
+   # three elements - not sorted
    ser = (
-       pl.Series([datetime(2022, 1, 1, 23)] * 2).dt.replace_time_zone(time_zone)
+       pl.Series(
+           [
+               datetime(2022, 1, 1, 23),
+               datetime(2022, 1, 2, 21),
+               datetime(2022, 1, 3, 22),
+           ]
+       ).dt.replace_time_zone(time_zone)
    ).sort()
    result = ser.dt.time()
    assert not result.flags["SORTED_ASC"]
@@ -180,31 +183,34 @@ def test_local_time_before_epoch(time_unit: TimeUnit) -> None:
    ("time_zone", "offset", "expected"),
    [
        (None, "1d", True),
-       ("Asia/Kathmandu", "1d", False),
+       ("Europe/London", "1d", False),
        ("UTC", "1d", True),
        (None, "1mo", True),
-       ("Asia/Kathmandu", "1mo", False),
+       ("Europe/London", "1mo", False),
        ("UTC", "1mo", True),
        (None, "1w", True),
-       ("Asia/Kathmandu", "1w", False),
+       ("Europe/London", "1w", False),
        ("UTC", "1w", True),
        (None, "1h", True),
-       ("Asia/Kathmandu", "1h", True),
+       ("Europe/London", "1h", True),
        ("UTC", "1h", True),
    ],
)
def test_offset_by_sortedness(
    time_zone: str | None, offset: str, expected: bool
) -> None:
-   # create 2 values, as a single value is always sorted
-   ser = (
-       pl.Series(
-           [datetime(2022, 1, 1, 22), datetime(2022, 1, 1, 22)]
-       ).dt.replace_time_zone(time_zone)
+   s = pl.datetime_range(
+       datetime(2020, 10, 25),
+       datetime(2020, 10, 25, 3),
+       "30m",
+       time_zone=time_zone,
+       eager=True,
    ).sort()
-   result = ser.dt.offset_by(offset)
+   assert s.flags["SORTED_ASC"]
+   assert not s.flags["SORTED_DESC"]
+   result = s.dt.offset_by(offset)
    assert result.flags["SORTED_ASC"] == expected
-   assert result.flags["SORTED_DESC"] is False
+   assert not result.flags["SORTED_DESC"]
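To make the new parametrization concrete, a hedged sketch (not part of the commit): 2020-10-25 is the European DST fall-back, so a half-hourly range crossing it in a DST-observing zone is exactly where calendar offsets ("1d", "1w", "1mo") cannot guarantee the sortedness flag is kept, while the fixed "1h" offset can.

from datetime import datetime

import polars as pl

# Expected flag behaviour per the Europe/London rows in the table above.
s = pl.datetime_range(
    datetime(2020, 10, 25),
    datetime(2020, 10, 25, 3),
    "30m",
    time_zone="Europe/London",
    eager=True,
).sort()
print(s.dt.offset_by("1d").flags["SORTED_ASC"])  # expected False
print(s.dt.offset_by("1h").flags["SORTED_ASC"])  # expected True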


def test_dt_datetime_date_time_invalid() -> None:
