feat(pyspark): implement new experimental read/write directory methods #9272

Merged: 12 commits, Jul 1, 2024
203 changes: 200 additions & 3 deletions ibis/backends/pyspark/__init__.py
@@ -710,7 +710,7 @@
if self.mode == "streaming":
raise NotImplementedError(
"Pyspark in streaming mode does not support direction registration of parquet files. "
"Please use `read_parquet_directory` instead."
"Please use `read_parquet_dir` instead."
)
path = util.normalize_filename(path)
spark_df = self._session.read.parquet(path, **kwargs)
@@ -748,7 +748,7 @@
if self.mode == "streaming":
raise NotImplementedError(
"Pyspark in streaming mode does not support direction registration of CSV files. "
"Please use `read_csv_directory` instead."
"Please use `read_csv_dir` instead."
)
inferSchema = kwargs.pop("inferSchema", True)
header = kwargs.pop("header", True)
@@ -790,7 +790,7 @@
if self.mode == "streaming":
raise NotImplementedError(
"Pyspark in streaming mode does not support direction registration of JSON files. "
"Please use `read_json_directory` instead."
"Please use `read_json_dir` instead."
)
source_list = normalize_filenames(source_list)
spark_df = self._session.read.json(source_list, **kwargs)
@@ -1055,3 +1055,200 @@
sq = sq.option(k, v)
sq.start()
return sq

@util.experimental
def read_csv_dir(
self, path: str | Path, table_name: str | None = None, **kwargs: Any
) -> ir.Table:
"""Register a CSV directory as a table in the current database.

Parameters
----------
path
The data source. A directory of CSV files.
table_name
An optional name to use for the created table. This defaults to
a randomly generated name.
kwargs
Additional keyword arguments passed to the PySpark loading function.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.csv.html

Returns
-------
ir.Table
The just-registered table

"""
inferSchema = kwargs.pop("inferSchema", True)
header = kwargs.pop("header", True)
path = util.normalize_filename(path)
if self.mode == "batch":
spark_df = self._session.read.csv(
path, inferSchema=inferSchema, header=header, **kwargs
)
elif self.mode == "streaming":
spark_df = self._session.readStream.csv(
path, inferSchema=inferSchema, header=header, **kwargs
)
table_name = table_name or util.gen_name("read_csv_dir")

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)
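
A minimal usage sketch for the batch path of `read_csv_dir` follows; the SparkSession setup, directory path, table name, and reader option are illustrative assumptions rather than anything taken from this diff:

    from pyspark.sql import SparkSession

    import ibis

    # assumes a local directory of comma-separated CSV files with a header row
    spark = SparkSession.builder.getOrCreate()
    con = ibis.pyspark.connect(spark)  # batch mode
    t = con.read_csv_dir("/tmp/events_csv", table_name="events", sep=",")
    print(t.count().execute())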

@util.experimental
def read_parquet_dir(
self,
path: str | Path,
table_name: str | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a parquet file as a table in the current database.

Parameters
----------
path
The data source. A directory of parquet files.
table_name
An optional name to use for the created table. This defaults to
a randomly generated name.
kwargs
Additional keyword arguments passed to the PySpark loading function.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.parquet.html

Returns
-------
ir.Table
The just-registered table

"""
path = util.normalize_filename(path)
if self.mode == "batch":
spark_df = self._session.read.parquet(path, **kwargs)
elif self.mode == "streaming":
spark_df = self._session.readStream.parquet(path, **kwargs)
table_name = table_name or util.gen_name("read_parquet_dir")

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)
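
A sketch of `read_parquet_dir` against a streaming connection; the session configuration and path are assumptions. Spark's `DataStreamReader.parquet` needs either an explicit schema or `spark.sql.streaming.schemaInference` enabled, which is why the test configuration further down turns that flag on:

    from pyspark.sql import SparkSession

    import ibis

    spark = (
        SparkSession.builder.appName("streaming_example")
        .config("spark.sql.streaming.schemaInference", True)
        .getOrCreate()
    )
    con = ibis.pyspark.connect(spark, mode="streaming")
    t = con.read_parquet_dir("/tmp/parquet_landing")  # registers a temp view over the stream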

@util.experimental
def read_json_dir(
self, path: str | Path, table_name: str | None = None, **kwargs: Any
) -> ir.Table:
"""Register a JSON file as a table in the current database.

Parameters
----------
path
The data source. A directory of JSON files.
table_name
An optional name to use for the created table. This defaults to
a randomly generated name.
kwargs
Additional keyword arguments passed to the PySpark loading function.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.ss/api/pyspark.sql.streaming.DataStreamReader.json.html

Returns
-------
ir.Table
The just-registered table

"""
path = util.normalize_filename(path)
if self.mode == "batch":
spark_df = self._session.read.json(path, **kwargs)
elif self.mode == "streaming":
spark_df = self._session.readStream.json(path, **kwargs)
table_name = table_name or util.gen_name("read_json_dir")

Check warning on line 1162 in ibis/backends/pyspark/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/pyspark/__init__.py#L1161-L1162

Added lines #L1161 - L1162 were not covered by tests

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)
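
A short sketch for `read_json_dir`, reusing the streaming connection from the sketch above and forwarding a Spark JSON reader option through `**kwargs`; the option and path are assumed for illustration:

    # multiLine handles JSON records that span multiple lines
    t = con.read_json_dir("/tmp/json_dir", table_name="raw_json", multiLine=True)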


def _to_filesystem_output(
self,
expr: ir.Expr,
format: str,
path: str | Path,
params: Mapping[ir.Scalar, Any] | None = None,
limit: int | str | None = None,
options: Mapping[str, str] | None = None,
) -> StreamingQuery | None:
df = self._session.sql(expr.compile(params=params, limit=limit))
if self.mode == "batch":
df = df.write.format(format)
for k, v in (options or {}).items():
df = df.option(k, v)
df.save(path)
return None

sq = df.writeStream.format(format)
sq = sq.option("path", os.fspath(path))
for k, v in (options or {}).items():
sq = sq.option(k, v)
sq.start()
return sq

@util.experimental
def to_parquet_dir(
self,
expr: ir.Expr,
path: str | Path,
params: Mapping[ir.Scalar, Any] | None = None,
limit: int | str | None = None,
options: Mapping[str, str] | None = None,
) -> StreamingQuery | None:
"""Write the results of executing the given expression to a parquet directory.

Parameters
----------
expr
The ibis expression to execute and persist to parquet.
path
The output location. A string or Path to the parquet directory.
params
Mapping of scalar parameter expressions to value.
limit
An integer to effect a specific row limit. A value of `None` means
"no limit". The default is in `ibis/config.py`.
options
Additional options passed to pyspark.sql.streaming.DataStreamWriter in streaming mode, or to pyspark.sql.DataFrameWriter in batch mode.

Returns
-------
StreamingQuery | None
Returns a PySpark StreamingQuery object if in streaming mode, otherwise None
"""
self._run_pre_execute_hooks(expr)
return self._to_filesystem_output(expr, "parquet", path, params, limit, options)

@util.experimental
def to_csv_dir(
self,
expr: ir.Expr,
path: str | Path,
params: Mapping[ir.Scalar, Any] | None = None,
limit: int | str | None = None,
options: Mapping[str, str] | None = None,
) -> StreamingQuery | None:
"""Write the results of executing the given expression to a CSV directory.

Parameters
----------
expr
The ibis expression to execute and persist to CSV.
path
The output location. A string or Path to the CSV directory.
params
Mapping of scalar parameter expressions to value.
limit
An integer to effect a specific row limit. A value of `None` means
"no limit". The default is in `ibis/config.py`.
options
Additional options passed to pyspark.sql.streaming.DataStreamWriter in streaming mode, or to pyspark.sql.DataFrameWriter in batch mode.

Returns
-------
StreamingQuery | None
Returns a PySpark StreamingQuery object if in streaming mode, otherwise None
"""
self._run_pre_execute_hooks(expr)
return self._to_filesystem_output(expr, "csv", path, params, limit, options)
80 changes: 37 additions & 43 deletions ibis/backends/pyspark/tests/conftest.py
@@ -2,18 +2,22 @@

import os
from datetime import datetime, timedelta, timezone
-from typing import Any
+from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
import pytest
from filelock import FileLock

import ibis
from ibis import util
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest
from ibis.backends.tests.data import json_types, topk, win

if TYPE_CHECKING:
from pathlib import Path


def set_pyspark_database(con, database):
con._session.catalog.setCurrentDatabase(database)
@@ -161,6 +165,7 @@
.config("spark.ui.enabled", False)
.config("spark.ui.showConsoleProgress", False)
.config("spark.sql.execution.arrow.pyspark.enabled", False)
.config("spark.sql.streaming.schemaInference", True)
)

try:
@@ -193,52 +198,41 @@
t = t.sort(sort_col)
t.createOrReplaceTempView(name)

-    @staticmethod
-    def connect(*, tmpdir, worker_id, **kw):
-        # Spark internally stores timestamps as UTC values, and timestamp
-        # data that is brought in without a specified time zone is
-        # converted as local time to UTC with microsecond resolution.
-        # https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics
-
-        from pyspark.sql import SparkSession
-
-        config = (
-            SparkSession.builder.appName("ibis_testing")
-            .master("local[1]")
-            .config("spark.cores.max", 1)
-            .config("spark.default.parallelism", 1)
-            .config("spark.driver.extraJavaOptions", "-Duser.timezone=GMT")
-            .config("spark.dynamicAllocation.enabled", False)
-            .config("spark.executor.extraJavaOptions", "-Duser.timezone=GMT")
-            .config("spark.executor.heartbeatInterval", "3600s")
-            .config("spark.executor.instances", 1)
-            .config("spark.network.timeout", "4200s")
-            .config("spark.rdd.compress", False)
-            .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-            .config("spark.shuffle.compress", False)
-            .config("spark.shuffle.spill.compress", False)
-            .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
-            .config("spark.sql.session.timeZone", "UTC")
-            .config("spark.sql.shuffle.partitions", 1)
-            .config("spark.storage.blockManagerSlaveTimeoutMs", "4200s")
-            .config("spark.ui.enabled", False)
-            .config("spark.ui.showConsoleProgress", False)
-            .config("spark.sql.execution.arrow.pyspark.enabled", False)
-            .config("spark.sql.streaming.schemaInference", True)
-        )
-
-        try:
-            from delta.pip_utils import configure_spark_with_delta_pip
-        except ImportError:
-            configure_spark_with_delta_pip = lambda cfg: cfg
-        else:
-            config = config.config(
-                "spark.sql.catalog.spark_catalog",
-                "org.apache.spark.sql.delta.catalog.DeltaCatalog",
-            ).config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
-
-        spark = configure_spark_with_delta_pip(config).getOrCreate()
-        return ibis.pyspark.connect(spark, mode="streaming", **kw)
+    @classmethod
+    def load_data(
+        cls, data_dir: Path, tmpdir: Path, worker_id: str, **kw: Any
+    ) -> BackendTest:
+        """Load testdata from `data_dir`."""
+        # handling for multi-processes pytest
+
+        # get the temp directory shared by all workers
+        root_tmp_dir = tmpdir.getbasetemp() / "streaming"
+        if worker_id != "master":
+            root_tmp_dir = root_tmp_dir.parent
+
+        fn = root_tmp_dir / cls.name()
+        with FileLock(f"{fn}.lock"):
+            cls.skip_if_missing_deps()
+
+            inst = cls(data_dir=data_dir, tmpdir=tmpdir, worker_id=worker_id, **kw)
+
+            if inst.stateful:
+                inst.stateful_load(fn, **kw)
+            else:
+                inst.stateless_load(**kw)
+            inst.postload(tmpdir=tmpdir, worker_id=worker_id, **kw)
+        return inst
+
+    @staticmethod
+    def connect(*, tmpdir, worker_id, **kw):
+        from pyspark.sql import SparkSession
+
+        # SparkContext is shared globally; only one SparkContext should be active
+        # per JVM. We need to create a new SparkSession for streaming tests but
+        # this session shares the same SparkContext.
+        spark = SparkSession.getActiveSession().newSession()
+        con = ibis.pyspark.connect(spark, mode="streaming", **kw)
+        return con


@pytest.fixture(scope="session")