Skip to content

Commit

Permalink
chore: fix overlapping args for trino, postgres, bigquery and pyspark
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Dec 14, 2023
1 parent 45fc0ac commit 2e7fb76
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/bigquery/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ def _integer_range(translator, op):
ops.EndsWith: fixed_arity("ENDS_WITH", 2),
ops.TableColumn: table_column,
ops.CountDistinctStar: _count_distinct_star,
ops.Argument: lambda _, op: op.name,
ops.Argument: lambda _, op: op.param,
ops.Unnest: unary("UNNEST"),
ops.TimeDelta: _time_delta,
ops.DateDelta: _date_delta,
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,9 @@ def _integer_range(t, op):
ops.Literal: _literal,
# We override this here to support time zones
ops.TableColumn: _table_column,
ops.Argument: lambda t, op: sa.column(op.name, type_=t.get_sqla_type(op.dtype)),
ops.Argument: lambda t, op: sa.column(
op.param, type_=t.get_sqla_type(op.dtype)
),
# types
ops.TypeOf: _typeof,
# Floating
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,7 +1689,7 @@ def compile_array_collect(t, op, **kwargs):

@compiles(ops.Argument)
def compile_argument(t, op, arg_columns, **kwargs):
return arg_columns[op.name]
return arg_columns[op.param]


@compiles(ops.ArrayFilter)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/trino/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def _try_cast(t, op):
def _array_intersect(t, op):
x = ops.Argument(name="x", shape=op.left.shape, dtype=op.left.dtype.value_type)
return t.translate(
ops.ArrayFilter(op.left, param="x", body=ops.ArrayContains(op.right, x))
ops.ArrayFilter(op.left, param=x.param, body=ops.ArrayContains(op.right, x))
)


Expand Down Expand Up @@ -538,7 +538,7 @@ def _integer_range(t, op):
lambda sep, arr: sa.func.array_join(arr, sep), 2
),
ops.StartsWith: fixed_arity(sa.func.starts_with, 2),
ops.Argument: lambda _, op: sa.literal_column(op.name),
ops.Argument: lambda _, op: sa.literal_column(op.param),
ops.First: partial(_first_last, offset=1),
ops.Last: partial(_first_last, offset=-1),
ops.ArrayZip: _zip,
Expand Down

0 comments on commit 2e7fb76

Please sign in to comment.