Skip to content

Commit

Permalink
fix(datatypes): ensure that array construction supports literals and …
Browse files Browse the repository at this point in the history
…infers their shape from its inputs (#8049)

We were previously returning `ArrayColumn` from `ibis.array` when any
inputs were expressions, regardless of their shape. This PR renames
`ArrayColumn` -> `Array` and uses the input arguments' shapes to
determine the output array shape.

Fixes #8022.

---------

Co-authored-by: Nick Crews <nicholas.b.crews@gmail.com>
  • Loading branch information
cpcloud and NickCrews authored Jan 22, 2024
1 parent 424b206 commit 899dce1
Show file tree
Hide file tree
Showing 19 changed files with 159 additions and 126 deletions.
4 changes: 2 additions & 2 deletions ibis/backends/bigquery/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _array_concat(translator, op):


def _array_column(translator, op):
return "[{}]".format(", ".join(map(translator.translate, op.cols)))
return "[{}]".format(", ".join(map(translator.translate, op.exprs)))


def _array_index(translator, op):
Expand Down Expand Up @@ -912,7 +912,7 @@ def _timestamp_range(translator, op):
ops.StructColumn: _struct_column,
ops.ArrayCollect: _array_agg,
ops.ArrayConcat: _array_concat,
ops.ArrayColumn: _array_column,
ops.Array: _array_column,
ops.ArrayIndex: _array_index,
ops.ArrayLength: unary("ARRAY_LENGTH"),
ops.ArrayRepeat: _array_repeat,
Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,9 +554,9 @@ def _translate(op, *, arg, where, **_):
return _translate


@translate_val.register(ops.ArrayColumn)
def _array_column(op, *, cols, **_):
return F.array(*cols)
@translate_val.register(ops.Array)
def _array_column(op, *, exprs, **_):
return F.array(*exprs)


@translate_val.register(ops.StructColumn)
Expand Down
22 changes: 19 additions & 3 deletions ibis/backends/dask/execution/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dask.dataframe as dd
import dask.dataframe.groupby as ddgb
import numpy as np
import pandas as pd

import ibis.expr.operations as ops
from ibis.backends.dask.core import execute
Expand Down Expand Up @@ -34,10 +35,25 @@
)


@execute_node.register(ops.ArrayColumn, tuple)
@execute_node.register(ops.Array, tuple)
def execute_array_column(op, cols, **kwargs):
cols = [execute(arg, **kwargs) for arg in cols]
df = dd.concat(cols, axis=1)
vals = [execute(arg, **kwargs) for arg in cols]

length = next((len(v) for v in vals if isinstance(v, dd.Series)), None)
if length is None:
return vals

n_partitions = next((v.npartitions for v in vals if isinstance(v, dd.Series)), None)

def ensure_series(v):
if isinstance(v, dd.Series):
return v
else:
return dd.from_pandas(pd.Series([v] * length), npartitions=n_partitions)

# dd.concat() can only handle array-likes.
# If we're given a scalar, we need to broadcast it as a Series.
df = dd.concat([ensure_series(v) for v in vals], axis=1)
return df.apply(
lambda row: np.array(row, dtype=object), axis=1, meta=(None, "object")
)
Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/datafusion/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,9 +733,9 @@ def _not_null(op, *, arg, **_):
return sg.not_(arg.is_(NULL))


@translate_val.register(ops.ArrayColumn)
def array_column(op, *, cols, **_):
return F.make_array(*cols)
@translate_val.register(ops.Array)
def array_column(op, *, exprs, **_):
return F.make_array(*exprs)


@translate_val.register(ops.ArrayRepeat)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,9 @@ def _array_remove(t, op):

