Skip to content

Commit

Permalink
fix(mutate/select): ensure that unsplatted dictionaries work in `muta…
Browse files Browse the repository at this point in the history
…te` and `select` APIs
  • Loading branch information
cpcloud committed Jan 18, 2024
1 parent 1d1417d commit 6cb933f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 23 deletions.
52 changes: 29 additions & 23 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,17 +1800,22 @@ def mutate(
import ibis.expr.analysis as an

exprs = [] if exprs is None else util.promote_list(exprs)
exprs = itertools.chain(
itertools.chain.from_iterable(
util.promote_list(_ensure_expr(self, expr)) for expr in exprs
),
(
e.name(name)
for name, expr in mutations.items()
for e in util.promote_list(_ensure_expr(self, expr))
),

new_exprs = []

for expr in exprs:
if isinstance(expr, Mapping):
new_exprs.extend(val.name(name) for name, val in expr.items())
else:
new_exprs.extend(util.promote_list(_ensure_expr(self, expr)))

new_exprs.extend(
e.name(name)
for name, expr in mutations.items()
for e in util.promote_list(_ensure_expr(self, expr))
)
mutation_exprs = an.get_mutation_exprs(list(exprs), self)

mutation_exprs = an.get_mutation_exprs(new_exprs, self)
return self.select(mutation_exprs)

def select(
Expand Down Expand Up @@ -1993,31 +1998,32 @@ def select(
import ibis.expr.analysis as an
from ibis.selectors import Selector

exprs = [
e
for expr in exprs
for e in (
expr.expand(self)
if isinstance(expr, Selector)
else map(self._ensure_expr, util.promote_list(expr))
)
]
exprs.extend(
new_exprs = []

for expr in exprs:
if isinstance(expr, Selector):
new_exprs.extend(expr.expand(self))
elif isinstance(expr, Mapping):
new_exprs.extend(value.name(name) for name, value in expr.items())
else:
new_exprs.extend(map(self._ensure_expr, util.promote_list(expr)))

new_exprs.extend(
self._ensure_expr(expr).name(name) for name, expr in named_exprs.items()
)

if not exprs:
if not new_exprs:
raise com.IbisTypeError(
"You must select at least one column for a valid projection"
)
for ex in exprs:
for ex in new_exprs:
if not isinstance(ex, Expr):
raise com.IbisTypeError(
"All arguments to `.select` must be coerceable to "
f"expressions - got {type(ex)!r}"
)

op = an.Projector(self, exprs).get_result()
op = an.Projector(self, new_exprs).get_result()

return op.to_expr()

Expand Down
2 changes: 2 additions & 0 deletions ibis/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def promote_list(val: V | Sequence[V]) -> list[V]:
"""
if isinstance(val, list):
return val
elif isinstance(val, dict):
return [val]
elif is_iterable(val):
return list(val)
elif val is None:
Expand Down

0 comments on commit 6cb933f

Please sign in to comment.