From 5877d0b7a92044bb82637225e31d27543ddaadc7 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Fri, 14 Oct 2022 12:44:56 -0500 Subject: [PATCH] fix: fix `table.mutate` with deferred named expressions --- ibis/expr/analysis.py | 1 - ibis/expr/types/relations.py | 18 +++++++++++------- ibis/tests/expr/test_table.py | 36 ++++++++++++++++++++++++++++++----- 3 files changed, 42 insertions(+), 13 deletions(-) diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index 24879a78534e..fa649956d0ef 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -229,7 +229,6 @@ def get_mutation_exprs(exprs: list[ir.Expr], table: ir.Table) -> list[ir.Expr | # name does not exist in the original table. # Given these two data structures, we can compute the mutation node exprs # based on whether any columns are being overwritten. - # TODO issue #2649 overwriting_cols_to_expr: dict[str, ir.Expr | None] = {} non_overwriting_exprs: list[ir.Expr] = [] table_schema = table.schema() diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 6b38001bb6bd..9393170e156f 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -3,7 +3,6 @@ import collections import functools import itertools -import operator import sys import warnings from typing import ( @@ -590,16 +589,21 @@ def mutate( import ibis.expr.analysis as an import ibis.expr.rules as rlz - exprs = [] if exprs is None else util.promote_list(exprs) - for name, expr in sorted(mutations.items(), key=operator.itemgetter(0)): + def ensure_expr(expr): + # This is different than self._ensure_expr, since we don't want to + # treat `str` or `int` values as column indices if util.is_function(expr): - value = expr(self) + return expr(self) elif isinstance(expr, Deferred): - value = expr.resolve(self) + return expr.resolve(self) else: - value = rlz.any(expr).to_expr() - exprs.append(value.name(name)) + return rlz.any(expr).to_expr() + exprs = [] if exprs is None else util.promote_list(exprs) + exprs = [ensure_expr(expr) for expr in exprs] + exprs.extend( + ensure_expr(expr).name(name) for name, expr in sorted(mutations.items()) + ) mutation_exprs = an.get_mutation_exprs(exprs, self) return self.select(mutation_exprs) diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index 2902c6f307dd..e8e5f0597a4c 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -238,11 +238,37 @@ def test_projection_array_expr(table): def test_mutate(table): - one = table.f * 2 - foo = (table.a + table.b).name('foo') - - expr = table.mutate(foo, one=one, two=2) - expected = table[table, foo, one.name('one'), ibis.literal(2).name('two')] + expr = table.mutate( + [ + (table.a + 1).name("x1"), + table.b.sum().name("x2"), + (_.a + 2).name("x3"), + lambda _: (_.a + 3).name("x4"), + 4, + "five", + ], + kw1=(table.a + 6), + kw2=table.b.sum(), + kw3=(_.a + 7), + kw4=lambda _: (_.a + 8), + kw5=9, + kw6="ten", + ) + expected = table[ + table, + (table.a + 1).name("x1"), + table.b.sum().name("x2"), + (table.a + 2).name("x3"), + (table.a + 3).name("x4"), + ibis.literal(4).name("4"), + ibis.literal("five").name("'five'"), + (table.a + 6).name("kw1"), + table.b.sum().name("kw2"), + (table.a + 7).name("kw3"), + (table.a + 8).name("kw4"), + ibis.literal(9).name("kw5"), + ibis.literal("ten").name("kw6"), + ] assert_equal(expr, expected)