Skip to content

Commit

Permalink
feat(snowflake): implement interval arithmetic (#9794)
Browse files Browse the repository at this point in the history
Closes #9783.
  • Loading branch information
cpcloud authored Aug 8, 2024
1 parent 01dc81e commit 41e10ca
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
23 changes: 15 additions & 8 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ class SnowflakeCompiler(SQLGlotCompiler):
UNSUPPORTED_OPS = (
ops.RowID,
ops.MultiQuantile,
ops.IntervalFromInteger,
ops.IntervalAdd,
ops.IntervalSubtract,
ops.IntervalMultiply,
ops.IntervalFloorDivide,
ops.TimestampDiff,
)

Expand Down Expand Up @@ -266,6 +268,8 @@ def visit_Cast(self, op, *, arg, to):
return self.if_(self.f.is_object(arg), arg, NULL)
elif to.is_array():
return self.if_(self.f.is_array(arg), arg, NULL)
elif op.arg.dtype.is_integer() and to.is_interval():
return sge.Interval(this=arg, unit=self.v[to.unit.name])
return self.cast(arg, to)

def visit_ToJSONMap(self, op, *, arg):
Expand Down Expand Up @@ -365,14 +369,17 @@ def visit_DateDelta(self, op, *, part, left, right):
def visit_TimestampDelta(self, op, *, part, left, right):
return self.f.timestampdiff(part, right, left, dialect=self.dialect)

def visit_TimestampDateAdd(self, op, *, left, right):
if not isinstance(op.right, ops.Literal):
raise com.OperationNotDefinedError(
f"right side of {type(op).__name__} operation must be an interval literal"
)
return sg.exp.Add(this=left, expression=right)
def visit_TimestampAdd(self, op, *, left, right):
return self.f.timestampadd(right.unit, right.this, left, dialect=self.dialect)

def visit_TimestampSub(self, op, *, left, right):
return self.f.timestampadd(right.unit, -right.this, left, dialect=self.dialect)

visit_DateAdd = visit_TimestampAdd
visit_DateSub = visit_TimestampSub

visit_DateAdd = visit_TimestampAdd = visit_TimestampDateAdd
def visit_IntervalFromInteger(self, op, *, arg, unit):
return sge.Interval(this=arg, unit=self.v[unit.name])

def visit_IntegerRange(self, op, *, start, stop, step):
return self.if_(
Expand Down
13 changes: 3 additions & 10 deletions ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def test_date_truncate(backend, alltypes, df, unit):
],
)
@pytest.mark.notimpl(
["datafusion", "sqlite", "snowflake", "mssql", "oracle", "druid", "exasol"],
["datafusion", "sqlite", "mssql", "oracle", "druid", "exasol"],
raises=com.OperationNotDefinedError,
)
def test_integer_to_interval_timestamp(
Expand Down Expand Up @@ -637,7 +637,6 @@ def convert_to_offset(offset, displacement_type=displacement_type):
"impala",
"mysql",
"sqlite",
"snowflake",
"polars",
"mssql",
"druid",
Expand Down Expand Up @@ -1017,10 +1016,7 @@ def test_timestamp_comparison_filter_numpy(backend, con, alltypes, df, func_name
backend.assert_frame_equal(result, expected)


@pytest.mark.notimpl(
["snowflake", "mssql", "exasol", "druid"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(["mssql", "exasol", "druid"], raises=com.OperationNotDefinedError)
def test_interval_add_cast_scalar(backend, alltypes):
timestamp_date = alltypes.timestamp_col.date()
delta = ibis.literal(10).cast("interval('D')")
Expand All @@ -1030,10 +1026,7 @@ def test_interval_add_cast_scalar(backend, alltypes):
backend.assert_series_equal(result, expected.astype(result.dtype))


@pytest.mark.notimpl(
["snowflake", "mssql", "exasol", "druid"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(["mssql", "exasol", "druid"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["flink"], raises=AssertionError, reason="incorrect results")
def test_interval_add_cast_column(backend, alltypes, df):
timestamp_date = alltypes.timestamp_col.date()
Expand Down

0 comments on commit 41e10ca

Please sign in to comment.