Skip to content

Commit

Permalink
feat: support type kwarg in array() and map()
Browse files Browse the repository at this point in the history
fixes #8289

This does a lot of changes. It was hard for me to separate them out as I implemented them. But now that it's all hashed out, I can try to split this up into separate commits if you want. But that might be sorta hard in
some cases.

The big structural change is that now the core Operations for Array and Structs have a different internal representation, so they can distringuish between
- the entire value is NULL
- the contained values are NULL

Before, ops.Array held onto a `VarTuple[Value]`. So the contained Values could be NULL, but there was no way to say the entire thing was null. Now, ops.Array stores a `None | VarTuple[Value]`. The same thing for ops.StructValue. ops.Map didn't suffer from this, because it stores a `ops.Array`s internally, so since `ops.Array` can distinguish between entirely-NULL and contains-NULL, so can ops.Map

A fallout of this is that ops.Array needs a way to explicitly store its dtype. Before, it derived its dtype based on the dtype of its args. But
now that `None` is a valid value, it is now possible for there to be no values to inspect! So the Op actually stores its dtype explicitly. If you pass in values, then supplying the dtype on construction is optional, we go back to the old behavior of deriving it from the inputs.

This requires the backend compilers to now deal with that case.

Several of the backends were always broken here, they just weren't getting caught. I marked them as broken,
we can fix them in a followup.

You can test this locally with eg
`pytest -m <backend> -k factory ibis/backends/tests/test_array.py  ibis/backends/tests/test_map.py ibis/backends/tests/test_struct.py`

Also, fix a typing bug: map() can accept ArrayValues, not just ArrayColumns

Also, fix executing NULL arrays on pandas.

Also, fixing converting dtypes on clickhouse, Structs should be converted to nonnullable dtypes.

Also, fix casting structs on pandas.
See #8687

Also, support passing in None to all these constructors.

Also, error when the value type can't be inferred from empty python literals
(eg what is the value type for the elements of []?)

Also, make the type argument for struct() always have an effect, not just when passing in python literals.
So basically it can act like a cast.

Also, make these constructors idempotent.
  • Loading branch information
NickCrews committed Mar 22, 2024
1 parent 133a1f1 commit 55497bb
Show file tree
Hide file tree
Showing 19 changed files with 382 additions and 108 deletions.
8 changes: 4 additions & 4 deletions ibis/backends/clickhouse/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,10 +455,10 @@ def visit_GroupConcat(self, op, *, arg, sep, where):
def visit_Cot(self, op, *, arg):
return 1.0 / self.f.tan(arg)

def visit_StructColumn(self, op, *, values, names):
# ClickHouse struct types cannot be nullable
# (non-nested fields can be nullable)
return self.cast(self.f.tuple(*values), op.dtype.copy(nullable=False))
def visit_StructColumn(self, op, *, values, names, dtype):
if values is None:
return self.cast(NULL, dtype)
return self.cast(self.f.tuple(*values), dtype)

def visit_Clip(self, op, *, arg, lower, upper):
if upper is not None:
Expand Down
21 changes: 19 additions & 2 deletions ibis/backends/dask/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,26 @@ def mapper(df, cases):
return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype)

@classmethod
def visit(cls, op: ops.Array, exprs):
def visit(cls, op: ops.StructColumn, names, values, dtype):
if values is None:
return None

def process_row(row):
return {
name: DaskConverter.convert_scalar(value, dty)
for name, value, dty in zip(names, row, dtype.fields.values())
}

pdt = PandasType.from_ibis(op.dtype)
return cls.rowwise(process_row, values, name=op.name, dtype=pdt)

@classmethod
def visit(cls, op: ops.Array, exprs, dtype):
if exprs is None:
return None
pdt = PandasType.from_ibis(op.dtype)
return cls.rowwise(
lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object
lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=pdt
)

@classmethod
Expand Down
29 changes: 20 additions & 9 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,20 @@ def _aggregate(self, funcname: str, *args, where):
return sge.Filter(this=expr, expression=sge.Where(this=where))
return expr

def visit_StructColumn(self, op, *, names, values):
return sge.Struct.from_arg_list(
[
sge.PropertyEQ(
this=sg.to_identifier(name, quoted=self.quoted), expression=value
)
for name, value in zip(names, values)
]
)
def visit_StructColumn(self, op, *, names, values, dtype):
if values is None:
val = NULL
else:
val = sge.Struct.from_arg_list(
[
sge.PropertyEQ(
this=sg.to_identifier(name, quoted=self.quoted),
expression=value,
)
for name, value in zip(names, values)
]
)
return self.cast(val, dtype)

