From 4a8d611965b9015f442e258a56494b47e7e7626b Mon Sep 17 00:00:00 2001
From: Gil Forsyth
Date: Wed, 6 Sep 2023 15:09:58 -0400
Subject: [PATCH] fix(pyspark): gate datediff op to restore pyspark 3.2 support

---
 ibis/backends/pyspark/compiler.py    | 21 ++++++++++++++-------
 ibis/backends/tests/test_temporal.py |  5 +++++
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py
index f705dd488429..c4dab04ac275 100644
--- a/ibis/backends/pyspark/compiler.py
+++ b/ibis/backends/pyspark/compiler.py
@@ -9,6 +9,7 @@
 import pyspark
 import pyspark.sql.functions as F
 import pyspark.sql.types as pt
+from packaging.version import parse as vparse
 from pyspark.sql import Window
 from pyspark.sql.functions import PandasUDFType, pandas_udf
 
@@ -1557,14 +1558,20 @@ def compile_date_sub(t, op, **kwargs):
     )
 
 
-@compiles(ops.DateDiff)
-def compile_date_diff(t, op, **kwargs):
-    left = t.translate(op.left, **kwargs)
-    right = t.translate(op.right, **kwargs)
+if vparse(pyspark.__version__) >= vparse("3.3"):
 
-    return F.concat(F.lit("INTERVAL '"), F.datediff(left, right), F.lit("' DAY")).cast(
-        pt.DayTimeIntervalType(pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.DAY)
-    )
+    @compiles(ops.DateDiff)
+    def compile_date_diff(t, op, **kwargs):
+        left = t.translate(op.left, **kwargs)
+        right = t.translate(op.right, **kwargs)
+
+        return F.concat(
+            F.lit("INTERVAL '"), F.datediff(left, right), F.lit("' DAY")
+        ).cast(
+            pt.DayTimeIntervalType(
+                pt.DayTimeIntervalType.DAY, pt.DayTimeIntervalType.DAY
+            )
+        )
 
 
 @compiles(ops.TimestampAdd)
diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py
index a4b6bbae8593..04544fa5aaab 100644
--- a/ibis/backends/tests/test_temporal.py
+++ b/ibis/backends/tests/test_temporal.py
@@ -949,6 +949,11 @@ def convert_to_offset(x):
             ),
             id="date-subtract-date",
             marks=[
+                pytest.mark.xfail_version(
+                    pyspark=["pyspark<3.3"],
+                    raises=AttributeError,
+                    reason="DayTimeIntervalType added in pyspark 3.3",
+                ),
                 pytest.mark.notimpl(["bigquery"], raises=com.OperationNotDefinedError),
                 pytest.mark.notimpl(
                     ["druid"],