From ce059870d76e68c1241ad331ae468bfeabbf213f Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 17 Jul 2022 10:19:12 -0400 Subject: [PATCH] feat(pyspark): add support for struct operations --- ibis/backends/pyspark/compiler.py | 36 +++++++++++++---------- ibis/backends/pyspark/datatypes.py | 39 ++++++++++++------------- ibis/backends/pyspark/tests/conftest.py | 29 ++++++++++++++++-- ibis/backends/tests/test_struct.py | 9 ++++-- 4 files changed, 71 insertions(+), 42 deletions(-) diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index 14cee3062f61..b07ee187407d 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -17,11 +17,7 @@ from ibis import interval from ibis.backends.pandas.client import PandasInMemoryTable from ibis.backends.pandas.execution import execute -from ibis.backends.pyspark.datatypes import ( - ibis_array_dtype_to_spark_dtype, - ibis_dtype_to_spark_dtype, - spark_dtype, -) +from ibis.backends.pyspark.datatypes import spark_dtype from ibis.backends.pyspark.timecontext import ( combine_time_context, filter_by_time_context, @@ -236,6 +232,16 @@ def compile_struct_field(t, op, **kwargs): return arg[op.field] +@compiles(ops.StructColumn) +def compile_struct_column(t, op, **kwargs): + return F.struct( + *( + t.translate(col, **kwargs).alias(name) + for name, col in zip(op.names, op.values) + ) + ) + + @compiles(ops.SelfReference) def compile_self_reference(t, op, **kwargs): return t.translate(op.table, **kwargs) @@ -252,10 +258,7 @@ def compile_cast(t, op, **kwargs): 'in the PySpark backend. {} not allowed.'.format(type(op.arg)) ) - if op.to.is_array(): - cast_type = ibis_array_dtype_to_spark_dtype(op.to) - else: - cast_type = ibis_dtype_to_spark_dtype(op.to) + cast_type = spark_dtype(op.to) src_column = t.translate(op.arg, **kwargs) return src_column.cast(cast_type) @@ -357,8 +360,10 @@ def compile_literal(t, op, *, raw=False, **kwargs): return set(value) else: return value - elif isinstance(value, tuple): + elif dtype.is_array(): return F.array(*map(F.lit, value)) + elif dtype.is_struct(): + return F.struct(*(F.lit(val).alias(name) for name, val in value.items())) else: if isinstance(value, pd.Timestamp) and value.tz is None: value = value.tz_localize("UTC").to_pydatetime() @@ -629,7 +634,7 @@ def compile_covariance(t, op, **kwargs): fn = {"sample": F.covar_samp, "pop": F.covar_pop}[how] - pyspark_double_type = ibis_dtype_to_spark_dtype(dt.double) + pyspark_double_type = spark_dtype(dt.double) new_op = op.__class__( left=ops.Cast(op.left, to=pyspark_double_type), right=ops.Cast(op.right, to=pyspark_double_type), @@ -644,7 +649,7 @@ def compile_correlation(t, op, **kwargs): if (how := op.how) == "pop": raise ValueError("PySpark only implements sample correlation") - pyspark_double_type = ibis_dtype_to_spark_dtype(dt.double) + pyspark_double_type = spark_dtype(dt.double) new_op = op.__class__( left=ops.Cast(op.left, to=pyspark_double_type), right=ops.Cast(op.right, to=pyspark_double_type), @@ -706,7 +711,6 @@ def compile_abs(t, op, **kwargs): @compiles(ops.Clip) def compile_clip(t, op, **kwargs): - spark_dtype = ibis_dtype_to_spark_dtype(op.output_dtype) col = t.translate(op.arg, **kwargs) upper = t.translate(op.upper, **kwargs) if op.upper is not None else float('inf') lower = t.translate(op.lower, **kwargs) if op.lower is not None else float('-inf') @@ -722,7 +726,7 @@ def column_max(value, limit): def clip(column, lower_value, upper_value): return column_max(column_min(column, F.lit(lower_value)), F.lit(upper_value)) - return clip(col, lower, upper).cast(spark_dtype) + return clip(col, lower, upper).cast(spark_dtype(op.output_dtype)) @compiles(ops.Round) @@ -1575,7 +1579,7 @@ def compile_array_length(t, op, **kwargs): def compile_array_slice(t, op, **kwargs): start = op.start.value if op.start is not None else op.start stop = op.stop.value if op.stop is not None else op.stop - spark_type = ibis_array_dtype_to_spark_dtype(op.arg.output_dtype) + spark_type = spark_dtype(op.arg.output_dtype) @F.udf(spark_type) def array_slice(array): @@ -1817,7 +1821,7 @@ def compile_random(*args, **kwargs): @compiles(PandasInMemoryTable) def compile_in_memory_table(t, op, session, **kwargs): fields = [ - pt.StructField(name, ibis_dtype_to_spark_dtype(dtype), dtype.nullable) + pt.StructField(name, spark_dtype(dtype), dtype.nullable) for name, dtype in op.schema.items() ] schema = pt.StructType(fields) diff --git a/ibis/backends/pyspark/datatypes.py b/ibis/backends/pyspark/datatypes.py index 55ee0fe75e73..cdd9a07e194a 100644 --- a/ibis/backends/pyspark/datatypes.py +++ b/ibis/backends/pyspark/datatypes.py @@ -22,37 +22,37 @@ def type_to_sql_string(tval): # maps pyspark type class to ibis type class _SPARK_DTYPE_TO_IBIS_DTYPE = { - pt.NullType: dt.Null, - pt.StringType: dt.String, pt.BinaryType: dt.Binary, pt.BooleanType: dt.Boolean, + pt.ByteType: dt.Int8, pt.DateType: dt.Date, pt.DoubleType: dt.Float64, pt.FloatType: dt.Float32, - pt.ByteType: dt.Int8, pt.IntegerType: dt.Int32, pt.LongType: dt.Int64, + pt.NullType: dt.Null, pt.ShortType: dt.Int16, + pt.StringType: dt.String, pt.TimestampType: dt.Timestamp, } @dt.dtype.register(pt.DataType) -def spark_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True): +def _spark_dtype(spark_dtype_obj, nullable=True): """Convert Spark SQL type objects to ibis type objects.""" ibis_type_class = _SPARK_DTYPE_TO_IBIS_DTYPE.get(type(spark_dtype_obj)) return ibis_type_class(nullable=nullable) @dt.dtype.register(pt.DecimalType) -def spark_decimal_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True): +def _spark_decimal(spark_dtype_obj, nullable=True): precision = spark_dtype_obj.precision scale = spark_dtype_obj.scale return dt.Decimal(precision, scale, nullable=nullable) @dt.dtype.register(pt.ArrayType) -def spark_array_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True): +def _spark_array(spark_dtype_obj, nullable=True): value_type = dt.dtype( spark_dtype_obj.elementType, nullable=spark_dtype_obj.containsNull ) @@ -60,7 +60,7 @@ def spark_array_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True): @dt.dtype.register(pt.MapType) -def spark_map_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True): +def _spark_map(spark_dtype_obj, nullable=True): key_type = dt.dtype(spark_dtype_obj.keyType) value_type = dt.dtype( spark_dtype_obj.valueType, nullable=spark_dtype_obj.valueContainsNull @@ -69,7 +69,7 @@ def spark_map_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True): @dt.dtype.register(pt.StructType) -def spark_struct_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True): +def _spark_struct(spark_dtype_obj, nullable=True): names = spark_dtype_obj.names fields = spark_dtype_obj.fields ibis_types = [dt.dtype(f.dataType, nullable=f.nullable) for f in fields] @@ -79,43 +79,40 @@ def spark_struct_dtype_to_ibis_dtype(spark_dtype_obj, nullable=True): _IBIS_DTYPE_TO_SPARK_DTYPE = {v: k for k, v in _SPARK_DTYPE_TO_IBIS_DTYPE.items()} _IBIS_DTYPE_TO_SPARK_DTYPE[dt.JSON] = pt.StringType -spark_dtype = functools.singledispatch('spark_dtype') -# from multipledispatch import Dispatcher -# spark_dtype = Dispatcher('spark_dtype') - -@spark_dtype.register(object) -def default(value, **kwargs) -> pt.DataType: +@functools.singledispatch +def spark_dtype(value, **kwargs): raise com.IbisTypeError(f'Value {value!r} is not a valid datatype') @spark_dtype.register(pt.DataType) -def from_spark_dtype(value: pt.DataType) -> pt.DataType: +def _spark(value: pt.DataType) -> pt.DataType: return value @spark_dtype.register(dt.DataType) -def ibis_dtype_to_spark_dtype(ibis_dtype_obj): +def _dtype(ibis_dtype_obj): """Convert ibis types types to Spark SQL.""" - return _IBIS_DTYPE_TO_SPARK_DTYPE.get(type(ibis_dtype_obj))() + dtype = _IBIS_DTYPE_TO_SPARK_DTYPE[type(ibis_dtype_obj)] + return dtype() @spark_dtype.register(dt.Decimal) -def ibis_decimal_dtype_to_spark_dtype(ibis_dtype_obj): +def _decimal(ibis_dtype_obj): precision = ibis_dtype_obj.precision scale = ibis_dtype_obj.scale return pt.DecimalType(precision, scale) @spark_dtype.register(dt.Array) -def ibis_array_dtype_to_spark_dtype(ibis_dtype_obj): +def _array(ibis_dtype_obj): element_type = spark_dtype(ibis_dtype_obj.value_type) contains_null = ibis_dtype_obj.value_type.nullable return pt.ArrayType(element_type, contains_null) @spark_dtype.register(dt.Map) -def ibis_map_dtype_to_spark_dtype(ibis_dtype_obj): +def _map(ibis_dtype_obj): key_type = spark_dtype(ibis_dtype_obj.key_type) value_type = spark_dtype(ibis_dtype_obj.value_type) value_contains_null = ibis_dtype_obj.value_type.nullable @@ -124,7 +121,7 @@ def ibis_map_dtype_to_spark_dtype(ibis_dtype_obj): @spark_dtype.register(dt.Struct) @spark_dtype.register(Schema) -def ibis_struct_dtype_to_spark_dtype(ibis_dtype_obj): +def _struct(ibis_dtype_obj): fields = [ pt.StructField(n, spark_dtype(t), t.nullable) for n, t in zip(ibis_dtype_obj.names, ibis_dtype_obj.types) diff --git a/ibis/backends/pyspark/tests/conftest.py b/ibis/backends/pyspark/tests/conftest.py index ae162dc5cea4..e22d4187dc4d 100644 --- a/ibis/backends/pyspark/tests/conftest.py +++ b/ibis/backends/pyspark/tests/conftest.py @@ -16,7 +16,8 @@ pytest.importorskip("pyspark") import pyspark.sql.functions as F # noqa: E402 -from pyspark.sql import SparkSession # noqa: E402 +import pyspark.sql.types as pt # noqa: E402 +from pyspark.sql import Row, SparkSession # noqa: E402 def get_common_spark_testing_client(data_directory, connect): @@ -75,7 +76,31 @@ def get_common_spark_testing_client(data_directory, connect): df_simple = s.createDataFrame([(1, 'a')], ['foo', 'bar']) df_simple.createOrReplaceTempView('simple') - df_struct = s.createDataFrame([((1, 2, 'a'),)], ['struct_col']) + df_struct = s.createDataFrame( + [ + Row(abc=Row(a=1.0, b='banana', c=2)), + Row(abc=Row(a=2.0, b='apple', c=3)), + Row(abc=Row(a=3.0, b='orange', c=4)), + Row(abc=Row(a=None, b='banana', c=2)), + Row(abc=Row(a=2.0, b=None, c=3)), + Row(abc=None), + Row(abc=Row(a=3.0, b='orange', c=None)), + ], + schema=pt.StructType( + [ + pt.StructField( + "abc", + pt.StructType( + [ + pt.StructField("a", pt.DoubleType(), True), + pt.StructField("b", pt.StringType(), True), + pt.StructField("c", pt.IntegerType(), True), + ] + ), + ) + ] + ), + ) df_struct.createOrReplaceTempView('struct') df_nested_types = s.createDataFrame( diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py index 7bcf15263d0b..f0cea815540d 100644 --- a/ibis/backends/tests/test_struct.py +++ b/ibis/backends/tests/test_struct.py @@ -12,11 +12,12 @@ pytestmark = [ pytest.mark.never(["mysql", "sqlite", "mssql"], reason="No struct support"), pytest.mark.notyet(["impala"]), - pytest.mark.notimpl(["datafusion", "pyspark"]), + pytest.mark.notimpl(["datafusion"]), ] @pytest.mark.notimpl(["dask", "snowflake"]) +@pytest.mark.broken(["pyspark"], reason="fixed in #5097") @pytest.mark.parametrize("field", ["a", "b", "c"]) def test_single_field(backend, struct, struct_df, field): expr = struct.abc[field] @@ -56,10 +57,11 @@ def test_all_fields(struct, struct_df): @pytest.mark.parametrize("field", ["a", "b", "c"]) def test_literal(con, field): query = _STRUCT_LITERAL[field] - result = pd.Series([con.execute(query)]) + dtype = query.type().to_pandas() + result = pd.Series([con.execute(query)], dtype=dtype) result = result.replace({np.nan: None}) expected = pd.Series([_SIMPLE_DICT[field]]) - tm.assert_series_equal(result, expected) + tm.assert_series_equal(result, expected.astype(dtype)) @pytest.mark.notimpl(["postgres", "snowflake"]) @@ -87,6 +89,7 @@ def test_null_literal(con, field): @pytest.mark.notimpl(["bigquery", "dask", "pandas", "postgres", "snowflake"]) +@pytest.mark.broken(["pyspark"], reason="fixed in #5097") def test_struct_column(alltypes, df): t = alltypes expr = ibis.struct(dict(a=t.string_col, b=1, c=t.bigint_col)).name("s")