def visit_ArrayDistinct(self, op, *, arg):
return self.if_(
Expand Down Expand Up @@ -199,6 +204,12 @@ def visit_ArrayZip(self, op, *, arg):
any_arg_null = sg.or_(*(arr.is_(NULL) for arr in arg))
return self.if_(any_arg_null, NULL, zipped_arrays)

def visit_Map(self, op, *, keys, values):
# workaround for https://github.com/ibis-project/ibis/issues/8632
regular = self.f.map(keys, values)
either_null = sg.or_(keys.is_(NULL), values.is_(NULL))
return self.if_(either_null, NULL, regular)

def visit_MapGet(self, op, *, arg, key, default):
return self.f.ifnull(
self.f.list_extract(self.f.element_at(arg, key), 1), default
Expand Down
17 changes: 16 additions & 1 deletion ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,22 @@ def visit(cls, op: ops.FindInSet, needle, values):
return pd.Series(result, name=op.name)

@classmethod
def visit(cls, op: ops.Array, exprs):
def visit(cls, op: ops.StructColumn, names, values, dtype):
if values is None:
return None

def process_row(row):
return {
name: PandasConverter.convert_scalar(value, dty)
for name, value, dty in zip(names, row, dtype.fields.values())
}

return cls.rowwise(process_row, values)

@classmethod
def visit(cls, op: ops.Array, exprs, dtype):
if exprs is None:
return None
return cls.rowwise(lambda row: np.array(row, dtype=object), exprs)

@classmethod
Expand Down
10 changes: 9 additions & 1 deletion ibis/backends/pandas/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,14 @@ def round_serieswise(arg, digits):
return np.round(arg, digits).astype("float64")


def map_(row: pd.Series) -> dict:
k = row["keys"]
v = row["values"]
if k is None or v is None:
return None
return dict(zip(k, v))


reductions = {
ops.Min: lambda x: x.min(),
ops.Max: lambda x: x.max(),
Expand Down Expand Up @@ -362,7 +370,7 @@ def round_serieswise(arg, digits):
ops.EndsWith: lambda row: row["arg"].endswith(row["end"]),
ops.IntegerRange: integer_range_rowwise,
ops.JSONGetItem: lambda row: safe_json_getitem(row["arg"], row["index"]),
ops.Map: lambda row: dict(zip(row["keys"], row["values"])),
ops.Map: map_,
ops.MapGet: lambda row: safe_get(row["arg"], row["key"], row["default"]),
ops.MapContains: lambda row: safe_contains(row["arg"], row["key"]),
ops.MapMerge: lambda row: safe_merge(row["left"], row["right"]),
Expand Down
9 changes: 8 additions & 1 deletion ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,8 @@ def struct_field(op, **kw):

@translate.register(ops.StructColumn)
def struct_column(op, **kw):
if op.values is None:
return pl.lit(None)
fields = [translate(v, **kw).alias(k) for k, v in zip(op.names, op.values)]
return pl.struct(fields)

Expand Down Expand Up @@ -969,8 +971,13 @@ def array_concat(op, **kw):

@translate.register(ops.Array)
def array_column(op, **kw):
if op.exprs is None:
return pl.lit(None, dtype=PolarsType.from_ibis(op.dtype))
cols = [translate(col, **kw) for col in op.exprs]
return pl.concat_list(cols)
if len(cols) > 0:
return pl.concat_list(cols)
else:
return pl.lit([], dtype=PolarsType.from_ibis(op.dtype))


@translate.register(ops.ArrayCollect)
Expand Down
11 changes: 8 additions & 3 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,10 @@ def visit_StructField(self, op, *, arg, field):
op.dtype,
)

def visit_StructColumn(self, op, *, names, values):
return self.f.row(*map(self.cast, values, op.dtype.types))
def visit_StructColumn(self, op, *, names, values, dtype):
if values is None:
return self.cast(self.f.jsonb_build_object(), op.dtype)
return self.f.row(*map(self.cast, values, dtype.types))

def visit_ToJSONArray(self, op, *, arg):
return self.if_(
Expand All @@ -330,7 +332,10 @@ def visit_ToJSONArray(self, op, *, arg):
)

def visit_Map(self, op, *, keys, values):
return self.f.map(self.f.array(*keys), self.f.array(*values))
# map(["a", "b"], NULL) results in {"a": NULL, "b": NULL} in regular postgres,
# so we need to modify it to return NULL instead
regular = self.f.map(keys, values)
return self.if_(values.is_(NULL), NULL, regular)

def visit_MapLength(self, op, *, arg):
return self.f.cardinality(self.f.akeys(arg))
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,9 @@ def visit_IntegerRange(self, op, *, start, stop, step):
step.neq(0), self.f.array_generate_range(start, stop, step), self.f.array()
)

def visit_StructColumn(self, op, *, names, values):
def visit_StructColumn(self, op, *, names, values, dtype):
if values is None:
return self.cast(NULL, dtype)
return self.f.object_construct_keep_null(
*itertools.chain.from_iterable(zip(names, values))
)
Expand Down
23 changes: 17 additions & 6 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,13 +951,24 @@ def visit_ExistsSubquery(self, op, *, rel):
def visit_InSubquery(self, op, *, rel, needle):
return needle.isin(query=rel.this)

def visit_Array(self, op, *, exprs):
return self.f.array(*exprs)
def visit_Array(self, op, *, exprs, dtype):
if exprs is None:
vals = NULL
else:
vals = self.f.array(*exprs)
return self.cast(vals, dtype)

def visit_StructColumn(self, op, *, names, values):
return sge.Struct.from_arg_list(
[value.as_(name, quoted=self.quoted) for name, value in zip(names, values)]
)
def visit_StructColumn(self, op, *, names, values, dtype):
if values is None:
vals = NULL
else:
vals = sge.Struct.from_arg_list(
[
value.as_(name, quoted=self.quoted)
for name, value in zip(names, values)
]
)
return self.cast(vals, dtype)

def visit_StructField(self, op, *, arg, field):
return sge.Dot(this=arg, expression=sg.to_identifier(field, quoted=self.quoted))
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,8 +960,10 @@ class ClickHouseType(SqlglotType):
def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
"""Convert a sqlglot type to an ibis type."""
typ = super().from_ibis(dtype)
if dtype.nullable and not (dtype.is_map() or dtype.is_array()):
# map cannot be nullable in clickhouse
# nested types cannot be nullable in clickhouse
if dtype.nullable and not (
dtype.is_map() or dtype.is_array() or dtype.is_struct()
):
return sge.DataType(this=typecode.NULLABLE, expressions=[typ])
else:
return typ
Expand Down
39 changes: 39 additions & 0 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
PySparkAnalysisException,
TrinoUserError,
)
from ibis.common.annotations import ValidationError

