Skip to content

Commit

Permalink
fix(duckdb): make sure that array1.union(array2) null handling matches across backends
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Aug 5, 2023
1 parent 04f5a11 commit 849dea4
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 12 deletions.
20 changes: 17 additions & 3 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,27 @@ def _try_cast(t, op):
ops.ArrayPosition: fixed_arity(
lambda lst, el: sa.func.list_indexof(lst, el) - 1, 2
),
ops.ArrayDistinct: fixed_arity(sa.func.list_distinct, 1),
ops.ArrayDistinct: fixed_arity(
lambda arg: if_(
arg.is_(sa.null()),
sa.null(),
# append a null if the input array has a null
sa.func.list_distinct(arg)
+ if_(
# list_count doesn't count nulls
sa.func.list_count(arg) < sa.func.array_length(arg),
sa.func.list_value(sa.null()),
sa.func.list_value(),
),
),
1,
),
ops.ArraySort: fixed_arity(sa.func.list_sort, 1),
ops.ArrayRemove: lambda t, op: _array_filter(
t, ops.ArrayFilter(op.arg, flip(ops.NotEquals, op.other))
),
ops.ArrayUnion: fixed_arity(
lambda left, right: sa.func.list_distinct(sa.func.list_cat(left, right)), 2
ops.ArrayUnion: lambda t, op: t.translate(
ops.ArrayDistinct(ops.ArrayConcat((op.left, op.right)))
),
ops.ArrayZip: _array_zip,
ops.DayOfWeekName: unary(sa.func.dayname),
Expand Down
11 changes: 3 additions & 8 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,10 +618,10 @@ def test_array_remove(backend, con):
reason="argument passes none of the following rules:....",
)
def test_array_unique(backend, con):
t = ibis.memtable({"a": [[1, 3, 3], [], [42, 42], []]})
t = ibis.memtable({"a": [[1, 3, 3], [], [42, 42], [], [None], None]})
expr = t.a.unique()
result = con.execute(expr).map(set, na_action="ignore")
expected = pd.Series([{3, 1}, set(), {42}, set()], dtype="object")
expected = pd.Series([{3, 1}, set(), {42}, set(), {None}, None], dtype="object")
backend.assert_series_equal(result, expected, check_names=False)


Expand All @@ -646,11 +646,6 @@ def test_array_sort(backend, con):
["dask", "datafusion", "impala", "mssql", "pandas", "polars", "postgres"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.broken(
["snowflake", "trino", "pyspark"],
raises=AssertionError,
reason="array_distinct([NULL]) seems to differ from other backends",
)
@pytest.mark.notyet(
["bigquery"],
raises=BadRequest,
Expand All @@ -665,7 +660,7 @@ def test_array_union(con):
t = ibis.memtable({"a": [[3, 2], [], []], "b": [[1, 3], [None], [5]]})
expr = t.a.union(t.b)
result = con.execute(expr).map(set, na_action="ignore")
expected = pd.Series([{1, 2, 3}, set(), {5}], dtype="object")
expected = pd.Series([{1, 2, 3}, {None}, {5}], dtype="object")
assert len(result) == len(expected)

for i, (lhs, rhs) in enumerate(zip(result, expected)):
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/types/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def union(self, other: ir.ArrayValue) -> ir.ArrayValue:
│ array<int64> │
├────────────────────────┤
│ [1, 2, ... +1] │
│ [] │
│ [None] │
│ [5] │
└────────────────────────┘
>>> t.arr1.union(t.arr2).contains(3)
Expand Down

0 comments on commit 849dea4

Please sign in to comment.