From 1133973eef4192e18e52371e806fb1c6079347c9 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Thu, 8 Aug 2024 15:26:50 -0400 Subject: [PATCH] feat(trino): support years and months in datetime arithmetic --- ibis/backends/sql/compilers/trino.py | 21 ++++++++++++--------- ibis/backends/tests/test_temporal.py | 7 ------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/ibis/backends/sql/compilers/trino.py b/ibis/backends/sql/compilers/trino.py index 2499bdc2f63d..80ed1394acb3 100644 --- a/ibis/backends/sql/compilers/trino.py +++ b/ibis/backends/sql/compilers/trino.py @@ -442,17 +442,20 @@ def visit_TemporalDelta(self, op, *, part, left, right): def _make_interval(self, arg, unit): short = unit.short - if short in ("Y", "Q", "M", "W"): + if short in ("Q", "W"): raise com.UnsupportedOperationError(f"Interval unit {unit!r} not supported") + + if isinstance(arg, sge.Literal): + # force strings in interval literals because trino requires it + arg.args["is_string"] = True + return super()._make_interval(arg, unit) + + elif short in ("Y", "M"): + return arg * super()._make_interval(sge.convert("1"), unit) elif short in ("D", "h", "m", "s", "ms", "us"): - if isinstance(arg, sge.Literal): - # force strings in interval literals because trino requires it - arg.args["is_string"] = True - return super()._make_interval(arg, unit) - else: - return self.f.parse_duration( - self.f.concat(self.cast(arg, dt.string), short.lower()) - ) + return self.f.parse_duration( + self.f.concat(self.cast(arg, dt.string), short.lower()) + ) else: raise com.UnsupportedOperationError( f"Interval unit {unit.name!r} not supported" diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index a035b9c62449..4c8cec0b903c 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -504,11 +504,6 @@ def test_date_truncate(backend, alltypes, df, unit): raises=TypeError, reason="duration() got an unexpected keyword argument 'months'", ), - pytest.mark.notyet( - ["trino"], - raises=com.UnsupportedOperationError, - reason="month not implemented", - ), pytest.mark.notyet( ["oracle"], raises=OracleInterfaceError, @@ -624,7 +619,6 @@ def convert_to_offset(offset, displacement_type=displacement_type): param( "Y", marks=[ - pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), pytest.mark.notyet( ["polars"], raises=TypeError, reason="not supported by polars" ), @@ -635,7 +629,6 @@ def convert_to_offset(offset, displacement_type=displacement_type): param( "M", marks=[ - pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), pytest.mark.notyet( ["polars"], raises=TypeError, reason="not supported by polars" ),