
fix(datatypes): ensure that array construction supports literals and infers their shape from its inputs #8049

Merged: 11 commits, Jan 22, 2024
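In user-facing terms, the change lets `ibis.array` accept a mix of column expressions and plain Python literals, broadcasting the literals to the length inferred from the column inputs. A minimal sketch of the intended behavior (the table and column here are invented for illustration):

```python
import ibis

t = ibis.table({"a": "int64"}, name="t")

# Mixed input: the array length is inferred from column `a`,
# and the literal 1 is broadcast to every row.
mixed = ibis.array([t.a, 1])

# An all-literal array still works and produces a scalar array value.
lit = ibis.array([1, 2, 3])
```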
4 changes: 2 additions & 2 deletions ibis/backends/bigquery/registry.py
@@ -129,7 +129,7 @@


def _array_column(translator, op):
-    return "[{}]".format(", ".join(map(translator.translate, op.cols)))
+    return "[{}]".format(", ".join(map(translator.translate, op.exprs)))

(Codecov: added line ibis/backends/bigquery/registry.py#L132 was not covered by tests)


def _array_index(translator, op):

@@ -912,7 +912,7 @@
    ops.StructColumn: _struct_column,
    ops.ArrayCollect: _array_agg,
    ops.ArrayConcat: _array_concat,
-    ops.ArrayColumn: _array_column,
+    ops.Array: _array_column,
    ops.ArrayIndex: _array_index,
    ops.ArrayLength: unary("ARRAY_LENGTH"),
    ops.ArrayRepeat: _array_repeat,
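For context, the translation rule above joins the translated element expressions inside a BigQuery array literal. A rough sketch of the string it builds (the translated fragments are invented):

```python
translated = ["`a`", "1", "2"]  # hypothetical translator.translate() results
sql = "[{}]".format(", ".join(translated))
assert sql == "[`a`, 1, 2]"
```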
6 changes: 3 additions & 3 deletions ibis/backends/clickhouse/compiler/values.py
@@ -554,9 +554,9 @@ def _translate(op, *, arg, where, **_):
    return _translate


-@translate_val.register(ops.ArrayColumn)
-def _array_column(op, *, cols, **_):
-    return F.array(*cols)
+@translate_val.register(ops.Array)
+def _array_column(op, *, exprs, **_):
+    return F.array(*exprs)


@translate_val.register(ops.StructColumn)
22 changes: 19 additions & 3 deletions ibis/backends/dask/execution/arrays.py
@@ -6,6 +6,7 @@
import dask.dataframe as dd
import dask.dataframe.groupby as ddgb
import numpy as np
+import pandas as pd

import ibis.expr.operations as ops
from ibis.backends.dask.core import execute
@@ -34,10 +35,25 @@
)


-@execute_node.register(ops.ArrayColumn, tuple)
+@execute_node.register(ops.Array, tuple)
def execute_array_column(op, cols, **kwargs):
-    cols = [execute(arg, **kwargs) for arg in cols]
-    df = dd.concat(cols, axis=1)
+    vals = [execute(arg, **kwargs) for arg in cols]
+
+    length = next((len(v) for v in vals if isinstance(v, dd.Series)), None)
+    if length is None:
+        return vals
+
+    n_partitions = next((v.npartitions for v in vals if isinstance(v, dd.Series)), None)
+
+    def ensure_series(v):
+        if isinstance(v, dd.Series):
+            return v
+        else:
+            return dd.from_pandas(pd.Series([v] * length), npartitions=n_partitions)
(Review comment from a contributor: "also worth double checking me here, I'm not familiar with dask, so I was mostly just trying to make this not crash, but I'm not sure if setting n_partitions this way makes sense/is the most performant.")


+    # dd.concat() can only handle array-likes.
+    # If we're given a scalar, we need to broadcast it as a Series.
+    df = dd.concat([ensure_series(v) for v in vals], axis=1)
    return df.apply(
        lambda row: np.array(row, dtype=object), axis=1, meta=(None, "object")
    )
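Outside of ibis, the broadcasting idea can be sketched as follows: a scalar has no dask partitioning of its own, so it is materialized as a pandas Series of the inferred length and wrapped with `dd.from_pandas`, reusing the partition count of an existing dask Series (a heuristic, as the review comment above notes; the values here are invented):

```python
import dask.dataframe as dd
import pandas as pd

col = dd.from_pandas(pd.Series([10, 20, 30]), npartitions=2)

# Broadcast the scalar 7 to the column's length, reusing its partition count.
lit = dd.from_pandas(pd.Series([7] * 3), npartitions=col.npartitions)

# Column-wise concat now works because both inputs are dask Series.
df = dd.concat([col, lit], axis=1)
print(df.compute())  # columns: [10, 20, 30] and [7, 7, 7]
```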
6 changes: 3 additions & 3 deletions ibis/backends/datafusion/compiler/values.py
@@ -733,9 +733,9 @@ def _not_null(op, *, arg, **_):
    return sg.not_(arg.is_(NULL))


-@translate_val.register(ops.ArrayColumn)
-def array_column(op, *, cols, **_):
-    return F.make_array(*cols)
+@translate_val.register(ops.Array)
+def array_column(op, *, exprs, **_):
+    return F.make_array(*exprs)


@translate_val.register(ops.ArrayRepeat)
4 changes: 2 additions & 2 deletions ibis/backends/duckdb/registry.py
@@ -399,9 +399,9 @@ def _array_remove(t, op):

operation_registry.update(
    {
-        ops.ArrayColumn: (
+        ops.Array: (
            lambda t, op: sa.cast(
-                sa.func.list_value(*map(t.translate, op.cols)),
+                sa.func.list_value(*map(t.translate, op.exprs)),
                t.get_sqla_type(op.dtype),
            )
        ),
39 changes: 27 additions & 12 deletions ibis/backends/pandas/execution/arrays.py
@@ -17,10 +17,23 @@
from collections.abc import Collection


-@execute_node.register(ops.ArrayColumn, tuple)
-def execute_array_column(op, cols, **kwargs):
-    cols = [execute(arg, **kwargs) for arg in cols]
-    df = pd.concat(cols, axis=1)
+@execute_node.register(ops.Array, tuple)
+def execute_array(op, cols, **kwargs):
+    vals = [execute(arg, **kwargs) for arg in cols]
+    length = next((len(v) for v in vals if isinstance(v, pd.Series)), None)
+
+    if length is None:
+        return vals
+
+    def ensure_series(v):
+        if isinstance(v, pd.Series):
+            return v
+        else:
+            return pd.Series(v, index=range(length))
+
+    # pd.concat() can only handle array-likes.
+    # If we're given a scalar, we need to broadcast it as a Series.
+    df = pd.concat([ensure_series(v) for v in vals], axis=1)
    return df.apply(lambda row: np.array(row, dtype=object), axis=1)
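The pandas version of the same idea is simpler, because `pd.Series` broadcasts a scalar when given an explicit index. A sketch of what the new handler computes for one column plus one literal (values invented):

```python
import numpy as np
import pandas as pd

vals = [pd.Series([1, 2, 3]), 99]  # one column result, one scalar literal
length = next(len(v) for v in vals if isinstance(v, pd.Series))

# Broadcast the scalar into a Series, then build one array per row.
cols = [v if isinstance(v, pd.Series) else pd.Series(v, index=range(length)) for v in vals]
rows = pd.concat(cols, axis=1).apply(lambda row: np.array(row, dtype=object), axis=1)
# rows[0] == array([1, 99], dtype=object), etc.
```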


@@ -29,7 +42,7 @@ def execute_array_length(op, data, **kwargs):
    return data.apply(len)


-@execute_node.register(ops.ArrayLength, np.ndarray)
+@execute_node.register(ops.ArrayLength, (list, np.ndarray))
def execute_array_length_scalar(op, data, **kwargs):
    return len(data)

@@ -39,7 +52,7 @@ def execute_array_slice(op, data, start, stop, **kwargs):
    return data.apply(operator.itemgetter(slice(start, stop)))


-@execute_node.register(ops.ArraySlice, np.ndarray, int, (int, type(None)))
+@execute_node.register(ops.ArraySlice, (list, np.ndarray), int, (int, type(None)))
def execute_array_slice_scalar(op, data, start, stop, **kwargs):
    return data[start:stop]

@@ -53,15 +66,15 @@ def execute_array_index(op, data, index, **kwargs):
    )


-@execute_node.register(ops.ArrayIndex, np.ndarray, int)
+@execute_node.register(ops.ArrayIndex, (list, np.ndarray), int)
def execute_array_index_scalar(op, data, index, **kwargs):
    try:
        return data[index]
    except IndexError:
        return None


-@execute_node.register(ops.ArrayContains, np.ndarray, object)
+@execute_node.register(ops.ArrayContains, (list, np.ndarray), object)
def execute_node_contains_value_array(op, haystack, needle, **kwargs):
    return needle in haystack

@@ -91,7 +104,7 @@ def execute_array_concat_series(op, first, second, *args, **kwargs):


@execute_node.register(
-    ops.ArrayConcat, np.ndarray, pd.Series, [(pd.Series, np.ndarray)]
+    ops.ArrayConcat, (list, np.ndarray), pd.Series, [(pd.Series, list, np.ndarray)]
)
def execute_array_concat_mixed_left(op, left, right, *args, **kwargs):
    # ArrayConcat given a column (pd.Series) and a scalar (np.ndarray).
@@ -102,15 +115,17 @@ def execute_array_concat_mixed_left(op, left, right, *args, **kwargs):


@execute_node.register(
-    ops.ArrayConcat, pd.Series, np.ndarray, [(pd.Series, np.ndarray)]
+    ops.ArrayConcat, pd.Series, (list, np.ndarray), [(pd.Series, list, np.ndarray)]
)
def execute_array_concat_mixed_right(op, left, right, *args, **kwargs):
    # Broadcast `right` to the length of `left`
    right = np.tile(right, (len(left), 1))
    return _concat_iterables_to_series(left, right)
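Here `np.tile` with a `(len(left), 1)` reps argument stacks one copy of the scalar array per row before concatenation. Concretely:

```python
import numpy as np

right = np.array([7, 8])
tiled = np.tile(right, (3, 1))  # one copy per row of a length-3 column
# tiled == array([[7, 8], [7, 8], [7, 8]])
```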


-@execute_node.register(ops.ArrayConcat, np.ndarray, np.ndarray, [np.ndarray])
+@execute_node.register(
+    ops.ArrayConcat, (list, np.ndarray), (list, np.ndarray), [(list, np.ndarray)]
+)
def execute_array_concat_scalar(op, left, right, *args, **kwargs):
    return np.concatenate([left, right, *args])

@@ -122,7 +137,7 @@ def execute_array_repeat(op, data, n, **kwargs):
    return pd.Series(np.tile(arr, n) for arr in data)


-@execute_node.register(ops.ArrayRepeat, np.ndarray, int)
+@execute_node.register(ops.ArrayRepeat, (list, np.ndarray), int)
def execute_array_repeat_scalar(op, data, n, **kwargs):
    # Negative n will be treated as 0 (repeat will produce empty array)
    return np.tile(data, max(n, 0))
22 changes: 22 additions & 0 deletions ibis/backends/pandas/execution/generic.py
@@ -145,6 +145,28 @@
    return data.map(cast_to_array)


+@execute_node.register(ops.Cast, list, dt.Array)
+def execute_cast_list_array(op, data, type, **kwargs):
+    value_type = type.value_type
+    numpy_type = constants.IBIS_TYPE_TO_PANDAS_TYPE.get(value_type, None)
+    if numpy_type is None:
+        raise ValueError(
+            "Array value type must be a primitive type "
+            "(e.g., number, string, or timestamp)"
+        )

(Codecov: added line ibis/backends/pandas/execution/generic.py#L153 was not covered by tests)

+    def cast_to_array(array, numpy_type=numpy_type):
+        elems = [
+            el if el is None else np.array(el, dtype=numpy_type).item() for el in array
+        ]
+        try:
+            return np.array(elems, dtype=numpy_type)
+        except TypeError:
+            return np.array(elems)

(Codecov: added lines ibis/backends/pandas/execution/generic.py#L164-L165 were not covered by tests)

+    return cast_to_array(data)
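The try/except fallback matters when the list contains `None`: NumPy refuses to coerce `None` into a primitive dtype, so the cast degrades to an object array instead of failing. A sketch, assuming a float64 value type maps to NumPy's `float64` (the real mapping comes from `constants.IBIS_TYPE_TO_PANDAS_TYPE`):

```python
import numpy as np

numpy_type = np.dtype("float64")  # assumed lookup result
data = [1, None, 3]

elems = [el if el is None else np.array(el, dtype=numpy_type).item() for el in data]
try:
    result = np.array(elems, dtype=numpy_type)
except TypeError:
    # None cannot be coerced to float64; fall back to an object array.
    result = np.array(elems)
```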


@execute_node.register(ops.Cast, pd.Series, dt.Timestamp)
def execute_cast_series_timestamp(op, data, type, **kwargs):
    arg = op.arg
4 changes: 2 additions & 2 deletions ibis/backends/polars/compiler.py
@@ -888,9 +888,9 @@ def array_concat(op, **kw):
    return result


-@translate.register(ops.ArrayColumn)
+@translate.register(ops.Array)
def array_column(op, **kw):
-    cols = [translate(col, **kw) for col in op.cols]
+    cols = [translate(col, **kw) for col in op.exprs]
    return pl.concat_list(cols)


2 changes: 1 addition & 1 deletion ibis/backends/postgres/registry.py
@@ -750,7 +750,7 @@ def _range(t, op):
        # array operations
        ops.ArrayLength: unary(sa.func.cardinality),
        ops.ArrayCollect: reduction(sa.func.array_agg),
-        ops.ArrayColumn: (lambda t, op: pg.array(list(map(t.translate, op.cols)))),
+        ops.Array: (lambda t, op: pg.array(list(map(t.translate, op.exprs)))),
        ops.ArraySlice: _array_slice(
            index_converter=_neg_idx_to_pos,
            array_length=sa.func.cardinality,
4 changes: 2 additions & 2 deletions ibis/backends/pyspark/compiler.py
@@ -1634,9 +1634,9 @@ def compile_interval_from_integer(t, op, **kwargs):
# -------------------------- Array Operations ----------------------------


-@compiles(ops.ArrayColumn)
+@compiles(ops.Array)
def compile_array_column(t, op, **kwargs):
-    cols = [t.translate(col, **kwargs) for col in op.cols]
+    cols = [t.translate(col, **kwargs) for col in op.exprs]
    return F.array(cols)


4 changes: 1 addition & 3 deletions ibis/backends/snowflake/registry.py
@@ -457,9 +457,7 @@ def _timestamp_range(t, op):
        ops.ArrayConcat: varargs(
            lambda *args: functools.reduce(sa.func.array_cat, args)
        ),
-        ops.ArrayColumn: lambda t, op: sa.func.array_construct(
-            *map(t.translate, op.cols)
-        ),
+        ops.Array: lambda t, op: sa.func.array_construct(*map(t.translate, op.exprs)),
        ops.ArraySlice: _array_slice,
        ops.ArrayCollect: reduction(
            lambda arg: sa.func.array_agg(