From cd91a7e4fb03040c7e20cbb4d3d23504ea24f348 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 16 Oct 2023 12:45:44 +0200 Subject: [PATCH] refactor(analysis): remove `substitute_parents()` --- .../base/sql/compiler/select_builder.py | 19 ++++++------- ibis/expr/analysis.py | 27 ------------------- ibis/tests/expr/test_analysis.py | 15 +++-------- 3 files changed, 11 insertions(+), 50 deletions(-) diff --git a/ibis/backends/base/sql/compiler/select_builder.py b/ibis/backends/base/sql/compiler/select_builder.py index 4966c3778811..3b66f8f88828 100644 --- a/ibis/backends/base/sql/compiler/select_builder.py +++ b/ibis/backends/base/sql/compiler/select_builder.py @@ -213,14 +213,12 @@ def _collect_Aggregation(self, op, toplevel=False): # format these depending on the database. Most likely the # GROUP BY 1, 2, ... style if toplevel: - sub_op = an.substitute_parents(op) - - self.group_by = self._convert_group_by(sub_op.by) - self.having = sub_op.having - self.select_set = sub_op.by + sub_op.metrics - self.table_set = sub_op.table - self.filters = sub_op.predicates - self.order_by = sub_op.sort_keys + self.group_by = self._convert_group_by(op.by) + self.having = op.having + self.select_set = op.by + op.metrics + self.table_set = op.table + self.filters = op.predicates + self.order_by = op.sort_keys self._collect(op.table) @@ -256,9 +254,8 @@ def _convert_group_by(self, nodes): def _collect_Join(self, op, toplevel=False): if toplevel: - subbed = an.substitute_parents(op) - self.table_set = subbed - self.select_set = [subbed] + self.table_set = op + self.select_set = [op] def _collect_PhysicalTable(self, op, toplevel=False): if toplevel: diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index 303b653caf14..033588470767 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -144,33 +144,6 @@ def substitute(fn, node): return node -def substitute_parents(node): - """Rewrite `node` by replacing table nodes that commute.""" - assert isinstance(node, ops.Node), type(node) - - def fn(node): - if isinstance(node, ops.Selection): - # stop substituting child nodes - return g.halt - elif isinstance(node, ops.TableColumn): - # For table column references, in the event that we're on top of a - # projection, we need to check whether the ref comes from the base - # table schema or is a derived field. If we've projected out of - # something other than a physical table, then lifting should not - # occur - table = node.table - - if isinstance(table, ops.Selection): - for val in table.selections: - if isinstance(val, ops.PhysicalTable) and node.name in val.schema: - return ops.TableColumn(val, node.name) - - # keep looking for nodes to substitute - return g.proceed - - return substitute(fn, node) - - def get_mutation_exprs(exprs: list[ir.Expr], table: ir.Table) -> list[ir.Expr | None]: """Return the exprs to use to instantiate the mutation.""" # The below logic computes the mutation node exprs by splitting the diff --git a/ibis/tests/expr/test_analysis.py b/ibis/tests/expr/test_analysis.py index c98c4fe96d5e..bb5eefd3c045 100644 --- a/ibis/tests/expr/test_analysis.py +++ b/ibis/tests/expr/test_analysis.py @@ -11,6 +11,8 @@ # Place to collect esoteric expression analysis bugs and tests +# TODO(kszucs): not directly using an analysis function anymore, move to a +# more appropriate test module def test_rewrite_join_projection_without_other_ops(con): # See #790, predicate pushdown in joins not supported @@ -34,9 +36,7 @@ def test_rewrite_join_projection_without_other_ops(con): ex_pred2 = table["bar_id"] == table3["bar_id"] ex_expr = table.left_join(table2, [pred1]).inner_join(table3, [ex_pred2]) - rewritten_proj = an.substitute_parents(view.op()) - - assert not rewritten_proj.table.equals(ex_expr.op()) + assert view.op().table != ex_expr.op() def test_multiple_join_deeper_reference(): @@ -149,15 +149,6 @@ def test_filter_self_join(): assert_equal(proj_exprs[1], metric.op()) -def test_no_rewrite(con): - table = con.table("test1") - table4 = table[["c", (table["c"] * 2).name("foo")]] - expr = table4["c"] == table4["foo"] - result = an.substitute_parents(expr.op()).to_expr() - expected = expr - assert result.equals(expected) - - def test_join_table_choice(): # GH807 x = ibis.table(ibis.schema([("n", "int64")]), "x")