Skip to content

Commit

Permalink
fix(datafusion): fix some temporal operations
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored and cpcloud committed Oct 27, 2023
1 parent d38d2c4 commit 3206dbc
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 59 deletions.
48 changes: 36 additions & 12 deletions ibis/backends/datafusion/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def translate_val(op, **_):

agg = AggGen(aggfunc=_aggregate)
cast = make_cast(PostgresType)
if_ = F["if"]

_simple_ops = {
ops.Abs: "abs",
Expand Down Expand Up @@ -164,6 +165,21 @@ def alias(op, *, arg, name, **_):
return arg.as_(name)


def _to_timestamp(value, target_dtype, literal=False):
    """Build a DataFusion ``arrow_cast`` call casting *value* to a timestamp.

    The Arrow type string is assembled from *target_dtype*: its timezone
    (rendered as the Rust-style ``Some("...")``/``None``) and its unit.
    When the dtype carries no explicit scale, the unit falls back to
    ``Microsecond``.  With ``literal=True`` the value is stringified first
    so datetime literals are passed to DataFusion as text.
    """
    timezone = target_dtype.timezone
    tz = "None" if timezone is None else f'Some("{timezone}")'

    if target_dtype.scale is not None:
        unit = target_dtype.unit.name.capitalize()
    else:
        unit = "Microsecond"

    arg = str(value) if literal else value
    return F.arrow_cast(arg, f"Timestamp({unit}, {tz})")


@translate_val.register(ops.Literal)
def _literal(op, *, value, dtype, **kw):
if value is None and dtype.nullable:
Expand Down Expand Up @@ -192,7 +208,7 @@ def _literal(op, *, value, dtype, **kw):

return interval(value, unit=dtype.resolution.upper())
elif dtype.is_timestamp():
return F.to_timestamp(value.isoformat())
return _to_timestamp(value, dtype, literal=True)
elif dtype.is_date():
return F.date_trunc("day", value.isoformat())
elif dtype.is_time():
Expand Down Expand Up @@ -233,9 +249,8 @@ def _cast(op, *, arg, to, **_):
if to.is_interval():
unit_name = to.unit.name.lower()
return sg.cast(F.concat(sg.cast(arg, "text"), f" {unit_name}"), "interval")
if to.is_timestamp() and (timezone := to.timezone) is not None:
unit = to.unit.name.capitalize()
return F.arrow_cast(arg, f'Timestamp({unit}, Some("{timezone}"))')
if to.is_timestamp():
return _to_timestamp(arg, to)
if to.is_decimal():
return F.arrow_cast(arg, f"{PyArrowType.from_ibis(to)}".capitalize())
return cast(arg, to)
Expand Down Expand Up @@ -492,16 +507,25 @@ def extract_day_of_the_week_index(op, *, arg, **_):
return (F.date_part("dow", arg) + 6) % 7


_DOW_INDEX_NAME = {
0: "Monday",
1: "Tuesday",
2: "Wednesday",
3: "Thursday",
4: "Friday",
5: "Saturday",
6: "Sunday",
}


@translate_val.register(ops.DayOfWeekName)
def extract_day_of_the_week_name(op, *, arg, **_):
    """Translate ``ops.DayOfWeekName`` into a SQL CASE expression.

    Reuses the Monday-indexed day-of-week computation
    (``(date_part('dow', arg) + 6) % 7``) as the CASE subject and maps
    each index 0..6 to its English day name via ``_DOW_INDEX_NAME``,
    replacing the old UDF-based implementation.
    """
    # NOTE(review): the extracted diff interleaved the removed UDF-based
    # body with the added CASE-based body; this is the post-commit version.
    cases, results = zip(*_DOW_INDEX_NAME.items())

    return sg.exp.Case(
        this=paren((F.date_part("dow", arg) + 6) % 7),
        ifs=list(map(if_, cases, results)),
    )


@translate_val.register(ops.Date)
Expand Down
21 changes: 0 additions & 21 deletions ibis/backends/datafusion/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,6 @@ def extract_microsecond(array: dt.Timestamp(scale=9)) -> dt.int32:
return pc.cast(pc.add(pc.microsecond(array), arr), pa.int32())


def _extract_dow_name(array) -> str:
    """Map each element's day-of-week index (Monday=0) to its English name.

    Delegates to ``pyarrow.compute``: ``day_of_week`` yields the index and
    ``choose`` selects the matching name for every element.
    """
    day_names = (
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday",
    )
    return pc.choose(pc.day_of_week(array), *day_names)


def extract_dow_name_date(array: dt.Date) -> str:
    """Return the English weekday name for each element of a date *array*."""
    return _extract_dow_name(array)


def extract_dow_name_timestamp(array: dt.Timestamp(scale=9)) -> str:
    """Return the English weekday name for each element of a nanosecond-timestamp *array*."""
    return _extract_dow_name(array)


def _extract_query_arrow(
arr: pa.StringArray, *, param: str | None = None
) -> pa.StringArray:
Expand Down
26 changes: 0 additions & 26 deletions ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,11 +1649,6 @@ def test_string_to_timestamp(alltypes, fmt):
)
@pytest.mark.notimpl(["mssql", "druid", "oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["impala"], raises=com.UnsupportedBackendType)
@pytest.mark.xfail_version(
datafusion=["datafusion"],
raises=Exception,
reason="Exception: Arrow error: Cast error: Cannot cast string to value of Date64 type",
)
def test_day_of_week_scalar(con, date, expected_index, expected_day):
expr = ibis.literal(date).cast(dt.date)
result_index = con.execute(expr.day_of_week.index().name("tmp"))
Expand Down Expand Up @@ -2268,14 +2263,6 @@ def test_integer_cast_to_timestamp_scalar(alltypes, df):
reason="PySpark doesn't handle big timestamps",
raises=pd.errors.OutOfBoundsDatetime,
)
@pytest.mark.notimpl(
["datafusion"],
raises=Exception,
reason=(
"Exception: Arrow error: Parser error: The dates that can be represented as nanoseconds have to be "
"between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804"
),
)
def test_big_timestamp(con):
# TODO: test with a timezone
value = ibis.timestamp("2419-10-11 10:10:25")
Expand Down Expand Up @@ -2358,14 +2345,6 @@ def test_timestamp_date_comparison(backend, alltypes, df, left_fn, right_fn):
"value: OverflowError('int too big to convert'), traceback: None }"
),
)
@pytest.mark.notimpl(
["datafusion"],
raises=Exception,
reason=(
"Exception: Arrow error: Parser error: The dates that can be represented as nanoseconds have to be "
"between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804"
),
)
def test_large_timestamp(con):
huge_timestamp = datetime.datetime(year=4567, month=1, day=1)
expr = ibis.timestamp("4567-01-01 00:00:00")
Expand Down Expand Up @@ -2448,11 +2427,6 @@ def test_large_timestamp(con):
raises=sa.exc.DatabaseError,
reason="ORA-01843: invalid month was specified",
)
@pytest.mark.notimpl(
["datafusion"],
raises=Exception,
reason="Exception: DataFusion error: NotImplemented",
)
def test_timestamp_precision_output(con, ts, scale, unit):
dtype = dt.Timestamp(scale=scale)
expr = ibis.literal(ts).cast(dtype)
Expand Down

0 comments on commit 3206dbc

Please sign in to comment.