add unit tests for window aggregation in pyspark backend
Chloe He committed Jul 9, 2024
1 parent 6f14e43 commit 769fb23
Showing 5 changed files with 125 additions and 37 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/flink/tests/test_window.py
@@ -78,7 +78,7 @@ def test_tumble_window_by_grouped_agg(con):
assert result.shape == (610, 4)


def test_tumble_window_by_global_agg(con):
def test_tumble_window_by_ungrouped_agg(con):
t = con.table("functional_alltypes_with_watermark")
expr = (
t.window_by(time_col=t.timestamp_col)
41 changes: 39 additions & 2 deletions ibis/backends/pyspark/__init__.py
@@ -2,6 +2,7 @@

import contextlib
import os
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

@@ -1045,7 +1046,11 @@ def to_kafka(

@util.experimental
def read_csv_dir(
self, path: str | Path, table_name: str | None = None, **kwargs: Any
self,
path: str | Path,
table_name: str | None = None,
watermark: Watermark | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a CSV directory as a table in the current database.
@@ -1056,6 +1061,8 @@ def read_csv_dir(
table_name
An optional name to use for the created table. This defaults to
a random generated name.
watermark
Watermark strategy for the table.
kwargs
Additional keyword arguments passed to PySpark loading function.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.csv.html
@@ -1073,10 +1080,17 @@ def read_csv_dir(
spark_df = self._session.read.csv(
path, inferSchema=inferSchema, header=header, **kwargs
)
if watermark is not None:
warnings.warn("Watermark is not supported in batch mode")
elif self.mode == "streaming":
spark_df = self._session.readStream.csv(
path, inferSchema=inferSchema, header=header, **kwargs
)
if watermark is not None:
spark_df = spark_df.withWatermark(
watermark.time_col,
_interval_to_string(watermark.allowed_delay),
)
table_name = table_name or util.gen_name("read_csv_dir")

spark_df.createOrReplaceTempView(table_name)
@@ -1087,6 +1101,7 @@ def read_parquet_dir(
self,
path: str | Path,
table_name: str | None = None,
watermark: Watermark | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a parquet file as a table in the current database.
@@ -1098,6 +1113,8 @@ def read_parquet_dir(
table_name
An optional name to use for the created table. This defaults to
a random generated name.
watermark
Watermark strategy for the table.
kwargs
Additional keyword arguments passed to PySpark.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.parquet.html
@@ -1111,16 +1128,27 @@ def read_parquet_dir(
path = util.normalize_filename(path)
if self.mode == "batch":
spark_df = self._session.read.parquet(path, **kwargs)
if watermark is not None:
warnings.warn("Watermark is not supported in batch mode")
elif self.mode == "streaming":
spark_df = self._session.readStream.parquet(path, **kwargs)
if watermark is not None:
spark_df = spark_df.withWatermark(
watermark.time_col,
_interval_to_string(watermark.allowed_delay),
)
table_name = table_name or util.gen_name("read_parquet_dir")

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)

@util.experimental
def read_json_dir(
self, path: str | Path, table_name: str | None = None, **kwargs: Any
self,
path: str | Path,
table_name: str | None = None,
watermark: Watermark | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a JSON file as a table in the current database.
@@ -1131,6 +1159,8 @@ def read_json_dir(
table_name
An optional name to use for the created table. This defaults to
a random generated name.
watermark
Watermark strategy for the table.
kwargs
Additional keyword arguments passed to PySpark loading function.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.json.html
@@ -1144,8 +1174,15 @@ def read_json_dir(
path = util.normalize_filename(path)
if self.mode == "batch":
spark_df = self._session.read.json(path, **kwargs)
if watermark is not None:
warnings.warn("Watermark is not supported in batch mode")
elif self.mode == "streaming":
spark_df = self._session.readStream.json(path, **kwargs)
if watermark is not None:
spark_df = spark_df.withWatermark(
watermark.time_col,
_interval_to_string(watermark.allowed_delay),
)
table_name = table_name or util.gen_name("read_json_dir")

spark_df.createOrReplaceTempView(table_name)
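For context, the new watermark argument added to these readers would be used roughly as follows on a streaming connection. This is a sketch: the session setup, the directory path, and the exact ibis.pyspark.connect() signature are assumed for illustration and are not part of this commit.

    import ibis
    from pyspark.sql import SparkSession

    session = SparkSession.builder.getOrCreate()
    # Assumed: a PySpark backend connected in streaming mode.
    con = ibis.pyspark.connect(session, mode="streaming")

    # Register a directory of parquet files and attach an event-time watermark
    # that tolerates events arriving up to 10 seconds late.
    t = con.read_parquet_dir(
        "/data/functional_alltypes",  # hypothetical path
        table_name="functional_alltypes",
        watermark=ibis.watermark(
            time_col="timestamp_col",
            allowed_delay=ibis.interval(seconds=10),
        ),
    )

In batch mode the same call only warns that watermarks are unsupported and reads the directory as before.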
43 changes: 38 additions & 5 deletions ibis/backends/pyspark/tests/conftest.py
@@ -3,6 +3,7 @@
import os
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any
from unittest import mock

import numpy as np
import pandas as pd
@@ -12,6 +13,8 @@
import ibis
from ibis import util
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.pyspark import Backend
from ibis.backends.pyspark.datatypes import PySparkSchema
from ibis.backends.tests.base import BackendTest
from ibis.backends.tests.data import json_types, topk, win

@@ -189,13 +192,17 @@ def _load_data(self, **_: Any) -> None:
s = self.connection._session
num_partitions = 4

sort_cols = {"functional_alltypes": "id"}
watermark_cols = {"functional_alltypes": "timestamp_col"}

for name in TEST_TABLES:
for name, schema in TEST_TABLES.items():
path = str(self.data_dir / "directory" / "parquet" / name)
t = s.readStream.parquet(path).repartition(num_partitions)
if (sort_col := sort_cols.get(name)) is not None:
t = t.sort(sort_col)
t = (
s.readStream.schema(PySparkSchema.from_ibis(schema))
.parquet(path)
.repartition(num_partitions)
)
if (watermark_col := watermark_cols.get(name)) is not None:
t = t.withWatermark(watermark_col, "10 seconds")
t.createOrReplaceTempView(name)

@classmethod
@@ -409,3 +416,29 @@ def temp_table_db(con, temp_database):
yield temp_database, name
assert name in con.list_tables(database=temp_database), name
con.drop_table(name, database=temp_database)


@pytest.fixture(scope="session", autouse=True)
def default_session_fixture():
with mock.patch.object(Backend, "write_to_memory", write_to_memory, create=True):
yield


def write_to_memory(self, expr, table_name):
if self.mode == "batch":
raise NotImplementedError
df = self._session.sql(expr.compile())
df.writeStream.format("memory").queryName(table_name).start()


@pytest.fixture(autouse=True, scope="function")
def stop_active_jobs(con_streaming):
yield
for sq in con_streaming._session.streams.active:
sq.stop()
sq.awaitTermination()


@pytest.fixture
def awards_players_schema():
return TEST_TABLES["awards_players"]
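
The write_to_memory helper patched onto the backend above starts the compiled streaming query against Spark's in-memory sink and registers the results under table_name, so a test can poll them with plain SQL. A test might consume it roughly like this (the helper name and polling loop are illustrative only, not part of this commit):

    from time import sleep

    def collect_streaming_result(con, expr, name="tmp_sink"):
        # Starts the streaming query; results accumulate in the in-memory table `name`.
        con.write_to_memory(expr, name)
        sleep(5)  # give the micro-batches time to run
        return con._session.sql(f"SELECT * FROM {name}").toPandas()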
29 changes: 0 additions & 29 deletions ibis/backends/pyspark/tests/test_import_export.py
@@ -2,42 +2,13 @@

from operator import methodcaller
from time import sleep
from unittest import mock

import pandas as pd
import pytest

from ibis.backends.conftest import TEST_TABLES
from ibis.backends.pyspark import Backend
from ibis.backends.pyspark.datatypes import PySparkSchema


@pytest.fixture(scope="session", autouse=True)
def default_session_fixture():
with mock.patch.object(Backend, "write_to_memory", write_to_memory, create=True):
yield


def write_to_memory(self, expr, table_name):
if self.mode == "batch":
raise NotImplementedError
df = self._session.sql(expr.compile())
df.writeStream.format("memory").queryName(table_name).start()


@pytest.fixture(autouse=True, scope="function")
def stop_active_jobs(con_streaming):
yield
for sq in con_streaming._session.streams.active:
sq.stop()
sq.awaitTermination()


@pytest.fixture
def awards_players_schema():
return TEST_TABLES["awards_players"]


@pytest.mark.parametrize(
"method",
[
47 changes: 47 additions & 0 deletions ibis/backends/pyspark/tests/test_window.py
@@ -1,9 +1,13 @@
from __future__ import annotations

from time import sleep

import pandas as pd
import pandas.testing as tm
import pytest

import ibis
from ibis import _

pyspark = pytest.importorskip("pyspark")

@@ -87,3 +91,46 @@ def test_multiple_windows(t, spark_table, ibis_windows, spark_range):
.toPandas()
)
tm.assert_frame_equal(result, expected)


def test_tumble_window_by_grouped_agg(con_streaming, tmp_path):
t = con_streaming.table("functional_alltypes")
expr = (
t.window_by(time_col=t.timestamp_col)
.tumble(size=ibis.interval(seconds=30))
.agg(by=["string_col"], avg=_.float_col.mean())
)
path = tmp_path / "out"
con_streaming.to_csv_dir(
expr,
path=path,
options={"checkpointLocation": tmp_path / "checkpoint", "header": True},
)
sleep(5)
dfs = [pd.read_csv(f) for f in path.glob("*.csv")]
df = pd.concat([df for df in dfs if not df.empty])
assert list(df.columns) == ["window_start", "window_end", "string_col", "avg"]
# [NOTE] The expected number of rows here is 7299 because, when all the data is
# available at once, no event is dropped as out of order. By contrast, Flink
# discards all out-of-order events as late arrivals and only emits 610 windows.
assert df.shape == (7299, 4)


def test_tumble_window_by_ungrouped_agg(con_streaming, tmp_path):
t = con_streaming.table("functional_alltypes")
expr = (
t.window_by(time_col=t.timestamp_col)
.tumble(size=ibis.interval(seconds=30))
.agg(avg=_.float_col.mean())
)
path = tmp_path / "out"
con_streaming.to_csv_dir(
expr,
path=path,
options={"checkpointLocation": tmp_path / "checkpoint", "header": True},
)
sleep(5)
dfs = [pd.read_csv(f) for f in path.glob("*.csv")]
df = pd.concat([df for df in dfs if not df.empty])
assert list(df.columns) == ["window_start", "window_end", "avg"]
assert df.shape == (7299, 3)
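
At the PySpark level, the grouped test above exercises roughly the following Structured Streaming job: the conftest attaches a 10-second watermark to functional_alltypes, and the tumble window becomes a 30-second window() grouping. This is a sketch of an equivalent raw PySpark query, not the SQL that ibis actually emits; streaming_df and the paths are placeholders.

    from pyspark.sql import functions as F

    # streaming_df: the streaming DataFrame for functional_alltypes (assumed in scope)
    agg = (
        streaming_df.withWatermark("timestamp_col", "10 seconds")
        .groupBy(F.window("timestamp_col", "30 seconds"), "string_col")
        .agg(F.avg("float_col").alias("avg"))
    )
    query = (
        agg.writeStream.outputMode("append")
        .format("csv")
        .option("path", "/tmp/out")                       # placeholder
        .option("checkpointLocation", "/tmp/checkpoint")  # placeholder
        .option("header", True)
        .start()
    )

Because the test data arrives essentially in a single micro-batch, the watermark never causes any event to be treated as late, which is why all 7299 window/group rows appear in the output rather than the 610 windows Flink emits after dropping out-of-order events.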
