Skip to content

Commit

Permalink
feat(trino): add EXTRACT-based functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and gforsyth committed Dec 9, 2022
1 parent 2f262cd commit 6549657
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 22 deletions.
14 changes: 14 additions & 0 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,13 @@ def _count_star(t, op):
return sa.func.count(t.translate(ops.Where(where, 1, None)))


def _extract(fmt: str):
    """Build a translation rule that extracts the ``fmt`` field of ``op.arg``.

    The returned callable takes the translator ``t`` and an operation node
    ``op``, translates ``op.arg``, wraps it in ``sa.extract(fmt, ...)`` and
    casts the result to ``sa.SMALLINT``.
    """

    def translator(t, op: ops.Node):
        extracted = sa.extract(fmt, t.translate(op.arg))
        return sa.cast(extracted, sa.SMALLINT)

    return translator


sqlalchemy_operation_registry: Dict[Any, Any] = {
ops.Alias: _alias,
ops.And: fixed_arity(operator.and_, 2),
Expand Down Expand Up @@ -612,6 +619,13 @@ def _count_star(t, op):
ops.BitwiseRightShift: _bitwise_op(">>"),
ops.BitwiseNot: _bitwise_not,
ops.JSONGetItem: fixed_arity(lambda x, y: x.op("->")(y), 2),
ops.ExtractYear: _extract('year'),
ops.ExtractQuarter: _extract('quarter'),
ops.ExtractMonth: _extract('month'),
ops.ExtractDay: _extract('day'),
ops.ExtractHour: _extract('hour'),
ops.ExtractMinute: _extract('minute'),
ops.ExtractSecond: _extract('second'),
}


Expand Down
20 changes: 6 additions & 14 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported
from ibis.backends.base.sql.alchemy.registry import (
_bitwise_op,
_extract,
geospatial_functions,
get_col_or_deferred_col,
)
Expand All @@ -38,19 +39,16 @@
operation_registry.update(geospatial_functions)


def _extract(fmt, output_type=sa.SMALLINT):
    """Return a translation rule extracting ``fmt`` and casting the result.

    ``fmt`` is the field name handed to ``sa.extract``; ``output_type`` is the
    SQLAlchemy type the extracted value is cast to (``sa.SMALLINT`` unless
    overridden).
    """

    def translator(t, op, output_type=output_type):
        # ``output_type`` is bound as a default so the rule can still be
        # invoked with just ``(t, op)``.
        extracted = sa.extract(fmt, t.translate(op.arg))
        return sa.cast(extracted, output_type)

    return translator
def _epoch(t, op):
    """Translate ``op.arg`` and extract its ``'epoch'`` field, cast to ``sa.INTEGER``."""
    epoch_value = sa.extract('epoch', t.translate(op.arg))
    return sa.cast(epoch_value, sa.INTEGER)


def _second(t, op):
# Extracting the 'second' field also yields the fractional part of the
# second, so the value is floored before being cast to SMALLINT.
# NOTE(review): the two return lines below appear to be the removed
# (sa.func.FLOOR) and added (sa.func.floor) sides of this diff hunk rendered
# without +/- markers; presumably only the lowercase variant exists in the
# post-commit file -- confirm against the repository.
sa_arg = t.translate(op.arg)
return sa.cast(sa.func.FLOOR(sa.extract('second', sa_arg)), sa.SMALLINT)
return sa.cast(sa.func.floor(sa.extract('second', sa_arg)), sa.SMALLINT)


def _millisecond(t, op):
Expand Down Expand Up @@ -570,15 +568,9 @@ def _array_collect(t, op):
ops.TimestampSub: fixed_arity(operator.sub, 2),
ops.TimestampDiff: fixed_arity(operator.sub, 2),
ops.Strftime: _strftime,
ops.ExtractYear: _extract('year'),
ops.ExtractMonth: _extract('month'),
ops.ExtractDay: _extract('day'),
ops.ExtractEpochSeconds: _epoch,
ops.ExtractDayOfYear: _extract('doy'),
ops.ExtractQuarter: _extract('quarter'),
ops.ExtractEpochSeconds: _extract('epoch', sa.Integer),
ops.ExtractWeekOfYear: _extract('week'),
ops.ExtractHour: _extract('hour'),
ops.ExtractMinute: _extract('minute'),
ops.ExtractSecond: _second,
ops.ExtractMillisecond: _millisecond,
ops.DayOfWeekIndex: _day_of_week_index,
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_distinct_column(alltypes, df, column):
("day", set(range(1, 32))),
],
)
@pytest.mark.notimpl(["datafusion", "trino"])
@pytest.mark.notimpl(["datafusion"])
@pytest.mark.notyet(["impala"])
def test_date_extract_field(con, opname, expected):
op = operator.methodcaller(opname)
Expand Down
15 changes: 8 additions & 7 deletions ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
),
],
)
@pytest.mark.notimpl(["datafusion", "trino"])
@pytest.mark.notimpl(["datafusion"])
def test_date_extract(backend, alltypes, df, attr, expr_fn):
expr = getattr(expr_fn(alltypes.timestamp_col), attr)()
expected = getattr(df.timestamp_col.dt, attr).astype('int32')
Expand All @@ -44,15 +44,16 @@ def test_date_extract(backend, alltypes, df, attr, expr_fn):
'month',
'day',
param(
'day_of_year', marks=pytest.mark.notimpl(["bigquery", "impala", "mssql"])
'day_of_year',
marks=pytest.mark.notimpl(["bigquery", "impala", "mssql", "trino"]),
),
param('quarter', marks=pytest.mark.notimpl(["mssql"])),
'hour',
'minute',
'second',
],
)
@pytest.mark.notimpl(["datafusion", "trino"])
@pytest.mark.notimpl(["datafusion"])
def test_timestamp_extract(backend, alltypes, df, attr):
method = getattr(alltypes.timestamp_col, attr)
expr = method().name(attr)
Expand All @@ -77,7 +78,7 @@ def test_timestamp_extract(backend, alltypes, df, attr):
359,
id='millisecond',
marks=[
pytest.mark.notimpl(["clickhouse", "pyspark"]),
pytest.mark.notimpl(["clickhouse", "pyspark", "trino"]),
pytest.mark.broken(
["mysql"],
reason="MySQL implementation of milliseconds is broken",
Expand All @@ -88,17 +89,17 @@ def test_timestamp_extract(backend, alltypes, df, attr):
lambda x: x.day_of_week.index(),
1,
id='day_of_week_index',
marks=pytest.mark.notimpl(["mssql"]),
marks=pytest.mark.notimpl(["mssql", "trino"]),
),
param(
lambda x: x.day_of_week.full_name(),
'Tuesday',
id='day_of_week_full_name',
marks=pytest.mark.notimpl(["mssql"]),
marks=pytest.mark.notimpl(["mssql", "trino"]),
),
],
)
@pytest.mark.notimpl(["datafusion", "snowflake", "trino"])
@pytest.mark.notimpl(["datafusion", "snowflake"])
def test_timestamp_extract_literal(con, func, expected):
value = ibis.timestamp('2015-09-01 14:48:05.359')
assert con.execute(func(value).name("tmp")) == expected
Expand Down

0 comments on commit 6549657

Please sign in to comment.