Skip to content

Commit

Permalink
refactor(analysis): remove sub_for(), substitute(), `find_toplevel_aggs()`
Browse files (browse the repository at this point in the history)
  • Loading branch information
kszucs authored and cpcloud committed Oct 17, 2023
1 parent cf95ff7 commit 492b296
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 106 deletions.
1 change: 0 additions & 1 deletion ibis/backends/base/sql/compiler/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,6 @@ def decorator(f):
rewrites = ExprTranslator.rewrites


# TODO(kszucs): use analysis.substitute() instead of a custom rewriter
@rewrites(ops.Bucket)
def _bucket(op):
# TODO(kszucs): avoid the expression roundtrip
Expand Down
79 changes: 6 additions & 73 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@
import ibis.expr.operations.relations as rels
import ibis.expr.types as ir
from ibis import util
from ibis.common.annotations import ValidationError
from ibis.common.deferred import _, deferred, var
from ibis.common.exceptions import IbisTypeError, IntegrityError
from ibis.common.patterns import Eq, In, pattern
from ibis.util import Namespace

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping
from collections.abc import Iterable, Iterator

p = Namespace(pattern, module=ops)
c = Namespace(deferred, module=ops)
Expand All @@ -30,39 +29,10 @@
# compilation later


def sub_for(node: ops.Node, substitutions: Mapping[ops.Node, ops.Node]) -> ops.Node:
    """Replace operations inside `node` according to `substitutions`.

    Parameters
    ----------
    node
        An Ibis operation.
    substitutions
        A mapping from node to node. Whenever a visited subnode of `node`
        is a key of `substitutions`, the mapped value takes its place.

    Returns
    -------
    Node
        The result of calling `substitute` with the lookup rule built from
        `substitutions`.
    """
    assert isinstance(node, ops.Node), type(node)

    def _lookup(candidate):
        try:
            replacement = substitutions[candidate]
        except KeyError:
            # No replacement registered: stop at table nodes, otherwise let
            # the rewriter continue with this node.
            return g.halt if isinstance(candidate, ops.TableNode) else g.proceed
        return replacement

    return substitute(_lookup, node)


def sub_immediate_parents(op: ops.Node, table: ops.TableNode) -> ops.Node:
def sub_immediate_parents(node: ops.Node, table: ops.TableNode) -> ops.Node:
"""Replace immediate parent tables in `op` with `table`."""
return sub_for(op, {base: table for base in find_immediate_parent_tables(op)})
parents = find_immediate_parent_tables(node)
return node.replace(In(parents) >> table)


def find_immediate_parent_tables(input_node, keep_input=True):
Expand Down Expand Up @@ -116,34 +86,6 @@ def finder(node):
return list(toolz.unique(g.traverse(finder, input_node)))


def substitute(fn, node):
    """Substitute expressions with other expressions.

    Parameters
    ----------
    fn
        Callable invoked with each visited node. Returning `g.halt` keeps
        the node unchanged without visiting its arguments, returning
        `g.proceed` recurses into the node's arguments, and any other
        return value must be an `ops.Node` which replaces the visited node.
    node
        The operation to rewrite.

    Returns
    -------
    Node
        The rewritten operation, or the original `node` when it cannot be
        rebuilt from the substituted arguments.
    """
    assert isinstance(node, ops.Node), type(node)

    result = fn(node)
    if result is g.halt:
        return node
    elif result is not g.proceed:
        assert isinstance(result, ops.Node), type(result)
        return result

    new_args = []
    for arg in node.args:
        if isinstance(arg, tuple):
            # Test each element, not the enclosing tuple: checking `arg`
            # here is always false, which silently skipped substitution
            # for nodes nested in tuple-valued arguments.
            arg = tuple(
                substitute(fn, x) if isinstance(x, ops.Node) else x for x in arg
            )
        elif isinstance(arg, ops.Node):
            arg = substitute(fn, arg)
        new_args.append(arg)

    try:
        return node.__class__(*new_args)
    except (TypeError, ValidationError):
        # Best effort: if the node cannot be reconstructed with the new
        # arguments, fall back to the unmodified node.
        return node


def get_mutation_exprs(exprs: list[ir.Expr], table: ir.Table) -> list[ir.Expr | None]:
"""Return the exprs to use to instantiate the mutation."""
# The below logic computes the mutation node exprs by splitting the
Expand Down Expand Up @@ -256,7 +198,8 @@ def simplify_aggregation(agg):
def _pushdown(nodes):
subbed = []
for node in nodes:
subbed.append(sub_for(node, {agg.table: agg.table.table}))
new_node = node.replace(Eq(agg.table) >> agg.table.table)
subbed.append(new_node)

