diff --git a/hamilton/execution/graph_functions.py b/hamilton/execution/graph_functions.py index 0399833dc..c825e76cc 100644 --- a/hamilton/execution/graph_functions.py +++ b/hamilton/execution/graph_functions.py @@ -291,7 +291,7 @@ def nodes_between( search_condition: lambda node_: bool, ) -> Tuple[Optional[node.Node], List[node.Node]]: """Utility function to search backwards from an end node to a start node. - This returns all nodes for which the following conditions are met: + This returns all nodes for which both of the following conditions are met: 1. It contains a node that matches the start_condition as an ancestor 2. It contains a node that matches the end node as a dependent @@ -299,49 +299,42 @@ def nodes_between( Note that currently it is assumed that only one node will match search_condition. + This just grabs the search node when it finds it -- the nonlocal is a bit hacky but more fun + than passing a ton of data back and forth (who the parent is, etc...). + :param end_node: Node to trace back from :param search_condition: Condition to stop the search for ancestors :return: A tuple of [start_node, between], where start_node is None if there is no path (and between will be empty). - """ + #""" + out = set() visited = set() + search_node = None def dfs_traverse(node_: node.Node): - """Recursive call. Note that it returns None to signify - that we should not traverse any nodes, and a list to say that - we should continue traversing""" + if search_condition(node_): + # if we hit the end, we want to include all others in it + nonlocal search_node + search_node = node_ + return True if node_ in visited: - return [] + # if we've already seen it, we want to include it + return node_ in out + # now we mark that we've seen it visited.add(node_) - if search_condition(node_): - return [node_] - if node_.user_defined: - return None - out = [] + + any_deps_included = False for n in node_.dependencies: - traverse = dfs_traverse(n) - if traverse is not None: - out.extend(traverse) - out.append(n) - if len(out) == 0: - return None - return out - - output = [] - for node_ in dfs_traverse(end_node) or []: - output.append(node_) - begin_node = None - nodes = [] - for node_ in output: - # TODO -- handle the case that there are multiple nodes that match the search condition - if search_condition(node_): - begin_node = node_ - elif node_ == end_node: - continue - else: - nodes.append(node_) - return begin_node, nodes + any_deps_included |= dfs_traverse(n) + if any_deps_included: + out.add(node_) + return any_deps_included + + for dep in end_node.dependencies: + dfs_traverse(dep) + + return search_node, list(out) def node_is_required_by_anything(node_: node.Node, node_set: Set[node.Node]) -> bool: diff --git a/tests/execution/test_graph_functions.py b/tests/execution/test_graph_functions.py index d1449346a..fa9a63968 100644 --- a/tests/execution/test_graph_functions.py +++ b/tests/execution/test_graph_functions.py @@ -131,8 +131,22 @@ def _inner(n: node.Node) -> bool: "d", ), ({"a": [], "b": [], "c": ["a", "b"], "d": "c"}, {"c"}, "a", "d"), + # https://github.com/DAGWorks-Inc/hamilton/issues/948 + ( + { + "random_int": [], + "numbers": ["random_int"], + "add1": ["numbers", "random_int"], + "add2": ["add1", "random_int"], + "collect_numbers": ["add2"], + "final_result": ["collect_numbers"], + }, + {"add1", "add2"}, + "numbers", + "collect_numbers", + ), ], - ids=["simple_base", "longer_chain", "complex_dag", "subdag_with_external_dep"], + ids=["simple_base", "longer_chain", "complex_dag", "subdag_with_external_dep", "issue_948"], ) def test_find_nodes_between(dag_repr, expected_nodes_in_between, start_node, end_node): nodes = _create_dummy_dag(dag_repr, dict_output=True)