diff --git a/ibis/backends/pandas/kernels.py b/ibis/backends/pandas/kernels.py index da650d1211c3..61dc278c5562 100644 --- a/ibis/backends/pandas/kernels.py +++ b/ibis/backends/pandas/kernels.py @@ -144,6 +144,12 @@ def array_position_rowwise(row): return -1 +def array_remove_rowwise(row): + if row["arg"] is None: + return None + return [x for x in row["arg"] if x != row["other"]] + + def array_slice_rowwise(row): arg, start, stop = row["arg"], row["start"], row["stop"] if isnull(start) and isnull(stop): @@ -380,11 +386,12 @@ def wrapper(*args, **kwargs): ops.Repeat: lambda df: df["arg"] * df["times"], } + rowwise = { ops.ArrayContains: lambda row: row["other"] in row["arg"], ops.ArrayIndex: array_index_rowwise, ops.ArrayPosition: array_position_rowwise, - ops.ArrayRemove: lambda row: [x for x in row["arg"] if x != row["other"]], + ops.ArrayRemove: array_remove_rowwise, ops.ArrayRepeat: lambda row: np.tile(row["arg"], max(0, row["times"])), ops.ArraySlice: array_slice_rowwise, ops.ArrayUnion: lambda row: toolz.unique(row["left"] + row["right"]), diff --git a/ibis/backends/sql/compilers/bigquery.py b/ibis/backends/sql/compilers/bigquery.py index e9cb87e2e4ec..77fd5f6bdb24 100644 --- a/ibis/backends/sql/compilers/bigquery.py +++ b/ibis/backends/sql/compilers/bigquery.py @@ -526,7 +526,9 @@ def _unnest(self, expression, *, as_, offset=None): def visit_ArrayRemove(self, op, *, arg, other): name = sg.to_identifier(util.gen_name("bq_arr")) unnest = self._unnest(arg, as_=name) - return self.f.array(sg.select(name).from_(unnest).where(name.neq(other))) + both_null = sg.and_(name.is_(NULL), other.is_(NULL)) + cond = sg.or_(name.neq(other), both_null) + return self.f.array(sg.select(name).from_(unnest).where(cond)) def visit_ArrayDistinct(self, op, *, arg): name = util.gen_name("bq_arr") diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py index 2ad60977d2f9..ac669229be45 100644 --- a/ibis/backends/sql/compilers/clickhouse.py +++ 
b/ibis/backends/sql/compilers/clickhouse.py @@ -584,9 +584,10 @@ def visit_ArrayFilter(self, op, *, arg, param, body): return self.f.arrayFilter(func, arg) def visit_ArrayRemove(self, op, *, arg, other): - x = sg.to_identifier("x") - body = x.neq(other) - return self.f.arrayFilter(sge.Lambda(this=body, expressions=[x]), arg) + x = sg.to_identifier(util.gen_name("x")) + should_keep_null = sg.and_(x.is_(NULL), sg.not_(other.is_(NULL))) + cond = sg.or_(x.neq(other), should_keep_null) + return self.f.arrayFilter(sge.Lambda(this=cond, expressions=[x]), arg) def visit_ArrayUnion(self, op, *, left, right): arg = self.f.arrayConcat(left, right) diff --git a/ibis/backends/sql/compilers/duckdb.py b/ibis/backends/sql/compilers/duckdb.py index 77806461b74a..bcd45c82f28b 100644 --- a/ibis/backends/sql/compilers/duckdb.py +++ b/ibis/backends/sql/compilers/duckdb.py @@ -10,6 +10,7 @@ import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis import util from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler from ibis.backends.sql.datatypes import DuckDBType from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect @@ -168,10 +169,10 @@ def visit_ArrayIntersect(self, op, *, left, right): return self.f.list_filter(left, lamduh) def visit_ArrayRemove(self, op, *, arg, other): - param = sg.to_identifier("x") - body = param.neq(other) - lamduh = sge.Lambda(this=body, expressions=[param]) - return self.f.list_filter(arg, lamduh) + x = sg.to_identifier(util.gen_name("x")) + should_keep_null = sg.and_(x.is_(NULL), other.is_(sg.not_(NULL))) + cond = sg.or_(x.neq(other), should_keep_null) + return self.f.list_filter(arg, sge.Lambda(this=cond, expressions=[x])) def visit_ArrayUnion(self, op, *, left, right): arg = self.f.list_concat(left, right) diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 3d38fe2fbfeb..680c0f63e05b 100644 --- 
a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -601,10 +601,11 @@ def test_array_position(con, a, expected_array): @builtin_array @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) @pytest.mark.parametrize( - ("a"), + ("inp", "exp"), [ param( [[3, 2], [], [42, 2], [2, 2], []], + [[3], [], [42], [], []], id="including-empty-array", marks=[ pytest.mark.notyet( @@ -614,17 +615,34 @@ def test_array_position(con, a, expected_array): ) ], ), - param([[3, 2], [2], [42, 2], [2, 2], [2]], id="all-non-empty-arrays"), + param( + [[3, 2], [2], [42, 2], [2, 2], [2]], + [[3], [], [42], [], []], + id="all-non-empty-arrays", + ), + param( + [[3, 2, None], [None], [42, 2], [2, 2], None], + [[3, None], [None], [42], [], None], + id="including_null", + # marks=[ + # pytest.mark.broken( + # ["duckdb"], + # raises=AssertionError, + # reason="not implemented correctly", + # ), + # ], + ), ], ) -def test_array_remove(con, a): - t = ibis.memtable({"a": a}) +def test_array_remove(con, inp, exp): + t = ibis.memtable({"a": inp}) expr = t.a.remove(2) result = con.execute(expr) - expected = pd.Series([[3], [], [42], [], []], dtype="object") - assert frozenset(map(tuple, result.values)) == frozenset( - map(tuple, expected.values) - ) + expected = pd.Series(exp, dtype="object") + + assert frozenset( + tuple(v) if v is not None else None for v in result.values + ) == frozenset(tuple(v) if v is not None else None for v in expected.values) @builtin_array