From 90befb22ce4baaeecbed9eb1fb24719b3419e055 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 26 Sep 2023 06:36:36 -0400 Subject: [PATCH] refactor(backends): adjust backends to work with new array representation --- ibis/backends/bigquery/registry.py | 9 +++++---- ibis/backends/duckdb/registry.py | 29 ++++++++++++++++++++--------- ibis/backends/postgres/registry.py | 8 ++++---- ibis/backends/pyspark/compiler.py | 4 ++-- ibis/backends/trino/registry.py | 11 ++++------- 5 files changed, 35 insertions(+), 26 deletions(-) diff --git a/ibis/backends/bigquery/registry.py b/ibis/backends/bigquery/registry.py index 92cfb7fcccc2..367c552892e9 100644 --- a/ibis/backends/bigquery/registry.py +++ b/ibis/backends/bigquery/registry.py @@ -214,14 +214,15 @@ def _array_zip(translator, op): def _array_map(translator, op): arg = translator.translate(op.arg) - result = translator.translate(op.result) - return f"ARRAY(SELECT {result} FROM UNNEST({arg}) {op.parameter})" + result = translator.translate(op.body) + param = op.param + return f"ARRAY(SELECT {result} FROM UNNEST({arg}) {param})" def _array_filter(translator, op): arg = translator.translate(op.arg) - result = translator.translate(op.result) - param = op.parameter + result = translator.translate(op.body) + param = op.param return f"ARRAY(SELECT {param} FROM UNNEST({arg}) {param} WHERE {result})" diff --git a/ibis/backends/duckdb/registry.py b/ibis/backends/duckdb/registry.py index 76871d876bcc..a93c8f770376 100644 --- a/ibis/backends/duckdb/registry.py +++ b/ibis/backends/duckdb/registry.py @@ -8,7 +8,6 @@ import sqlalchemy as sa from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql.functions import GenericFunction -from toolz.curried import flip import ibis.expr.operations as ops from ibis.backends.base.sql import alchemy @@ -225,9 +224,7 @@ def compiles_list_apply(element, compiler, **kw): def _array_map(t, op): return array_map( - t.translate(op.arg), - sa.literal_column(f"({op.parameter})"), - t.translate(op.result), + t.translate(op.arg), sa.literal_column(f"({op.param})"), t.translate(op.body) ) @@ -239,15 +236,19 @@ def compiles_list_filter(element, compiler, **kw): def _array_filter(t, op): return array_filter( - t.translate(op.arg), - sa.literal_column(f"({op.parameter})"), - t.translate(op.result), + t.translate(op.arg), sa.literal_column(f"({op.param})"), t.translate(op.body) ) def _array_intersect(t, op): + name = "x" + parameter = ops.Argument( + name=name, shape=op.left.shape, dtype=op.left.dtype.value_type + ) return t.translate( - ops.ArrayFilter(op.left, func=lambda x: ops.ArrayContains(op.right, x)) + ops.ArrayFilter( + op.left, param=name, body=ops.ArrayContains(op.right, parameter) + ) ) @@ -372,7 +373,17 @@ def _try_cast(t, op): ), ops.ArraySort: fixed_arity(sa.func.list_sort, 1), ops.ArrayRemove: lambda t, op: _array_filter( - t, ops.ArrayFilter(op.arg, flip(ops.NotEquals, op.other)) + t, + ops.ArrayFilter( + op.arg, + param="x", + body=ops.NotEquals( + ops.Argument( + name="x", shape=op.arg.shape, dtype=op.arg.dtype.value_type + ), + op.other, + ), + ), ), ops.ArrayUnion: lambda t, op: t.translate( ops.ArrayDistinct(ops.ArrayConcat((op.left, op.right))) diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py index 785d3e240c19..14e13f327584 100644 --- a/ibis/backends/postgres/registry.py +++ b/ibis/backends/postgres/registry.py @@ -574,26 +574,26 @@ def _array_map(t, op): return sa.func.array( # this translates to the function call, with column names the same as # the parameter names in the lambda - sa.select(t.translate(op.result)) + sa.select(t.translate(op.body)) .select_from( # unnest the input array sa.func.unnest(t.translate(op.arg)) # name the columns of the result the same as the lambda parameter # so that we can reference them as such in the outer query - .table_valued(op.parameter).render_derived() + .table_valued(op.param).render_derived() ) .scalar_subquery() ) def _array_filter(t, op): - param = op.parameter + param = op.param return sa.func.array( sa.select(sa.column(param, type_=t.get_sqla_type(op.arg.dtype.value_type))) .select_from( sa.func.unnest(t.translate(op.arg)).table_valued(param).render_derived() ) - .where(t.translate(op.result)) + .where(t.translate(op.body)) .scalar_subquery() ) diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index 5b0d3d620ce7..81f5fe2447fc 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -1702,7 +1702,7 @@ def compile_array_filter(t, op, **kwargs): src_column = t.translate(op.arg, **kwargs) return F.filter( src_column, - lambda x: t.translate(op.result, arg_columns={op.parameter: x}, **kwargs), + lambda x: t.translate(op.body, arg_columns={op.param: x}, **kwargs), ) @@ -1711,7 +1711,7 @@ def compile_array_map(t, op, **kwargs): src_column = t.translate(op.arg, **kwargs) return F.transform( src_column, - lambda x: t.translate(op.result, arg_columns={op.parameter: x}, **kwargs), + lambda x: t.translate(op.body, arg_columns={op.param: x}, **kwargs), ) diff --git a/ibis/backends/trino/registry.py b/ibis/backends/trino/registry.py index 16e1863cd29e..d7b98e72f3df 100644 --- a/ibis/backends/trino/registry.py +++ b/ibis/backends/trino/registry.py @@ -236,9 +236,7 @@ def compiles_list_apply(element, compiler, **kw): def _array_map(t, op): return array_map( - t.translate(op.arg), - sa.literal_column(f"({op.parameter})"), - t.translate(op.result), + t.translate(op.arg), sa.literal_column(f"({op.param})"), t.translate(op.body) ) @@ -250,9 +248,7 @@ def compiles_list_filter(element, compiler, **kw): def _array_filter(t, op): return array_filter( - t.translate(op.arg), - sa.literal_column(f"({op.parameter})"), - t.translate(op.result), + t.translate(op.arg), sa.literal_column(f"({op.param})"), t.translate(op.body) ) @@ -313,8 +309,9 @@ def _try_cast(t, op): def _array_intersect(t, op): + x = ops.Argument(name="x", shape=op.left.shape, dtype=op.left.dtype.value_type) return t.translate( - ops.ArrayFilter(op.left, func=lambda x: ops.ArrayContains(op.right, x)) + ops.ArrayFilter(op.left, param="x", body=ops.ArrayContains(op.right, x)) )