From adade5e745440393ef66fd6961ae181ddb9c9e05 Mon Sep 17 00:00:00 2001 From: Chloe He Date: Mon, 1 Jul 2024 05:55:42 -0700 Subject: [PATCH] feat(pyspark): implement new experimental read/write directory methods (#9272) Co-authored-by: Chloe He Co-authored-by: Chloe He Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com> --- ibis/backends/pyspark/__init__.py | 203 +++++++++++++++++- ibis/backends/pyspark/tests/conftest.py | 80 ++++--- .../pyspark/tests/test_import_export.py | 84 ++++++++ 3 files changed, 321 insertions(+), 46 deletions(-) diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index 98f03c53ec7d..9d27da440bc9 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -710,7 +710,7 @@ def read_parquet( if self.mode == "streaming": raise NotImplementedError( "Pyspark in streaming mode does not support direction registration of parquet files. " - "Please use `read_parquet_directory` instead." + "Please use `read_parquet_dir` instead." ) path = util.normalize_filename(path) spark_df = self._session.read.parquet(path, **kwargs) @@ -748,7 +748,7 @@ def read_csv( if self.mode == "streaming": raise NotImplementedError( "Pyspark in streaming mode does not support direction registration of CSV files. " - "Please use `read_csv_directory` instead." + "Please use `read_csv_dir` instead." ) inferSchema = kwargs.pop("inferSchema", True) header = kwargs.pop("header", True) @@ -790,7 +790,7 @@ def read_json( if self.mode == "streaming": raise NotImplementedError( "Pyspark in streaming mode does not support direction registration of JSON files. " - "Please use `read_json_directory` instead." + "Please use `read_json_dir` instead." ) source_list = normalize_filenames(source_list) spark_df = self._session.read.json(source_list, **kwargs) @@ -1055,3 +1055,200 @@ def to_kafka( sq = sq.option(k, v) sq.start() return sq + + @util.experimental + def read_csv_dir( + self, path: str | Path, table_name: str | None = None, **kwargs: Any + ) -> ir.Table: + """Register a CSV directory as a table in the current database. + + Parameters + ---------- + path + The data source. + table_name + An optional name to use for the created table. This defaults to + a random generated name. + kwargs + Additional keyword arguments passed to PySpark loading function. + https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.csv.html + + Returns + ------- + ir.Table + The just-registered table + + """ + inferSchema = kwargs.pop("inferSchema", True) + header = kwargs.pop("header", True) + path = util.normalize_filename(path) + if self.mode == "batch": + spark_df = self._session.read.csv( + path, inferSchema=inferSchema, header=header, **kwargs + ) + elif self.mode == "streaming": + spark_df = self._session.readStream.csv( + path, inferSchema=inferSchema, header=header, **kwargs + ) + table_name = table_name or util.gen_name("read_csv_dir") + + spark_df.createOrReplaceTempView(table_name) + return self.table(table_name) + + @util.experimental + def read_parquet_dir( + self, + path: str | Path, + table_name: str | None = None, + **kwargs: Any, + ) -> ir.Table: + """Register a parquet file as a table in the current database. + + Parameters + ---------- + path + The data source. A directory of parquet files. + table_name + An optional name to use for the created table. This defaults to + a random generated name. + kwargs + Additional keyword arguments passed to PySpark. 
+ https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.parquet.html + + Returns + ------- + ir.Table + The just-registered table + + """ + path = util.normalize_filename(path) + if self.mode == "batch": + spark_df = self._session.read.parquet(path, **kwargs) + elif self.mode == "streaming": + spark_df = self._session.readStream.parquet(path, **kwargs) + table_name = table_name or util.gen_name("read_parquet_dir") + + spark_df.createOrReplaceTempView(table_name) + return self.table(table_name) + + @util.experimental + def read_json_dir( + self, path: str | Path, table_name: str | None = None, **kwargs: Any + ) -> ir.Table: + """Register a JSON file as a table in the current database. + + Parameters + ---------- + path + The data source. A directory of JSON files. + table_name + An optional name to use for the created table. This defaults to + a random generated name. + kwargs + Additional keyword arguments passed to PySpark loading function. + https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.json.html + + Returns + ------- + ir.Table + The just-registered table + + """ + path = util.normalize_filename(path) + if self.mode == "batch": + spark_df = self._session.read.json(path, **kwargs) + elif self.mode == "streaming": + spark_df = self._session.readStream.json(path, **kwargs) + table_name = table_name or util.gen_name("read_json_dir") + + spark_df.createOrReplaceTempView(table_name) + return self.table(table_name) + + def _to_filesystem_output( + self, + expr: ir.Expr, + format: str, + path: str | Path, + params: Mapping[ir.Scalar, Any] | None = None, + limit: int | str | None = None, + options: Mapping[str, str] | None = None, + ) -> StreamingQuery | None: + df = self._session.sql(expr.compile(params=params, limit=limit)) + if self.mode == "batch": + df = df.write.format(format) + for k, v in (options or {}).items(): + df = df.option(k, v) + df.save(path) + return None + sq = df.writeStream.format(format) + sq = sq.option("path", os.fspath(path)) + for k, v in (options or {}).items(): + sq = sq.option(k, v) + sq.start() + return sq + + @util.experimental + def to_parquet_dir( + self, + expr: ir.Expr, + path: str | Path, + params: Mapping[ir.Scalar, Any] | None = None, + limit: int | str | None = None, + options: Mapping[str, str] | None = None, + ) -> StreamingQuery | None: + """Write the results of executing the given expression to a parquet directory. + + Parameters + ---------- + expr + The ibis expression to execute and persist to parquet. + path + The data source. A string or Path to the parquet directory. + params + Mapping of scalar parameter expressions to value. + limit + An integer to effect a specific row limit. A value of `None` means + "no limit". The default is in `ibis/config.py`. + options + Additional keyword arguments passed to pyspark.sql.streaming.DataStreamWriter + + Returns + ------- + StreamingQuery | None + Returns a Pyspark StreamingQuery object if in streaming mode, otherwise None + """ + self._run_pre_execute_hooks(expr) + return self._to_filesystem_output(expr, "parquet", path, params, limit, options) + + @util.experimental + def to_csv_dir( + self, + expr: ir.Expr, + path: str | Path, + params: Mapping[ir.Scalar, Any] | None = None, + limit: int | str | None = None, + options: Mapping[str, str] | None = None, + ) -> StreamingQuery | None: + """Write the results of executing the given expression to a CSV directory. 
+ + Parameters + ---------- + expr + The ibis expression to execute and persist to CSV. + path + The data source. A string or Path to the CSV directory. + params + Mapping of scalar parameter expressions to value. + limit + An integer to effect a specific row limit. A value of `None` means + "no limit". The default is in `ibis/config.py`. + options + Additional keyword arguments passed to pyspark.sql.streaming.DataStreamWriter + + Returns + ------- + StreamingQuery | None + Returns a Pyspark StreamingQuery object if in streaming mode, otherwise None + """ + self._run_pre_execute_hooks(expr) + return self._to_filesystem_output(expr, "csv", path, params, limit, options) diff --git a/ibis/backends/pyspark/tests/conftest.py b/ibis/backends/pyspark/tests/conftest.py index a15f405e0ecb..7ffce3ac297e 100644 --- a/ibis/backends/pyspark/tests/conftest.py +++ b/ibis/backends/pyspark/tests/conftest.py @@ -2,11 +2,12 @@ import os from datetime import datetime, timedelta, timezone -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd import pytest +from filelock import FileLock import ibis from ibis import util @@ -14,6 +15,9 @@ from ibis.backends.tests.base import BackendTest from ibis.backends.tests.data import json_types, topk, win +if TYPE_CHECKING: + from pathlib import Path + def set_pyspark_database(con, database): con._session.catalog.setCurrentDatabase(database) @@ -161,6 +165,7 @@ def connect(*, tmpdir, worker_id, **kw): .config("spark.ui.enabled", False) .config("spark.ui.showConsoleProgress", False) .config("spark.sql.execution.arrow.pyspark.enabled", False) + .config("spark.sql.streaming.schemaInference", True) ) try: @@ -193,52 +198,41 @@ def _load_data(self, **_: Any) -> None: t = t.sort(sort_col) t.createOrReplaceTempView(name) - @staticmethod - def connect(*, tmpdir, worker_id, **kw): - # Spark internally stores timestamps as UTC values, and timestamp - # data that is brought in without a specified time zone is - # converted as local time to UTC with microsecond resolution. 
- # https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics + @classmethod + def load_data( + cls, data_dir: Path, tmpdir: Path, worker_id: str, **kw: Any + ) -> BackendTest: + """Load testdata from `data_dir`.""" + # handling for multi-processes pytest - from pyspark.sql import SparkSession + # get the temp directory shared by all workers + root_tmp_dir = tmpdir.getbasetemp() / "streaming" + if worker_id != "master": + root_tmp_dir = root_tmp_dir.parent - config = ( - SparkSession.builder.appName("ibis_testing") - .master("local[1]") - .config("spark.cores.max", 1) - .config("spark.default.parallelism", 1) - .config("spark.driver.extraJavaOptions", "-Duser.timezone=GMT") - .config("spark.dynamicAllocation.enabled", False) - .config("spark.executor.extraJavaOptions", "-Duser.timezone=GMT") - .config("spark.executor.heartbeatInterval", "3600s") - .config("spark.executor.instances", 1) - .config("spark.network.timeout", "4200s") - .config("spark.rdd.compress", False) - .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .config("spark.shuffle.compress", False) - .config("spark.shuffle.spill.compress", False) - .config("spark.sql.legacy.timeParserPolicy", "LEGACY") - .config("spark.sql.session.timeZone", "UTC") - .config("spark.sql.shuffle.partitions", 1) - .config("spark.storage.blockManagerSlaveTimeoutMs", "4200s") - .config("spark.ui.enabled", False) - .config("spark.ui.showConsoleProgress", False) - .config("spark.sql.execution.arrow.pyspark.enabled", False) - .config("spark.sql.streaming.schemaInference", True) - ) + fn = root_tmp_dir / cls.name() + with FileLock(f"{fn}.lock"): + cls.skip_if_missing_deps() - try: - from delta.pip_utils import configure_spark_with_delta_pip - except ImportError: - configure_spark_with_delta_pip = lambda cfg: cfg - else: - config = config.config( - "spark.sql.catalog.spark_catalog", - "org.apache.spark.sql.delta.catalog.DeltaCatalog", - ).config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + inst = cls(data_dir=data_dir, tmpdir=tmpdir, worker_id=worker_id, **kw) - spark = configure_spark_with_delta_pip(config).getOrCreate() - return ibis.pyspark.connect(spark, mode="streaming", **kw) + if inst.stateful: + inst.stateful_load(fn, **kw) + else: + inst.stateless_load(**kw) + inst.postload(tmpdir=tmpdir, worker_id=worker_id, **kw) + return inst + + @staticmethod + def connect(*, tmpdir, worker_id, **kw): + from pyspark.sql import SparkSession + + # SparkContext is shared globally; only one SparkContext should be active + # per JVM. We need to create a new SparkSession for streaming tests but + # this session shares the same SparkContext. 
+ spark = SparkSession.getActiveSession().newSession() + con = ibis.pyspark.connect(spark, mode="streaming", **kw) + return con @pytest.fixture(scope="session") diff --git a/ibis/backends/pyspark/tests/test_import_export.py b/ibis/backends/pyspark/tests/test_import_export.py index 1aed2537d830..9955359a9ac2 100644 --- a/ibis/backends/pyspark/tests/test_import_export.py +++ b/ibis/backends/pyspark/tests/test_import_export.py @@ -1,9 +1,42 @@ from __future__ import annotations from operator import methodcaller +from time import sleep +from unittest import mock +import pandas as pd import pytest +from ibis.backends.conftest import TEST_TABLES +from ibis.backends.pyspark import Backend +from ibis.backends.pyspark.datatypes import PySparkSchema + + +@pytest.fixture(scope="session", autouse=True) +def default_session_fixture(): + with mock.patch.object(Backend, "write_to_memory", write_to_memory, create=True): + yield + + +def write_to_memory(self, expr, table_name): + if self.mode == "batch": + raise NotImplementedError + df = self._session.sql(expr.compile()) + df.writeStream.format("memory").queryName(table_name).start() + + +@pytest.fixture(autouse=True, scope="function") +def stop_active_jobs(con_streaming): + yield + for sq in con_streaming._session.streams.active: + sq.stop() + sq.awaitTermination() + + +@pytest.fixture +def awards_players_schema(): + return TEST_TABLES["awards_players"] + @pytest.mark.parametrize( "method", @@ -17,3 +50,54 @@ def test_streaming_import_not_implemented(con_streaming, method): with pytest.raises(NotImplementedError): method(con_streaming) + + +def test_read_csv_dir(con_streaming, awards_players_schema): + t = con_streaming.read_csv_dir( + "ci/ibis-testing-data/directory/csv/awards_players", + table_name="t", + schema=PySparkSchema.from_ibis(awards_players_schema), + header=True, + ) + con_streaming.write_to_memory(t, "n") + sleep(2) # wait for results to populate; count(*) returns 0 if executed right away + pd_df = con_streaming._session.sql("SELECT count(*) FROM n").toPandas() + assert not pd_df.empty + assert pd_df.iloc[0, 0] == 6078 + + +def test_read_parquet_dir(con_streaming): + t = con_streaming.read_parquet_dir( + "ci/ibis-testing-data/directory/parquet/awards_players", table_name="t" + ) + con_streaming.write_to_memory(t, "n") + sleep(2) # wait for results to populate; count(*) returns 0 if executed right away + pd_df = con_streaming._session.sql("SELECT count(*) FROM n").toPandas() + assert not pd_df.empty + assert pd_df.iloc[0, 0] == 6078 + + +def test_to_csv_dir(con_streaming, tmp_path): + t = con_streaming.table("awards_players") + path = tmp_path / "out" + con_streaming.to_csv_dir( + t.limit(5), + path=path, + options={"checkpointLocation": tmp_path / "checkpoint", "header": True}, + ) + sleep(2) + df = pd.concat([pd.read_csv(f) for f in path.glob("*.csv")]) + assert len(df) == 5 + + +def test_to_parquet_dir(con_streaming, tmp_path): + t = con_streaming.table("awards_players") + path = tmp_path / "out" + con_streaming.to_parquet_dir( + t.limit(5), + path=path, + options={"checkpointLocation": tmp_path / "checkpoint"}, + ) + sleep(2) + df = pd.concat([pd.read_parquet(f) for f in path.glob("*.parquet")]) + assert len(df) == 5
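
A minimal usage sketch for the new experimental directory readers, assuming a local Spark session; the directory paths, table names, and schema below are placeholders:

    import ibis
    from ibis.backends.pyspark.datatypes import PySparkSchema
    from pyspark.sql import SparkSession

    # A local session. Streaming CSV sources need either an explicit schema or
    # spark.sql.streaming.schemaInference=true set on the session.
    spark = (
        SparkSession.builder.appName("ibis-read-dir-demo")
        .master("local[1]")
        .getOrCreate()
    )
    con = ibis.pyspark.connect(spark, mode="streaming")

    # Placeholder schema for a placeholder directory of CSV files; the schema
    # kwarg is forwarded to the underlying Spark reader.
    csv_schema = ibis.schema({"playerID": "string", "awardID": "string", "yearID": "int32"})
    t = con.read_csv_dir(
        "/data/awards_players/csv",
        table_name="awards_players",
        schema=PySparkSchema.from_ibis(csv_schema),
        header=True,
    )

    # Parquet directories carry their own schema, so only the path is needed.
    p = con.read_parquet_dir("/data/awards_players/parquet", table_name="awards_parquet")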
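
A matching sketch for the directory writers. In streaming mode the write runs asynchronously as a Structured Streaming query and file sinks need a checkpoint location; the output and checkpoint paths here are placeholders, and the active queries are stopped the same way the test teardown does:

    from time import sleep

    con.to_parquet_dir(
        t.limit(5),
        path="/tmp/awards_players_out",
        options={"checkpointLocation": "/tmp/awards_players_ckpt"},
    )
    sleep(2)  # file sinks flush per micro-batch; give the first batch time to land

    # Stop whatever queries the writers started, mirroring the test fixture.
    for q in con._session.streams.active:
        q.stop()
        q.awaitTermination()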