# TODO(kszucs): perhaps this validation could be omitted
if subbed:
Expand Down Expand Up @@ -560,13 +503,3 @@ def finder(node):
)

return g.traverse(finder, nodes, filter=ops.Node)


def find_toplevel_aggs(nodes: Iterable[ops.Node]) -> Iterator[ops.Table]:
    """Traverse `nodes` and collect the `ops.Reduction` nodes encountered.

    The visitor handed to `g.traverse` returns a pair: whether the visited
    node is an `ops.Value`, and the node itself when it is an
    `ops.Reduction` (otherwise `None`).
    """
    # NOTE(review): presumably the first element of the pair tells
    # `g.traverse` whether to keep descending and the second is the value to
    # yield -- confirm against `g.traverse`.

    def _visit(candidate):
        is_value = isinstance(candidate, ops.Value)
        reduction = candidate if isinstance(candidate, ops.Reduction) else None
        return is_value, reduction

    return g.traverse(_visit, nodes, filter=ops.Node)
9 changes: 3 additions & 6 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ibis.common.collections import FrozenDict # noqa: TCH001
from ibis.common.deferred import Deferred
from ibis.common.grounds import Immutable
from ibis.common.patterns import Coercible
from ibis.common.patterns import Coercible, Eq
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.operations.core import Column, Named, Node, Scalar, Value
from ibis.expr.operations.sortkeys import SortKey # noqa: TCH001
Expand Down Expand Up @@ -236,7 +236,6 @@ def __init__(self, left, right, predicates, **kwargs):
# TODO(kszucs): predicates should be already a list of operations, need
# to update the validation rule for the Join classes which is a noop
# currently
import ibis.expr.analysis as an
import ibis.expr.operations as ops
import ibis.expr.types as ir

Expand All @@ -258,10 +257,8 @@ def __init__(self, left, right, predicates, **kwargs):
# and tables on the right are incorrectly scoped
old = right
new = right = ops.SelfReference(right)
predicates = [
an.sub_for(pred, {old: new}) if isinstance(pred, ops.Node) else pred
for pred in predicates
]
rule = Eq(old) >> new
predicates = [pred.replace(rule) for pred in predicates]

predicates = _clean_join_predicates(left, right, predicates)

Expand Down
22 changes: 7 additions & 15 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3854,6 +3854,7 @@ def pivot_wider(
import ibis.expr.analysis as an
import ibis.selectors as s
from ibis import _
from ibis.expr.analysis import p, x

orig_names_from = util.promote_list(names_from)

Expand Down Expand Up @@ -3895,23 +3896,14 @@ def pivot_wider(
for values_col in values_cols:
arg = values_agg(values_col)

# add in the where clause to filter the appropriate values
# in/out
#
# this allows users to write the aggregate without having to deal with
# the filter themselves
existing_aggs = an.find_toplevel_aggs(arg.op())
subs = {
agg: agg.copy(
where=(
where
if (existing := agg.where) is None
else where & existing
)
)
for agg in existing_aggs
}
arg = an.sub_for(arg.op(), subs).to_expr()
rules = (
# add in the where clause to filter the appropriate values
p.Reduction(where=None) >> _.copy(where=where)
| p.Reduction(where=x) >> _.copy(where=where & x)
)
arg = arg.op().replace(rules, filter=p.Value).to_expr()

# build the components of the group by key
key_components = (
Expand Down
11 changes: 0 additions & 11 deletions ibis/tests/expr/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import ibis
import ibis.common.exceptions as com
import ibis.expr.analysis as an
import ibis.expr.operations as ops
from ibis.tests.util import assert_equal

Expand Down Expand Up @@ -149,16 +148,6 @@ def test_filter_self_join():
assert_equal(proj_exprs[1], metric.op())


def test_join_table_choice():
    """Regression test for GH807.

    Substituting the aggregation node with its parent table must leave a
    predicate built on the aggregation's output unchanged.
    """
    schema = ibis.schema([("n", "int64")])
    base = ibis.table(schema, "x")
    counted = base.aggregate(cnt=base.n.count())
    predicate = counted.cnt > 0

    mapping = {counted.op(): counted.op().table}
    rewritten = an.sub_for(predicate.op(), mapping)
    assert rewritten == predicate.op()


def test_is_ancestor_analytic():
x = ibis.table(ibis.schema([("col", "int32")]), "x")
with_filter_col = x[x.columns + [ibis.null().name("filter")]]
Expand Down

0 comments on commit 492b296

Please sign in to comment.