From 2bffa5ad4be33cfa12ab5967fb96bce92ca046fe Mon Sep 17 00:00:00 2001
From: Daniel Mesejo
Date: Wed, 8 Nov 2023 15:06:36 +0100
Subject: [PATCH] feat(datafusion): add TimestampFromUNIX and subtract/add operations

---
Reviewer notes (free-text area after the three-dash line; not part of the
commit message):

ibis/backends/datafusion/compiler/values.py
  * Imports TimestampUnit alongside IntervalUnit.
  * Removes the `ops.StringLength: "length"` entry from the simple-ops
    mapping (the reason is not visible in this patch).
  * Adds TimestampDiff, TimestampSub, TimestampAdd, IntervalAdd and
    IntervalSubtract to the operator mapping that already holds
    DateAdd/DateSub/DateDiff, using operator.add / operator.sub.
  * Interval literals now pass `dtype.unit.plural.lower()` as the unit
    instead of `dtype.resolution.upper()`.
  * array_string_join / array_string_find now accept and ignore extra
    keyword arguments via `**_`.
  * Registers a translation for ops.TimestampFromUNIX:
      - TimestampUnit.SECOND            -> F.from_unixtime(arg)
      - MILLISECOND/MICROSECOND/NANOSECOND
                                        -> F.arrow_cast(arg,
                                           "Timestamp(<Unit>, None)") where
                                           <Unit> is unit.name.capitalize()
      - anything else                   -> com.UnsupportedOperationError

ibis/backends/tests/test_temporal.py
  * test_temporal_binop: the blanket
    `@pytest.mark.broken(["datafusion"], raises=BaseException)` is removed
    and replaced by per-parameter `pytest.mark.broken` marks on two params.
  * "datafusion" is dropped from the notimpl lists of
    test_temporal_binop_pandas_timedelta and test_integer_to_timestamp.
  * NOTE(review): the two new marks use `raises=Exception` while their
    reason strings name pyarrow.lib.ArrowInvalid and
    pyarrow.lib.ArrowNotImplementedError; this could presumably be
    narrowed -- confirm against the actual failure.
  * NOTE(review): in the first test hunk the context string
    "... arguments of type ' - '." looks like it lost the operand types
    between the quotes during extraction; it is left as found -- verify
    against the target file before applying.

 ibis/backends/datafusion/compiler/values.py | 28 +++++++++++++++++----
 ibis/backends/tests/test_temporal.py | 17 +++++++++----
 2 files changed, 35 insertions(+), 10 deletions(-)

diff --git a/ibis/backends/datafusion/compiler/values.py b/ibis/backends/datafusion/compiler/values.py
index 0df4a1ef8867..402ea004a489 100644
--- a/ibis/backends/datafusion/compiler/values.py
+++ b/ibis/backends/datafusion/compiler/values.py
@@ -20,7 +20,7 @@
     parenthesize,
 )
 from ibis.backends.base.sqlglot.datatypes import PostgresType
-from ibis.common.temporal import IntervalUnit
+from ibis.common.temporal import IntervalUnit, TimestampUnit
 from ibis.expr.operations.udf import InputType
 from ibis.formats.pyarrow import PyArrowType
 
@@ -108,7 +108,6 @@ def translate_val(op, **_):
     ops.ArrayContains: "array_contains",
     ops.ArrayLength: "array_length",
     ops.ArrayRemove: "array_remove_all",
-    ops.StringLength: "length",
 }
 
 for _op, _name in _simple_ops.items():
@@ -148,6 +147,11 @@ def _fmt(_, _name: str = _name, **kw):
     ops.DateAdd: operator.add,
     ops.DateSub: operator.sub,
     ops.DateDiff: operator.sub,
+    ops.TimestampDiff: operator.sub,
+    ops.TimestampSub: operator.sub,
+    ops.TimestampAdd: operator.add,
+    ops.IntervalAdd: operator.add,
+    ops.IntervalSubtract: operator.sub,
 }
 
 
