feat(pyspark): support reading from and writing to Kafka (#9266)
Co-authored-by: Chloe He <chloe@chloe-mac.lan>
chloeh13q and Chloe He authored Jun 3, 2024
1 parent ccfcbbc commit 1c7c6e3
Showing 1 changed file with 123 additions and 12 deletions.
ibis/backends/pyspark/__init__.py: 123 additions & 12 deletions
@@ -6,12 +6,12 @@
from typing import TYPE_CHECKING, Any, Literal

import pyspark
+import pyspark.sql.functions as F
import sqlglot as sg
import sqlglot.expressions as sge
from packaging.version import parse as vparse
from pyspark import SparkConf
from pyspark.sql import SparkSession
-from pyspark.sql.functions import PandasUDFType, pandas_udf
from pyspark.sql.types import BooleanType, DoubleType, LongType, StringType

import ibis.common.exceptions as com
@@ -35,6 +35,9 @@
import pandas as pd
import polars as pl
import pyarrow as pa
+from pyspark.sql.streaming import StreamingQuery

+from ibis.expr.api import Watermark

PYSPARK_LT_34 = vparse(pyspark.__version__) < vparse("3.4")

@@ -48,7 +51,7 @@ def normalize_filenames(source_list):
return list(map(util.normalize_filename, source_list))


-@pandas_udf(returnType=DoubleType(), functionType=PandasUDFType.SCALAR)
+@F.pandas_udf(returnType=DoubleType(), functionType=F.PandasUDFType.SCALAR)
def unwrap_json_float(s: pd.Series) -> pd.Series:
import json

@@ -73,7 +76,7 @@ def unwrap_json(typ):

type_mapping = {str: StringType(), int: LongType(), bool: BooleanType()}

-@pandas_udf(returnType=type_mapping[typ], functionType=PandasUDFType.SCALAR)
+@F.pandas_udf(returnType=type_mapping[typ], functionType=F.PandasUDFType.SCALAR)
def unwrap(s: pd.Series) -> pd.Series:
def nullify_type_mismatched_value(raw):
if pd.isna(raw):
@@ -89,6 +92,10 @@ def nullify_type_mismatched_value(raw):
return unwrap


+def _interval_to_string(interval):
+    return f"{interval.op().value} {interval.op().dtype.unit.name.lower()}"


class Backend(SQLBackend, CanListCatalog, CanCreateDatabase):
name = "pyspark"
compiler = PySparkCompiler()
@@ -306,22 +313,22 @@ def _register_udfs(self, expr: ir.Expr) -> None:
udf_name = self.compiler.__sql_name__(udf)
udf_func = self._wrap_udf_to_return_pandas(udf.__func__, udf.dtype)
udf_return = PySparkType.from_ibis(udf.dtype)
-spark_udf = pandas_udf(udf_func, udf_return, PandasUDFType.SCALAR)
+spark_udf = F.pandas_udf(udf_func, udf_return, F.PandasUDFType.SCALAR)
self._session.udf.register(udf_name, spark_udf)

for udf in node.find(ops.ElementWiseVectorizedUDF):
udf_name = self.compiler.__sql_name__(udf)
udf_func = self._wrap_udf_to_return_pandas(udf.func, udf.return_type)
udf_return = PySparkType.from_ibis(udf.return_type)
-spark_udf = pandas_udf(udf_func, udf_return, PandasUDFType.SCALAR)
+spark_udf = F.pandas_udf(udf_func, udf_return, F.PandasUDFType.SCALAR)
self._session.udf.register(udf_name, spark_udf)

for udf in node.find(ops.ReductionVectorizedUDF):
udf_name = self.compiler.__sql_name__(udf)
-udf_func = self._wrap_udf_to_return_pandas(udf.func, udf.return_type)
+udf_func = udf.func
udf_return = PySparkType.from_ibis(udf.return_type)
-spark_udf = pandas_udf(udf_func, udf_return, PandasUDFType.GROUPED_AGG)
+spark_udf = F.pandas_udf(udf_func, udf_return, F.PandasUDFType.GROUPED_AGG)
self._session.udf.register(udf_name, spark_udf)

for typ in (str, int, bool):
@@ -460,12 +467,9 @@ def get_schema(
def create_table(
self,
name: str,
-obj: ir.Table
-| pd.DataFrame
-| pa.Table
-| pl.DataFrame
-| pl.LazyFrame
-| None = None,
+obj: (
+    ir.Table | pd.DataFrame | pa.Table | pl.DataFrame | pl.LazyFrame | None
+) = None,
*,
schema: sch.Schema | None = None,
database: str | None = None,
@@ -929,3 +933,110 @@ def to_pyarrow_batches(
return pa.ipc.RecordBatchReader.from_batches(
pa_table.schema, pa_table.to_batches(max_chunksize=chunk_size)
)

@util.experimental
def read_kafka(
self,
table_name: str | None = None,
watermark: Watermark | None = None,
auto_parse: bool = False,
schema: sch.Schema | None = None,
options: Mapping[str, str] | None = None,
) -> ir.Table:
"""Register a Kafka topic as a table.
Parameters
----------
table_name
An optional name to use for the created table. This defaults to
a sequentially generated name.
watermark
Watermark strategy for the table.
auto_parse
Whether to parse Kafka messages automatically. If `False`, the source is read
as binary keys and values. If `True`, the key is discarded and the value is
parsed using the provided schema.
schema
Schema of the value of the Kafka messages.
options
Additional arguments passed to PySpark as .option("key", "value").
https://spark.apache.org/docs/latest/structured-streaming-kafka-integration.html

Returns
-------
ir.Table
The just-registered table
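
Examples
--------
A minimal sketch; the broker address, topic, and message schema are
illustrative placeholders, and a streaming-mode PySpark connection is assumed.

>>> import ibis
>>> from pyspark.sql import SparkSession
>>> session = SparkSession.builder.getOrCreate()
>>> con = ibis.pyspark.connect(session, mode="streaming")
>>> t = con.read_kafka(
...     table_name="events",
...     auto_parse=True,
...     schema=ibis.schema({"user": "string", "amount": "float64", "ts": "timestamp"}),
...     watermark=ibis.watermark(
...         time_col="ts", allowed_delay=ibis.interval(seconds=10)
...     ),
...     options={
...         "kafka.bootstrap.servers": "localhost:9092",
...         "subscribe": "events",
...         "startingOffsets": "earliest",
...     },
... )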
"""
if self.mode == "batch":
raise NotImplementedError(
"Reading from Kafka in batch mode is not supported"
)
spark_df = self._session.readStream.format("kafka")
for k, v in (options or {}).items():
spark_df = spark_df.option(k, v)
spark_df = spark_df.load()

# parse the values of the Kafka messages using the provided schema
if auto_parse:
if schema is None:
raise com.IbisError(
"When auto_parse is True, a schema must be provided to parse the messages"
)
schema = PySparkSchema.from_ibis(schema)
spark_df = spark_df.select(
F.from_json(F.col("value").cast("string"), schema).alias("parsed_value")
).select("parsed_value.*")

if watermark is not None:
spark_df = spark_df.withWatermark(
watermark.time_col,
_interval_to_string(watermark.allowed_delay),
)

table_name = table_name or util.gen_name("read_kafka")
spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)

@util.experimental
def to_kafka(
self,
expr: ir.Expr,
auto_format: bool = False,
options: Mapping[str, str] | None = None,
) -> StreamingQuery:
"""Write the results of executing the given expression to a Kafka topic.
This method does not return outputs. Streaming queries are run continuously in
the background.
Parameters
----------
expr
The ibis expression to execute and persist to a Kafka topic.
auto_format
Whether to format the Kafka messages before writing. If `False`, the output is
written as-is. If `True`, the output is converted into JSON and written as the
value of the Kafka messages.
options
PySpark Kafka write arguments.
https://spark.apache.org/docs/latest/structured-streaming-kafka-integration.html

Returns
-------
StreamingQuery
A PySpark StreamingQuery object
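
Examples
--------
A minimal sketch, continuing from the `read_kafka` example; the output topic,
broker address, and checkpoint location are illustrative placeholders.

>>> expr = t.filter(t.amount > 100)
>>> query = con.to_kafka(
...     expr,
...     auto_format=True,
...     options={
...         "kafka.bootstrap.servers": "localhost:9092",
...         "topic": "large_events",
...         "checkpointLocation": "/tmp/checkpoints/large_events",
...     },
... )
>>> query.stop()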
"""
if self.mode == "batch":
raise NotImplementedError("Writing to Kafka in batch mode is not supported")
df = self._session.sql(expr.compile())
if auto_format:
df = df.select(
F.to_json(F.struct([F.col(c).alias(c) for c in df.columns])).alias(
"value"
)
)
sq = df.writeStream.format("kafka")
for k, v in (options or {}).items():
sq = sq.option(k, v)
return sq.start()