operation_registry.update(
{
ops.ArrayColumn: (
ops.Array: (
lambda t, op: sa.cast(
sa.func.list_value(*map(t.translate, op.cols)),
sa.func.list_value(*map(t.translate, op.exprs)),
t.get_sqla_type(op.dtype),
)
),
Expand Down
39 changes: 27 additions & 12 deletions ibis/backends/pandas/execution/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,23 @@
from collections.abc import Collection


@execute_node.register(ops.ArrayColumn, tuple)
def execute_array_column(op, cols, **kwargs):
cols = [execute(arg, **kwargs) for arg in cols]
df = pd.concat(cols, axis=1)
@execute_node.register(ops.Array, tuple)
def execute_array(op, cols, **kwargs):
vals = [execute(arg, **kwargs) for arg in cols]
length = next((len(v) for v in vals if isinstance(v, pd.Series)), None)

if length is None:
return vals

def ensure_series(v):
if isinstance(v, pd.Series):
return v
else:
return pd.Series(v, index=range(length))

# pd.concat() can only handle array-likes.
# If we're given a scalar, we need to broadcast it as a Series.
df = pd.concat([ensure_series(v) for v in vals], axis=1)
return df.apply(lambda row: np.array(row, dtype=object), axis=1)


Expand All @@ -29,7 +42,7 @@ def execute_array_length(op, data, **kwargs):
return data.apply(len)


@execute_node.register(ops.ArrayLength, np.ndarray)
@execute_node.register(ops.ArrayLength, (list, np.ndarray))
def execute_array_length_scalar(op, data, **kwargs):
return len(data)

Expand All @@ -39,7 +52,7 @@ def execute_array_slice(op, data, start, stop, **kwargs):
return data.apply(operator.itemgetter(slice(start, stop)))


@execute_node.register(ops.ArraySlice, np.ndarray, int, (int, type(None)))
@execute_node.register(ops.ArraySlice, (list, np.ndarray), int, (int, type(None)))
def execute_array_slice_scalar(op, data, start, stop, **kwargs):
return data[start:stop]

Expand All @@ -53,15 +66,15 @@ def execute_array_index(op, data, index, **kwargs):
)


@execute_node.register(ops.ArrayIndex, np.ndarray, int)
@execute_node.register(ops.ArrayIndex, (list, np.ndarray), int)
def execute_array_index_scalar(op, data, index, **kwargs):
try:
return data[index]
except IndexError:
return None


@execute_node.register(ops.ArrayContains, np.ndarray, object)
@execute_node.register(ops.ArrayContains, (list, np.ndarray), object)
def execute_node_contains_value_array(op, haystack, needle, **kwargs):
return needle in haystack

Expand Down Expand Up @@ -91,7 +104,7 @@ def execute_array_concat_series(op, first, second, *args, **kwargs):


@execute_node.register(
ops.ArrayConcat, np.ndarray, pd.Series, [(pd.Series, np.ndarray)]
ops.ArrayConcat, (list, np.ndarray), pd.Series, [(pd.Series, list, np.ndarray)]
)
def execute_array_concat_mixed_left(op, left, right, *args, **kwargs):
# ArrayConcat given a column (pd.Series) and a scalar (np.ndarray).
Expand All @@ -102,15 +115,17 @@ def execute_array_concat_mixed_left(op, left, right, *args, **kwargs):


@execute_node.register(
ops.ArrayConcat, pd.Series, np.ndarray, [(pd.Series, np.ndarray)]
ops.ArrayConcat, pd.Series, (list, np.ndarray), [(pd.Series, list, np.ndarray)]
)
def execute_array_concat_mixed_right(op, left, right, *args, **kwargs):
# Broadcast `right` to the length of `left`
right = np.tile(right, (len(left), 1))
return _concat_iterables_to_series(left, right)


@execute_node.register(ops.ArrayConcat, np.ndarray, np.ndarray, [np.ndarray])
@execute_node.register(
ops.ArrayConcat, (list, np.ndarray), (list, np.ndarray), [(list, np.ndarray)]
)
def execute_array_concat_scalar(op, left, right, *args, **kwargs):
return np.concatenate([left, right, *args])

Expand All @@ -122,7 +137,7 @@ def execute_array_repeat(op, data, n, **kwargs):
return pd.Series(np.tile(arr, n) for arr in data)


@execute_node.register(ops.ArrayRepeat, np.ndarray, int)
@execute_node.register(ops.ArrayRepeat, (list, np.ndarray), int)
def execute_array_repeat_scalar(op, data, n, **kwargs):
# Negative n will be treated as 0 (repeat will produce empty array)
return np.tile(data, max(n, 0))
Expand Down
22 changes: 22 additions & 0 deletions ibis/backends/pandas/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,28 @@ def cast_to_array(array, numpy_type=numpy_type):
return data.map(cast_to_array)


@execute_node.register(ops.Cast, list, dt.Array)
def execute_cast_list_array(op, data, type, **kwargs):
    """Cast a Python list to a NumPy array of the target array's value type.

    Parameters
    ----------
    op
        The cast operation being executed (not read here).
    data
        The list of elements to convert; ``None`` elements are kept as-is.
    type
        The target array type; only its ``value_type`` is consulted.
    **kwargs
        Ignored.

    Returns
    -------
    numpy.ndarray
        The converted elements, using the mapped NumPy dtype when NumPy
        accepts it and NumPy's inferred dtype otherwise.

    Raises
    ------
    ValueError
        If the value type has no entry in
        ``constants.IBIS_TYPE_TO_PANDAS_TYPE``.
    """
    target_dtype = constants.IBIS_TYPE_TO_PANDAS_TYPE.get(type.value_type)
    if target_dtype is None:
        raise ValueError(
            "Array value type must be a primitive type "
            "(e.g., number, string, or timestamp)"
        )

    # Convert every non-null element on its own and unwrap it with .item();
    # None is passed through untouched.
    converted = []
    for element in data:
        if element is not None:
            element = np.array(element, dtype=target_dtype).item()
        converted.append(element)

    try:
        return np.array(converted, dtype=target_dtype)
    except TypeError:
        # The requested dtype rejected the values (presumably because of
        # None entries -- TODO confirm); let NumPy infer the dtype instead.
        return np.array(converted)


@execute_node.register(ops.Cast, pd.Series, dt.Timestamp)
def execute_cast_series_timestamp(op, data, type, **kwargs):
arg = op.arg
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,9 +888,9 @@ def array_concat(op, **kw):
return result


@translate.register(ops.ArrayColumn)
@translate.register(ops.Array)
def array_column(op, **kw):
cols = [translate(col, **kw) for col in op.cols]
cols = [translate(col, **kw) for col in op.exprs]
return pl.concat_list(cols)


Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ def _range(t, op):
# array operations
ops.ArrayLength: unary(sa.func.cardinality),
ops.ArrayCollect: reduction(sa.func.array_agg),
ops.ArrayColumn: (lambda t, op: pg.array(list(map(t.translate, op.cols)))),
ops.Array: (lambda t, op: pg.array(list(map(t.translate, op.exprs)))),
ops.ArraySlice: _array_slice(
index_converter=_neg_idx_to_pos,
array_length=sa.func.cardinality,
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1634,9 +1634,9 @@ def compile_interval_from_integer(t, op, **kwargs):
# -------------------------- Array Operations ----------------------------


@compiles(ops.ArrayColumn)
@compiles(ops.Array)
def compile_array_column(t, op, **kwargs):
cols = [t.translate(col, **kwargs) for col in op.cols]
cols = [t.translate(col, **kwargs) for col in op.exprs]
return F.array(cols)


Expand Down
4 changes: 1 addition & 3 deletions ibis/backends/snowflake/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,7 @@ def _timestamp_range(t, op):
ops.ArrayConcat: varargs(
lambda *args: functools.reduce(sa.func.array_cat, args)
),
ops.ArrayColumn: lambda t, op: sa.func.array_construct(
*map(t.translate, op.cols)
),
ops.Array: lambda t, op: sa.func.array_construct(*map(t.translate, op.exprs)),
ops.ArraySlice: _array_slice,
ops.ArrayCollect: reduction(
lambda arg: sa.func.array_agg(
Expand Down
Loading

0 comments on commit 899dce1

Please sign in to comment.