Skip to content

Commit

Permalink
feat(datafusion): add ExtractWeekOfYear, ExtractMicrosecond, ExtractEpochSeconds
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored and cpcloud committed Aug 23, 2023
1 parent 476a659 commit 5612d48
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
41 changes: 40 additions & 1 deletion ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,7 @@ def down(array: pa.Array) -> pa.Array:
input_types=[PyArrowType.from_ibis(op.arg.dtype)],
return_type=PyArrowType.from_ibis(op.dtype),
volatility="immutable",
name="extract_seconds_udf",
name="extract_down_udf",
)
arg = translate(op.arg, **kw)
return extract_down_udf(arg)
Expand All @@ -996,3 +996,42 @@ def down(array: pa.Array) -> pa.Array:
def date(op, **kw):
    """Translate ``op`` by truncating its translated argument to ``"day"``.

    The argument expression is translated first and then passed to
    DataFusion's ``date_trunc`` together with the literal ``"day"``.
    """
    translated = translate(op.arg, **kw)
    granularity = df.literal("day")
    return df.functions.date_trunc(granularity, translated)


@translate.register(ops.ExtractWeekOfYear)
def extract_week_of_year(op, **kw):
    """Translate ``ops.ExtractWeekOfYear`` into a DataFusion ``date_part`` call.

    The translated argument is passed to ``date_part`` with the literal
    part name ``"week"``.
    """
    translated = translate(op.arg, **kw)
    part = df.literal("week")
    return df.functions.date_part(part, translated)


@translate.register(ops.ExtractMicrosecond)
def extract_microsecond(op, **kw):
    """Translate ``ops.ExtractMicrosecond`` into a DataFusion Python UDF call.

    The UDF computes ``millisecond(array) * 1000 + microsecond(array)`` with
    PyArrow compute functions and casts the sum to ``int32``. The UDF's input
    and return types are derived from the ibis dtypes of ``op.arg`` and ``op``.
    """

    def us(array: pa.Array) -> pa.Array:
        # Scale the millisecond component so it can be summed with the
        # microsecond component.
        millis_as_micros = pc.multiply(pc.millisecond(array), 1000)
        total = pc.add(pc.microsecond(array), millis_as_micros)
        return pc.cast(total, pa.int32())

    source_type = PyArrowType.from_ibis(op.arg.dtype)
    result_type = PyArrowType.from_ibis(op.dtype)
    extract_microseconds_udf = df.udf(
        us,
        input_types=[source_type],
        return_type=result_type,
        volatility="immutable",
        name="extract_microseconds_udf",
    )
    return extract_microseconds_udf(translate(op.arg, **kw))


@translate.register(ops.ExtractEpochSeconds)
def extract_epoch_seconds(op, **kw):
    """Translate ``ops.ExtractEpochSeconds`` into a DataFusion Python UDF call.

    The UDF reinterprets the incoming array as ``int64`` ticks, divides by the
    number of ticks per second for the array's timestamp unit, and casts the
    quotient to ``int32``. The UDF's input and return types are derived from
    the ibis dtypes of ``op.arg`` and ``op``.
    """
    # Ticks per second for each Arrow timestamp unit. The previous
    # implementation always divided by 1_000_000, which is only correct for
    # microsecond-resolution input.
    ticks_per_second = {
        "s": 1,
        "ms": 1_000,
        "us": 1_000_000,
        "ns": 1_000_000_000,
    }

    def epoch_seconds(array: pa.Array) -> pa.Array:
        # Fall back to microseconds (the original assumption) when the array
        # type exposes no ``unit`` attribute.
        unit = getattr(array.type, "unit", "us")
        ticks = pc.cast(array, pa.int64())
        # NOTE(review): integer ``pc.divide`` is kept as in the original;
        # confirm its rounding for pre-epoch values is the intended one.
        seconds = pc.divide(ticks, ticks_per_second[unit])
        return pc.cast(seconds, pa.int32())

    extract_epoch_seconds_udf = df.udf(
        epoch_seconds,
        input_types=[PyArrowType.from_ibis(op.arg.dtype)],
        return_type=PyArrowType.from_ibis(op.dtype),
        volatility="immutable",
        name="extract_epoch_seconds_udf",
    )
    arg = translate(op.arg, **kw)
    return extract_epoch_seconds_udf(arg)
6 changes: 3 additions & 3 deletions ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def test_timestamp_extract_literal(con, func, expected):
assert con.execute(func(value).name("tmp")) == expected


@pytest.mark.notimpl(["datafusion", "oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.broken(
["druid"],
raises=AttributeError,
Expand Down Expand Up @@ -302,7 +302,7 @@ def test_timestamp_extract_milliseconds(backend, alltypes, df):
backend.assert_series_equal(result, expected)


@pytest.mark.notimpl(["datafusion", "oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.broken(
["druid"],
raises=AttributeError,
Expand All @@ -323,7 +323,7 @@ def test_timestamp_extract_epoch_seconds(backend, alltypes, df):
backend.assert_series_equal(result, expected)


@pytest.mark.notimpl(["datafusion", "oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(
["druid"],
raises=AttributeError,
Expand Down

0 comments on commit 5612d48

Please sign in to comment.