From 5612d488bd0351da0e8c45c34ce3256834bda05d Mon Sep 17 00:00:00 2001 From: Daniel Mesejo Date: Tue, 22 Aug 2023 15:44:26 +0200 Subject: [PATCH] feat(datafusion): add ExtractWeekOfYear, ExtractMicrosecond, ExtractEpochSeconds --- ibis/backends/datafusion/compiler.py | 41 +++++++++++++++++++++++++++- ibis/backends/tests/test_temporal.py | 6 ++-- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index cdfbf2d2d5d7..cdd77f38d27b 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -986,7 +986,7 @@ def down(array: pa.Array) -> pa.Array: input_types=[PyArrowType.from_ibis(op.arg.dtype)], return_type=PyArrowType.from_ibis(op.dtype), volatility="immutable", - name="extract_seconds_udf", + name="extract_down_udf", ) arg = translate(op.arg, **kw) return extract_down_udf(arg) @@ -996,3 +996,42 @@ def down(array: pa.Array) -> pa.Array: def date(op, **kw): arg = translate(op.arg, **kw) return df.functions.date_trunc(df.literal("day"), arg) + + +@translate.register(ops.ExtractWeekOfYear) +def extract_week_of_year(op, **kw): + arg = translate(op.arg, **kw) + return df.functions.date_part(df.literal("week"), arg) + + +@translate.register(ops.ExtractMicrosecond) +def extract_microsecond(op, **kw): + def us(array: pa.Array) -> pa.Array: + arr = pc.multiply(pc.millisecond(array), 1000) + return pc.cast(pc.add(pc.microsecond(array), arr), pa.int32()) + + extract_microseconds_udf = df.udf( + us, + input_types=[PyArrowType.from_ibis(op.arg.dtype)], + return_type=PyArrowType.from_ibis(op.dtype), + volatility="immutable", + name="extract_microseconds_udf", + ) + arg = translate(op.arg, **kw) + return extract_microseconds_udf(arg) + + +@translate.register(ops.ExtractEpochSeconds) +def extract_epoch_seconds(op, **kw): + def epoch_seconds(array: pa.Array) -> pa.Array: + return pc.cast(pc.divide(pc.cast(array, pa.int64()), 1000_000), pa.int32()) + + extract_epoch_seconds_udf = df.udf( + epoch_seconds, + input_types=[PyArrowType.from_ibis(op.arg.dtype)], + return_type=PyArrowType.from_ibis(op.dtype), + volatility="immutable", + name="extract_epoch_seconds_udf", + ) + arg = translate(op.arg, **kw) + return extract_epoch_seconds_udf(arg) diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 55d7fed7d3c8..a4b6bbae8593 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -255,7 +255,7 @@ def test_timestamp_extract_literal(con, func, expected): assert con.execute(func(value).name("tmp")) == expected -@pytest.mark.notimpl(["datafusion", "oracle"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError) @pytest.mark.broken( ["druid"], raises=AttributeError, @@ -302,7 +302,7 @@ def test_timestamp_extract_milliseconds(backend, alltypes, df): backend.assert_series_equal(result, expected) -@pytest.mark.notimpl(["datafusion", "oracle"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError) @pytest.mark.broken( ["druid"], raises=AttributeError, @@ -323,7 +323,7 @@ def test_timestamp_extract_epoch_seconds(backend, alltypes, df): backend.assert_series_equal(result, expected) -@pytest.mark.notimpl(["datafusion", "oracle"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["druid"], raises=AttributeError,