From e3f2217c4767a13f84e62adb2c8f8fc31290b61f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?=
Date: Fri, 29 Dec 2023 16:50:17 +0100
Subject: [PATCH] feat(common): add a memory efficient `Node.map()` implementation

Alternative implementation of `map` to reduce memory usage. While `map` keeps
all the results in memory until the end of the traversal, the new `map_clear()`
method removes intermediate results as soon as they are not needed anymore.
---
 ibis/common/graph.py            | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 ibis/common/tests/test_graph.py | 24 +++++++++++++++++++++++-
 2 files changed, 77 insertions(+), 3 deletions(-)

diff --git a/ibis/common/graph.py b/ibis/common/graph.py
index b8bbe932dd95..89f81af4d8cd 100644
--- a/ibis/common/graph.py
+++ b/ibis/common/graph.py
@@ -245,15 +245,67 @@ def map(self, fn: Callable, filter: Optional[Finder] = None) -> dict[Node, Any]:
         A mapping of nodes to their results.
         """
         results: dict[Node, Any] = {}
-        for node in Graph.from_bfs(self, filter=filter).toposort():
+
+        graph, _ = Graph.from_bfs(self, filter=filter).toposort()
+        for node in graph:
             # minor optimization to directly recurse into the children
             kwargs = {
                 k: _recursive_lookup(v, results)
                 for k, v in zip(node.__argnames__, node.__args__)
             }
             results[node] = fn(node, results, **kwargs)
+
         return results
 
+    @experimental
+    def map_clear(
+        self, fn: Callable, filter: Optional[Finder] = None
+    ) -> Any:
+        """Apply a function to all nodes in the graph more memory efficiently.
+
+        Alternative implementation of `map` to reduce memory usage. While `map` keeps
+        all the results in memory until the end of the traversal, this method removes
+        intermediate results as soon as they are not needed anymore.
+
+        Prefer this method over `map` if the results consume a significant amount of
+        memory and if the intermediate results are not needed.
+
+        Parameters
+        ----------
+        fn
+            Function to apply to each node. It receives the node as the first argument,
+            the results as the second and the results of the children as keyword
+            arguments.
+        filter
+            Pattern-like object to filter out nodes from the traversal. The traversal
+            will only visit nodes that match the given pattern and stop otherwise.
+
+        Returns
+        -------
+        In contrast to `map`, this method returns the result of the root node only since
+        the rest of the results are already discarded.
+        """
+        results: dict[Node, Any] = {}
+
+        graph, dependents = Graph.from_bfs(self, filter=filter).toposort()
+        dependents = {k: set(v) for k, v in dependents.items()}
+
+        for node, dependencies in graph.items():
+            kwargs = {
+                k: _recursive_lookup(v, results)
+                for k, v in zip(node.__argnames__, node.__args__)
+            }
+            results[node] = fn(node, results, **kwargs)
+
+            # drop dependency results no longer needed by the remaining nodes;
+            # set() guards against a node listing the same dependency twice
+            for dependency in set(dependencies):
+                dependents[dependency].remove(node)
+                if not dependents[dependency]:
+                    del results[dependency]
+
+        return results[self]
+
     # TODO(kszucs): perhaps rename it to find_all() for better clarity
     def find(
         self,
@@ -489,7 +541,7 @@ def toposort(self) -> Self:
         if any(in_degree.values()):
             raise ValueError("cycle detected in the graph")
 
-        return result
+        return result, dependents
 
 
 # these could be callables instead
diff --git a/ibis/common/tests/test_graph.py b/ibis/common/tests/test_graph.py
index 50839529aa9a..787926660a9e 100644
--- a/ibis/common/tests/test_graph.py
+++ b/ibis/common/tests/test_graph.py
@@ -101,8 +101,9 @@ def test_invert():
 
 
 def test_toposort():
-    g = Graph(A).toposort()
+    g, dependents = Graph(A).toposort()
     assert list(g.keys()) == [C, D, E, B, A]
+    assert dependents == Graph(A).invert()
 
 
 def test_toposort_cycle_detection():
@@ -427,3 +428,24 @@ def test_node_find_topmost_dont_traverse_the_same_node_twice():
     result = E.find_topmost(If(_.name == "G"))
     expected = [G]
     assert result == expected
+
+
+def test_map_clear():
+    Z = MyNode(name="Z", children=[A])
+    result_sequence = {}
+
+    def record_result_keys(node, results, **kwargs):
+        result_sequence[node] = tuple(results.keys())
+        return node
+
+    expected_result_sequence = {
+        C: (),
+        D: (C,),
+        E: (C, D),
+        B: (C, D, E),
+        A: (C, B),
+        Z: (A,),
+    }
+    result = Z.map_clear(record_result_keys)
+    assert result == Z
+    assert result_sequence == expected_result_sequence