From 6bc301ab07c33b2afb3d66750367e086ddc96400 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 4 Apr 2023 15:11:59 -0400 Subject: [PATCH] fix(sql): generate consistent `pivot_longer` semantics in the presence of multiple `unnest`s --- ibis/backends/base/sql/alchemy/datatypes.py | 9 ++-- ibis/backends/duckdb/registry.py | 1 + ibis/backends/postgres/compiler.py | 3 ++ ibis/backends/postgres/registry.py | 60 ++++++++++++++++++++- ibis/backends/tests/test_generic.py | 19 ++++++- ibis/backends/trino/datatypes.py | 8 +-- ibis/backends/trino/registry.py | 9 +++- ibis/expr/types/relations.py | 25 +++++---- 8 files changed, 112 insertions(+), 22 deletions(-) diff --git a/ibis/backends/base/sql/alchemy/datatypes.py b/ibis/backends/base/sql/alchemy/datatypes.py index bdd275994da5..6f766c9ad05b 100644 --- a/ibis/backends/base/sql/alchemy/datatypes.py +++ b/ibis/backends/base/sql/alchemy/datatypes.py @@ -13,6 +13,7 @@ import ibis.expr.datatypes as dt import ibis.expr.schema as sch from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported +from ibis.common.collections import FrozenDict if geospatial_supported: import geoalchemy2 as ga @@ -29,10 +30,12 @@ def compiles_array(element, compiler, **kw): class StructType(sat.UserDefinedType): + cache_ok = True + def __init__(self, fields: Mapping[str, sat.TypeEngine]) -> None: - self.fields = { - name: sat.to_instance(type) for name, type in dict(fields).items() - } + self.fields = FrozenDict( + {name: sat.to_instance(typ) for name, typ in fields.items()} + ) @compiles(StructType, "default") diff --git a/ibis/backends/duckdb/registry.py b/ibis/backends/duckdb/registry.py index 3e835eceda57..72e2618b0121 100644 --- a/ibis/backends/duckdb/registry.py +++ b/ibis/backends/duckdb/registry.py @@ -348,6 +348,7 @@ def _array_filter(t, op): ops.ArrayMap: _array_map, ops.ArrayFilter: _array_filter, ops.Argument: lambda _, op: sa.literal_column(op.name), + ops.Unnest: 
unary(sa.func.unnest), } ) diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py index 22de0bef869f..ed4de9041794 100644 --- a/ibis/backends/postgres/compiler.py +++ b/ibis/backends/postgres/compiler.py @@ -23,6 +23,9 @@ class PostgreSQLExprTranslator(AlchemyExprTranslator): _has_reduction_filter_syntax = True _dialect_name = "postgresql" + # postgres does support unnest in SELECT, but we can't use that form because of support for pivot_longer + supports_unnest_in_select = False + rewrites = PostgreSQLExprTranslator.rewrites diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py index d65c341890d1..8f55cfa3930a 100644 --- a/ibis/backends/postgres/registry.py +++ b/ibis/backends/postgres/registry.py @@ -11,6 +11,8 @@ import sqlalchemy as sa from sqlalchemy.dialects import postgresql as pg +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.functions import GenericFunction import ibis.backends.base.sql.registry.geospatial as geo import ibis.common.exceptions as com @@ -488,6 +490,60 @@ def _arbitrary(t, op): return t._reduction(func, op) +class struct_field(GenericFunction): + inherit_cache = True + + +@compiles(struct_field) +def compile_struct_field_postgresql(element, compiler, **kw): + arg, field = element.clauses + return f"({compiler.process(arg, **kw)}).{field.name}" + + +def _struct_field(t, op): + arg = op.arg + idx = arg.output_dtype.names.index(op.field) + 1 + field_name = sa.literal_column(f"f{idx:d}") + return struct_field( + t.translate(arg), field_name, type_=t.get_sqla_type(op.output_dtype) + ) + + +def _struct_column(t, op): + types = op.output_dtype.types + return sa.func.row( + # we have to cast here, otherwise postgres refuses to allow the statement + *map(t.translate, map(ops.Cast, op.values, types)), + type_=t.get_sqla_type( + dt.Struct({f"f{i:d}": typ for i, typ in enumerate(types, start=1)}) + ), + ) + + +def _unnest(t, op): + arg = op.arg + row_type = arg.output_dtype.value_type + + types = 
getattr(row_type, "types", (row_type,)) + + is_struct = row_type.is_struct() + derived = ( + sa.func.unnest(t.translate(arg)) + .table_valued( + *( + sa.column(f"f{i:d}", stype) + for i, stype in enumerate(map(t.get_sqla_type, types), start=1) + ) + ) + .render_derived(with_types=is_struct) + ) + + # wrap in a row column so that we can return a single column from this rule + if not is_struct: + return derived.c[0] + return sa.func.row(*derived.c) + + operation_registry.update( { ops.Literal: _literal, @@ -594,7 +650,7 @@ def _arbitrary(t, op): ), ops.ArrayConcat: fixed_arity(sa.sql.expression.ColumnElement.concat, 2), ops.ArrayRepeat: _array_repeat, - ops.Unnest: unary(sa.func.unnest), + ops.Unnest: _unnest, ops.Covariance: _covar, ops.Correlation: _corr, ops.BitwiseXor: _bitwise_op("#"), @@ -639,5 +695,7 @@ def _arbitrary(t, op): ops.RStrip: unary(lambda arg: sa.func.rtrim(arg, string.whitespace)), ops.StartsWith: fixed_arity(lambda arg, prefix: arg.op("^@")(prefix), 2), ops.Arbitrary: _arbitrary, + ops.StructColumn: _struct_column, + ops.StructField: _struct_field, } ) diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index ec626759820a..f1b6c27320fd 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -987,6 +987,7 @@ def query(t, group_cols): ) def test_pivot_longer(backend): diamonds = backend.diamonds + df = diamonds.execute() res = diamonds.pivot_longer(s.c("x", "y", "z"), names_to="pos", values_to="xyz") assert res.schema().names == ( "carat", @@ -999,8 +1000,22 @@ def test_pivot_longer(backend): "pos", "xyz", ) - df = res.limit(5).execute() - assert not df.empty + expected = pd.melt( + df, + id_vars=[ + "carat", + "cut", + "color", + "clarity", + "depth", + "table", + "price", + ], + value_vars=list('xyz'), + var_name="pos", + value_name="xyz", + ) + assert len(res.execute()) == len(expected) @pytest.mark.notyet(["datafusion"], raises=com.OperationNotDefinedError) diff --git 
a/ibis/backends/trino/datatypes.py b/ibis/backends/trino/datatypes.py index c89ad281cbc0..08000f7bf5cb 100644 --- a/ibis/backends/trino/datatypes.py +++ b/ibis/backends/trino/datatypes.py @@ -137,7 +137,7 @@ def _timestamp(_, itype): return TIMESTAMP(precision=itype.scale, timezone=bool(itype.timezone)) -@compiles(TIMESTAMP, "trino") +@compiles(TIMESTAMP) def compiles_timestamp(typ, compiler, **kw): result = "TIMESTAMP" @@ -150,7 +150,7 @@ def compiles_timestamp(typ, compiler, **kw): return result -@compiles(ROW, "trino") +@compiles(ROW) def _compiles_row(element, compiler, **kw): # TODO: @compiles should live in the dialect quote = compiler.dialect.identifier_preparer.quote @@ -168,7 +168,7 @@ def _map(dialect, itype): ) -@compiles(MAP, "trino") +@compiles(MAP) def compiles_map(typ, compiler, **kw): # TODO: @compiles should live in the dialect key_type = compiler.process(typ.key_type, **kw) @@ -191,7 +191,7 @@ def _real(*_): return sa.REAL() -@compiles(DOUBLE, "trino") +@compiles(DOUBLE) @compiles(sa.REAL, "trino") def _floating(element, compiler, **kw): return type(element).__name__.upper() diff --git a/ibis/backends/trino/registry.py b/ibis/backends/trino/registry.py index 968b8d1e4208..b1adac58da50 100644 --- a/ibis/backends/trino/registry.py +++ b/ibis/backends/trino/registry.py @@ -172,7 +172,14 @@ def _round(t, op): def _unnest(t, op): arg = op.arg name = arg.name - return sa.func.unnest(t.translate(arg)).table_valued(name).render_derived().c[name] + row_type = op.arg.output_dtype.value_type + names = getattr(row_type, "names", (name,)) + rd = sa.func.unnest(t.translate(arg)).table_valued(*names).render_derived() + # wrap in a row column so that we can return a single column from this rule + if len(names) == 1: + return rd.c[0] + row = sa.func.row(*(rd.c[name] for name in names)) + return sa.cast(row, t.get_sqla_type(row_type)) def _where(t, op): diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 76967f937831..bbf1b9577989 100644 
--- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -3104,21 +3104,24 @@ def pivot_longer( elif isinstance(values_transform, Deferred): values_transform = values_transform.resolve - names_map = {name: [] for name in names_to} - values = [] + pieces = [] for pivot_col in pivot_cols: col_name = pivot_col.get_name() match_result = names_pattern.match(col_name) - for name, value in zip(names_to, match_result.groups()): - transformer = names_transform[name] - names_map[name].append(transformer(value)) - values.append(values_transform(pivot_col)) - - new_cols = {key: ibis.array(value).unnest() for key, value in names_map.items()} - new_cols[values_to] = ibis.array(values).unnest() - - return self.select(~pivot_sel, **new_cols) + row = { + name: names_transform[name](value) + for name, value in zip(names_to, match_result.groups()) + } + row[values_to] = values_transform(pivot_col) + pieces.append(ibis.struct(row)) + + # nest into an array of structs to zip unnests together + pieces = ibis.array(pieces) + + return self.select(~pivot_sel, __pivoted__=pieces.unnest()).unpack( + "__pivoted__" + ) @util.experimental def pivot_wider(