@@ -212,7 +216,7 @@ def _literal(op, *, value, dtype, **kw):
                 "DataFusion doesn't support subsecond interval resolutions"
             )
 
-        return interval(value, unit=dtype.resolution.upper())
+        return interval(value, unit=dtype.unit.plural.lower())
     elif dtype.is_timestamp():
         return _to_timestamp(value, dtype, literal=True)
     elif dtype.is_date():
@@ -780,10 +784,24 @@ def is_nan(op, *, arg, **_):
 
 
 @translate_val.register(ops.ArrayStringJoin)
-def array_string_join(op, *, sep, arg):
+def array_string_join(op, *, sep, arg, **_):
     return F.array_join(arg, sep)
 
 
 @translate_val.register(ops.FindInSet)
-def array_string_find(op, *, needle, values):
+def array_string_find(op, *, needle, values, **_):
     return F.coalesce(F.array_position(F.make_array(*values), needle), 0)
+
+
+@translate_val.register(ops.TimestampFromUNIX)
+def timestamp_from_unix(op, *, arg, unit, **_):
+    if unit == TimestampUnit.SECOND:
+        return F.from_unixtime(arg)
+    elif unit in (
+        TimestampUnit.MILLISECOND,
+        TimestampUnit.MICROSECOND,
+        TimestampUnit.NANOSECOND,
+    ):
+        return F.arrow_cast(arg, f"Timestamp({unit.name.capitalize()}, None)")
+    else:
+        raise com.UnsupportedOperationError(f"Unsupported unit {unit}")
diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py
index 5970994f2e4a..8e92ae945aff 100644
--- a/ibis/backends/tests/test_temporal.py
+++ b/ibis/backends/tests/test_temporal.py
@@ -1159,6 +1159,11 @@ def convert_to_offset(x):
                         "CalciteContextException: Cannot apply '-' to arguments of type ' - '."
                     ),
                 ),
+                pytest.mark.broken(
+                    ["datafusion"],
+                    raises=Exception,
+                    reason="pyarrow.lib.ArrowInvalid: Casting from duration[us] to duration[s] would lose data",
+                ),
             ],
         ),
         param(
@@ -1186,12 +1191,16 @@ def convert_to_offset(x):
                     raises=com.UnsupportedOperationError,
                     reason="DATE_DIFF is not supported in Flink",
                 ),
+                pytest.mark.broken(
+                    ["datafusion"],
+                    raises=Exception,
+                    reason="pyarrow.lib.ArrowNotImplementedError: Unsupported cast",
+                ),
             ],
         ),
     ],
 )
 @pytest.mark.notimpl(["mssql", "oracle"], raises=com.OperationNotDefinedError)
-@pytest.mark.broken(["datafusion"], raises=BaseException)
 def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
     expr = expr_fn(alltypes, backend).name("tmp")
     expected = expected_fn(df, backend)
@@ -1377,9 +1386,7 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
         ),
     ],
 )
-@pytest.mark.notimpl(
-    ["datafusion", "sqlite", "mssql", "oracle"], raises=com.OperationNotDefinedError
-)
+@pytest.mark.notimpl(["sqlite", "mssql", "oracle"], raises=com.OperationNotDefinedError)
 def test_temporal_binop_pandas_timedelta(
     backend, con, alltypes, df, timedelta, temporal_fn
 ):
@@ -1775,7 +1782,7 @@ def test_strftime(backend, alltypes, df, expr_fn, pandas_pattern):
     ],
 )
 @pytest.mark.notimpl(
-    ["datafusion", "mysql", "postgres", "sqlite", "druid", "oracle"],
+    ["mysql", "postgres", "sqlite", "druid", "oracle"],
     raises=com.OperationNotDefinedError,
 )
 def test_integer_to_timestamp(backend, con, unit):