diff --git a/ibis/backends/sql/rewrites.py b/ibis/backends/sql/rewrites.py index 52f8ceba19df..820d9f890c22 100644 --- a/ibis/backends/sql/rewrites.py +++ b/ibis/backends/sql/rewrites.py @@ -370,7 +370,7 @@ def sqlize( # lower the expression graph to a SQL-like relational algebra context = {"params": params} - sqlized = node.replace( + result = node.replace( replace_parameter | project_to_select | filter_to_select @@ -385,24 +385,23 @@ def sqlize( # squash subsequent Select nodes into one if fuse_selects: - simplified = sqlized.replace(merge_select_select) - else: - simplified = sqlized + result = result.replace(merge_select_select) if post_rewrites: - simplified = simplified.replace(reduce(operator.or_, post_rewrites)) + result = result.replace(reduce(operator.or_, post_rewrites)) # extract common table expressions while wrapping them in a CTE node - ctes = extract_ctes(simplified) + ctes = extract_ctes(result) - def wrap(node, _, **kwargs): - new = node.__recreate__(kwargs) - return CTE(new) if node in ctes else new + if ctes: - result = simplified.replace(wrap) - ctes = [cte.parent for cte in result.find(CTE, ordered=True)] + def apply_ctes(node, kwargs): + new = node.__recreate__(kwargs) if kwargs else node + return CTE(new) if node in ctes else new - return result, ctes + result = result.replace(apply_ctes) + return result, [cte.parent for cte in result.find(CTE, ordered=True)] + return result, [] # supplemental rewrites selectively used on a per-backend basis diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index cf9e43dc75ac..7790dfcca111 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -1382,16 +1382,14 @@ def test_histogram(con, alltypes): hist = con.execute(alltypes.int_col.histogram(n).name("hist")) vc = hist.value_counts().sort_index() vc_np, _bin_edges = np.histogram(alltypes.int_col.execute(), bins=n) - assert vc.tolist() == vc_np.tolist() - assert ( - con.execute( - ibis.memtable({"value": range(100)}) - .select(bin=_.value.histogram(10)) - .value_counts() - .bin_count.nunique() - ) - == 1 + expr = ( + ibis.memtable({"value": range(100)}) + .select(bin=_.value.histogram(10)) + .value_counts() + .bin_count.nunique() ) + assert vc.tolist() == vc_np.tolist() + assert con.execute(expr) == 1 @pytest.mark.parametrize("const", ["pi", "e"]) diff --git a/ibis/common/graph.py b/ibis/common/graph.py index 9667cf3ce747..1b11e1f00d75 100644 --- a/ibis/common/graph.py +++ b/ibis/common/graph.py @@ -22,7 +22,7 @@ Finder = Callable[["Node"], bool] FinderLike = Union[Finder, Pattern, _ClassInfo] -Replacer = Callable[["Node", dict["Node", Any]], "Node"] +Replacer = Callable[["Node", dict["Node", Any] | None], "Node"] ReplacerLike = Union[Replacer, Pattern, Mapping] @@ -127,6 +127,47 @@ def _recursive_lookup(obj: Any, dct: dict) -> Any: return obj +def _apply_replacements(obj: Any, replacements: dict) -> tuple[Any, bool]: + """Replace nodes in a possibly nested object. + + Parameters + ---------- + obj + The object to traverse. + replacements + A mapping of replacement values. + + Returns + ------- + tuple[Any, bool] + A tuple of the replaced object and whether any replacements were made. + """ + if isinstance(obj, Node): + val = replacements.get(obj) + return (obj, False) if val is None else (val, True) + typ = type(obj) + if typ in (tuple, frozenset, list): + changed = False + items = [] + for i in obj: + i, ichanged = _apply_replacements(i, replacements) + changed |= ichanged + items.append(i) + return typ(items), changed + elif isinstance(obj, dict): + changed = False + items = {} + for k, v in obj.items(): + k, kchanged = _apply_replacements(k, replacements) + v, vchanged = _apply_replacements(v, replacements) + changed |= kchanged + changed |= vchanged + items[k] = v + return items, changed + else: + return obj, False + + def _coerce_finder(obj: FinderLike, context: Optional[dict] = None) -> Finder: """Coerce an object into a callable finder function. @@ -165,8 +206,7 @@ def _coerce_replacer(obj: ReplacerLike, context: Optional[dict] = None) -> Repla Parameters ---------- obj - A Pattern, a Mapping or a callable which can be fed to `node.map()` - to replace nodes. + A Pattern, Mapping, or Callable. context Optional context to use if the replacer is a pattern. @@ -177,26 +217,26 @@ def _coerce_replacer(obj: ReplacerLike, context: Optional[dict] = None) -> Repla """ if isinstance(obj, Pattern): - def fn(node, _, **kwargs): + def fn(node, kwargs): ctx = context or {} # need to first reconstruct the node from the possible rewritten # children, so we can match on the new node containing the rewritten # child arguments, this way we can propagate the rewritten nodes - # upward in the hierarchy, using a specialized __recreate__ method - # improves the performance by 17% compared node.__class__(**kwargs) - recreated = node.__recreate__(kwargs) + # upward in the hierarchy + recreated = node.__recreate__(kwargs) if kwargs else node if (result := obj.match(recreated, ctx)) is NoMatch: return recreated - else: - return result + return result elif isinstance(obj, Mapping): - def fn(node, _, **kwargs): + def fn(node, kwargs): + # For a mapping we want to lookup the original node first, and + # return a recreated one from the children if it's not present try: return obj[node] except KeyError: - return node.__recreate__(kwargs) + return node.__recreate__(kwargs) if kwargs else node elif callable(obj): fn = obj else: @@ -313,7 +353,7 @@ def map_clear(self, fn: Callable, filter: Optional[Finder] = None) -> Any: if not dependents[dependency]: del results[dependency] - return results[self] + return results.get(self, self) @experimental def map_nodes(self, fn: Callable, filter: Optional[Finder] = None) -> Any: @@ -451,8 +491,9 @@ def replace( Parameters ---------- replacer - A `Pattern`, a `Mapping` or a callable which can be fed to - `node.map()` directly to replace nodes. + A `Pattern`, `Mapping` or Callable taking the original unrewritten + node, and a mapping of attribute name to value of its rewritten + children (or None if no children were rewritten). filter A type, tuple of types, a pattern or a callable to filter out nodes from the traversal. The traversal will only visit nodes that match @@ -465,9 +506,28 @@ def replace( The root node of the graph with the replaced nodes. """ - replacer = _coerce_replacer(replacer, context) - results = self.map(replacer, filter=filter) - return results.get(self, self) + replacements: dict[Node, Any] = {} + + fn = _coerce_replacer(replacer, context) + + graph, _ = Graph.from_bfs(self, filter=filter).toposort() + for node in graph: + kwargs = {} + # Apply already rewritten nodes to the children of the node + changed = False + for k, v in zip(node.__argnames__, node.__args__): + v, vchanged = _apply_replacements(v, replacements) + changed |= vchanged + kwargs[k] = v + + # Call the replacer on the node with any rewritten nodes (or None + # if unchanged). + result = fn(node, kwargs if changed else None) + if result is not node: + # The node is changed, store it in the mapping of replacements + replacements[node] = result + + return replacements.get(self, self) class Graph(dict[Node, Sequence[Node]]): diff --git a/ibis/common/tests/test_graph.py b/ibis/common/tests/test_graph.py index ea0ea5c9136f..db4f466fdfd2 100644 --- a/ibis/common/tests/test_graph.py +++ b/ibis/common/tests/test_graph.py @@ -19,7 +19,7 @@ traverse, ) from ibis.common.grounds import Annotable, Concrete -from ibis.common.patterns import Eq, If, InstanceOf, Object, TupleOf, _ +from ibis.common.patterns import Eq, If, InstanceOf, Object, TupleOf, _, pattern class MyNode(Node): @@ -170,6 +170,36 @@ def test_replace_with_mapping(): assert result == new_A +@pytest.mark.parametrize("kind", ["pattern", "mapping", "function"]) +def test_replace_doesnt_recreate_unchanged_nodes(kind): + A1 = MyNode(name="A1", children=[]) + A2 = MyNode(name="A2", children=[A1]) + B1 = MyNode(name="B1", children=[]) + B2 = MyNode(name="B2", children=[B1]) + C = MyNode(name="C", children=[A2, B2]) + + B3 = MyNode(name="B3", children=[]) + + if kind == "pattern": + replacer = pattern(MyNode)(name="B2") >> B3 + elif kind == "mapping": + replacer = {B2: B3} + else: + + def replacer(node, children): + if node is B2: + return B3 + return node.__recreate__(children) if children else node + + res = C.replace(replacer) + + assert res is not C + assert res.name == "C" + assert len(res.children) == 2 + assert res.children[0] is A2 + assert res.children[1] is B3 + + def test_example(): class Example(Annotable, Node): def __hash__(self): @@ -343,17 +373,18 @@ def test_coerce_finder(): def test_coerce_replacer(): - r = _coerce_replacer(lambda x, _, **kwargs: D) - assert r(C, {}) == D + r = _coerce_replacer(lambda x, children: D if children else C) + assert r(C, {"children": []}) is D + assert r(C, None) is C r = _coerce_replacer({C: D, D: E}) assert r(C, {}) == D assert r(D, {}) == E - assert r(A, {}, name="A", children=[B, C]) == A + assert r(A, {"name": "A", "children": [B, C]}) == A r = _coerce_replacer(InstanceOf(MyNode) >> _.copy(name=_.name.lower())) - assert r(C, {}, name="C", children=[]) == MyNode(name="c", children=[]) - assert r(D, {}, name="D", children=[]) == MyNode(name="d", children=[]) + assert r(C, {"name": "C", "children": []}) == MyNode(name="c", children=[]) + assert r(D, {"name": "D", "children": []}) == MyNode(name="d", children=[]) def test_node_find_using_type():