From 6be6c2be0a46c24c82af386e2452b91b57f0fdf4 Mon Sep 17 00:00:00 2001
From: Daniel Mesejo
Date: Sun, 13 Aug 2023 14:23:25 +0200
Subject: [PATCH] feat(datafusion): add temporal functions

---
 ibis/backends/datafusion/compiler.py | 128 +++++++++++++++++++++++++++
 ibis/backends/tests/test_temporal.py |  25 +++---
 2 files changed, 139 insertions(+), 14 deletions(-)

diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py
index e65cc23dee95..cdfbf2d2d5d7 100644
--- a/ibis/backends/datafusion/compiler.py
+++ b/ibis/backends/datafusion/compiler.py
@@ -885,3 +885,131 @@ def join(op, **kw):
     )
 
     return left.join(right, join_keys=(left_keys, right_keys), how=how)
+
+
+@translate.register(ops.ExtractYear)
+def extract_year(op, **kw):
+    """Translate `ops.ExtractYear` into `date_part("year", arg)`."""
+    arg = translate(op.arg, **kw)
+    return df.functions.date_part(df.literal("year"), arg)
+
+
+@translate.register(ops.ExtractMonth)
+def extract_month(op, **kw):
+    """Translate `ops.ExtractMonth` into `date_part("month", arg)`."""
+    arg = translate(op.arg, **kw)
+    return df.functions.date_part(df.literal("month"), arg)
+
+
+@translate.register(ops.ExtractDay)
+def extract_day(op, **kw):
+    """Translate `ops.ExtractDay` into `date_part("day", arg)`."""
+    arg = translate(op.arg, **kw)
+    return df.functions.date_part(df.literal("day"), arg)
+
+
+@translate.register(ops.ExtractQuarter)
+def extract_quarter(op, **kw):
+    """Translate `ops.ExtractQuarter` into `date_part("quarter", arg)`."""
+    arg = translate(op.arg, **kw)
+    return df.functions.date_part(df.literal("quarter"), arg)
+
+
+@translate.register(ops.ExtractMinute)
+def extract_minute(op, **kw):
+    """Translate `ops.ExtractMinute` into `date_part("minute", arg)`."""
+    arg = translate(op.arg, **kw)
+    return df.functions.date_part(df.literal("minute"), arg)
+
+
+@translate.register(ops.ExtractHour)
+def extract_hour(op, **kw):
+    """Translate `ops.ExtractHour` into `date_part("hour", arg)`."""
+    arg = translate(op.arg, **kw)
+    return df.functions.date_part(df.literal("hour"), arg)
+
+
+@translate.register(ops.ExtractMillisecond)
+def extract_millisecond(op, **kw):
+    """Translate `ops.ExtractMillisecond` via a UDF wrapping `pc.millisecond`."""
+
+    def ms(array: pa.Array) -> pa.Array:
+        return pc.cast(pc.millisecond(array), pa.int32())
+
+    extract_milliseconds_udf = df.udf(
+        ms,
+        input_types=[PyArrowType.from_ibis(op.arg.dtype)],
+        return_type=PyArrowType.from_ibis(op.dtype),
+        volatility="immutable",
+        name="extract_milliseconds_udf",
+    )
+    arg = translate(op.arg, **kw)
+    return extract_milliseconds_udf(arg)
+
+
+@translate.register(ops.ExtractSecond)
+def extract_second(op, **kw):
+    """Translate `ops.ExtractSecond` via a UDF wrapping `pc.second`."""
+
+    def s(array: pa.Array) -> pa.Array:
+        return pc.cast(pc.second(array), pa.int32())
+
+    extract_seconds_udf = df.udf(
+        s,
+        input_types=[PyArrowType.from_ibis(op.arg.dtype)],
+        return_type=PyArrowType.from_ibis(op.dtype),
+        volatility="immutable",
+        name="extract_seconds_udf",
+    )
+    arg = translate(op.arg, **kw)
+    return extract_seconds_udf(arg)
+
+
+@translate.register(ops.ExtractDayOfYear)
+def extract_day_of_the_year(op, **kw):
+    """Translate `ops.ExtractDayOfYear` into `date_part("doy", arg)`."""
+    arg = translate(op.arg, **kw)
+    return df.functions.date_part(df.literal("doy"), arg)
+
+
+@translate.register(ops.DayOfWeekIndex)
+def extract_day_of_the_week_index(op, **kw):
+    """Translate `ops.DayOfWeekIndex` into `(date_part("dow", arg) + 6) % 7`."""
+    arg = translate(op.arg, **kw)
+    # Shift by 6 mod 7: Monday-based index, assuming "dow" reports Sunday as 0.
+    return (df.functions.date_part(df.literal("dow"), arg) + df.lit(6)) % df.lit(7)
+
+
+@translate.register(ops.DayOfWeekName)
+def extract_down(op, **kw):
+    """Translate `ops.DayOfWeekName` via a UDF mapping day numbers to names."""
+
+    def down(array: pa.Array) -> pa.Array:
+        # Pick the name by day number; assumes pc.day_of_week yields Monday=0.
+        return pc.choose(
+            pc.day_of_week(array),
+            "Monday",
+            "Tuesday",
+            "Wednesday",
+            "Thursday",
+            "Friday",
+            "Saturday",
+            "Sunday",
+        )
+
+    extract_down_udf = df.udf(
+        down,
+        input_types=[PyArrowType.from_ibis(op.arg.dtype)],
+        return_type=PyArrowType.from_ibis(op.dtype),
+        volatility="immutable",
+        name="extract_down_udf",
+    )
+    arg = translate(op.arg, **kw)
+    return extract_down_udf(arg)
+
+
+@translate.register(ops.Date)
+def date(op, **kw):
+    """Translate `ops.Date` into `date_trunc("day", arg)`."""
+    arg = translate(op.arg, **kw)
+    return df.functions.date_trunc(df.literal("day"), arg)
diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py
index 5c2d6413d829..55d7fed7d3c8 100644
--- a/ibis/backends/tests/test_temporal.py
+++ b/ibis/backends/tests/test_temporal.py
@@ -72,7 +72,6 @@
         ),
     ],
 )
