Skip to content

Commit

Permalink
feat(trino): support temporal operations
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and jcrist committed Oct 13, 2023
1 parent f55d0db commit 8b8e885
Show file tree
Hide file tree
Showing 16 changed files with 132 additions and 32 deletions.
58 changes: 48 additions & 10 deletions ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,11 @@ def test_date_truncate(backend, alltypes, df, unit):
raises=com.UnsupportedOperationError,
reason="Interval from integer column is unsupported for the PySpark backend.",
),
pytest.mark.notyet(
["trino"],
raises=com.UnsupportedOperationError,
reason="year not implemented",
),
],
),
param("Q", pd.offsets.DateOffset, marks=pytest.mark.xfail),
Expand Down Expand Up @@ -610,6 +615,11 @@ def test_date_truncate(backend, alltypes, df, unit):
raises=com.UnsupportedOperationError,
reason="Interval from integer column is unsupported for the PySpark backend.",
),
pytest.mark.notyet(
["trino"],
raises=com.UnsupportedOperationError,
reason="month not implemented",
),
],
),
param(
Expand All @@ -632,6 +642,11 @@ def test_date_truncate(backend, alltypes, df, unit):
raises=com.UnsupportedOperationError,
reason="Interval from integer column is unsupported for the PySpark backend.",
),
pytest.mark.notyet(
["trino"],
raises=com.UnsupportedOperationError,
reason="week not implemented",
),
],
),
param(
Expand Down Expand Up @@ -704,12 +719,17 @@ def test_date_truncate(backend, alltypes, df, unit):
raises=com.UnsupportedArgumentError,
reason="Interval unit \"us\" is not allowed. Allowed units are: ['Y', 'W', 'M', 'D', 'h', 'm', 's']",
),
pytest.mark.notimpl(
["trino"],
raises=AssertionError,
reason="we're dropping microseconds to ensure results consistent with pandas",
),
],
),
],
)
@pytest.mark.notimpl(
["datafusion", "sqlite", "snowflake", "trino", "mssql", "oracle"],
["datafusion", "sqlite", "snowflake", "mssql", "oracle"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
Expand Down Expand Up @@ -740,7 +760,23 @@ def convert_to_offset(offset, displacement_type=displacement_type):


@pytest.mark.parametrize(
"unit", ["Y", param("Q", marks=pytest.mark.xfail), "M", "W", "D"]
"unit",
[
param(
"Y",
marks=pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError),
),
param("Q", marks=pytest.mark.xfail),
param(
"M",
marks=pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError),
),
param(
"W",
marks=pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError),
),
"D",
],
)
# TODO - DateOffset - #2553
@pytest.mark.notimpl(
Expand All @@ -753,7 +789,6 @@ def convert_to_offset(offset, displacement_type=displacement_type):
"snowflake",
"polars",
"mssql",
"trino",
"druid",
"oracle",
],
Expand Down Expand Up @@ -942,6 +977,11 @@ def convert_to_offset(x):
raises=AssertionError,
reason="duckdb 0.8.0 returns DateOffset columns",
),
pytest.mark.broken(
["trino"],
raises=AssertionError,
reason="doesn't match pandas results, unclear what the issue is, perhaps timezones",
),
],
),
param(
Expand Down Expand Up @@ -969,8 +1009,7 @@ def convert_to_offset(x):
],
)
@pytest.mark.notimpl(
["datafusion", "mssql", "trino", "oracle"],
raises=com.OperationNotDefinedError,
["datafusion", "mssql", "oracle"], raises=com.OperationNotDefinedError
)
def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
expr = expr_fn(alltypes, backend).name("tmp")
Expand Down Expand Up @@ -1144,8 +1183,7 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn):
],
)
@pytest.mark.notimpl(
["datafusion", "sqlite", "mssql", "trino", "oracle"],
raises=com.OperationNotDefinedError,
["datafusion", "sqlite", "mssql", "oracle"], raises=com.OperationNotDefinedError
)
def test_temporal_binop_pandas_timedelta(
backend, con, alltypes, df, timedelta, temporal_fn
Expand Down Expand Up @@ -1286,7 +1324,7 @@ def test_timestamp_comparison_filter_numpy(backend, con, alltypes, df, func_name


@pytest.mark.notimpl(
["datafusion", "sqlite", "snowflake", "mssql", "trino", "oracle"],
["datafusion", "sqlite", "snowflake", "mssql", "oracle"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.broken(
Expand All @@ -1307,7 +1345,7 @@ def test_interval_add_cast_scalar(backend, alltypes):
["pyspark"], reason="PySpark does not support casting columns to intervals"
)
@pytest.mark.notimpl(
["datafusion", "sqlite", "snowflake", "mssql", "trino", "oracle"],
["datafusion", "sqlite", "snowflake", "mssql", "oracle"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
Expand Down Expand Up @@ -1948,7 +1986,7 @@ def test_extract_time_from_timestamp(con, microsecond):
"bigquery": "INTERVAL",
"clickhouse": "IntervalSecond",
"sqlite": "integer",
"trino": "integer",
"trino": "interval day to second",
"duckdb": "INTERVAL",
"postgres": "interval",
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ FROM (
COUNT(*) AS count_order
FROM "hive".ibis_sf1.lineitem AS t1
WHERE
t1.l_shipdate <= CAST('1998-09-02' AS DATE)
t1.l_shipdate <= FROM_ISO8601_DATE('1998-09-02')
GROUP BY
1,
2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ WITH t0 AS (
ON t4.l_orderkey = t3.o_orderkey
WHERE
t2.c_mktsegment = 'BUILDING'
AND t3.o_orderdate < CAST('1995-03-15' AS DATE)
AND t4.l_shipdate > CAST('1995-03-15' AS DATE)
AND t3.o_orderdate < FROM_ISO8601_DATE('1995-03-15')
AND t4.l_shipdate > FROM_ISO8601_DATE('1995-03-15')
GROUP BY
1,
2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ WHERE
t1.l_orderkey = t0.o_orderkey AND t1.l_commitdate < t1.l_receiptdate
)
)
AND t0.o_orderdate >= CAST('1993-07-01' AS DATE)
AND t0.o_orderdate < CAST('1993-10-01' AS DATE)
AND t0.o_orderdate >= FROM_ISO8601_DATE('1993-07-01')
AND t0.o_orderdate < FROM_ISO8601_DATE('1993-10-01')
GROUP BY
1
ORDER BY
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ FROM (
ON t5.n_regionkey = t6.r_regionkey
WHERE
t6.r_name = 'ASIA'
AND t2.o_orderdate >= CAST('1994-01-01' AS DATE)
AND t2.o_orderdate < CAST('1995-01-01' AS DATE)
AND t2.o_orderdate >= FROM_ISO8601_DATE('1994-01-01')
AND t2.o_orderdate < FROM_ISO8601_DATE('1995-01-01')
GROUP BY
1
) AS t0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ SELECT
SUM(t0.l_extendedprice * t0.l_discount) AS revenue
FROM "hive".ibis_sf1.lineitem AS t0
WHERE
t0.l_shipdate >= CAST('1994-01-01' AS DATE)
AND t0.l_shipdate < CAST('1995-01-01' AS DATE)
t0.l_shipdate >= FROM_ISO8601_DATE('1994-01-01')
AND t0.l_shipdate < FROM_ISO8601_DATE('1995-01-01')
AND t0.l_discount BETWEEN 0.05 AND 0.07
AND t0.l_quantity < 24
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ FROM (
OR t0.cust_nation = 'GERMANY'
AND t0.supp_nation = 'FRANCE'
)
AND t0.l_shipdate BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
AND t0.l_shipdate BETWEEN FROM_ISO8601_DATE('1995-01-01') AND FROM_ISO8601_DATE('1996-12-31')
GROUP BY
1,
2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ WITH t0 AS (
FROM t0
WHERE
t0.r_name = 'AMERICA'
AND t0.o_orderdate BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
AND t0.o_orderdate BETWEEN FROM_ISO8601_DATE('1995-01-01') AND FROM_ISO8601_DATE('1996-12-31')
AND t0.p_type = 'ECONOMY ANODIZED STEEL'
), t2 AS (
SELECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ WITH t0 AS (
JOIN "hive".ibis_sf1.nation AS t5
ON t2.c_nationkey = t5.n_nationkey
WHERE
t3.o_orderdate >= CAST('1993-10-01' AS DATE)
AND t3.o_orderdate < CAST('1994-01-01' AS DATE)
t3.o_orderdate >= FROM_ISO8601_DATE('1993-10-01')
AND t3.o_orderdate < FROM_ISO8601_DATE('1994-01-01')
AND t4.l_returnflag = 'R'
GROUP BY
1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ FROM (
t2.l_shipmode IN ('MAIL', 'SHIP')
AND t2.l_commitdate < t2.l_receiptdate
AND t2.l_shipdate < t2.l_commitdate
AND t2.l_receiptdate >= CAST('1994-01-01' AS DATE)
AND t2.l_receiptdate < CAST('1995-01-01' AS DATE)
AND t2.l_receiptdate >= FROM_ISO8601_DATE('1994-01-01')
AND t2.l_receiptdate < FROM_ISO8601_DATE('1995-01-01')
GROUP BY
1
) AS t0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ FROM "hive".ibis_sf1.lineitem AS t0
JOIN "hive".ibis_sf1.part AS t1
ON t0.l_partkey = t1.p_partkey
WHERE
t0.l_shipdate >= CAST('1995-09-01' AS DATE)
AND t0.l_shipdate < CAST('1995-10-01' AS DATE)
t0.l_shipdate >= FROM_ISO8601_DATE('1995-09-01')
AND t0.l_shipdate < FROM_ISO8601_DATE('1995-10-01')
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ WITH t0 AS (
)) AS total_revenue
FROM "hive".ibis_sf1.lineitem AS t3
WHERE
t3.l_shipdate >= CAST('1996-01-01' AS DATE)
AND t3.l_shipdate < CAST('1996-04-01' AS DATE)
t3.l_shipdate >= FROM_ISO8601_DATE('1996-01-01')
AND t3.l_shipdate < FROM_ISO8601_DATE('1996-04-01')
GROUP BY
1
), t1 AS (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ WITH t0 AS (
WHERE
t6.l_partkey = t5.ps_partkey
AND t6.l_suppkey = t5.ps_suppkey
AND t6.l_shipdate >= CAST('1994-01-01' AS DATE)
AND t6.l_shipdate < CAST('1995-01-01' AS DATE)
AND t6.l_shipdate >= FROM_ISO8601_DATE('1994-01-01')
AND t6.l_shipdate < FROM_ISO8601_DATE('1995-01-01')
) * 0.5
) AS t4
)
Expand Down
8 changes: 7 additions & 1 deletion ibis/backends/trino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from ibis.backends.base.sql.alchemy.datatypes import ArrayType
from ibis.backends.trino.compiler import TrinoSQLCompiler
from ibis.backends.trino.datatypes import ROW, TrinoType
from ibis.backends.trino.datatypes import INTERVAL, ROW, TrinoType

if TYPE_CHECKING:
from collections.abc import Iterator, Mapping
Expand Down Expand Up @@ -172,6 +172,12 @@ def column_reflect(inspector, table, column_info):
column_info["type"] = toolz.nth(
typ.dimensions or 1, toolz.iterate(ArrayType, typ.item_type)
)
elif isinstance(typ, sa.Interval):
column_info["type"] = INTERVAL(
native=typ.native,
day_precision=typ.day_precision,
second_precision=typ.second_precision,
)

return meta

Expand Down
21 changes: 21 additions & 0 deletions ibis/backends/trino/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from datetime import time, timedelta
from typing import Any

import sqlalchemy.types as sat
Expand Down Expand Up @@ -33,6 +34,26 @@ def process(
return process


class INTERVAL(sat.Interval):
    def result_processor(self, dialect, coltype: str) -> callable:
        """Return a converter from Trino's textual ``INTERVAL DAY TO SECOND``
        values to :class:`datetime.timedelta`.

        Trino renders day-to-second intervals as ``'<days> HH:MM:SS.fff'``
        (e.g. ``'5 01:02:03.000'``); the returned ``process`` function splits
        off the day count and parses the remainder as an ISO time.

        NOTE: the original annotated this method ``-> None``, but per the
        SQLAlchemy ``TypeEngine.result_processor`` contract it returns the
        per-value conversion function.
        """

        def process(value: str | None) -> timedelta | None:
            # NULL columns come through as None; pass them along unchanged.
            if value is None:
                return value

            # TODO: support year-month intervals
            days, duration = value.split(" ", 1)
            t = time.fromisoformat(duration)
            # NOTE(review): for negative intervals the sign appears to live
            # only on the day component here — confirm Trino's textual format
            # for values like '-1 02:00:00' before relying on this.
            return timedelta(
                days=int(days),
                hours=t.hour,
                minutes=t.minute,
                seconds=t.second,
                microseconds=t.microsecond,
            )

        return process


@compiles(TIMESTAMP)
def compiles_timestamp(typ, compiler, **kw):
result = "TIMESTAMP"
Expand Down
Loading

0 comments on commit 8b8e885

Please sign in to comment.