feat(arrays): add concat method equivalent to __add__/__radd__
cpcloud authored and kszucs committed Jul 31, 2023
1 parent 65c57f3 commit 0ed0ab1
Showing 16 changed files with 142 additions and 94 deletions.
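In short: array expressions gain a variadic `concat` method that builds the same `ArrayConcat` node as the `+` operator. A minimal usage sketch (the DuckDB connection is illustrative; any backend with an `ArrayConcat` rule works):

import ibis

left = ibis.literal([1, 2, 3])
right = ibis.literal([2, 1])

expr = left.concat(right, right)  # new: variadic method form
via_add = left + right            # __add__ builds the same ArrayConcat node
via_radd = [1] + right            # __radd__ coerces the Python list first

con = ibis.duckdb.connect()  # illustrative backend
print(con.execute(expr.name("tmp")))  # -> [1 2 3 2 1 2 1]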
2 changes: 1 addition & 1 deletion ibis/backends/bigquery/registry.py
@@ -124,7 +124,7 @@ def _struct_column(translator, op):
 
 
 def _array_concat(translator, op):
-    return "ARRAY_CONCAT({})".format(", ".join(map(translator.translate, op.args)))
+    return "ARRAY_CONCAT({})".format(", ".join(map(translator.translate, op.arg)))
 
 
 def _array_column(translator, op):
9 changes: 7 additions & 2 deletions ibis/backends/clickhouse/compiler/values.py
@@ -924,6 +924,12 @@ def _map_get(op, **kw):
     return f"if(mapContains({arg}, {key}), {arg}[{key}], {default})"
 
 
+@translate_val.register(ops.ArrayConcat)
+def _array_concat(op, **kw):
+    args = ", ".join(map(_sql, map(partial(translate_val, **kw), op.arg)))
+    return f"arrayConcat({args})"
+
+
 def _binary_infix(symbol: str):
     def formatter(op, **kw):
         left = translate_val(op_left := op.left, **kw)
@@ -1056,7 +1062,6 @@ def formatter(op, **kw):
     # because clickhouse's greatest and least doesn't support varargs
     ops.Where: "if",
     ops.ArrayLength: "length",
-    ops.ArrayConcat: "arrayConcat",
     ops.Unnest: "arrayJoin",
     ops.Degrees: "degrees",
     ops.Radians: "radians",
@@ -1395,7 +1400,7 @@ def _array_remove(op, **kw):
 
 @translate_val.register(ops.ArrayUnion)
 def _array_union(op, **kw):
-    return translate_val(ops.ArrayDistinct(ops.ArrayConcat(op.left, op.right)), **kw)
+    return translate_val(ops.ArrayDistinct(ops.ArrayConcat((op.left, op.right))), **kw)
 
 
 @translate_val.register(ops.ArrayZip)
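Note the `ArrayUnion` tweak above: because `ArrayConcat` now has a single tuple-valued `arg`, nodes are constructed from one tuple rather than two positional operands. A sketch of the construction difference:

import ibis
import ibis.expr.operations as ops

left = ibis.literal([1, 2]).op()
right = ibis.literal([2, 3]).op()

# before this commit: ops.ArrayConcat(left, right)
node = ops.ArrayConcat((left, right))  # after: one tuple holds all inputs
print(node.arg)  # (<literal node>, <literal node>)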
6 changes: 6 additions & 0 deletions ibis/backends/dask/execution/arrays.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import itertools
+from functools import partial
 
 import dask.dataframe as dd
 import dask.dataframe.groupby as ddgb
@@ -51,3 +52,8 @@ def execute_array_collect(op, data, where, aggcontext=None, **kwargs):
 @execute_node.register(ops.ArrayCollect, ddgb.SeriesGroupBy, type(None))
 def execute_array_collect_grouped_series(op, data, where, **kwargs):
     return data.agg(collect_list)
+
+
+@execute_node.register(ops.ArrayConcat, tuple)
+def execute_array_concat(op, args, **kwargs):
+    return execute_node(op, *map(partial(execute, **kwargs), args), **kwargs)
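The dask rule above is a trampoline: it dispatches on the raw tuple of argument expressions, executes each one, then re-dispatches `execute_node` on the concrete result types so type-specific kernels (like the pandas ones later in this diff) do the real work. A stripped-down sketch of that execute-then-redispatch idea, with hypothetical names:

from functools import singledispatch

@singledispatch
def concat_kernel(first, *rest):
    raise NotImplementedError(type(first).__name__)

@concat_kernel.register(list)
def _(first, *rest):
    # type-specific kernel, chosen only after the arguments are materialized
    out = list(first)
    for r in rest:
        out.extend(r)
    return out

def execute_concat(arg_exprs, execute):
    # trampoline: materialize every argument, then re-dispatch on its type
    return concat_kernel(*map(execute, arg_exprs))

print(execute_concat(["ab", "cd"], list))  # ['a', 'b', 'c', 'd']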
8 changes: 7 additions & 1 deletion ibis/backends/duckdb/compiler.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import sqlalchemy as sa
 from sqlalchemy.ext.compiler import compiles
 
 import ibis.backends.base.sql.alchemy.datatypes as sat
@@ -39,7 +40,12 @@ def compile_uint(element, compiler, **kw):
 
 
 @compiles(sat.ArrayType, "duckdb")
 def compile_array(element, compiler, **kw):
-    return f"{compiler.process(element.value_type, **kw)}[]"
+    if isinstance(value_type := element.value_type, sa.types.NullType):
+        # duckdb infers empty arrays with no other context as array<int32>
+        typ = "INTEGER"
+    else:
+        typ = compiler.process(value_type, **kw)
+    return f"{typ}[]"
 
 
 rewrites = DuckDBSQLExprTranslator.rewrites
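Why the `NullType` branch: an untyped empty array literal (e.g. `ibis.literal([])`) reaches the type compiler with no known element type, which SQLAlchemy represents as `NullType`; the code falls back to `INTEGER` to match DuckDB's own inference for a bare `[]`. A sketch of the branch in isolation (not the real compiler plumbing):

import sqlalchemy as sa

def array_typename(value_type: sa.types.TypeEngine) -> str:
    # mirrors the branch in compile_array above
    if isinstance(value_type, sa.types.NullType):
        return "INTEGER[]"  # DuckDB's default for untyped empty arrays
    return f"{value_type}[]"

print(array_typename(sa.types.NullType()))  # INTEGER[]
print(array_typename(sa.BIGINT()))          # BIGINT[]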
1 change: 0 additions & 1 deletion ibis/backends/duckdb/registry.py
@@ -353,7 +353,6 @@ def _try_cast(t, op):
        )
    ),
    ops.TryCast: _try_cast,
-    ops.ArrayConcat: fixed_arity(sa.func.array_concat, 2),
    ops.ArrayRepeat: fixed_arity(
        lambda arg, times: sa.func.flatten(
            sa.func.array(
4 changes: 3 additions & 1 deletion ibis/backends/pandas/dispatch.py
@@ -22,8 +22,10 @@
 
 
 @execute_node.register(ops.Node, [object])
 def raise_unknown_op(node, *args, **kwargs):
+    signature = ", ".join(type(arg).__name__ for arg in args)
     raise com.OperationNotDefinedError(
-        f"Operation {type(node).__name__!r} is not implemented for this backend"
+        "Operation is not implemented for this backend with "
+        f"signature: execute_node({type(node).__name__}, {signature})"
     )
 
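The reworked error now reports the concrete dispatch signature, which matters once ops like `ArrayConcat` arrive as a tuple of arguments. A sketch of the message it builds (argument types are illustrative):

args = ((1, 2), "x")  # hypothetical arguments seen by the dispatcher
signature = ", ".join(type(arg).__name__ for arg in args)
print(
    "Operation is not implemented for this backend with "
    f"signature: execute_node(ArrayConcat, {signature})"
)
# Operation is not implemented for this backend with
# signature: execute_node(ArrayConcat, tuple, str)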
53 changes: 31 additions & 22 deletions ibis/backends/pandas/execution/arrays.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import operator
+from functools import partial
 from typing import Any, Collection
 
 import numpy as np
@@ -56,45 +57,53 @@ def execute_array_index_scalar(op, data, index, **kwargs):
     return None
 
 
-def _concat_iterables_to_series(
-    iter1: Collection[Any],
-    iter2: Collection[Any],
-) -> pd.Series:
+def _concat_iterables_to_series(*iters: Collection[Any]) -> pd.Series:
     """Concatenate two collections to create a Series.
     The two collections are assumed to have the same length.
     Used for ArrayConcat implementation.
     """
-    assert len(iter1) == len(iter2)
+    first, *rest = iters
+    assert all(len(series) == len(first) for series in rest)
     # Doing the iteration using `map` is much faster than doing the iteration
     # using `Series.apply` due to Pandas-related overhead.
-    result = pd.Series(map(lambda x, y: np.concatenate([x, y]), iter1, iter2))
-    return result
+    return pd.Series(map(lambda *args: np.concatenate(args), first, *rest))
 
 
-@execute_node.register(ops.ArrayConcat, pd.Series, pd.Series)
-def execute_array_concat_series(op, left, right, **kwargs):
-    return _concat_iterables_to_series(left, right)
+@execute_node.register(ops.ArrayConcat, tuple)
+def execute_array_concat(op, args, **kwargs):
+    return execute_node(op, *map(partial(execute, **kwargs), args), **kwargs)
+
+
+@execute_node.register(ops.ArrayConcat, pd.Series, pd.Series, [pd.Series])
+def execute_array_concat_series(op, first, second, *args, **kwargs):
+    return _concat_iterables_to_series(first, second, *args)
 
 
-@execute_node.register(ops.ArrayConcat, pd.Series, np.ndarray)
-@execute_node.register(ops.ArrayConcat, np.ndarray, pd.Series)
-def execute_array_concat_mixed(op, left, right, **kwargs):
-    # ArrayConcat given a column (pd.Series) and a scalar (np.ndarray).
-    # We will broadcast the scalar to the length of the column.
-    if isinstance(left, np.ndarray):
-        # Broadcast `left` to the length of `right`
-        left = np.tile(left, (len(right), 1))
-    elif isinstance(right, np.ndarray):
-        # Broadcast `right` to the length of `left`
-        right = np.tile(right, (len(left), 1))
+@execute_node.register(
+    ops.ArrayConcat, np.ndarray, pd.Series, [(pd.Series, np.ndarray)]
+)
+def execute_array_concat_mixed_left(op, left, right, *args, **kwargs):
+    # Broadcast `left` to the length of `right`
+    left = np.tile(left, (len(right), 1))
     return _concat_iterables_to_series(left, right)
 
 
+@execute_node.register(
+    ops.ArrayConcat, pd.Series, np.ndarray, [(pd.Series, np.ndarray)]
+)
+def execute_array_concat_mixed_right(op, left, right, *args, **kwargs):
+    # Broadcast `right` to the length of `left`
+    right = np.tile(right, (len(left), 1))
+    return _concat_iterables_to_series(left, right)
 
 
-@execute_node.register(ops.ArrayConcat, np.ndarray, np.ndarray)
-def execute_array_concat_scalar(op, left, right, **kwargs):
-    return np.concatenate([left, right])
+@execute_node.register(ops.ArrayConcat, np.ndarray, np.ndarray, [np.ndarray])
+def execute_array_concat_scalar(op, left, right, *args, **kwargs):
+    return np.concatenate([left, right, *args])
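For intuition, here is the map-based kernel above applied to two columns of arrays; per its comment, building the Series from `map` sidesteps the per-row overhead of `Series.apply`:

import numpy as np
import pandas as pd

s1 = pd.Series([np.array([1, 2]), np.array([3])])
s2 = pd.Series([np.array([9]), np.array([8, 7])])

# same shape as _concat_iterables_to_series above
out = pd.Series(map(lambda *xs: np.concatenate(xs), s1, s2))
print(out[0])  # [1 2 9]
print(out[1])  # [3 8 7]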
15 changes: 9 additions & 6 deletions ibis/backends/polars/compiler.py
@@ -862,12 +862,15 @@ def array_length(op, **kw):
 
 
 @translate.register(ops.ArrayConcat)
 def array_concat(op, **kw):
-    left = translate(op.left, **kw)
-    right = translate(op.right, **kw)
-    try:
-        return left.arr.concat(right)
-    except AttributeError:
-        return left.list.concat(right)
+    result, *rest = map(partial(translate, **kw), op.arg)
+
+    for arg in rest:
+        try:
+            result = result.arr.concat(arg)
+        except AttributeError:
+            result = result.list.concat(arg)
+
+    return result
 
 
 @translate.register(ops.ArrayColumn)
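The `try`/`except AttributeError` above exists because polars renamed its list namespace: older releases expose `Expr.arr.concat`, newer ones `Expr.list.concat`. A usage sketch assuming a recent polars with the `.list` namespace:

import polars as pl

df = pl.DataFrame({"x": [[1, 2], [3]], "y": [[9], [8, 7]]})
out = df.select(pl.col("x").list.concat(pl.col("y")).alias("xy"))
print(out)  # xy: [[1, 2, 9], [3, 8, 7]]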
3 changes: 2 additions & 1 deletion ibis/backends/postgres/registry.py
@@ -27,6 +27,7 @@
     sqlalchemy_operation_registry,
     sqlalchemy_window_functions_registry,
     unary,
+    varargs,
 )
 from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported
 from ibis.backends.base.sql.alchemy.registry import (
@@ -658,7 +659,7 @@ def _unnest(t, op):
     ops.ArrayIndex: _array_index(
         index_converter=_neg_idx_to_pos, func=lambda arg, index: arg[index]
     ),
-    ops.ArrayConcat: fixed_arity(sa.sql.expression.ColumnElement.concat, 2),
+    ops.ArrayConcat: varargs(lambda *args: functools.reduce(operator.add, args)),
     ops.ArrayRepeat: _array_repeat,
     ops.Unnest: _unnest,
     ops.Covariance: _covar,
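The new postgres rule folds any number of operands with SQLAlchemy's `+`, which renders as array concatenation; the Snowflake hunk below does the same fold with `sa.func.array_cat`, while Trino's `concat` is natively variadic. The fold itself is plain `functools.reduce`, easy to sanity-check with Python lists:

import functools
import operator

# left fold over n arguments, as in the registry entry above
args = ([1], [2, 3], [4])
print(functools.reduce(operator.add, args))  # [1, 2, 3, 4]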
2 changes: 1 addition & 1 deletion ibis/backends/postgres/tests/test_functions.py
@@ -1008,7 +1008,7 @@ def test_array_concat(array_types, catop):
 
 def test_array_concat_mixed_types(array_types):
     with pytest.raises(TypeError):
-        array_types.x + array_types.x.cast('array<double>')
+        array_types.y + array_types.x.cast('array<double>')
 
 
 @pytest.fixture
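Why the operand swap: with element-type promotion in place (see the `ArrayConcat` node changes at the end of this diff), an integer array column now combines with `array<double>` instead of raising, so the test switches to a column whose element type genuinely cannot promote.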
4 changes: 1 addition & 3 deletions ibis/backends/pyspark/compiler.py
@@ -1640,9 +1640,7 @@ def compile_array_index(t, op, **kwargs):
 
 @compiles(ops.ArrayConcat)
 def compile_array_concat(t, op, **kwargs):
-    left = t.translate(op.left, **kwargs)
-    right = t.translate(op.right, **kwargs)
-    return F.concat(left, right)
+    return F.concat(*map(partial(t.translate, **kwargs), op.arg))
 
 
 @compiles(ops.ArrayRepeat)
6 changes: 5 additions & 1 deletion ibis/backends/snowflake/registry.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import functools
 import itertools
 
 import numpy as np
@@ -18,6 +19,7 @@
     get_sqla_table,
     reduction,
     unary,
+    varargs,
 )
 from ibis.backends.postgres.registry import _literal as _postgres_literal
 from ibis.backends.postgres.registry import operation_registry as _operation_registry
@@ -370,7 +372,9 @@ def _map_get(t, op):
     ),
     ops.ArrayIndex: fixed_arity(sa.func.get, 2),
     ops.ArrayLength: fixed_arity(sa.func.array_size, 1),
-    ops.ArrayConcat: fixed_arity(sa.func.array_cat, 2),
+    ops.ArrayConcat: varargs(
+        lambda *args: functools.reduce(sa.func.array_cat, args)
+    ),
     ops.ArrayColumn: lambda t, op: sa.func.array_construct(
         *map(t.translate, op.cols)
     ),
27 changes: 27 additions & 0 deletions ibis/backends/tests/test_array.py
@@ -107,6 +107,33 @@ def test_array_concat(con):
     assert np.array_equal(result, expected)
 
 
+# Issues #2370
+@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
+def test_array_concat_variadic(con):
+    left = ibis.literal([1, 2, 3])
+    right = ibis.literal([2, 1])
+    expr = left.concat(right, right, right)
+    result = con.execute(expr.name("tmp"))
+    expected = np.array([1, 2, 3, 2, 1, 2, 1, 2, 1])
+    assert np.array_equal(result, expected)
+
+
+# Issues #2370
+@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
+@pytest.mark.notyet(
+    ["postgres", "trino"],
+    raises=sa.exc.ProgrammingError,
+    reason="postgres can't infer the type of an empty array",
+)
+def test_array_concat_some_empty(con):
+    left = ibis.literal([])
+    right = ibis.literal([2, 1])
+    expr = left.concat(right)
+    result = con.execute(expr.name("tmp"))
+    expected = np.array([2, 1])
+    assert np.array_equal(result, expected)
+
+
 @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
 def test_array_radd_concat(con):
     left = [1]
7 changes: 5 additions & 2 deletions ibis/backends/trino/registry.py
@@ -11,7 +11,9 @@
 import ibis
 import ibis.common.exceptions as com
 import ibis.expr.operations as ops
-from ibis.backends.base.sql.alchemy.registry import _literal as _alchemy_literal
+from ibis.backends.base.sql.alchemy.registry import (
+    _literal as _alchemy_literal,
+)
 from ibis.backends.base.sql.alchemy.registry import (
     array_filter,
     array_map,
@@ -21,6 +23,7 @@
     sqlalchemy_window_functions_registry,
     try_cast,
     unary,
+    varargs,
 )
 from ibis.backends.postgres.registry import _corr, _covar
 
@@ -328,7 +331,7 @@ def _try_cast(t, op):
     ops.BitwiseRightShift: fixed_arity(sa.func.bitwise_right_shift, 2),
     ops.BitwiseNot: unary(sa.func.bitwise_not),
     ops.ArrayCollect: reduction(sa.func.array_agg),
-    ops.ArrayConcat: fixed_arity(sa.func.concat, 2),
+    ops.ArrayConcat: varargs(sa.func.concat),
     ops.ArrayLength: unary(sa.func.cardinality),
     ops.ArrayIndex: fixed_arity(
         lambda arg, index: sa.func.element_at(arg, index + 1), 2
23 changes: 9 additions & 14 deletions ibis/expr/operations/arrays.py
@@ -2,7 +2,6 @@
 
 from public import public
 
-import ibis.common.exceptions as com
 import ibis.expr.datatypes as dt
 import ibis.expr.rules as rlz
 from ibis.common.annotations import attribute
@@ -52,21 +51,17 @@ def output_dtype(self):
 
 @public
 class ArrayConcat(Value):
-    left = rlz.array
-    right = rlz.array
+    arg = rlz.tuple_of(rlz.array, min_length=2)
 
-    output_dtype = rlz.dtype_like("left")
-    output_shape = rlz.shape_like("args")
+    @attribute.default
+    def output_dtype(self):
+        return dt.Array(
+            dt.highest_precedence(arg.output_dtype.value_type for arg in self.arg)
+        )
 
-    def __init__(self, left, right):
-        if left.output_dtype != right.output_dtype:
-            raise com.IbisTypeError(
-                'Array types must match exactly in a {} operation. '
-                'Left type {} != Right type {}'.format(
-                    type(self).__name__, left.output_dtype, right.output_dtype
-                )
-            )
-        super().__init__(left=left, right=right)
+    @attribute.default
+    def output_shape(self):
+        return rlz.highest_precedence_shape(self.arg)
 
 
 @public
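Net effect of the node rewrite: instead of requiring exactly matching array types, `ArrayConcat` promotes element types through `dt.highest_precedence` and broadcasts the output shape across all inputs. A quick sketch (printed results assume the usual numeric promotion rules):

import ibis
import ibis.expr.datatypes as dt

print(dt.highest_precedence([dt.int8, dt.float64]))  # float64

expr = ibis.literal([1, 2]).concat(ibis.literal([2.5]))
print(expr.type())  # array<float64>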
(1 changed file not shown)
