Skip to content

Commit

Permalink
fix(trino): differentiate between a single column struct and a non-st…
Browse files Browse the repository at this point in the history
…ruct column
  • Loading branch information
cpcloud committed Sep 15, 2023
1 parent b0b62cc commit b1f1939
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 9 deletions.
51 changes: 50 additions & 1 deletion ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas.testing as tm
import pytest
import sqlalchemy as sa
import sqlglot as sg
import toolz
from packaging.version import parse as parse_version
from pytest import param
Expand Down Expand Up @@ -36,6 +37,11 @@
except ImportError:
BadRequest = None

try:
from pyspark.sql.utils import AnalysisException as PySparkAnalysisException
except ImportError:
PySparkAnalysisException = None

pytestmark = [
pytest.mark.never(
["sqlite", "mysql", "mssql"],
Expand Down Expand Up @@ -125,7 +131,7 @@ def test_array_concat_variadic(con):
@pytest.mark.notyet(
["postgres", "trino"],
raises=sa.exc.ProgrammingError,
reason="postgres can't infer the type of an empty array",
reason="backend can't infer the type of an empty array",
)
def test_array_concat_some_empty(con):
left = ibis.literal([])
Expand Down Expand Up @@ -759,3 +765,46 @@ def test_zip(backend):
s = res.execute()
assert len(s[0][0]) == len(res.type().value_type)
assert len(x[0]) == len(s[0])


@unnest
@pytest.mark.broken(
["clickhouse"],
raises=sg.ParseError,
reason="we might be generating incorrect code here",
)
@pytest.mark.notimpl(["postgres"], raises=sa.exc.ProgrammingError)
@pytest.mark.notimpl(
["polars"],
raises=com.OperationNotDefinedError,
reason="polars unnest cannot be compiled outside of a projection",
)
@pytest.mark.notyet(
["pyspark"],
reason="pyspark doesn't seem to support field selection on explode",
raises=PySparkAnalysisException,
)
def test_array_of_struct_unnest(con):
jobs = ibis.memtable(
{
"steps": [
[
{"status": "success"},
{"status": "success"},
{"status": None},
{"status": "failure"},
],
[
{"status": None},
{"status": "success"},
],
]
},
schema=dict(steps="array<struct<status: string>>"),
)
expr = jobs.limit(1).steps.unnest().status
res = con.execute(expr)
value = res.iat[0]
# `value` can be `None` because the order of results is arbitrary; observed
# in the wild with the trino backend
assert value is None or isinstance(value, str)
3 changes: 0 additions & 3 deletions ibis/backends/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ def test_table(backend):
raises=NotImplementedError,
reason="backends hasn't implemented array literals",
),
mark.notimpl(
["trino"], reason="Cannot render array literals", raises=sa.exc.CompileError
),
],
id="array_literal",
)
Expand Down
24 changes: 24 additions & 0 deletions ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.base.sql.alchemy.query_builder import _AlchemyTableSetFormatter
from ibis.backends.trino.datatypes import TrinoType
from ibis.backends.trino.registry import operation_registry

Expand Down Expand Up @@ -47,7 +48,30 @@ def _rewrite_string_contains(op):
return ops.GreaterEqual(ops.StringFind(op.haystack, op.needle), 0)


class TrinoTableSetFormatter(_AlchemyTableSetFormatter):
def _format_in_memory_table(self, op, ref_op, translator):
if not op.data:
return sa.select(
*(
translator.translate(ops.Literal(None, dtype=type_)).label(name)
for name, type_ in op.schema.items()
)
).limit(0)

op_schema = list(op.schema.items())
rows = [
tuple(
translator.translate(ops.Literal(col, dtype=type_)).label(name)
for col, (name, type_) in zip(row, op_schema)
)
for row in ref_op.data.to_frame().itertuples(index=False)
]
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
return sa.values(*columns, name=ref_op.name).data(rows)


class TrinoSQLCompiler(AlchemyCompiler):
cheap_in_memory_tables = False
translator_class = TrinoSQLExprTranslator
null_limit = sa.literal_column("ALL")
table_set_formatter_class = TrinoTableSetFormatter
33 changes: 28 additions & 5 deletions ibis/backends/trino/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sqlalchemy as sa
import toolz
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import FunctionElement
from trino.sqlalchemy.datatype import DOUBLE

import ibis
Expand Down Expand Up @@ -35,15 +36,32 @@ def _array(t, elements):
return t.translate(ibis.array(elements).op())


class make_array(FunctionElement):
pass


@compiles(make_array, "trino")
def compile_make_array(element, compiler, **kw):
return f"ARRAY[{compiler.process(element.clauses, **kw)}]"


def _literal(t, op):
value = op.value
dtype = op.dtype

if value is None:
return sa.null()

if dtype.is_struct():
return sa.cast(sa.func.row(*value.values()), t.get_sqla_type(dtype))
elif dtype.is_struct():
elements = (
t.translate(ops.Literal(element, dtype=field_type))
for element, field_type in zip(value.values(), dtype.types)
)
return sa.cast(sa.func.row(*elements), t.get_sqla_type(dtype))
elif dtype.is_array():
value_type = dtype.value_type
return make_array(
*(t.translate(ops.Literal(element, dtype=value_type)) for element in value)
)
elif dtype.is_map():
return sa.func.map(_array(t, value.keys()), _array(t, value.values()))
elif dtype.is_float64():
Expand Down Expand Up @@ -184,8 +202,13 @@ def _unnest(t, op):
row_type = op.arg.dtype.value_type
names = getattr(row_type, "names", (name,))
rd = sa.func.unnest(t.translate(arg)).table_valued(*names).render_derived()
# wrap in a row column so that we can return a single column from this rule
if len(names) == 1:
# when unnesting a single column, unwrap the single ROW field access that
# would otherwise be generated, but keep the ROW if the array's element
# type is struct
if not row_type.is_struct():
assert (
len(names) == 1
), f"got non-struct dtype {row_type} with more than one name: {len(names)}"
return rd.c[0]
row = sa.func.row(*(rd.c[name] for name in names))
return sa.cast(row, t.get_sqla_type(row_type))
Expand Down

0 comments on commit b1f1939

Please sign in to comment.