add unit tests
Chloe He committed May 30, 2024
1 parent 5511770 commit b008b1f
Showing 3 changed files with 92 additions and 43 deletions.
50 changes: 8 additions & 42 deletions ibis/backends/pyspark/tests/conftest.py
@@ -161,6 +161,7 @@ def connect(*, tmpdir, worker_id, **kw):
.config("spark.ui.enabled", False)
.config("spark.ui.showConsoleProgress", False)
.config("spark.sql.execution.arrow.pyspark.enabled", False)
.config("spark.sql.streaming.schemaInference", True)
)

try:
@@ -195,50 +196,14 @@ def _load_data(self, **_: Any) -> None:

@staticmethod
def connect(*, tmpdir, worker_id, **kw):
# Spark internally stores timestamps as UTC values, and timestamp
# data that is brought in without a specified time zone is
# converted as local time to UTC with microsecond resolution.
# https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics

from pyspark.sql import SparkSession

config = (
SparkSession.builder.appName("ibis_testing")
.master("local[1]")
.config("spark.cores.max", 1)
.config("spark.default.parallelism", 1)
.config("spark.driver.extraJavaOptions", "-Duser.timezone=GMT")
.config("spark.dynamicAllocation.enabled", False)
.config("spark.executor.extraJavaOptions", "-Duser.timezone=GMT")
.config("spark.executor.heartbeatInterval", "3600s")
.config("spark.executor.instances", 1)
.config("spark.network.timeout", "4200s")
.config("spark.rdd.compress", False)
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.shuffle.compress", False)
.config("spark.shuffle.spill.compress", False)
.config("spark.sql.legacy.timeParserPolicy", "LEGACY")
.config("spark.sql.session.timeZone", "UTC")
.config("spark.sql.shuffle.partitions", 1)
.config("spark.storage.blockManagerSlaveTimeoutMs", "4200s")
.config("spark.ui.enabled", False)
.config("spark.ui.showConsoleProgress", False)
.config("spark.sql.execution.arrow.pyspark.enabled", False)
.config("spark.sql.streaming.schemaInference", True)
)

try:
from delta.pip_utils import configure_spark_with_delta_pip
except ImportError:
configure_spark_with_delta_pip = lambda cfg: cfg
else:
config = config.config(
"spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog",
).config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")

spark = configure_spark_with_delta_pip(config).getOrCreate()
return ibis.pyspark.connect(spark, mode="streaming", **kw)
# SparkContext is shared globally; only one SparkContext should be active
# per JVM. We need to create a new SparkSession for streaming tests but
# this session shares the same SparkContext.
spark = SparkSession.getActiveSession().newSession()
con = ibis.pyspark.connect(spark, mode="streaming", **kw)
return con
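
The comment above refers to Spark's one-active-SparkContext-per-JVM rule. A minimal sketch of the relationship this relies on (assuming a batch SparkSession is already active, as the batch fixtures in this suite ensure; not part of the commit itself):

from pyspark.sql import SparkSession

# An existing (batch) session must already be active for getActiveSession()
# to return anything; the batch test fixtures provide it in this suite.
batch_session = SparkSession.getActiveSession()
streaming_session = batch_session.newSession()

# Distinct sessions with independent SQL configuration and temp views...
assert streaming_session is not batch_session
# ...but backed by the single JVM-wide SparkContext.
assert streaming_session.sparkContext is batch_session.sparkContext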


@pytest.fixture(scope="session")
@@ -360,6 +325,7 @@ def con(data_dir, tmp_path_factory, worker_id):
@pytest.fixture(scope="session")
def con_streaming(data_dir, tmp_path_factory, worker_id):
backend_test = TestConfForStreaming.load_data(data_dir, tmp_path_factory, worker_id)
backend_test._load_data()
return backend_test.connection


84 changes: 84 additions & 0 deletions ibis/backends/pyspark/tests/test_import_export.py
@@ -1,9 +1,42 @@
from __future__ import annotations

from operator import methodcaller
from time import sleep
from unittest import mock

import pandas as pd
import pytest

from ibis.backends.conftest import TEST_TABLES
from ibis.backends.pyspark import Backend
from ibis.backends.pyspark.datatypes import PySparkSchema


@pytest.fixture(scope="session", autouse=True)
def default_session_fixture():
with mock.patch.object(Backend, "write_to_memory", write_to_memory, create=True):
yield


def write_to_memory(self, expr, table_name):
if self.mode == "batch":
raise NotImplementedError
df = self._session.sql(expr.compile())
df.writeStream.format("memory").queryName(table_name).start()
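
This patched-in helper compiles the ibis expression to SQL, turns it into a streaming DataFrame, and starts a query against Spark's in-memory sink registered under table_name. The tests below then read that sink back by name with raw SQL; roughly (illustrative names, not part of the commit):

t = con_streaming.read_parquet_directory(some_directory, table_name="t")  # streaming source
con_streaming.write_to_memory(t, "sink")                                  # start memory-sink query
con_streaming._session.sql("SELECT count(*) FROM sink").toPandas()        # query the sink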


@pytest.fixture(autouse=True, scope="function")
def stop_active_jobs(con_streaming):
yield
for sq in con_streaming._session.streams.active:
sq.stop()
sq.awaitTermination()


@pytest.fixture
def awards_players_schema():
return TEST_TABLES["awards_players"]


@pytest.mark.parametrize(
"method",
@@ -17,3 +50,54 @@
def test_streaming_import_not_implemented(con_streaming, method):
with pytest.raises(NotImplementedError):
method(con_streaming)


def test_read_csv_directory(con_streaming, awards_players_schema):
t = con_streaming.read_csv_directory(
"ci/ibis-testing-data/directory/csv/awards_players",
table_name="t",
schema=PySparkSchema.from_ibis(awards_players_schema),
header=True,
)
con_streaming.write_to_memory(t, "n")
sleep(2) # wait for results to populate; count(*) returns 0 if executed right away
pd_df = con_streaming._session.sql("SELECT count(*) FROM n").toPandas()
assert not pd_df.empty
assert pd_df.iloc[0, 0] == 6078


def test_read_parquet_directory(con_streaming):
t = con_streaming.read_parquet_directory(
"ci/ibis-testing-data/directory/parquet/awards_players", table_name="t"
)
con_streaming.write_to_memory(t, "n")
sleep(2) # wait for results to populate; count(*) returns 0 if executed right away
pd_df = con_streaming._session.sql("SELECT count(*) FROM n").toPandas()
assert not pd_df.empty
assert pd_df.iloc[0, 0] == 6078


def test_to_csv_directory(con_streaming, tmp_path):
t = con_streaming.table("awards_players")
path = tmp_path / "out"
con_streaming.to_csv_directory(
t.limit(5),
path=path,
options={"checkpointLocation": tmp_path / "checkpoint", "header": True},
)
sleep(2)
df = pd.concat([pd.read_csv(f) for f in path.glob("*.csv")])
assert len(df) == 5


def test_to_parquet_directory(con_streaming, tmp_path):
t = con_streaming.table("awards_players")
path = tmp_path / "out"
con_streaming.to_parquet_directory(
t.limit(5),
path=path,
options={"checkpointLocation": tmp_path / "checkpoint"},
)
sleep(2)
df = pd.concat([pd.read_parquet(f) for f in path.glob("*.parquet")])
assert len(df) == 5
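
The read_*_directory tests above rely on a fixed sleep(2) before querying the memory sink. A slightly more robust variant is to poll until the expected rows appear; a sketch only (the helper name and timeout are assumptions, not part of this commit):

import time

def wait_for_rows(session, table_name, expected, timeout=30.0):
    """Poll the in-memory sink until it holds `expected` rows or the timeout elapses."""
    deadline = time.monotonic() + timeout
    count = 0
    while time.monotonic() < deadline:
        count = session.sql(f"SELECT count(*) FROM {table_name}").first()[0]
        if count >= expected:
            return count
        time.sleep(0.1)
    raise TimeoutError(f"{table_name} reached only {count} of {expected} rows")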
1 change: 0 additions & 1 deletion ibis/backends/tests/base.py
@@ -151,7 +151,6 @@ def load_data(
cls.skip_if_missing_deps()

inst = cls(data_dir=data_dir, tmpdir=tmpdir, worker_id=worker_id, **kw)

if inst.stateful:
inst.stateful_load(fn, **kw)
else:
