diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py
index 723c95c61cf1..8e8ca5eb3548 100644
--- a/ibis/backends/tests/test_array.py
+++ b/ibis/backends/tests/test_array.py
@@ -9,6 +9,7 @@
 import pandas.testing as tm
 import pytest
 import sqlalchemy as sa
+import sqlglot as sg
 import toolz
 from packaging.version import parse as parse_version
 from pytest import param
@@ -36,6 +37,11 @@
 except ImportError:
     BadRequest = None
 
+try:
+    from pyspark.sql.utils import AnalysisException as PySparkAnalysisException
+except ImportError:
+    PySparkAnalysisException = None
+
 pytestmark = [
     pytest.mark.never(
         ["sqlite", "mysql", "mssql"],
@@ -125,7 +131,7 @@ def test_array_concat_variadic(con):
 @pytest.mark.notyet(
     ["postgres", "trino"],
     raises=sa.exc.ProgrammingError,
-    reason="postgres can't infer the type of an empty array",
+    reason="backend can't infer the type of an empty array",
 )
 def test_array_concat_some_empty(con):
     left = ibis.literal([])
@@ -759,3 +765,46 @@ def test_zip(backend):
     s = res.execute()
     assert len(s[0][0]) == len(res.type().value_type)
     assert len(x[0]) == len(s[0])
+
+
+@unnest
+@pytest.mark.broken(
+    ["clickhouse"],
+    raises=sg.ParseError,
+    reason="we might be generating incorrect code here",
+)
+@pytest.mark.notimpl(["postgres"], raises=sa.exc.ProgrammingError)
+@pytest.mark.notimpl(
+    ["polars"],
+    raises=com.OperationNotDefinedError,
+    reason="polars unnest cannot be compiled outside of a projection",
+)
+@pytest.mark.notyet(
+    ["pyspark"],
+    reason="pyspark doesn't seem to support field selection on explode",
+    raises=PySparkAnalysisException,
+)
+def test_array_of_struct_unnest(con):
+    jobs = ibis.memtable(
+        {
+            "steps": [
+                [
+                    {"status": "success"},
+                    {"status": "success"},
+                    {"status": None},
+                    {"status": "failure"},
+                ],
+                [
+                    {"status": None},
+                    {"status": "success"},
+                ],
+            ]
+        },
+        schema=dict(steps="array<struct<status: string>>"),
+    )
+    expr = jobs.limit(1).steps.unnest().status
+    res = con.execute(expr)
+    value = res.iat[0]
+    # `value` can be `None` because the order of results is arbitrary; observed
+    # in the wild with the trino backend
+    assert value is None or isinstance(value, str)
diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py
index f8a814dbf2a7..837d46e7b414 100644
--- a/ibis/backends/tests/test_sql.py
+++ b/ibis/backends/tests/test_sql.py
@@ -47,9 +47,6 @@ def test_table(backend):
             raises=NotImplementedError,
             reason="backends hasn't implemented array literals",
         ),
-        mark.notimpl(
-            ["trino"], reason="Cannot render array literals", raises=sa.exc.CompileError
-        ),
     ],
     id="array_literal",
 )
diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py
index 8f52d58e04e5..b484419273c4 100644
--- a/ibis/backends/trino/compiler.py
+++ b/ibis/backends/trino/compiler.py
@@ -4,6 +4,7 @@
 
 import ibis.expr.operations as ops
 from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
+from ibis.backends.base.sql.alchemy.query_builder import _AlchemyTableSetFormatter
 from ibis.backends.trino.datatypes import TrinoType
 from ibis.backends.trino.registry import operation_registry
 
@@ -47,7 +48,30 @@ def _rewrite_string_contains(op):
     return ops.GreaterEqual(ops.StringFind(op.haystack, op.needle), 0)
 
 
+class TrinoTableSetFormatter(_AlchemyTableSetFormatter):
+    def _format_in_memory_table(self, op, ref_op, translator):
+        if not op.data:
+            return sa.select(
+                *(
+                    translator.translate(ops.Literal(None, dtype=type_)).label(name)
+                    for name, type_ in op.schema.items()
+                )
+            ).limit(0)
+
+        op_schema = list(op.schema.items())
+        rows = [
+            tuple(
+                translator.translate(ops.Literal(col, dtype=type_)).label(name)
+                for col, (name, type_) in zip(row, op_schema)
+            )
+            for row in ref_op.data.to_frame().itertuples(index=False)
+        ]
+        columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
+        return sa.values(*columns, name=ref_op.name).data(rows)
+
+
 class TrinoSQLCompiler(AlchemyCompiler):
     cheap_in_memory_tables = False
     translator_class = TrinoSQLExprTranslator
     null_limit = sa.literal_column("ALL")
+    table_set_formatter_class = TrinoTableSetFormatter
diff --git a/ibis/backends/trino/registry.py b/ibis/backends/trino/registry.py
index dafda97bdd8a..16e1863cd29e 100644
--- a/ibis/backends/trino/registry.py
+++ b/ibis/backends/trino/registry.py
@@ -6,6 +6,7 @@
 import sqlalchemy as sa
 import toolz
 from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.sql.expression import FunctionElement
 from trino.sqlalchemy.datatype import DOUBLE
 
 import ibis
@@ -35,15 +36,32 @@ def _array(t, elements):
     return t.translate(ibis.array(elements).op())
 
 
+class make_array(FunctionElement):
+    pass
+
+
+@compiles(make_array, "trino")
+def compile_make_array(element, compiler, **kw):
+    return f"ARRAY[{compiler.process(element.clauses, **kw)}]"
+
+
 def _literal(t, op):
     value = op.value
     dtype = op.dtype
 
     if value is None:
         return sa.null()
-
-    if dtype.is_struct():
-        return sa.cast(sa.func.row(*value.values()), t.get_sqla_type(dtype))
+    elif dtype.is_struct():
+        elements = (
+            t.translate(ops.Literal(element, dtype=field_type))
+            for element, field_type in zip(value.values(), dtype.types)
+        )
+        return sa.cast(sa.func.row(*elements), t.get_sqla_type(dtype))
+    elif dtype.is_array():
+        value_type = dtype.value_type
+        return make_array(
+            *(t.translate(ops.Literal(element, dtype=value_type)) for element in value)
+        )
     elif dtype.is_map():
         return sa.func.map(_array(t, value.keys()), _array(t, value.values()))
     elif dtype.is_float64():
@@ -184,8 +202,13 @@ def _unnest(t, op):
     row_type = op.arg.dtype.value_type
     names = getattr(row_type, "names", (name,))
     rd = sa.func.unnest(t.translate(arg)).table_valued(*names).render_derived()
-    # wrap in a row column so that we can return a single column from this rule
-    if len(names) == 1:
+    # when unnesting a single column, unwrap the single ROW field access that
+    # would otherwise be generated, but keep the ROW if the array's element
+    # type is struct
+    if not row_type.is_struct():
+        assert (
+            len(names) == 1
+        ), f"got non-struct dtype {row_type} with more than one name: {len(names)}"
         return rd.c[0]
     row = sa.func.row(*(rd.c[name] for name in names))
     return sa.cast(row, t.get_sqla_type(row_type))
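
Aside, not part of the patch: the `make_array` construct in the registry diff uses SQLAlchemy's documented custom-construct pattern, a `FunctionElement` subclass paired with an `@compiles` rule. Below is a minimal self-contained sketch of that pattern for illustration only; it registers the rule for the default dialect rather than "trino", and the `inherit_cache` attribute, the `_compile_make_array` name, and the `sa.literal` usage are additions of this sketch, not of the patch.

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import FunctionElement


class make_array(FunctionElement):
    # opt in to SQLAlchemy statement caching for this custom construct
    inherit_cache = True


@compiles(make_array)  # default dialect here; the patch registers the rule for "trino"
def _compile_make_array(element, compiler, **kw):
    # element.clauses is the comma-separated list of already-translated elements
    return f"ARRAY[{compiler.process(element.clauses, **kw)}]"


expr = make_array(sa.literal(1), sa.literal(2), sa.literal(3))
# renders: ARRAY[1, 2, 3]
print(expr.compile(compile_kwargs={"literal_binds": True}))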