From 0ed0ab14f93621c5d647338c5b2cb882abe36a85 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 29 Jul 2023 10:20:38 -0400 Subject: [PATCH] feat(arrays): add `concat` method equivalent to `__add__`/`__radd__` --- ibis/backends/bigquery/registry.py | 2 +- ibis/backends/clickhouse/compiler/values.py | 9 ++- ibis/backends/dask/execution/arrays.py | 6 ++ ibis/backends/duckdb/compiler.py | 8 ++- ibis/backends/duckdb/registry.py | 1 - ibis/backends/pandas/dispatch.py | 4 +- ibis/backends/pandas/execution/arrays.py | 53 ++++++++------- ibis/backends/polars/compiler.py | 15 +++-- ibis/backends/postgres/registry.py | 3 +- .../backends/postgres/tests/test_functions.py | 2 +- ibis/backends/pyspark/compiler.py | 4 +- ibis/backends/snowflake/registry.py | 6 +- ibis/backends/tests/test_array.py | 27 ++++++++ ibis/backends/trino/registry.py | 7 +- ibis/expr/operations/arrays.py | 23 +++---- ibis/expr/types/arrays.py | 66 ++++++++----------- 16 files changed, 142 insertions(+), 94 deletions(-) diff --git a/ibis/backends/bigquery/registry.py b/ibis/backends/bigquery/registry.py index 3f066673b72c..6d5e60048d25 100644 --- a/ibis/backends/bigquery/registry.py +++ b/ibis/backends/bigquery/registry.py @@ -124,7 +124,7 @@ def _struct_column(translator, op): def _array_concat(translator, op): - return "ARRAY_CONCAT({})".format(", ".join(map(translator.translate, op.args))) + return "ARRAY_CONCAT({})".format(", ".join(map(translator.translate, op.arg))) def _array_column(translator, op): diff --git a/ibis/backends/clickhouse/compiler/values.py b/ibis/backends/clickhouse/compiler/values.py index 6c4d2e511ab4..bc45b09193bd 100644 --- a/ibis/backends/clickhouse/compiler/values.py +++ b/ibis/backends/clickhouse/compiler/values.py @@ -924,6 +924,12 @@ def _map_get(op, **kw): return f"if(mapContains({arg}, {key}), {arg}[{key}], {default})" +@translate_val.register(ops.ArrayConcat) +def _array_concat(op, **kw): + args = ", ".join(map(_sql, map(partial(translate_val, **kw), op.arg))) + return f"arrayConcat({args})" + + def _binary_infix(symbol: str): def formatter(op, **kw): left = translate_val(op_left := op.left, **kw) @@ -1056,7 +1062,6 @@ def formatter(op, **kw): # because clickhouse"s greatest and least doesn"t support varargs ops.Where: "if", ops.ArrayLength: "length", - ops.ArrayConcat: "arrayConcat", ops.Unnest: "arrayJoin", ops.Degrees: "degrees", ops.Radians: "radians", @@ -1395,7 +1400,7 @@ def _array_remove(op, **kw): @translate_val.register(ops.ArrayUnion) def _array_union(op, **kw): - return translate_val(ops.ArrayDistinct(ops.ArrayConcat(op.left, op.right)), **kw) + return translate_val(ops.ArrayDistinct(ops.ArrayConcat((op.left, op.right))), **kw) @translate_val.register(ops.ArrayZip) diff --git a/ibis/backends/dask/execution/arrays.py b/ibis/backends/dask/execution/arrays.py index d7c53e0c1fb4..04e9607ede44 100644 --- a/ibis/backends/dask/execution/arrays.py +++ b/ibis/backends/dask/execution/arrays.py @@ -1,6 +1,7 @@ from __future__ import annotations import itertools +from functools import partial import dask.dataframe as dd import dask.dataframe.groupby as ddgb @@ -51,3 +52,8 @@ def execute_array_collect(op, data, where, aggcontext=None, **kwargs): @execute_node.register(ops.ArrayCollect, ddgb.SeriesGroupBy, type(None)) def execute_array_collect_grouped_series(op, data, where, **kwargs): return data.agg(collect_list) + + +@execute_node.register(ops.ArrayConcat, tuple) +def execute_array_concat(op, args, **kwargs): + return execute_node(op, *map(partial(execute, **kwargs), args), **kwargs) diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py index a03cac54fab0..719427c4f4e2 100644 --- a/ibis/backends/duckdb/compiler.py +++ b/ibis/backends/duckdb/compiler.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sqlalchemy as sa from sqlalchemy.ext.compiler import compiles import ibis.backends.base.sql.alchemy.datatypes as sat @@ -39,7 +40,12 @@ def compile_uint(element, compiler, **kw): @compiles(sat.ArrayType, "duckdb") def compile_array(element, compiler, **kw): - return f"{compiler.process(element.value_type, **kw)}[]" + if isinstance(value_type := element.value_type, sa.types.NullType): + # duckdb infers empty arrays with no other context as array + typ = "INTEGER" + else: + typ = compiler.process(value_type, **kw) + return f"{typ}[]" rewrites = DuckDBSQLExprTranslator.rewrites diff --git a/ibis/backends/duckdb/registry.py b/ibis/backends/duckdb/registry.py index 60d7d0a35b48..ea756bd46544 100644 --- a/ibis/backends/duckdb/registry.py +++ b/ibis/backends/duckdb/registry.py @@ -353,7 +353,6 @@ def _try_cast(t, op): ) ), ops.TryCast: _try_cast, - ops.ArrayConcat: fixed_arity(sa.func.array_concat, 2), ops.ArrayRepeat: fixed_arity( lambda arg, times: sa.func.flatten( sa.func.array( diff --git a/ibis/backends/pandas/dispatch.py b/ibis/backends/pandas/dispatch.py index 88794a146ac4..0972f189d031 100644 --- a/ibis/backends/pandas/dispatch.py +++ b/ibis/backends/pandas/dispatch.py @@ -22,8 +22,10 @@ @execute_node.register(ops.Node, [object]) def raise_unknown_op(node, *args, **kwargs): + signature = ", ".join(type(arg).__name__ for arg in args) raise com.OperationNotDefinedError( - f"Operation {type(node).__name__!r} is not implemented for this backend" + "Operation is not implemented for this backend with " + f"signature: execute_node({type(node).__name__}, {signature})" ) diff --git a/ibis/backends/pandas/execution/arrays.py b/ibis/backends/pandas/execution/arrays.py index 89a3848bcb89..8b23a5c93f43 100644 --- a/ibis/backends/pandas/execution/arrays.py +++ b/ibis/backends/pandas/execution/arrays.py @@ -1,6 +1,7 @@ from __future__ import annotations import operator +from functools import partial from typing import Any, Collection import numpy as np @@ -56,45 +57,53 @@ def execute_array_index_scalar(op, data, index, **kwargs): return None -def _concat_iterables_to_series( - iter1: Collection[Any], - iter2: Collection[Any], -) -> pd.Series: +def _concat_iterables_to_series(*iters: Collection[Any]) -> pd.Series: """Concatenate two collections to create a Series. The two collections are assumed to have the same length. Used for ArrayConcat implementation. """ - assert len(iter1) == len(iter2) + first, *rest = iters + assert all(len(series) == len(first) for series in rest) # Doing the iteration using `map` is much faster than doing the iteration # using `Series.apply` due to Pandas-related overhead. - result = pd.Series(map(lambda x, y: np.concatenate([x, y]), iter1, iter2)) - return result + return pd.Series(map(lambda *args: np.concatenate(args), first, *rest)) -@execute_node.register(ops.ArrayConcat, pd.Series, pd.Series) -def execute_array_concat_series(op, left, right, **kwargs): - return _concat_iterables_to_series(left, right) +@execute_node.register(ops.ArrayConcat, tuple) +def execute_array_concat(op, args, **kwargs): + return execute_node(op, *map(partial(execute, **kwargs), args), **kwargs) + + +@execute_node.register(ops.ArrayConcat, pd.Series, pd.Series, [pd.Series]) +def execute_array_concat_series(op, first, second, *args, **kwargs): + return _concat_iterables_to_series(first, second, *args) -@execute_node.register(ops.ArrayConcat, pd.Series, np.ndarray) -@execute_node.register(ops.ArrayConcat, np.ndarray, pd.Series) -def execute_array_concat_mixed(op, left, right, **kwargs): +@execute_node.register( + ops.ArrayConcat, np.ndarray, pd.Series, [(pd.Series, np.ndarray)] +) +def execute_array_concat_mixed_left(op, left, right, *args, **kwargs): # ArrayConcat given a column (pd.Series) and a scalar (np.ndarray). # We will broadcast the scalar to the length of the column. - if isinstance(left, np.ndarray): - # Broadcast `left` to the length of `right` - left = np.tile(left, (len(right), 1)) - elif isinstance(right, np.ndarray): - # Broadcast `right` to the length of `left` - right = np.tile(right, (len(left), 1)) + # Broadcast `left` to the length of `right` + left = np.tile(left, (len(right), 1)) + return _concat_iterables_to_series(left, right) + + +@execute_node.register( + ops.ArrayConcat, pd.Series, np.ndarray, [(pd.Series, np.ndarray)] +) +def execute_array_concat_mixed_right(op, left, right, *args, **kwargs): + # Broadcast `right` to the length of `left` + right = np.tile(right, (len(left), 1)) return _concat_iterables_to_series(left, right) -@execute_node.register(ops.ArrayConcat, np.ndarray, np.ndarray) -def execute_array_concat_scalar(op, left, right, **kwargs): - return np.concatenate([left, right]) +@execute_node.register(ops.ArrayConcat, np.ndarray, np.ndarray, [np.ndarray]) +def execute_array_concat_scalar(op, left, right, *args, **kwargs): + return np.concatenate([left, right, *args]) @execute_node.register(ops.ArrayRepeat, pd.Series, int) diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index ab2dbaa5b54b..b41dc047442f 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -862,12 +862,15 @@ def array_length(op, **kw): @translate.register(ops.ArrayConcat) def array_concat(op, **kw): - left = translate(op.left, **kw) - right = translate(op.right, **kw) - try: - return left.arr.concat(right) - except AttributeError: - return left.list.concat(right) + result, *rest = map(partial(translate, **kw), op.arg) + + for arg in rest: + try: + result = result.arr.concat(arg) + except AttributeError: + result = result.list.concat(arg) + + return result @translate.register(ops.ArrayColumn) diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py index 1e5402920da8..989a2c9030c7 100644 --- a/ibis/backends/postgres/registry.py +++ b/ibis/backends/postgres/registry.py @@ -27,6 +27,7 @@ sqlalchemy_operation_registry, sqlalchemy_window_functions_registry, unary, + varargs, ) from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported from ibis.backends.base.sql.alchemy.registry import ( @@ -658,7 +659,7 @@ def _unnest(t, op): ops.ArrayIndex: _array_index( index_converter=_neg_idx_to_pos, func=lambda arg, index: arg[index] ), - ops.ArrayConcat: fixed_arity(sa.sql.expression.ColumnElement.concat, 2), + ops.ArrayConcat: varargs(lambda *args: functools.reduce(operator.add, args)), ops.ArrayRepeat: _array_repeat, ops.Unnest: _unnest, ops.Covariance: _covar, diff --git a/ibis/backends/postgres/tests/test_functions.py b/ibis/backends/postgres/tests/test_functions.py index 0c1b9cf8ee6d..3564a5faae3c 100644 --- a/ibis/backends/postgres/tests/test_functions.py +++ b/ibis/backends/postgres/tests/test_functions.py @@ -1008,7 +1008,7 @@ def test_array_concat(array_types, catop): def test_array_concat_mixed_types(array_types): with pytest.raises(TypeError): - array_types.x + array_types.x.cast('array') + array_types.y + array_types.x.cast('array') @pytest.fixture diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index 8343bce98080..fe5d8f260a78 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -1640,9 +1640,7 @@ def compile_array_index(t, op, **kwargs): @compiles(ops.ArrayConcat) def compile_array_concat(t, op, **kwargs): - left = t.translate(op.left, **kwargs) - right = t.translate(op.right, **kwargs) - return F.concat(left, right) + return F.concat(*map(partial(t.translate, **kwargs), op.arg)) @compiles(ops.ArrayRepeat) diff --git a/ibis/backends/snowflake/registry.py b/ibis/backends/snowflake/registry.py index 3d5db1b13900..428bf391d600 100644 --- a/ibis/backends/snowflake/registry.py +++ b/ibis/backends/snowflake/registry.py @@ -1,5 +1,6 @@ from __future__ import annotations +import functools import itertools import numpy as np @@ -18,6 +19,7 @@ get_sqla_table, reduction, unary, + varargs, ) from ibis.backends.postgres.registry import _literal as _postgres_literal from ibis.backends.postgres.registry import operation_registry as _operation_registry @@ -370,7 +372,9 @@ def _map_get(t, op): ), ops.ArrayIndex: fixed_arity(sa.func.get, 2), ops.ArrayLength: fixed_arity(sa.func.array_size, 1), - ops.ArrayConcat: fixed_arity(sa.func.array_cat, 2), + ops.ArrayConcat: varargs( + lambda *args: functools.reduce(sa.func.array_cat, args) + ), ops.ArrayColumn: lambda t, op: sa.func.array_construct( *map(t.translate, op.cols) ), diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 8cb37d48d156..e882db2b07da 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -107,6 +107,33 @@ def test_array_concat(con): assert np.array_equal(result, expected) +# Issues #2370 +@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +def test_array_concat_variadic(con): + left = ibis.literal([1, 2, 3]) + right = ibis.literal([2, 1]) + expr = left.concat(right, right, right) + result = con.execute(expr.name("tmp")) + expected = np.array([1, 2, 3, 2, 1, 2, 1, 2, 1]) + assert np.array_equal(result, expected) + + +# Issues #2370 +@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notyet( + ["postgres", "trino"], + raises=sa.exc.ProgrammingError, + reason="postgres can't infer the type of an empty array", +) +def test_array_concat_some_empty(con): + left = ibis.literal([]) + right = ibis.literal([2, 1]) + expr = left.concat(right) + result = con.execute(expr.name("tmp")) + expected = np.array([2, 1]) + assert np.array_equal(result, expected) + + @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_array_radd_concat(con): left = [1] diff --git a/ibis/backends/trino/registry.py b/ibis/backends/trino/registry.py index 17b24d001337..d1fbc5921061 100644 --- a/ibis/backends/trino/registry.py +++ b/ibis/backends/trino/registry.py @@ -11,7 +11,9 @@ import ibis import ibis.common.exceptions as com import ibis.expr.operations as ops -from ibis.backends.base.sql.alchemy.registry import _literal as _alchemy_literal +from ibis.backends.base.sql.alchemy.registry import ( + _literal as _alchemy_literal, +) from ibis.backends.base.sql.alchemy.registry import ( array_filter, array_map, @@ -21,6 +23,7 @@ sqlalchemy_window_functions_registry, try_cast, unary, + varargs, ) from ibis.backends.postgres.registry import _corr, _covar @@ -328,7 +331,7 @@ def _try_cast(t, op): ops.BitwiseRightShift: fixed_arity(sa.func.bitwise_right_shift, 2), ops.BitwiseNot: unary(sa.func.bitwise_not), ops.ArrayCollect: reduction(sa.func.array_agg), - ops.ArrayConcat: fixed_arity(sa.func.concat, 2), + ops.ArrayConcat: varargs(sa.func.concat), ops.ArrayLength: unary(sa.func.cardinality), ops.ArrayIndex: fixed_arity( lambda arg, index: sa.func.element_at(arg, index + 1), 2 diff --git a/ibis/expr/operations/arrays.py b/ibis/expr/operations/arrays.py index 6039f1d3fbc7..ea7e30b7a801 100644 --- a/ibis/expr/operations/arrays.py +++ b/ibis/expr/operations/arrays.py @@ -2,7 +2,6 @@ from public import public -import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.rules as rlz from ibis.common.annotations import attribute @@ -52,21 +51,17 @@ def output_dtype(self): @public class ArrayConcat(Value): - left = rlz.array - right = rlz.array + arg = rlz.tuple_of(rlz.array, min_length=2) - output_dtype = rlz.dtype_like("left") - output_shape = rlz.shape_like("args") + @attribute.default + def output_dtype(self): + return dt.Array( + dt.highest_precedence(arg.output_dtype.value_type for arg in self.arg) + ) - def __init__(self, left, right): - if left.output_dtype != right.output_dtype: - raise com.IbisTypeError( - 'Array types must match exactly in a {} operation. ' - 'Left type {} != Right type {}'.format( - type(self).__name__, left.output_dtype, right.output_dtype - ) - ) - super().__init__(left=left, right=right) + @attribute.default + def output_shape(self): + return rlz.highest_precedence_shape(self.arg) @public diff --git a/ibis/expr/types/arrays.py b/ibis/expr/types/arrays.py index c411e4898a6f..9c8433c459d3 100644 --- a/ibis/expr/types/arrays.py +++ b/ibis/expr/types/arrays.py @@ -125,18 +125,20 @@ def __getitem__(self, index: int | ir.IntegerValue | slice) -> ir.Value: op = ops.ArrayIndex(self, index) return op.to_expr() - def __add__(self, other: ArrayValue) -> ArrayValue: - """Concatenate this array with another. + def concat(self, other: ArrayValue, *args: ArrayValue) -> ArrayValue: + """Concatenate this array with one or more arrays. Parameters ---------- other - Array to concat with `self` + Other array to concat with `self` + args + Other arrays to concat with `self` Returns ------- ArrayValue - `self` concatenated with `other` + `self` concatenated with `other` and `args` Examples -------- @@ -153,9 +155,9 @@ def __add__(self, other: ArrayValue) -> ArrayValue: │ [3] │ │ NULL │ └──────────────────────┘ - >>> t.a + t.a + >>> t.a.concat(t.a) ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ ArrayConcat(a, a) ┃ + ┃ ArrayConcat() ┃ ┡━━━━━━━━━━━━━━━━━━━━━━┩ │ array │ ├──────────────────────┤ @@ -163,9 +165,9 @@ def __add__(self, other: ArrayValue) -> ArrayValue: │ [3, 3] │ │ NULL │ └──────────────────────┘ - >>> t.a + ibis.literal([4], type="array") + >>> t.a.concat(ibis.literal([4], type="array")) ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ ArrayConcat(a, (4,)) ┃ + ┃ ArrayConcat() ┃ ┡━━━━━━━━━━━━━━━━━━━━━━┩ │ array │ ├──────────────────────┤ @@ -173,49 +175,37 @@ def __add__(self, other: ArrayValue) -> ArrayValue: │ [3, 4] │ │ [4] │ └──────────────────────┘ - """ - return ops.ArrayConcat(self, other).to_expr() - - def __radd__(self, other: ArrayValue) -> ArrayValue: - """Concatenate this array with another. - Parameters - ---------- - other - Array to concat with `self` + `concat` is also available using the `+` operator - Returns - ------- - ArrayValue - `self` concatenated with `other` - - Examples - -------- - >>> import ibis - >>> ibis.options.interactive = True - >>> t = ibis.memtable({"a": [[7], [3] , None]}) - >>> t + >>> [1] + t.a ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ a ┃ + ┃ ArrayConcat() ┃ ┡━━━━━━━━━━━━━━━━━━━━━━┩ │ array │ ├──────────────────────┤ - │ [7] │ - │ [3] │ - │ NULL │ + │ [1, 7] │ + │ [1, 3] │ + │ [1] │ └──────────────────────┘ - >>> ibis.literal([4], type="array") + t.a + >>> t.a + [1] ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ ArrayConcat((4,), a) ┃ + ┃ ArrayConcat() ┃ ┡━━━━━━━━━━━━━━━━━━━━━━┩ │ array │ ├──────────────────────┤ - │ [4, 7] │ - │ [4, 3] │ - │ [4] │ + │ [7, 1] │ + │ [3, 1] │ + │ [1] │ └──────────────────────┘ """ - return ops.ArrayConcat(other, self).to_expr() + return ops.ArrayConcat((self, other, *args)).to_expr() + + def __add__(self, other: ArrayValue) -> ArrayValue: + return self.concat(other) + + def __radd__(self, other: ArrayValue) -> ArrayValue: + return ops.ArrayConcat((other, self)).to_expr() def __mul__(self, n: int | ir.IntegerValue) -> ArrayValue: """Repeat this array `n` times.