From 2e7fb76d6e317595cc7102877c38c8fe5c8a36e1 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Wed, 13 Dec 2023 16:16:31 -0500 Subject: [PATCH] chore: fix overlapping args for trino, postgres, bigquery and pyspark --- ibis/backends/bigquery/registry.py | 2 +- ibis/backends/postgres/registry.py | 4 +++- ibis/backends/pyspark/compiler.py | 2 +- ibis/backends/trino/registry.py | 4 ++-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ibis/backends/bigquery/registry.py b/ibis/backends/bigquery/registry.py index f5102a90aa1c..9c43ad2741ae 100644 --- a/ibis/backends/bigquery/registry.py +++ b/ibis/backends/bigquery/registry.py @@ -944,7 +944,7 @@ def _integer_range(translator, op): ops.EndsWith: fixed_arity("ENDS_WITH", 2), ops.TableColumn: table_column, ops.CountDistinctStar: _count_distinct_star, - ops.Argument: lambda _, op: op.name, + ops.Argument: lambda _, op: op.param, ops.Unnest: unary("UNNEST"), ops.TimeDelta: _time_delta, ops.DateDelta: _date_delta, diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py index ac8818da01ef..97761f2e0e76 100644 --- a/ibis/backends/postgres/registry.py +++ b/ibis/backends/postgres/registry.py @@ -644,7 +644,9 @@ def _integer_range(t, op): ops.Literal: _literal, # We override this here to support time zones ops.TableColumn: _table_column, - ops.Argument: lambda t, op: sa.column(op.name, type_=t.get_sqla_type(op.dtype)), + ops.Argument: lambda t, op: sa.column( + op.param, type_=t.get_sqla_type(op.dtype) + ), # types ops.TypeOf: _typeof, # Floating diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index 0245c24bda1f..ffe113f27b8a 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -1689,7 +1689,7 @@ def compile_array_collect(t, op, **kwargs): @compiles(ops.Argument) def compile_argument(t, op, arg_columns, **kwargs): - return arg_columns[op.name] + return arg_columns[op.param] @compiles(ops.ArrayFilter) diff --git a/ibis/backends/trino/registry.py b/ibis/backends/trino/registry.py index 87c3e776e2ef..a100a034a022 100644 --- a/ibis/backends/trino/registry.py +++ b/ibis/backends/trino/registry.py @@ -327,7 +327,7 @@ def _try_cast(t, op): def _array_intersect(t, op): x = ops.Argument(name="x", shape=op.left.shape, dtype=op.left.dtype.value_type) return t.translate( - ops.ArrayFilter(op.left, param="x", body=ops.ArrayContains(op.right, x)) + ops.ArrayFilter(op.left, param=x.param, body=ops.ArrayContains(op.right, x)) ) @@ -538,7 +538,7 @@ def _integer_range(t, op): lambda sep, arr: sa.func.array_join(arr, sep), 2 ), ops.StartsWith: fixed_arity(sa.func.starts_with, 2), - ops.Argument: lambda _, op: sa.literal_column(op.name), + ops.Argument: lambda _, op: sa.literal_column(op.param), ops.First: partial(_first_last, offset=1), ops.Last: partial(_first_last, offset=-1), ops.ArrayZip: _zip,