From 067120df1f3c0d2b45af60fc1b28cb4b0463d5fb Mon Sep 17 00:00:00 2001
From: tokoko
Date: Sun, 26 Mar 2023 03:35:29 +0400
Subject: [PATCH] feat(pyspark): add partial support for interval types

---
 ibis/backends/pyspark/datatypes.py        | 29 +++++++++++
 ibis/backends/pyspark/tests/conftest.py   | 59 ++++++++++++++++++++++-
 ibis/backends/pyspark/tests/test_basic.py | 23 +++++++++
 3 files changed, 110 insertions(+), 1 deletion(-)

diff --git a/ibis/backends/pyspark/datatypes.py b/ibis/backends/pyspark/datatypes.py
index 8dd6c34f26b4..14d76d4d4feb 100644
--- a/ibis/backends/pyspark/datatypes.py
+++ b/ibis/backends/pyspark/datatypes.py
@@ -79,6 +79,35 @@ def _spark_struct(spark_dtype_obj, nullable=True):
     return dt.Struct(fields, nullable=nullable)
 
 
+# Spark day-time interval field -> ibis interval unit.
+_SPARK_INTERVAL_TO_IBIS_INTERVAL = {
+    pt.DayTimeIntervalType.SECOND: 's',
+    pt.DayTimeIntervalType.MINUTE: 'm',
+    pt.DayTimeIntervalType.HOUR: 'h',
+    pt.DayTimeIntervalType.DAY: 'D',
+}
+
+
+@dt.dtype.register(pt.DayTimeIntervalType)
+def _spark_interval(spark_dtype_obj, nullable=True):
+    """Convert a Spark ``DayTimeIntervalType`` into an ibis ``Interval``.
+
+    Only single-unit intervals (``startField == endField``) whose unit is
+    listed in ``_SPARK_INTERVAL_TO_IBIS_INTERVAL`` are supported; any other
+    combination raises ``IbisTypeError``.
+    """
+    if (
+        spark_dtype_obj.startField == spark_dtype_obj.endField
+        and spark_dtype_obj.startField in _SPARK_INTERVAL_TO_IBIS_INTERVAL
+    ):
+        return dt.Interval(
+            _SPARK_INTERVAL_TO_IBIS_INTERVAL[spark_dtype_obj.startField],
+            nullable=nullable,
+        )
+    else:
+        raise com.IbisTypeError("DayTimeIntervalType couldn't be converted to Interval")
+
+
 _IBIS_DTYPE_TO_SPARK_DTYPE = {v: k for k, v in _SPARK_DTYPE_TO_IBIS_DTYPE.items()}
 _IBIS_DTYPE_TO_SPARK_DTYPE[dt.JSON] = pt.StringType
 
diff --git a/ibis/backends/pyspark/tests/conftest.py b/ibis/backends/pyspark/tests/conftest.py
index d52d57549a7c..ec8ffe7c6ef8 100644
--- a/ibis/backends/pyspark/tests/conftest.py
+++ b/ibis/backends/pyspark/tests/conftest.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import os
-from datetime import datetime, timezone
+from datetime import datetime, timezone, timedelta
 
 import numpy as np
 import pandas as pd
@@ -276,6 +276,63 @@ def client(data_directory):
 
     df_time_indexed.createTempView('time_indexed_table')
 
+    df_interval = client._session.createDataFrame(
+        [
+            [
+                timedelta(days=10),
+                timedelta(hours=10),
+                timedelta(minutes=10),
+                timedelta(seconds=10),
+            ]
+        ],
+        pt.StructType(
+            [
+                pt.StructField(
+                    "interval_day",
+                    pt.DayTimeIntervalType(
+                        pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.DAY
+                    ),
+                ),
+                pt.StructField(
+                    "interval_hour",
+                    pt.DayTimeIntervalType(
+                        pt.DayTimeIntervalType.HOUR, pt.DayTimeIntervalType.HOUR
+                    ),
+                ),
+                pt.StructField(
+                    "interval_minute",
+                    pt.DayTimeIntervalType(
+                        pt.DayTimeIntervalType.MINUTE, pt.DayTimeIntervalType.MINUTE
+                    ),
+                ),
+                pt.StructField(
+                    "interval_second",
+                    pt.DayTimeIntervalType(
+                        pt.DayTimeIntervalType.SECOND, pt.DayTimeIntervalType.SECOND
+                    ),
+                ),
+            ]
+        ),
+    )
+
+    df_interval.createTempView('interval_table')
+
+    df_interval_invalid = client._session.createDataFrame(
+        [[timedelta(days=10, hours=10, minutes=10, seconds=10)]],
+        pt.StructType(
+            [
+                pt.StructField(
+                    "interval_day_hour",
+                    pt.DayTimeIntervalType(
+                        pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.HOUR
+                    ),
+                )
+            ]
+        ),
+    )
+
+    df_interval_invalid.createTempView('invalid_interval_table')
+
     return client
 
 
diff --git a/ibis/backends/pyspark/tests/test_basic.py b/ibis/backends/pyspark/tests/test_basic.py
index c2dd523bb806..0072d55fb405 100644
--- a/ibis/backends/pyspark/tests/test_basic.py
+++ b/ibis/backends/pyspark/tests/test_basic.py
@@ -10,6 +10,8 @@
 import pyspark.sql.functions as F  # noqa: E402
 
 from ibis.backends.pyspark.compiler import _can_be_replaced_by_column_name  # noqa: E402
+from ibis.expr import datatypes as dt  # noqa: E402
+from ibis.common.exceptions import IbisTypeError  # noqa: E402
 
 
 def test_basic(client):
@@ -211,3 +213,24 @@ def test_can_be_replaced_by_column_name(selection_fn, selection_idx, expected):
     selection_to_test = table.op().selections[selection_idx]
     result = _can_be_replaced_by_column_name(selection_to_test, table.op().table)
     assert result == expected
+
+
+def test_interval_columns(client):
+    """Single-unit Spark interval columns map to ibis ``Interval`` dtypes."""
+    table = client.table('interval_table')
+    assert table.schema() == ibis.schema(
+        pairs=[
+            ('interval_day', dt.Interval('D')),
+            ('interval_hour', dt.Interval('h')),
+            ('interval_minute', dt.Interval('m')),
+            ('interval_second', dt.Interval('s')),
+        ]
+    )
+
+
+def test_interval_columns_invalid(client):
+    """A DAY TO HOUR interval column is rejected with ``IbisTypeError``."""
+    with pytest.raises(
+        IbisTypeError, match="DayTimeIntervalType couldn't be converted to Interval"
+    ):
+        client.table('invalid_interval_table')