pytestmark = [
pytest.mark.never(
Expand Down Expand Up @@ -70,6 +71,43 @@
# list.


def test_array_factory(con):
a = ibis.array([1, 2, 3])
assert con.execute(a) == [1, 2, 3]

a2 = ibis.array(a)
assert con.execute(a2) == [1, 2, 3]

typed = ibis.array([1, 2, 3], type="array<string>")
assert con.execute(typed) == ["1", "2", "3"]

typed2 = ibis.array(a, type="array<string>")
assert con.execute(typed2) == ["1", "2", "3"]


@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError)
def test_array_factory_empty(con):
with pytest.raises(ValidationError):
ibis.array([])

empty_typed = ibis.array([], type="array<string>")
assert empty_typed.type() == dt.Array(value_type=dt.string)
assert con.execute(empty_typed) == []


@pytest.mark.notyet(
"clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
)
def test_array_factory_null(con):
with pytest.raises(ValidationError):
ibis.array(None)
with pytest.raises(ValidationError):
ibis.array(None, type="int64")
none_typed = ibis.array(None, type="array<string>")
assert none_typed.type() == dt.Array(value_type=dt.string)
assert con.execute(none_typed) is None


def test_array_column(backend, alltypes, df):
expr = ibis.array(
[alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)]
Expand Down Expand Up @@ -107,6 +145,7 @@ def test_array_scalar(con):


@pytest.mark.notimpl(["flink", "polars"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl("postgres", raises=PsycoPg2SyntaxError)
def test_array_repeat(con):
expr = ibis.array([1.0, 2.0]) * 2

Expand Down
45 changes: 43 additions & 2 deletions ibis/backends/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
import ibis
import ibis.common.exceptions as exc
import ibis.expr.datatypes as dt
from ibis.backends.tests.errors import Py4JJavaError
from ibis.backends.tests.errors import (
ClickHouseDatabaseError,
Py4JJavaError,
)
from ibis.common.annotations import ValidationError

pytestmark = [
pytest.mark.never(
Expand Down Expand Up @@ -39,6 +43,43 @@
)


@mark_notimpl_risingwave_hstore
@mark_notyet_postgres
def test_map_factory(con):
m = ibis.map({"a": 1, "b": 2})
assert con.execute(m) == {"a": 1, "b": 2}

m2 = ibis.map(m)
assert con.execute(m2) == {"a": 1, "b": 2}

typed = ibis.map({"a": 1, "b": 2}, type="map<string, string>")
assert con.execute(typed) == {"a": "1", "b": "2"}

typed2 = ibis.map(m, type="map<string, string>")
assert con.execute(typed2) == {"a": "1", "b": "2"}


@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError)
@mark_notimpl_risingwave_hstore
def test_map_factory_empty(con):
with pytest.raises(ValidationError):
ibis.map({})
empty_typed = ibis.map({}, type="map<string, string>")
assert empty_typed.type() == dt.Map(key_type=dt.string, value_type=dt.string)
assert con.execute(empty_typed) == {}


@pytest.mark.notyet(
"clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
)
def test_map_factory_null(con):
with pytest.raises(ValidationError):
ibis.map(None)
null_typed = ibis.map(None, type="map<string, string>")
assert null_typed.type() == dt.Map(key_type=dt.string, value_type=dt.string)
assert con.execute(null_typed) is None


@pytest.mark.notimpl(["pandas", "dask"])
def test_map_table(backend):
table = backend.map
Expand Down Expand Up @@ -474,6 +515,6 @@ def test_map_keys_unnest(backend):

@mark_notimpl_risingwave_hstore
def test_map_contains_null(con):
expr = ibis.map(["a"], ibis.literal([None], type="array<string>"))
expr = ibis.map(["a"], ibis.array([None], type="array<string>"))
assert con.execute(expr.contains("a"))
assert not con.execute(expr.contains("b"))
Loading

0 comments on commit 55497bb

Please sign in to comment.