Skip to content

Commit

Permalink
fix(mutate/select): ensure that unsplatted dictionaries work in `mutate` and `select` APIs (#8014)
Browse files Browse the repository at this point in the history

We were not handling unsplatted dicts in `mutate`/`select`; this PR fixes
that. Fixes #8013.
  • Loading branch information
cpcloud authored Jan 18, 2024
1 parent 5bde8da commit 8ed19ea
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 23 deletions.
11 changes: 11 additions & 0 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,3 +1059,14 @@ def test_repr_timestamp_array(con, monkeypatch):
assert ibis.options.default_backend is con
expr = ibis.array(pd.date_range("2010-01-01", "2010-01-03", freq="D").tolist())
assert repr(expr)


@pytest.mark.notyet(
    ["dask", "datafusion", "flink", "pandas", "polars"],
    raises=com.OperationNotDefinedError,
)
def test_unnest_range(con):
    """Check `mutate` with an unsplatted dict on an unnested range table.

    Builds a one-column table ``x`` from ``ibis.range(2).unnest()``, adds a
    constant column ``y`` by passing a dict positionally to ``mutate``, and
    compares the executed result against the expected frame.
    """
    unnested = ibis.range(2).unnest().name("x").as_table()
    with_constant = unnested.mutate({"y": 1.0})

    actual = con.execute(with_constant)

    expected = pd.DataFrame(
        {"x": np.array([0, 1], dtype="int8"), "y": [1.0, 1.0]}
    )
    tm.assert_frame_equal(actual, expected)
13 changes: 13 additions & 0 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1741,3 +1741,16 @@ def test_simple_memtable_construct(con):
# because memtables have a unique name per table per process, so smoke test
# it
assert str(ibis.to_sql(expr, dialect=con.name)).startswith("SELECT")


def test_select_mutate_with_dict(backend):
    """Check that `mutate` and `select` accept an unsplatted dict.

    Both APIs are called with a mapping passed positionally (not as
    ``**kwargs``); each resulting expression is executed and compared against
    a single-row frame holding the constant column ``a``.
    """
    t = backend.functional_alltypes
    expected = pd.DataFrame({"a": [1.0]})

    # mutate path: dict of name -> Python scalar
    expr = t.mutate({"a": 1.0}).select("a").limit(1)
    result = expr.execute()
    backend.assert_frame_equal(result, expected)

    # select path: dict of name -> ibis expression
    expr = t.select({"a": ibis.literal(1.0)}).limit(1)
    # Execute the new expression; without this the assertion below would
    # re-check the result of the `mutate` expression and never exercise
    # `select` with a dict.
    result = expr.execute()
    backend.assert_frame_equal(result, expected)
56 changes: 33 additions & 23 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,17 +1800,24 @@ def mutate(
import ibis.expr.analysis as an

exprs = [] if exprs is None else util.promote_list(exprs)
exprs = itertools.chain(
itertools.chain.from_iterable(
util.promote_list(_ensure_expr(self, expr)) for expr in exprs
),
(
e.name(name)
for name, expr in mutations.items()
for e in util.promote_list(_ensure_expr(self, expr))
),

new_exprs = []

for expr in exprs:
if isinstance(expr, Mapping):
new_exprs.extend(
_ensure_expr(self, val).name(name) for name, val in expr.items()
)
else:
new_exprs.extend(util.promote_list(_ensure_expr(self, expr)))

new_exprs.extend(
e.name(name)
for name, expr in mutations.items()
for e in util.promote_list(_ensure_expr(self, expr))
)
mutation_exprs = an.get_mutation_exprs(list(exprs), self)

mutation_exprs = an.get_mutation_exprs(new_exprs, self)
return self.select(mutation_exprs)

def select(
Expand Down Expand Up @@ -1993,31 +2000,34 @@ def select(
import ibis.expr.analysis as an
from ibis.selectors import Selector

exprs = [
e
for expr in exprs
for e in (
expr.expand(self)
if isinstance(expr, Selector)
else map(self._ensure_expr, util.promote_list(expr))
)
]
exprs.extend(
new_exprs = []

for expr in exprs:
if isinstance(expr, Selector):
new_exprs.extend(expr.expand(self))
elif isinstance(expr, Mapping):
new_exprs.extend(
self._ensure_expr(value).name(name) for name, value in expr.items()
)
else:
new_exprs.extend(map(self._ensure_expr, util.promote_list(expr)))

new_exprs.extend(
self._ensure_expr(expr).name(name) for name, expr in named_exprs.items()
)

if not exprs:
if not new_exprs:
raise com.IbisTypeError(
"You must select at least one column for a valid projection"
)
for ex in exprs:
for ex in new_exprs:
if not isinstance(ex, Expr):
raise com.IbisTypeError(
"All arguments to `.select` must be coerceable to "
f"expressions - got {type(ex)!r}"
)

op = an.Projector(self, exprs).get_result()
op = an.Projector(self, new_exprs).get_result()

return op.to_expr()

Expand Down
2 changes: 2 additions & 0 deletions ibis/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def promote_list(val: V | Sequence[V]) -> list[V]:
"""
if isinstance(val, list):
return val
elif isinstance(val, dict):
return [val]
elif is_iterable(val):
return list(val)
elif val is None:
Expand Down

0 comments on commit 8ed19ea

Please sign in to comment.