Skip to content

Commit

Permalink
feat(pyspark): implement new experimental read/write directory methods (
Browse files Browse the repository at this point in the history
#9272)

Co-authored-by: Chloe He <chloe@chloe-mac.lan>
Co-authored-by: Chloe He <chloe@chloe-mac.local>
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
  • Loading branch information
4 people authored Jul 1, 2024
1 parent 84974fe commit adade5e
Show file tree
Hide file tree
Showing 3 changed files with 321 additions and 46 deletions.
203 changes: 200 additions & 3 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ def read_parquet(
if self.mode == "streaming":
raise NotImplementedError(
"Pyspark in streaming mode does not support direction registration of parquet files. "
"Please use `read_parquet_directory` instead."
"Please use `read_parquet_dir` instead."
)
path = util.normalize_filename(path)
spark_df = self._session.read.parquet(path, **kwargs)
Expand Down Expand Up @@ -748,7 +748,7 @@ def read_csv(
if self.mode == "streaming":
raise NotImplementedError(
"Pyspark in streaming mode does not support direction registration of CSV files. "
"Please use `read_csv_directory` instead."
"Please use `read_csv_dir` instead."
)
inferSchema = kwargs.pop("inferSchema", True)
header = kwargs.pop("header", True)
Expand Down Expand Up @@ -790,7 +790,7 @@ def read_json(
if self.mode == "streaming":
raise NotImplementedError(
"Pyspark in streaming mode does not support direction registration of JSON files. "
"Please use `read_json_directory` instead."
"Please use `read_json_dir` instead."
)
source_list = normalize_filenames(source_list)
spark_df = self._session.read.json(source_list, **kwargs)
Expand Down Expand Up @@ -1055,3 +1055,200 @@ def to_kafka(
sq = sq.option(k, v)
sq.start()
return sq

@util.experimental
def read_csv_dir(
self, path: str | Path, table_name: str | None = None, **kwargs: Any
) -> ir.Table:
"""Register a CSV directory as a table in the current database.
Parameters
----------
path
The data source.
table_name
An optional name to use for the created table. This defaults to
a random generated name.
kwargs
Additional keyword arguments passed to PySpark loading function.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.csv.html
Returns
-------
ir.Table
The just-registered table
"""
inferSchema = kwargs.pop("inferSchema", True)
header = kwargs.pop("header", True)
path = util.normalize_filename(path)
if self.mode == "batch":
spark_df = self._session.read.csv(
path, inferSchema=inferSchema, header=header, **kwargs
)
elif self.mode == "streaming":
spark_df = self._session.readStream.csv(
path, inferSchema=inferSchema, header=header, **kwargs
)
table_name = table_name or util.gen_name("read_csv_dir")

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)

@util.experimental
def read_parquet_dir(
self,
path: str | Path,
table_name: str | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a parquet file as a table in the current database.
Parameters
----------
path
The data source. A directory of parquet files.
table_name
An optional name to use for the created table. This defaults to
a random generated name.
kwargs
Additional keyword arguments passed to PySpark.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.parquet.html
Returns
-------
ir.Table
The just-registered table
"""
path = util.normalize_filename(path)
if self.mode == "batch":
spark_df = self._session.read.parquet(path, **kwargs)
elif self.mode == "streaming":
spark_df = self._session.readStream.parquet(path, **kwargs)
table_name = table_name or util.gen_name("read_parquet_dir")

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)

@util.experimental
def read_json_dir(
self, path: str | Path, table_name: str | None = None, **kwargs: Any
) -> ir.Table:
"""Register a JSON file as a table in the current database.
Parameters
----------
path
The data source. A directory of JSON files.
table_name
An optional name to use for the created table. This defaults to
a random generated name.
kwargs
Additional keyword arguments passed to PySpark loading function.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.json.html
Returns
-------
ir.Table
The just-registered table
"""
path = util.normalize_filename(path)
if self.mode == "batch":
spark_df = self._session.read.json(path, **kwargs)
elif self.mode == "streaming":
spark_df = self._session.readStream.json(path, **kwargs)
table_name = table_name or util.gen_name("read_json_dir")

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)

def _to_filesystem_output(
self,
expr: ir.Expr,
format: str,
path: str | Path,
params: Mapping[ir.Scalar, Any] | None = None,
limit: int | str | None = None,
options: Mapping[str, str] | None = None,
) -> StreamingQuery | None:
df = self._session.sql(expr.compile(params=params, limit=limit))
if self.mode == "batch":
df = df.write.format(format)
for k, v in (options or {}).items():
df = df.option(k, v)
df.save(path)
return None
sq = df.writeStream.format(format)
sq = sq.option("path", os.fspath(path))
for k, v in (options or {}).items():
sq = sq.option(k, v)
sq.start()
return sq

@util.experimental
def to_parquet_dir(
self,
expr: ir.Expr,
path: str | Path,
params: Mapping[ir.Scalar, Any] | None = None,
limit: int | str | None = None,
options: Mapping[str, str] | None = None,
) -> StreamingQuery | None:
"""Write the results of executing the given expression to a parquet directory.
Parameters
----------
expr
The ibis expression to execute and persist to parquet.
path
The data source. A string or Path to the parquet directory.
params
Mapping of scalar parameter expressions to value.
limit
An integer to effect a specific row limit. A value of `None` means
"no limit". The default is in `ibis/config.py`.
options
Additional keyword arguments passed to pyspark.sql.streaming.DataStreamWriter
Returns
-------
StreamingQuery | None
Returns a Pyspark StreamingQuery object if in streaming mode, otherwise None
"""
self._run_pre_execute_hooks(expr)
return self._to_filesystem_output(expr, "parquet", path, params, limit, options)

@util.experimental
def to_csv_dir(
self,
expr: ir.Expr,
path: str | Path,
params: Mapping[ir.Scalar, Any] | None = None,
limit: int | str | None = None,
options: Mapping[str, str] | None = None,
) -> StreamingQuery | None:
"""Write the results of executing the given expression to a CSV directory.
Parameters
----------
expr
The ibis expression to execute and persist to CSV.
path
The data source. A string or Path to the CSV directory.
params
Mapping of scalar parameter expressions to value.
limit
An integer to effect a specific row limit. A value of `None` means
"no limit". The default is in `ibis/config.py`.
options
Additional keyword arguments passed to pyspark.sql.streaming.DataStreamWriter
Returns
-------
StreamingQuery | None
Returns a Pyspark StreamingQuery object if in streaming mode, otherwise None
"""
self._run_pre_execute_hooks(expr)
return self._to_filesystem_output(expr, "csv", path, params, limit, options)
80 changes: 37 additions & 43 deletions ibis/backends/pyspark/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@

import os
from datetime import datetime, timedelta, timezone
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
import pytest
from filelock import FileLock

import ibis
from ibis import util
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest
from ibis.backends.tests.data import json_types, topk, win

if TYPE_CHECKING:
from pathlib import Path


def set_pyspark_database(con, database):
con._session.catalog.setCurrentDatabase(database)
Expand Down Expand Up @@ -161,6 +165,7 @@ def connect(*, tmpdir, worker_id, **kw):
.config("spark.ui.enabled", False)
.config("spark.ui.showConsoleProgress", False)
.config("spark.sql.execution.arrow.pyspark.enabled", False)
.config("spark.sql.streaming.schemaInference", True)
)

try:
Expand Down Expand Up @@ -193,52 +198,41 @@ def _load_data(self, **_: Any) -> None:
t = t.sort(sort_col)
t.createOrReplaceTempView(name)

@staticmethod
def connect(*, tmpdir, worker_id, **kw):
# Spark internally stores timestamps as UTC values, and timestamp
# data that is brought in without a specified time zone is
# converted as local time to UTC with microsecond resolution.
# https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics
@classmethod
def load_data(
cls, data_dir: Path, tmpdir: Path, worker_id: str, **kw: Any
) -> BackendTest:
"""Load testdata from `data_dir`."""
# handling for multi-processes pytest

from pyspark.sql import SparkSession
# get the temp directory shared by all workers
root_tmp_dir = tmpdir.getbasetemp() / "streaming"
if worker_id != "master":
root_tmp_dir = root_tmp_dir.parent

config = (
SparkSession.builder.appName("ibis_testing")
.master("local[1]")
.config("spark.cores.max", 1)
.config("spark.default.parallelism", 1)
.config("spark.driver.extraJavaOptions", "-Duser.timezone=GMT")
.config("spark.dynamicAllocation.enabled", False)
.config("spark.executor.extraJavaOptions", "-Duser.timezone=GMT")
.config("spark.executor.heartbeatInterval", "3600s")
.config("spark.executor.instances", 1)
.config("spark.network.timeout", "4200s")
.config("spark.rdd.compress", False)
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.shuffle.compress", False)
.config("spark.shuffle.spill.compress", False)
.config("spark.sql.legacy.timeParserPolicy", "LEGACY")
.config("spark.sql.session.timeZone", "UTC")
.config("spark.sql.shuffle.partitions", 1)
.config("spark.storage.blockManagerSlaveTimeoutMs", "4200s")
.config("spark.ui.enabled", False)
.config("spark.ui.showConsoleProgress", False)
.config("spark.sql.execution.arrow.pyspark.enabled", False)
.config("spark.sql.streaming.schemaInference", True)
)
fn = root_tmp_dir / cls.name()
with FileLock(f"{fn}.lock"):
cls.skip_if_missing_deps()

try:
from delta.pip_utils import configure_spark_with_delta_pip
except ImportError:
configure_spark_with_delta_pip = lambda cfg: cfg
else:
config = config.config(
"spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog",
).config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
inst = cls(data_dir=data_dir, tmpdir=tmpdir, worker_id=worker_id, **kw)

spark = configure_spark_with_delta_pip(config).getOrCreate()
return ibis.pyspark.connect(spark, mode="streaming", **kw)
if inst.stateful:
inst.stateful_load(fn, **kw)
else:
inst.stateless_load(**kw)
inst.postload(tmpdir=tmpdir, worker_id=worker_id, **kw)
return inst

@staticmethod
def connect(*, tmpdir, worker_id, **kw):
from pyspark.sql import SparkSession

# SparkContext is shared globally; only one SparkContext should be active
# per JVM. We need to create a new SparkSession for streaming tests but
# this session shares the same SparkContext.
spark = SparkSession.getActiveSession().newSession()
con = ibis.pyspark.connect(spark, mode="streaming", **kw)
return con


@pytest.fixture(scope="session")
Expand Down
Loading

0 comments on commit adade5e

Please sign in to comment.