Skip to content

Commit

Permalink
fix(duckdb): ensure that array remove doesn't remove NULLs
Browse files Browse the repository at this point in the history
  • Loading branch information
NickCrews committed Jul 31, 2024
1 parent 4d110a0 commit f0c3be4
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 17 deletions.
9 changes: 8 additions & 1 deletion ibis/backends/pandas/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@ def array_position_rowwise(row):
return -1


def array_remove_rowwise(row):
if row["arg"] is None:
return None
return [x for x in row["arg"] if x != row["other"]]


def array_slice_rowwise(row):
arg, start, stop = row["arg"], row["start"], row["stop"]
if isnull(start) and isnull(stop):
Expand Down Expand Up @@ -380,11 +386,12 @@ def wrapper(*args, **kwargs):
ops.Repeat: lambda df: df["arg"] * df["times"],
}


rowwise = {
ops.ArrayContains: lambda row: row["other"] in row["arg"],
ops.ArrayIndex: array_index_rowwise,
ops.ArrayPosition: array_position_rowwise,
ops.ArrayRemove: lambda row: [x for x in row["arg"] if x != row["other"]],
ops.ArrayRemove: array_remove_rowwise,
ops.ArrayRepeat: lambda row: np.tile(row["arg"], max(0, row["times"])),
ops.ArraySlice: array_slice_rowwise,
ops.ArrayUnion: lambda row: toolz.unique(row["left"] + row["right"]),
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/sql/compilers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,9 @@ def _unnest(self, expression, *, as_, offset=None):
def visit_ArrayRemove(self, op, *, arg, other):
name = sg.to_identifier(util.gen_name("bq_arr"))
unnest = self._unnest(arg, as_=name)
return self.f.array(sg.select(name).from_(unnest).where(name.neq(other)))
both_null = sg.and_(name.is_(NULL), other.is_(NULL))
cond = sg.or_(name.neq(other), both_null)
return self.f.array(sg.select(name).from_(unnest).where(cond))

def visit_ArrayDistinct(self, op, *, arg):
name = util.gen_name("bq_arr")
Expand Down
7 changes: 4 additions & 3 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,9 +584,10 @@ def visit_ArrayFilter(self, op, *, arg, param, body):
return self.f.arrayFilter(func, arg)

def visit_ArrayRemove(self, op, *, arg, other):
x = sg.to_identifier("x")
body = x.neq(other)
return self.f.arrayFilter(sge.Lambda(this=body, expressions=[x]), arg)
x = sg.to_identifier(util.gen_name("x"))
should_keep_null = sg.and_(x.is_(NULL), sg.not_(other.is_(NULL)))
cond = sg.or_(x.neq(other), should_keep_null)
return self.f.arrayFilter(sge.Lambda(this=cond, expressions=[x]), arg)

def visit_ArrayUnion(self, op, *, left, right):
arg = self.f.arrayConcat(left, right)
Expand Down
9 changes: 5 additions & 4 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import util
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DuckDBType
from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
Expand Down Expand Up @@ -168,10 +169,10 @@ def visit_ArrayIntersect(self, op, *, left, right):
return self.f.list_filter(left, lamduh)

def visit_ArrayRemove(self, op, *, arg, other):
param = sg.to_identifier("x")
body = param.neq(other)
lamduh = sge.Lambda(this=body, expressions=[param])
return self.f.list_filter(arg, lamduh)
x = sg.to_identifier(util.gen_name("x"))
should_keep_null = sg.and_(x.is_(NULL), other.is_(sg.not_(NULL)))
cond = sg.or_(x.neq(other), should_keep_null)
return self.f.list_filter(arg, sge.Lambda(this=cond, expressions=[x]))

def visit_ArrayUnion(self, op, *, left, right):
arg = self.f.list_concat(left, right)
Expand Down
34 changes: 26 additions & 8 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,10 +601,11 @@ def test_array_position(con, a, expected_array):
@builtin_array
@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
@pytest.mark.parametrize(
("a"),
("inp", "exp"),
[
param(
[[3, 2], [], [42, 2], [2, 2], []],
[[3], [], [42], [], []],
id="including-empty-array",
marks=[
pytest.mark.notyet(
Expand All @@ -614,17 +615,34 @@ def test_array_position(con, a, expected_array):
)
],
),
param([[3, 2], [2], [42, 2], [2, 2], [2]], id="all-non-empty-arrays"),
param(
[[3, 2], [2], [42, 2], [2, 2], [2]],
[[3], [], [42], [], []],
id="all-non-empty-arrays",
),
param(
[[3, 2, None], [None], [42, 2], [2, 2], None],
[[3, None], [None], [42], [], None],
id="including_null",
# marks=[
# pytest.mark.broken(
# ["duckdb"],
# raises=AssertionError,
# reason="not implmented correctly",
# ),
# ],
),
],
)
def test_array_remove(con, a):
t = ibis.memtable({"a": a})
def test_array_remove(con, inp, exp):
t = ibis.memtable({"a": inp})
expr = t.a.remove(2)
result = con.execute(expr)
expected = pd.Series([[3], [], [42], [], []], dtype="object")
assert frozenset(map(tuple, result.values)) == frozenset(
map(tuple, expected.values)
)
expected = pd.Series(exp, dtype="object")

assert frozenset(
tuple(v) if v is not None else None for v in result.values
) == frozenset(tuple(v) if v is not None else None for v in expected.values)


@builtin_array
Expand Down

0 comments on commit f0c3be4

Please sign in to comment.