Skip to content

Commit

Permalink
feat(common): add a memory efficient Node.map() implementation
Browse files Browse the repository at this point in the history
Alternative implementation of `map` to reduce memory usage. While `map`
keeps all the results in memory until the end of the traversal, the
new `map_clear()` method removes intermediate results as soon as they
are not needed anymore.
  • Loading branch information
kszucs committed Feb 12, 2024
1 parent d269776 commit e3f2217
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 3 deletions.
56 changes: 54 additions & 2 deletions ibis/common/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,15 +245,67 @@ def map(self, fn: Callable, filter: Optional[Finder] = None) -> dict[Node, Any]:
A mapping of nodes to their results.
"""
results: dict[Node, Any] = {}
for node in Graph.from_bfs(self, filter=filter).toposort():

graph, _ = Graph.from_bfs(self, filter=filter).toposort()
for node in graph:
# minor optimization to directly recurse into the children
kwargs = {
k: _recursive_lookup(v, results)
for k, v in zip(node.__argnames__, node.__args__)
}
results[node] = fn(node, results, **kwargs)

return results

@experimental
def map_clear(
self, fn: Callable, filter: Optional[Finder] = None
) -> dict[Node, Any]:
"""Apply a function to all nodes in the graph more memory efficiently.
Alternative implementation of `map` to reduce memory usage. While `map` keeps
all the results in memory until the end of the traversal, this method removes
intermediate results as soon as they are not needed anymore.
Prefer this method over `map` if the results consume significant amount of
memory and if the intermediate results are not needed.
Parameters
----------
fn
Function to apply to each node. It receives the node as the first argument,
the results as the second and the results of the children as keyword
arguments.
filter
Pattern-like object to filter out nodes from the traversal. The traversal
will only visit nodes that match the given pattern and stop otherwise.
Returns
-------
In contrast to `map`, this method returns the result of the root node only since
the rest of the results are already discarded.
"""
results: dict[Node, Any] = {}

graph, dependents = Graph.from_bfs(self, filter=filter).toposort()
dependents = {k: set(v) for k, v in dependents.items()}

for node, dependencies in graph.items():
kwargs = {
k: _recursive_lookup(v, results)
for k, v in zip(node.__argnames__, node.__args__)
}
results[node] = fn(node, results, **kwargs)

# remove the results belonging to the dependencies if they are not
# needed by other nodes during the rest of the traversal
for dependency in dependencies:
dependents[dependency].remove(node)
if not dependents[dependency]:
del results[dependency]

return results[self]

# TODO(kszucs): perhaps rename it to find_all() for better clarity
def find(
self,
Expand Down Expand Up @@ -489,7 +541,7 @@ def toposort(self) -> Self:
if any(in_degree.values()):
raise ValueError("cycle detected in the graph")

return result
return result, dependents


# these could be callables instead
Expand Down
24 changes: 23 additions & 1 deletion ibis/common/tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,9 @@ def test_invert():


def test_toposort():
g = Graph(A).toposort()
g, dependents = Graph(A).toposort()
assert list(g.keys()) == [C, D, E, B, A]
assert dependents == Graph(A).invert()


def test_toposort_cycle_detection():
Expand Down Expand Up @@ -427,3 +428,24 @@ def test_node_find_topmost_dont_traverse_the_same_node_twice():
result = E.find_topmost(If(_.name == "G"))
expected = [G]
assert result == expected


def test_map_clear():
Z = MyNode(name="Z", children=[A])
result_sequence = {}

def record_result_keys(node, results, **kwargs):
result_sequence[node] = tuple(results.keys())
return node

expected_result_sequence = {
C: (),
D: (C,),
E: (C, D),
B: (C, D, E),
A: (C, B),
Z: (A,),
}
result = Z.map_clear(record_result_keys)
assert result == Z
assert result_sequence == expected_result_sequence

0 comments on commit e3f2217

Please sign in to comment.