Commit ce05987

feat(pyspark): add support for struct operations
cpcloud committed Dec 31, 2022
1 parent 54372c3 commit ce05987
Showing 4 changed files with 71 additions and 42 deletions.
36 changes: 20 additions & 16 deletions ibis/backends/pyspark/compiler.py
@@ -17,11 +17,7 @@
from ibis import interval
from ibis.backends.pandas.client import PandasInMemoryTable
from ibis.backends.pandas.execution import execute
from ibis.backends.pyspark.datatypes import (
ibis_array_dtype_to_spark_dtype,
ibis_dtype_to_spark_dtype,
spark_dtype,
)
from ibis.backends.pyspark.datatypes import spark_dtype
from ibis.backends.pyspark.timecontext import (
combine_time_context,
filter_by_time_context,
@@ -236,6 +232,16 @@ def compile_struct_field(t, op, **kwargs):
return arg[op.field]


@compiles(ops.StructColumn)
def compile_struct_column(t, op, **kwargs):
return F.struct(
*(
t.translate(col, **kwargs).alias(name)
for name, col in zip(op.names, op.values)
)
)
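
For context, a minimal sketch of the kind of expression this new rule compiles (the table and column names below are illustrative, not part of this commit):

    import ibis

    t = ibis.table(dict(a="int64", b="string"), name="t")
    # ibis.struct builds an ops.StructColumn node; under the rule above the
    # PySpark backend would emit roughly
    # F.struct(col("a").alias("a"), F.lit(1).alias("b"))
    s = ibis.struct(dict(a=t.a, b=1)).name("s")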


@compiles(ops.SelfReference)
def compile_self_reference(t, op, **kwargs):
return t.translate(op.table, **kwargs)
@@ -252,10 +258,7 @@ def compile_cast(t, op, **kwargs):
'in the PySpark backend. {} not allowed.'.format(type(op.arg))
)

if op.to.is_array():
cast_type = ibis_array_dtype_to_spark_dtype(op.to)
else:
cast_type = ibis_dtype_to_spark_dtype(op.to)
cast_type = spark_dtype(op.to)

src_column = t.translate(op.arg, **kwargs)
return src_column.cast(cast_type)
@@ -357,8 +360,10 @@ def compile_literal(t, op, *, raw=False, **kwargs):
return set(value)
else:
return value
elif isinstance(value, tuple):
elif dtype.is_array():
return F.array(*map(F.lit, value))
elif dtype.is_struct():
return F.struct(*(F.lit(val).alias(name) for name, val in value.items()))
else:
if isinstance(value, pd.Timestamp) and value.tz is None:
value = value.tz_localize("UTC").to_pydatetime()
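
A rough illustration of the two literal paths touched here (the values are made up): dispatch now keys on the ibis dtype rather than on the Python value's type, and a struct literal becomes a PySpark struct of aliased literals.

    import ibis

    arr = ibis.literal([1, 2])              # array dtype -> F.array(F.lit(1), F.lit(2))
    rec = ibis.literal({"a": 1, "b": "x"})  # struct dtype -> F.struct(F.lit(1).alias("a"),
                                            #                          F.lit("x").alias("b"))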
@@ -629,7 +634,7 @@ def compile_covariance(t, op, **kwargs):

fn = {"sample": F.covar_samp, "pop": F.covar_pop}[how]

pyspark_double_type = ibis_dtype_to_spark_dtype(dt.double)
pyspark_double_type = spark_dtype(dt.double)
new_op = op.__class__(
left=ops.Cast(op.left, to=pyspark_double_type),
right=ops.Cast(op.right, to=pyspark_double_type),
@@ -644,7 +649,7 @@ def compile_correlation(t, op, **kwargs):
if (how := op.how) == "pop":
raise ValueError("PySpark only implements sample correlation")

pyspark_double_type = ibis_dtype_to_spark_dtype(dt.double)
pyspark_double_type = spark_dtype(dt.double)
new_op = op.__class__(
left=ops.Cast(op.left, to=pyspark_double_type),
right=ops.Cast(op.right, to=pyspark_double_type),
@@ -706,7 +711,6 @@ def compile_abs(t, op, **kwargs):

@compiles(ops.Clip)
def compile_clip(t, op, **kwargs):
spark_dtype = ibis_dtype_to_spark_dtype(op.output_dtype)
col = t.translate(op.arg, **kwargs)
upper = t.translate(op.upper, **kwargs) if op.upper is not None else float('inf')
lower = t.translate(op.lower, **kwargs) if op.lower is not None else float('-inf')
@@ -722,7 +726,7 @@ def column_max(value, limit):
def clip(column, lower_value, upper_value):
return column_max(column_min(column, F.lit(lower_value)), F.lit(upper_value))

return clip(col, lower, upper).cast(spark_dtype)
return clip(col, lower, upper).cast(spark_dtype(op.output_dtype))


@compiles(ops.Round)
@@ -1575,7 +1579,7 @@ def compile_array_length(t, op, **kwargs):
def compile_array_slice(t, op, **kwargs):
start = op.start.value if op.start is not None else op.start
stop = op.stop.value if op.stop is not None else op.stop
spark_type = ibis_array_dtype_to_spark_dtype(op.arg.output_dtype)
spark_type = spark_dtype(op.arg.output_dtype)

@F.udf(spark_type)
def array_slice(array):
@@ -1817,7 +1821,7 @@ def compile_random(*args, **kwargs):
@compiles(PandasInMemoryTable)
def compile_in_memory_table(t, op, session, **kwargs):
fields = [
pt.StructField(name, ibis_dtype_to_spark_dtype(dtype), dtype.nullable)
pt.StructField(name, spark_dtype(dtype), dtype.nullable)
for name, dtype in op.schema.items()
]
schema = pt.StructType(fields)
39 changes: 18 additions & 21 deletions ibis/backends/pyspark/datatypes.py
@@ -22,45 +22,45 @@ def type_to_sql_string(tval):

# maps pyspark type class to ibis type class
_SPARK_DTYPE_TO_IBIS_DTYPE = {
pt.NullType: dt.Null,
pt.StringType: dt.String,
pt.BinaryType: dt.Binary,
pt.BooleanType: dt.Boolean,
pt.ByteType: dt.Int8,
pt.DateType: dt.Date,
pt.DoubleType: dt.Float64,
pt.FloatType: dt.Float32,
pt.ByteType: dt.Int8,
pt.IntegerType: dt.Int32,
pt.LongType: dt.Int64,
pt.NullType: dt.Null,
pt.ShortType: dt.Int16,
pt.StringType: dt.String,
pt.TimestampType: dt.Timestamp,
}


@dt.dtype.register(pt.DataType)
def spark_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True):
def _spark_dtype(spark_dtype_obj, nullable=True):
"""Convert Spark SQL type objects to ibis type objects."""
ibis_type_class = _SPARK_DTYPE_TO_IBIS_DTYPE.get(type(spark_dtype_obj))
return ibis_type_class(nullable=nullable)


@dt.dtype.register(pt.DecimalType)
def spark_decimal_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True):
def _spark_decimal(spark_dtype_obj, nullable=True):
precision = spark_dtype_obj.precision
scale = spark_dtype_obj.scale
return dt.Decimal(precision, scale, nullable=nullable)


@dt.dtype.register(pt.ArrayType)
def spark_array_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True):
def _spark_array(spark_dtype_obj, nullable=True):
value_type = dt.dtype(
spark_dtype_obj.elementType, nullable=spark_dtype_obj.containsNull
)
return dt.Array(value_type, nullable=nullable)


@dt.dtype.register(pt.MapType)
def spark_map_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True):
def _spark_map(spark_dtype_obj, nullable=True):
key_type = dt.dtype(spark_dtype_obj.keyType)
value_type = dt.dtype(
spark_dtype_obj.valueType, nullable=spark_dtype_obj.valueContainsNull
@@ -69,7 +69,7 @@ def spark_map_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True):


@dt.dtype.register(pt.StructType)
def spark_struct_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True):
def _spark_struct(spark_dtype_obj, nullable=True):
names = spark_dtype_obj.names
fields = spark_dtype_obj.fields
ibis_types = [dt.dtype(f.dataType, nullable=f.nullable) for f in fields]
@@ -79,43 +79,40 @@ def spark_struct_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True):
_IBIS_DTYPE_TO_SPARK_DTYPE = {v: k for k, v in _SPARK_DTYPE_TO_IBIS_DTYPE.items()}
_IBIS_DTYPE_TO_SPARK_DTYPE[dt.JSON] = pt.StringType

spark_dtype = functools.singledispatch('spark_dtype')
# from multipledispatch import Dispatcher
# spark_dtype = Dispatcher('spark_dtype')


@spark_dtype.register(object)
def default(value, **kwargs) -> pt.DataType:
@functools.singledispatch
def spark_dtype(value, **kwargs):
raise com.IbisTypeError(f'Value {value!r} is not a valid datatype')


@spark_dtype.register(pt.DataType)
def from_spark_dtype(value: pt.DataType) -> pt.DataType:
def _spark(value: pt.DataType) -> pt.DataType:
return value


@spark_dtype.register(dt.DataType)
def ibis_dtype_to_spark_dtype(ibis_dtype_obj):
def _dtype(ibis_dtype_obj):
"""Convert ibis types types to Spark SQL."""
return _IBIS_DTYPE_TO_SPARK_DTYPE.get(type(ibis_dtype_obj))()
dtype = _IBIS_DTYPE_TO_SPARK_DTYPE[type(ibis_dtype_obj)]
return dtype()
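
A short sketch of how the consolidated spark_dtype dispatcher, together with the registrations that follow, is meant to be called (the concrete types below are illustrative and assume the ibis 4.x-era datatype objects):

    import ibis.expr.datatypes as dt
    import pyspark.sql.types as pt
    from ibis.backends.pyspark.datatypes import spark_dtype

    spark_dtype(pt.StringType())      # PySpark types pass through unchanged
    spark_dtype(dt.int64)             # -> pt.LongType()
    spark_dtype(dt.Array(dt.string))  # -> pt.ArrayType(pt.StringType(), True)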


@spark_dtype.register(dt.Decimal)
def ibis_decimal_dtype_to_spark_dtype(ibis_dtype_obj):
def _decimal(ibis_dtype_obj):
precision = ibis_dtype_obj.precision
scale = ibis_dtype_obj.scale
return pt.DecimalType(precision, scale)


@spark_dtype.register(dt.Array)
def ibis_array_dtype_to_spark_dtype(ibis_dtype_obj):
def _array(ibis_dtype_obj):
element_type = spark_dtype(ibis_dtype_obj.value_type)
contains_null = ibis_dtype_obj.value_type.nullable
return pt.ArrayType(element_type, contains_null)


@spark_dtype.register(dt.Map)
def ibis_map_dtype_to_spark_dtype(ibis_dtype_obj):
def _map(ibis_dtype_obj):
key_type = spark_dtype(ibis_dtype_obj.key_type)
value_type = spark_dtype(ibis_dtype_obj.value_type)
value_contains_null = ibis_dtype_obj.value_type.nullable
@@ -124,7 +121,7 @@ def ibis_map_dtype_to_spark_dtype(ibis_dtype_obj):

@spark_dtype.register(dt.Struct)
@spark_dtype.register(Schema)
def ibis_struct_dtype_to_spark_dtype(ibis_dtype_obj):
def _struct(ibis_dtype_obj):
fields = [
pt.StructField(n, spark_dtype(t), t.nullable)
for n, t in zip(ibis_dtype_obj.names, ibis_dtype_obj.types)
29 changes: 27 additions & 2 deletions ibis/backends/pyspark/tests/conftest.py
@@ -16,7 +16,8 @@
pytest.importorskip("pyspark")

import pyspark.sql.functions as F # noqa: E402
from pyspark.sql import SparkSession # noqa: E402
import pyspark.sql.types as pt # noqa: E402
from pyspark.sql import Row, SparkSession # noqa: E402


def get_common_spark_testing_client(data_directory, connect):
@@ -75,7 +76,31 @@ def get_common_spark_testing_client(data_directory, connect):
df_simple = s.createDataFrame([(1, 'a')], ['foo', 'bar'])
df_simple.createOrReplaceTempView('simple')

df_struct = s.createDataFrame([((1, 2, 'a'),)], ['struct_col'])
df_struct = s.createDataFrame(
[
Row(abc=Row(a=1.0, b='banana', c=2)),
Row(abc=Row(a=2.0, b='apple', c=3)),
Row(abc=Row(a=3.0, b='orange', c=4)),
Row(abc=Row(a=None, b='banana', c=2)),
Row(abc=Row(a=2.0, b=None, c=3)),
Row(abc=None),
Row(abc=Row(a=3.0, b='orange', c=None)),
],
schema=pt.StructType(
[
pt.StructField(
"abc",
pt.StructType(
[
pt.StructField("a", pt.DoubleType(), True),
pt.StructField("b", pt.StringType(), True),
pt.StructField("c", pt.IntegerType(), True),
]
),
)
]
),
)
df_struct.createOrReplaceTempView('struct')

df_nested_types = s.createDataFrame(
9 changes: 6 additions & 3 deletions ibis/backends/tests/test_struct.py
@@ -12,11 +12,12 @@
pytestmark = [
pytest.mark.never(["mysql", "sqlite", "mssql"], reason="No struct support"),
pytest.mark.notyet(["impala"]),
pytest.mark.notimpl(["datafusion", "pyspark"]),
pytest.mark.notimpl(["datafusion"]),
]


@pytest.mark.notimpl(["dask", "snowflake"])
@pytest.mark.broken(["pyspark"], reason="fixed in #5097")
@pytest.mark.parametrize("field", ["a", "b", "c"])
def test_single_field(backend, struct, struct_df, field):
expr = struct.abc[field]
@@ -56,10 +57,11 @@ def test_all_fields(struct, struct_df):
@pytest.mark.parametrize("field", ["a", "b", "c"])
def test_literal(con, field):
query = _STRUCT_LITERAL[field]
result = pd.Series([con.execute(query)])
dtype = query.type().to_pandas()
result = pd.Series([con.execute(query)], dtype=dtype)
result = result.replace({np.nan: None})
expected = pd.Series([_SIMPLE_DICT[field]])
tm.assert_series_equal(result, expected)
tm.assert_series_equal(result, expected.astype(dtype))


@pytest.mark.notimpl(["postgres", "snowflake"])
Expand Down Expand Up @@ -87,6 +89,7 @@ def test_null_literal(con, field):


@pytest.mark.notimpl(["bigquery", "dask", "pandas", "postgres", "snowflake"])
@pytest.mark.broken(["pyspark"], reason="fixed in #5097")
def test_struct_column(alltypes, df):
t = alltypes
expr = ibis.struct(dict(a=t.string_col, b=1, c=t.bigint_col)).name("s")
