Skip to content

Commit

Permalink
feat(trino): implement basic struct operations
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Dec 28, 2022
1 parent 602999d commit cc3c937
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 10 deletions.
1 change: 1 addition & 0 deletions ibis/backends/trino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def do_connect(
port=port,
database=database,
)
connect_args.setdefault("experimental_python_types", True)
super().do_connect(sa.create_engine(url, connect_args=connect_args))
self._meta = sa.MetaData(bind=self.con, schema=schema)

Expand Down
8 changes: 0 additions & 8 deletions ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from __future__ import annotations

import sqlalchemy as sa
from trino.sqlalchemy.datatype import JSON
from trino.sqlalchemy.dialect import TrinoDialect

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.trino.registry import operation_registry
Expand Down Expand Up @@ -48,8 +45,3 @@ def _rewrite_string_contains(op):
class TrinoSQLCompiler(AlchemyCompiler):
cheap_in_memory_tables = False
translator_class = TrinoSQLExprTranslator


@dt.dtype.register(TrinoDialect, JSON)
def sa_jsonb(_, satype, nullable=True):
return dt.JSON(nullable=nullable)
30 changes: 28 additions & 2 deletions ibis/backends/trino/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import parsy as p
import sqlalchemy as sa
import trino
from sqlalchemy.ext.compiler import compiles
from trino.sqlalchemy.datatype import DOUBLE, JSON, ROW
from trino.sqlalchemy.dialect import TrinoDialect

import ibis.backends.base.sql.alchemy.datatypes as sat
import ibis.expr.datatypes as dt
from ibis.common.parsing import (
COMMA,
Expand Down Expand Up @@ -124,7 +126,7 @@ def struct():
return ty.parse(text)


@dt.dtype.register(TrinoDialect, trino.sqlalchemy.datatype.DOUBLE)
@dt.dtype.register(TrinoDialect, DOUBLE)
def sa_trino_double(_, satype, nullable=True):
return dt.Float64(nullable=nullable)

Expand All @@ -133,3 +135,27 @@ def sa_trino_double(_, satype, nullable=True):
def sa_trino_array(dialect, satype, nullable=True):
value_dtype = dt.dtype(dialect, satype.item_type)
return dt.Array(value_dtype, nullable=nullable)


@dt.dtype.register(TrinoDialect, ROW)
def sa_trino_row(dialect, satype, nullable=True):
fields = ((name, dt.dtype(dialect, typ)) for name, typ in satype.attr_types)
return dt.Struct.from_tuples(fields, nullable=nullable)


@dt.dtype.register(TrinoDialect, JSON)
def sa_jsonb(_, satype, nullable=True):
return dt.JSON(nullable=nullable)


@compiles(sa.TEXT, "trino")
def compiles_text(element, compiler, **kw):
return "VARCHAR"


@compiles(sat.StructType, "trino")
def compiles_struct(element, compiler, **kw):
content = ", ".join(
f"{field} {compiler.process(typ, **kw)}" for field, typ in element.pairs
)
return f"ROW({content})"
20 changes: 20 additions & 0 deletions ibis/backends/trino/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy.datatypes import to_sqla_type
from ibis.backends.base.sql.alchemy.registry import _literal as _alchemy_literal
from ibis.backends.base.sql.alchemy.registry import (
fixed_arity,
reduction,
Expand All @@ -16,6 +17,12 @@
operation_registry.update(sqlalchemy_window_functions_registry)


def _literal(t, op):
if (dtype := op.output_dtype).is_struct():
return sa.cast(sa.func.row(*op.value.values()), to_sqla_type(dtype))
return _alchemy_literal(t, op)


def _arbitrary(t, op):
if op.how == "heavy":
raise ValueError('Trino does not support how="heavy"')
Expand Down Expand Up @@ -144,6 +151,16 @@ def _timestamp_from_unix(t, op):
raise ValueError(f"{unit!r} unit is not supported!")


def _struct_field(t, op):
return t.translate(op.arg).op(".")(sa.text(op.field))


def _struct_column(t, op):
return sa.cast(
sa.func.row(*map(t.translate, op.values)), to_sqla_type(op.output_dtype)
)


operation_registry.update(
{
# conditional expressions
Expand Down Expand Up @@ -201,5 +218,8 @@ def _timestamp_from_unix(t, op):
ops.StringToTimestamp: fixed_arity(sa.func.date_parse, 2),
ops.TimestampNow: fixed_arity(sa.func.now, 0),
ops.TimestampFromUNIX: _timestamp_from_unix,
ops.StructField: _struct_field,
ops.StructColumn: _struct_column,
ops.Literal: _literal,
}
)

0 comments on commit cc3c937

Please sign in to comment.