Skip to content

Commit

Permalink
feat(pyspark): add partial support for interval types
Browse files Browse the repository at this point in the history
  • Loading branch information
tokoko authored and cpcloud committed Mar 26, 2023
1 parent e2c159c commit 067120d
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 1 deletion.
22 changes: 22 additions & 0 deletions ibis/backends/pyspark/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,28 @@ def _spark_struct(spark_dtype_obj, nullable=True):
return dt.Struct(fields, nullable=nullable)


# Spark DayTimeIntervalType field identifiers mapped to the ibis interval
# unit strings they correspond to; only these single units are convertible.
_SPARK_INTERVAL_TO_IBIS_INTERVAL = {
    pt.DayTimeIntervalType.DAY: 'D',
    pt.DayTimeIntervalType.HOUR: 'h',
    pt.DayTimeIntervalType.MINUTE: 'm',
    pt.DayTimeIntervalType.SECOND: 's',
}


@dt.dtype.register(pt.DayTimeIntervalType)
def _spark_interval(spark_dtype_obj, nullable=True):
    """Convert a Spark ``DayTimeIntervalType`` to an ibis ``Interval`` dtype.

    Only single-unit intervals (start and end field identical, e.g.
    ``HOUR TO HOUR``) are supported; multi-unit intervals such as
    ``DAY TO HOUR`` have no ibis equivalent and are rejected.

    Parameters
    ----------
    spark_dtype_obj
        The ``pyspark.sql.types.DayTimeIntervalType`` instance to convert.
    nullable
        Whether the resulting ibis dtype is nullable.

    Raises
    ------
    com.IbisTypeError
        If the interval spans more than one unit or uses an unmapped unit.
    """
    # NOTE(review): the original definition reused the name `_spark_struct`,
    # silently shadowing the struct converter defined above in the module
    # namespace; renamed to avoid the collision (dispatch registration is
    # unaffected since it keys on the decorated type, not the name).
    unit = _SPARK_INTERVAL_TO_IBIS_INTERVAL.get(spark_dtype_obj.startField)
    if unit is not None and spark_dtype_obj.startField == spark_dtype_obj.endField:
        return dt.Interval(unit, nullable=nullable)
    raise com.IbisTypeError("DayTimeIntervalType couldn't be converted to Interval")


# Inverse mapping: ibis dtype class -> Spark type class, derived from the
# forward table defined earlier in this module.
_IBIS_DTYPE_TO_SPARK_DTYPE = {v: k for k, v in _SPARK_DTYPE_TO_IBIS_DTYPE.items()}
# JSON has no native Spark type, so it serializes as a Spark string.
_IBIS_DTYPE_TO_SPARK_DTYPE[dt.JSON] = pt.StringType

Expand Down
59 changes: 58 additions & 1 deletion ibis/backends/pyspark/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import os
from datetime import datetime, timezone
from datetime import datetime, timezone, timedelta

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -276,6 +276,63 @@ def client(data_directory):

df_time_indexed.createTempView('time_indexed_table')

# Fixture table with one column per supported single-unit day-time interval
# (DAY, HOUR, MINUTE, SECOND).  Each column's start and end fields are equal,
# which is the only interval shape the pyspark backend converts to an ibis
# Interval dtype.
df_interval = client._session.createDataFrame(
    [
        [
            timedelta(days=10),
            timedelta(hours=10),
            timedelta(minutes=10),
            timedelta(seconds=10),
        ]
    ],
    pt.StructType(
        [
            pt.StructField(
                "interval_day",
                pt.DayTimeIntervalType(
                    pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.DAY
                ),
            ),
            pt.StructField(
                "interval_hour",
                pt.DayTimeIntervalType(
                    pt.DayTimeIntervalType.HOUR, pt.DayTimeIntervalType.HOUR
                ),
            ),
            pt.StructField(
                "interval_minute",
                pt.DayTimeIntervalType(
                    pt.DayTimeIntervalType.MINUTE, pt.DayTimeIntervalType.MINUTE
                ),
            ),
            pt.StructField(
                "interval_second",
                pt.DayTimeIntervalType(
                    pt.DayTimeIntervalType.SECOND, pt.DayTimeIntervalType.SECOND
                ),
            ),
        ]
    ),
)

df_interval.createTempView('interval_table')

# Fixture table whose single column spans multiple units (DAY TO HOUR).
# The backend cannot represent such intervals, so tests use this table to
# assert that conversion raises IbisTypeError.
df_interval_invalid = client._session.createDataFrame(
    [[timedelta(days=10, hours=10, minutes=10, seconds=10)]],
    pt.StructType(
        [
            pt.StructField(
                "interval_day_hour",
                pt.DayTimeIntervalType(
                    pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.HOUR
                ),
            )
        ]
    ),
)

df_interval_invalid.createTempView('invalid_interval_table')

return client


Expand Down
21 changes: 21 additions & 0 deletions ibis/backends/pyspark/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import pyspark.sql.functions as F # noqa: E402

from ibis.backends.pyspark.compiler import _can_be_replaced_by_column_name # noqa: E402
from ibis.expr import datatypes as dt
from ibis.common.exceptions import IbisTypeError


def test_basic(client):
Expand Down Expand Up @@ -211,3 +213,22 @@ def test_can_be_replaced_by_column_name(selection_fn, selection_idx, expected):
selection_to_test = table.op().selections[selection_idx]
result = _can_be_replaced_by_column_name(selection_to_test, table.op().table)
assert result == expected


def test_interval_columns(client):
    """Columns of the interval fixture table surface as ibis Interval dtypes."""
    units = {'day': 'D', 'hour': 'h', 'minute': 'm', 'second': 's'}
    expected = ibis.schema(
        pairs=[
            (f'interval_{name}', dt.Interval(unit)) for name, unit in units.items()
        ]
    )
    assert client.table('interval_table').schema() == expected


def test_interval_columns_invalid(client):
    """A multi-unit (DAY TO HOUR) interval column is rejected on table load."""
    expected_msg = "DayTimeIntervalType couldn't be converted to Interval"
    with pytest.raises(IbisTypeError, match=expected_msg):
        client.table('invalid_interval_table')

0 comments on commit 067120d

Please sign in to comment.