From 41e10cacae07e43f36690508bb67004deaf0120a Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Thu, 8 Aug 2024 09:30:06 -0400 Subject: [PATCH] feat(snowflake): implement interval arithmetic (#9794) Closes #9783. --- ibis/backends/sql/compilers/snowflake.py | 23 +++++++++++++++-------- ibis/backends/tests/test_temporal.py | 13 +++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py index 2aff5ab5512b..5074c79b3a53 100644 --- a/ibis/backends/sql/compilers/snowflake.py +++ b/ibis/backends/sql/compilers/snowflake.py @@ -65,8 +65,10 @@ class SnowflakeCompiler(SQLGlotCompiler): UNSUPPORTED_OPS = ( ops.RowID, ops.MultiQuantile, - ops.IntervalFromInteger, ops.IntervalAdd, + ops.IntervalSubtract, + ops.IntervalMultiply, + ops.IntervalFloorDivide, ops.TimestampDiff, ) @@ -266,6 +268,8 @@ def visit_Cast(self, op, *, arg, to): return self.if_(self.f.is_object(arg), arg, NULL) elif to.is_array(): return self.if_(self.f.is_array(arg), arg, NULL) + elif op.arg.dtype.is_integer() and to.is_interval(): + return sge.Interval(this=arg, unit=self.v[to.unit.name]) return self.cast(arg, to) def visit_ToJSONMap(self, op, *, arg): @@ -365,14 +369,17 @@ def visit_DateDelta(self, op, *, part, left, right): def visit_TimestampDelta(self, op, *, part, left, right): return self.f.timestampdiff(part, right, left, dialect=self.dialect) - def visit_TimestampDateAdd(self, op, *, left, right): - if not isinstance(op.right, ops.Literal): - raise com.OperationNotDefinedError( - f"right side of {type(op).__name__} operation must be an interval literal" - ) - return sg.exp.Add(this=left, expression=right) + def visit_TimestampAdd(self, op, *, left, right): + return self.f.timestampadd(right.unit, right.this, left, dialect=self.dialect) + + def visit_TimestampSub(self, op, *, left, right): + return self.f.timestampadd(right.unit, -right.this, left, dialect=self.dialect) + + visit_DateAdd = visit_TimestampAdd + visit_DateSub = visit_TimestampSub - visit_DateAdd = visit_TimestampAdd = visit_TimestampDateAdd + def visit_IntervalFromInteger(self, op, *, arg, unit): + return sge.Interval(this=arg, unit=self.v[unit.name]) def visit_IntegerRange(self, op, *, start, stop, step): return self.if_( diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index d1bfdd82efd0..790ffecdc529 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -579,7 +579,7 @@ def test_date_truncate(backend, alltypes, df, unit): ], ) @pytest.mark.notimpl( - ["datafusion", "sqlite", "snowflake", "mssql", "oracle", "druid", "exasol"], + ["datafusion", "sqlite", "mssql", "oracle", "druid", "exasol"], raises=com.OperationNotDefinedError, ) def test_integer_to_interval_timestamp( @@ -637,7 +637,6 @@ def convert_to_offset(offset, displacement_type=displacement_type): "impala", "mysql", "sqlite", - "snowflake", "polars", "mssql", "druid", @@ -1017,10 +1016,7 @@ def test_timestamp_comparison_filter_numpy(backend, con, alltypes, df, func_name backend.assert_frame_equal(result, expected) -@pytest.mark.notimpl( - ["snowflake", "mssql", "exasol", "druid"], - raises=com.OperationNotDefinedError, -) +@pytest.mark.notimpl(["mssql", "exasol", "druid"], raises=com.OperationNotDefinedError) def test_interval_add_cast_scalar(backend, alltypes): timestamp_date = alltypes.timestamp_col.date() delta = ibis.literal(10).cast("interval('D')") @@ -1030,10 +1026,7 @@ def test_interval_add_cast_scalar(backend, alltypes): backend.assert_series_equal(result, expected.astype(result.dtype)) -@pytest.mark.notimpl( - ["snowflake", "mssql", "exasol", "druid"], - raises=com.OperationNotDefinedError, -) +@pytest.mark.notimpl(["mssql", "exasol", "druid"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["flink"], raises=AssertionError, reason="incorrect results") def test_interval_add_cast_column(backend, alltypes, df): timestamp_date = alltypes.timestamp_col.date()