From 8b8e885d35ac613ec0b9c682cac028fa86d58209 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Fri, 13 Oct 2023 07:48:12 -0400 Subject: [PATCH] feat(trino): support temporal operations --- ibis/backends/tests/test_temporal.py | 58 +++++++++++++++---- .../test_h01/test_tpc_h01/trino/h01.sql | 2 +- .../test_h03/test_tpc_h03/trino/h03.sql | 4 +- .../test_h04/test_tpc_h04/trino/h04.sql | 4 +- .../test_h05/test_tpc_h05/trino/h05.sql | 4 +- .../test_h06/test_tpc_h06/trino/h06.sql | 4 +- .../test_h07/test_tpc_h07/trino/h07.sql | 2 +- .../test_h08/test_tpc_h08/trino/h08.sql | 2 +- .../test_h10/test_tpc_h10/trino/h10.sql | 4 +- .../test_h12/test_tpc_h12/trino/h12.sql | 4 +- .../test_h14/test_tpc_h14/trino/h14.sql | 4 +- .../test_h15/test_tpc_h15/trino/h15.sql | 4 +- .../test_h20/test_tpc_h20/trino/h20.sql | 4 +- ibis/backends/trino/__init__.py | 8 ++- ibis/backends/trino/datatypes.py | 21 +++++++ ibis/backends/trino/registry.py | 35 +++++++++++ 16 files changed, 132 insertions(+), 32 deletions(-) diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index 28384bb97856..6951339cda3b 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -582,6 +582,11 @@ def test_date_truncate(backend, alltypes, df, unit): raises=com.UnsupportedOperationError, reason="Interval from integer column is unsupported for the PySpark backend.", ), + pytest.mark.notyet( + ["trino"], + raises=com.UnsupportedOperationError, + reason="year not implemented", + ), ], ), param("Q", pd.offsets.DateOffset, marks=pytest.mark.xfail), @@ -610,6 +615,11 @@ def test_date_truncate(backend, alltypes, df, unit): raises=com.UnsupportedOperationError, reason="Interval from integer column is unsupported for the PySpark backend.", ), + pytest.mark.notyet( + ["trino"], + raises=com.UnsupportedOperationError, + reason="month not implemented", + ), ], ), param( @@ -632,6 +642,11 @@ def test_date_truncate(backend, alltypes, df, unit): raises=com.UnsupportedOperationError, reason="Interval from integer column is unsupported for the PySpark backend.", ), + pytest.mark.notyet( + ["trino"], + raises=com.UnsupportedOperationError, + reason="week not implemented", + ), ], ), param( @@ -704,12 +719,17 @@ def test_date_truncate(backend, alltypes, df, unit): raises=com.UnsupportedArgumentError, reason="Interval unit \"us\" is not allowed. Allowed units are: ['Y', 'W', 'M', 'D', 'h', 'm', 's']", ), + pytest.mark.notimpl( + ["trino"], + raises=AssertionError, + reason="we're dropping microseconds to ensure results consistent with pandas", + ), ], ), ], ) @pytest.mark.notimpl( - ["datafusion", "sqlite", "snowflake", "trino", "mssql", "oracle"], + ["datafusion", "sqlite", "snowflake", "mssql", "oracle"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl( @@ -740,7 +760,23 @@ def convert_to_offset(offset, displacement_type=displacement_type): @pytest.mark.parametrize( - "unit", ["Y", param("Q", marks=pytest.mark.xfail), "M", "W", "D"] + "unit", + [ + param( + "Y", + marks=pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + ), + param("Q", marks=pytest.mark.xfail), + param( + "M", + marks=pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + ), + param( + "W", + marks=pytest.mark.notyet(["trino"], raises=com.UnsupportedOperationError), + ), + "D", + ], ) # TODO - DateOffset - #2553 @pytest.mark.notimpl( @@ -753,7 +789,6 @@ def convert_to_offset(offset, displacement_type=displacement_type): "snowflake", "polars", "mssql", - "trino", "druid", "oracle", ], @@ -942,6 +977,11 @@ def convert_to_offset(x): raises=AssertionError, reason="duckdb 0.8.0 returns DateOffset columns", ), + pytest.mark.broken( + ["trino"], + raises=AssertionError, + reason="doesn't match pandas results, unclear what the issue is, perhaps timezones", + ), ], ), param( @@ -969,8 +1009,7 @@ def convert_to_offset(x): ], ) @pytest.mark.notimpl( - ["datafusion", "mssql", "trino", "oracle"], - raises=com.OperationNotDefinedError, + ["datafusion", "mssql", "oracle"], raises=com.OperationNotDefinedError ) def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): expr = expr_fn(alltypes, backend).name("tmp") @@ -1144,8 +1183,7 @@ def test_temporal_binop(backend, con, alltypes, df, expr_fn, expected_fn): ], ) @pytest.mark.notimpl( - ["datafusion", "sqlite", "mssql", "trino", "oracle"], - raises=com.OperationNotDefinedError, + ["datafusion", "sqlite", "mssql", "oracle"], raises=com.OperationNotDefinedError ) def test_temporal_binop_pandas_timedelta( backend, con, alltypes, df, timedelta, temporal_fn @@ -1286,7 +1324,7 @@ def test_timestamp_comparison_filter_numpy(backend, con, alltypes, df, func_name @pytest.mark.notimpl( - ["datafusion", "sqlite", "snowflake", "mssql", "trino", "oracle"], + ["datafusion", "sqlite", "snowflake", "mssql", "oracle"], raises=com.OperationNotDefinedError, ) @pytest.mark.broken( @@ -1307,7 +1345,7 @@ def test_interval_add_cast_scalar(backend, alltypes): ["pyspark"], reason="PySpark does not support casting columns to intervals" ) @pytest.mark.notimpl( - ["datafusion", "sqlite", "snowflake", "mssql", "trino", "oracle"], + ["datafusion", "sqlite", "snowflake", "mssql", "oracle"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl( @@ -1948,7 +1986,7 @@ def test_extract_time_from_timestamp(con, microsecond): "bigquery": "INTERVAL", "clickhouse": "IntervalSecond", "sqlite": "integer", - "trino": "integer", + "trino": "interval day to second", "duckdb": "INTERVAL", "postgres": "interval", } diff --git a/ibis/backends/tests/tpch/snapshots/test_h01/test_tpc_h01/trino/h01.sql b/ibis/backends/tests/tpch/snapshots/test_h01/test_tpc_h01/trino/h01.sql index efac88b989d9..6285952673a4 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h01/test_tpc_h01/trino/h01.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h01/test_tpc_h01/trino/h01.sql @@ -29,7 +29,7 @@ FROM ( COUNT(*) AS count_order FROM "hive".ibis_sf1.lineitem AS t1 WHERE - t1.l_shipdate <= CAST('1998-09-02' AS DATE) + t1.l_shipdate <= FROM_ISO8601_DATE('1998-09-02') GROUP BY 1, 2 diff --git a/ibis/backends/tests/tpch/snapshots/test_h03/test_tpc_h03/trino/h03.sql b/ibis/backends/tests/tpch/snapshots/test_h03/test_tpc_h03/trino/h03.sql index c0b34fcaeff3..717d6e2be2cb 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h03/test_tpc_h03/trino/h03.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h03/test_tpc_h03/trino/h03.sql @@ -13,8 +13,8 @@ WITH t0 AS ( ON t4.l_orderkey = t3.o_orderkey WHERE t2.c_mktsegment = 'BUILDING' - AND t3.o_orderdate < CAST('1995-03-15' AS DATE) - AND t4.l_shipdate > CAST('1995-03-15' AS DATE) + AND t3.o_orderdate < FROM_ISO8601_DATE('1995-03-15') + AND t4.l_shipdate > FROM_ISO8601_DATE('1995-03-15') GROUP BY 1, 2, diff --git a/ibis/backends/tests/tpch/snapshots/test_h04/test_tpc_h04/trino/h04.sql b/ibis/backends/tests/tpch/snapshots/test_h04/test_tpc_h04/trino/h04.sql index e1e6fc3fda30..50c7bc78e714 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h04/test_tpc_h04/trino/h04.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h04/test_tpc_h04/trino/h04.sql @@ -12,8 +12,8 @@ WHERE t1.l_orderkey = t0.o_orderkey AND t1.l_commitdate < t1.l_receiptdate ) ) - AND t0.o_orderdate >= CAST('1993-07-01' AS DATE) - AND t0.o_orderdate < CAST('1993-10-01' AS DATE) + AND t0.o_orderdate >= FROM_ISO8601_DATE('1993-07-01') + AND t0.o_orderdate < FROM_ISO8601_DATE('1993-10-01') GROUP BY 1 ORDER BY diff --git a/ibis/backends/tests/tpch/snapshots/test_h05/test_tpc_h05/trino/h05.sql b/ibis/backends/tests/tpch/snapshots/test_h05/test_tpc_h05/trino/h05.sql index 3924d1d20967..46526d998f1d 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h05/test_tpc_h05/trino/h05.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h05/test_tpc_h05/trino/h05.sql @@ -20,8 +20,8 @@ FROM ( ON t5.n_regionkey = t6.r_regionkey WHERE t6.r_name = 'ASIA' - AND t2.o_orderdate >= CAST('1994-01-01' AS DATE) - AND t2.o_orderdate < CAST('1995-01-01' AS DATE) + AND t2.o_orderdate >= FROM_ISO8601_DATE('1994-01-01') + AND t2.o_orderdate < FROM_ISO8601_DATE('1995-01-01') GROUP BY 1 ) AS t0 diff --git a/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/trino/h06.sql b/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/trino/h06.sql index e6e0c790f251..15c774b007b3 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/trino/h06.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h06/test_tpc_h06/trino/h06.sql @@ -2,7 +2,7 @@ SELECT SUM(t0.l_extendedprice * t0.l_discount) AS revenue FROM "hive".ibis_sf1.lineitem AS t0 WHERE - t0.l_shipdate >= CAST('1994-01-01' AS DATE) - AND t0.l_shipdate < CAST('1995-01-01' AS DATE) + t0.l_shipdate >= FROM_ISO8601_DATE('1994-01-01') + AND t0.l_shipdate < FROM_ISO8601_DATE('1995-01-01') AND t0.l_discount BETWEEN 0.05 AND 0.07 AND t0.l_quantity < 24 \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/trino/h07.sql b/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/trino/h07.sql index 8f999e273acf..d24ce2d30877 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/trino/h07.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h07/test_tpc_h07/trino/h07.sql @@ -39,7 +39,7 @@ FROM ( OR t0.cust_nation = 'GERMANY' AND t0.supp_nation = 'FRANCE' ) - AND t0.l_shipdate BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + AND t0.l_shipdate BETWEEN FROM_ISO8601_DATE('1995-01-01') AND FROM_ISO8601_DATE('1996-12-31') GROUP BY 1, 2, diff --git a/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/trino/h08.sql b/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/trino/h08.sql index 877269ad2167..9512aedcf57f 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/trino/h08.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h08/test_tpc_h08/trino/h08.sql @@ -34,7 +34,7 @@ WITH t0 AS ( FROM t0 WHERE t0.r_name = 'AMERICA' - AND t0.o_orderdate BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + AND t0.o_orderdate BETWEEN FROM_ISO8601_DATE('1995-01-01') AND FROM_ISO8601_DATE('1996-12-31') AND t0.p_type = 'ECONOMY ANODIZED STEEL' ), t2 AS ( SELECT diff --git a/ibis/backends/tests/tpch/snapshots/test_h10/test_tpc_h10/trino/h10.sql b/ibis/backends/tests/tpch/snapshots/test_h10/test_tpc_h10/trino/h10.sql index d56c7eb5b478..30b785ea0a07 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h10/test_tpc_h10/trino/h10.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h10/test_tpc_h10/trino/h10.sql @@ -18,8 +18,8 @@ WITH t0 AS ( JOIN "hive".ibis_sf1.nation AS t5 ON t2.c_nationkey = t5.n_nationkey WHERE - t3.o_orderdate >= CAST('1993-10-01' AS DATE) - AND t3.o_orderdate < CAST('1994-01-01' AS DATE) + t3.o_orderdate >= FROM_ISO8601_DATE('1993-10-01') + AND t3.o_orderdate < FROM_ISO8601_DATE('1994-01-01') AND t4.l_returnflag = 'R' GROUP BY 1, diff --git a/ibis/backends/tests/tpch/snapshots/test_h12/test_tpc_h12/trino/h12.sql b/ibis/backends/tests/tpch/snapshots/test_h12/test_tpc_h12/trino/h12.sql index 6d4eaf2cf9ce..25a39abaf140 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h12/test_tpc_h12/trino/h12.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h12/test_tpc_h12/trino/h12.sql @@ -14,8 +14,8 @@ FROM ( t2.l_shipmode IN ('MAIL', 'SHIP') AND t2.l_commitdate < t2.l_receiptdate AND t2.l_shipdate < t2.l_commitdate - AND t2.l_receiptdate >= CAST('1994-01-01' AS DATE) - AND t2.l_receiptdate < CAST('1995-01-01' AS DATE) + AND t2.l_receiptdate >= FROM_ISO8601_DATE('1994-01-01') + AND t2.l_receiptdate < FROM_ISO8601_DATE('1995-01-01') GROUP BY 1 ) AS t0 diff --git a/ibis/backends/tests/tpch/snapshots/test_h14/test_tpc_h14/trino/h14.sql b/ibis/backends/tests/tpch/snapshots/test_h14/test_tpc_h14/trino/h14.sql index 57acb13965b1..de8b5b83ad20 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h14/test_tpc_h14/trino/h14.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h14/test_tpc_h14/trino/h14.sql @@ -10,5 +10,5 @@ FROM "hive".ibis_sf1.lineitem AS t0 JOIN "hive".ibis_sf1.part AS t1 ON t0.l_partkey = t1.p_partkey WHERE - t0.l_shipdate >= CAST('1995-09-01' AS DATE) - AND t0.l_shipdate < CAST('1995-10-01' AS DATE) \ No newline at end of file + t0.l_shipdate >= FROM_ISO8601_DATE('1995-09-01') + AND t0.l_shipdate < FROM_ISO8601_DATE('1995-10-01') \ No newline at end of file diff --git a/ibis/backends/tests/tpch/snapshots/test_h15/test_tpc_h15/trino/h15.sql b/ibis/backends/tests/tpch/snapshots/test_h15/test_tpc_h15/trino/h15.sql index 2a8bc3032571..7e43b375d8ef 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h15/test_tpc_h15/trino/h15.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h15/test_tpc_h15/trino/h15.sql @@ -6,8 +6,8 @@ WITH t0 AS ( )) AS total_revenue FROM "hive".ibis_sf1.lineitem AS t3 WHERE - t3.l_shipdate >= CAST('1996-01-01' AS DATE) - AND t3.l_shipdate < CAST('1996-04-01' AS DATE) + t3.l_shipdate >= FROM_ISO8601_DATE('1996-01-01') + AND t3.l_shipdate < FROM_ISO8601_DATE('1996-04-01') GROUP BY 1 ), t1 AS ( diff --git a/ibis/backends/tests/tpch/snapshots/test_h20/test_tpc_h20/trino/h20.sql b/ibis/backends/tests/tpch/snapshots/test_h20/test_tpc_h20/trino/h20.sql index 710a129e6b1e..634f45c20b79 100644 --- a/ibis/backends/tests/tpch/snapshots/test_h20/test_tpc_h20/trino/h20.sql +++ b/ibis/backends/tests/tpch/snapshots/test_h20/test_tpc_h20/trino/h20.sql @@ -54,8 +54,8 @@ WITH t0 AS ( WHERE t6.l_partkey = t5.ps_partkey AND t6.l_suppkey = t5.ps_suppkey - AND t6.l_shipdate >= CAST('1994-01-01' AS DATE) - AND t6.l_shipdate < CAST('1995-01-01' AS DATE) + AND t6.l_shipdate >= FROM_ISO8601_DATE('1994-01-01') + AND t6.l_shipdate < FROM_ISO8601_DATE('1995-01-01') ) * 0.5 ) AS t4 ) diff --git a/ibis/backends/trino/__init__.py b/ibis/backends/trino/__init__.py index 4e55a08fabd3..956aae3716c1 100644 --- a/ibis/backends/trino/__init__.py +++ b/ibis/backends/trino/__init__.py @@ -26,7 +26,7 @@ ) from ibis.backends.base.sql.alchemy.datatypes import ArrayType from ibis.backends.trino.compiler import TrinoSQLCompiler -from ibis.backends.trino.datatypes import ROW, TrinoType +from ibis.backends.trino.datatypes import INTERVAL, ROW, TrinoType if TYPE_CHECKING: from collections.abc import Iterator, Mapping @@ -172,6 +172,12 @@ def column_reflect(inspector, table, column_info): column_info["type"] = toolz.nth( typ.dimensions or 1, toolz.iterate(ArrayType, typ.item_type) ) + elif isinstance(typ, sa.Interval): + column_info["type"] = INTERVAL( + native=typ.native, + day_precision=typ.day_precision, + second_precision=typ.second_precision, + ) return meta diff --git a/ibis/backends/trino/datatypes.py b/ibis/backends/trino/datatypes.py index fc9e33cd2200..d34bc2c81997 100644 --- a/ibis/backends/trino/datatypes.py +++ b/ibis/backends/trino/datatypes.py @@ -1,5 +1,6 @@ from __future__ import annotations +from datetime import time, timedelta from typing import Any import sqlalchemy.types as sat @@ -33,6 +34,26 @@ def process( return process +class INTERVAL(sat.Interval): + def result_processor(self, dialect, coltype: str) -> None: + def process(value): + if value is None: + return value + + # TODO: support year-month intervals + days, duration = value.split(" ", 1) + t = time.fromisoformat(duration) + return timedelta( + days=int(days), + hours=t.hour, + minutes=t.minute, + seconds=t.second, + microseconds=t.microsecond, + ) + + return process + + @compiles(TIMESTAMP) def compiles_timestamp(typ, compiler, **kw): result = "TIMESTAMP" diff --git a/ibis/backends/trino/registry.py b/ibis/backends/trino/registry.py index ed17a247ad34..7d25d633f223 100644 --- a/ibis/backends/trino/registry.py +++ b/ibis/backends/trino/registry.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator from functools import partial, reduce from typing import Literal @@ -11,6 +12,7 @@ import ibis import ibis.common.exceptions as com +import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.base.sql.alchemy.registry import ( _literal as _alchemy_literal, @@ -27,6 +29,7 @@ varargs, ) from ibis.backends.postgres.registry import _corr, _covar +from ibis.backends.trino.datatypes import INTERVAL operation_registry = sqlalchemy_operation_registry.copy() operation_registry.update(sqlalchemy_window_functions_registry) @@ -68,8 +71,19 @@ def _literal(t, op): return sa.literal(float(value), type_=DOUBLE()) elif dtype.is_integer(): return sa.literal(int(value), type_=t.get_sqla_type(dtype)) + elif dtype.is_timestamp(): + return sa.cast( + sa.func.from_iso8601_timestamp(value.isoformat()), t.get_sqla_type(dtype) + ) elif dtype.is_date(): + return sa.func.from_iso8601_date(value.isoformat()) + elif dtype.is_time(): return sa.cast(sa.literal(str(value)), t.get_sqla_type(dtype)) + elif dtype.is_interval(): + return sa.literal_column( + f"INTERVAL '{value}' {dtype.resolution.upper()}", type_=INTERVAL + ) + return _alchemy_literal(t, op) @@ -324,6 +338,18 @@ def _array_intersect(t, op): 3, ) + +def _interval_from_integer(t, op): + unit = op.unit.short + if unit in ("Y", "Q", "M", "W"): + raise com.UnsupportedOperationError(f"Interval unit {unit!r} not supported") + arg = sa.func.concat( + t.translate(ops.Cast(op.arg, dt.String(nullable=op.arg.dtype.nullable))), + unit.lower(), + ) + return sa.type_coerce(sa.func.parse_duration(arg), INTERVAL) + + operation_registry.update( { # conditional expressions @@ -511,6 +537,15 @@ def _array_intersect(t, op): ops.TimeDelta: _temporal_delta, ops.DateDelta: _temporal_delta, ops.TimestampDelta: _temporal_delta, + ops.TimestampAdd: fixed_arity(operator.add, 2), + ops.TimestampSub: fixed_arity(operator.sub, 2), + ops.TimestampDiff: fixed_arity(lambda x, y: sa.type_coerce(x - y, INTERVAL), 2), + ops.DateAdd: fixed_arity(operator.add, 2), + ops.DateSub: fixed_arity(operator.sub, 2), + ops.DateDiff: fixed_arity(lambda x, y: sa.type_coerce(x - y, INTERVAL), 2), + ops.IntervalAdd: fixed_arity(operator.add, 2), + ops.IntervalSubtract: fixed_arity(operator.sub, 2), + ops.IntervalFromInteger: _interval_from_integer, } )