-@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
 @pytest.mark.notimpl(
     ["druid"],
     raises=AttributeError,
@@ -108,7 +107,7 @@ def test_date_extract(backend, alltypes, df, attr, expr_fn):
         "second",
     ],
 )
-@pytest.mark.notimpl(["datafusion", "druid"], raises=com.OperationNotDefinedError)
+@pytest.mark.notimpl(["druid"], raises=com.OperationNotDefinedError)
 @pytest.mark.notimpl(
     ["druid"],
     raises=AttributeError,
@@ -251,7 +250,6 @@ def test_timestamp_extract(backend, alltypes, df, attr):
         ),
     ],
 )
-@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
 def test_timestamp_extract_literal(con, func, expected):
     value = ibis.timestamp("2015-09-01 14:48:05.359")
     assert con.execute(func(value).name("tmp")) == expected
@@ -283,7 +281,7 @@ def test_timestamp_extract_microseconds(backend, alltypes, df):
     backend.assert_series_equal(result, expected)
 
 
-@pytest.mark.notimpl(["datafusion", "oracle"], raises=com.OperationNotDefinedError)
+@pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError)
 @pytest.mark.broken(
     ["druid"],
     raises=AttributeError,
@@ -1581,10 +1579,13 @@ def test_string_to_timestamp(alltypes, fmt):
         param("2017-01-07", 5, "Saturday", id="saturday"),
     ],
 )
-@pytest.mark.notimpl(
-    ["datafusion", "mssql", "druid", "oracle"], raises=com.OperationNotDefinedError
-)
+@pytest.mark.notimpl(["mssql", "druid", "oracle"], raises=com.OperationNotDefinedError)
 @pytest.mark.notimpl(["impala"], raises=com.UnsupportedBackendType)
+@pytest.mark.broken(
+    ["datafusion"],
+    raises=Exception,
+    reason="Exception: Arrow error: Cast error: Cannot cast string to value of Date64 type",
+)
 def test_day_of_week_scalar(con, date, expected_index, expected_day):
     expr = ibis.literal(date).cast(dt.date)
     result_index = con.execute(expr.day_of_week.index().name("tmp"))
@@ -1594,9 +1595,7 @@ def test_day_of_week_scalar(con, date, expected_index, expected_day):
     assert result_day.lower() == expected_day.lower()
 
 
-@pytest.mark.notimpl(
-    ["datafusion", "mssql", "oracle"], raises=com.OperationNotDefinedError
-)
+@pytest.mark.notimpl(["mssql", "oracle"], raises=com.OperationNotDefinedError)
 @pytest.mark.broken(
     ["druid"],
     raises=AttributeError,
@@ -1632,7 +1631,7 @@ def test_day_of_week_column(backend, alltypes, df):
         ),
     ],
 )
-@pytest.mark.notimpl(["datafusion", "oracle"], raises=com.OperationNotDefinedError)
+@pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError)
 @pytest.mark.notimpl(
     ["druid"],
     raises=AttributeError,
@@ -2125,9 +2124,7 @@ def test_date_column_from_iso(con, alltypes, df):
     tm.assert_series_equal(golden.rename("tmp"), actual.rename("tmp"))
 
 
-@pytest.mark.notimpl(
-    ["datafusion", "druid", "oracle"], raises=com.OperationNotDefinedError
-)
+@pytest.mark.notimpl(["druid", "oracle"], raises=com.OperationNotDefinedError)
 @pytest.mark.notyet(
     ["pyspark"],
     raises=com.UnsupportedOperationError,