diff --git a/hamilton/execution/graph_functions.py b/hamilton/execution/graph_functions.py index 0399833dc..dfadbbf90 100644 --- a/hamilton/execution/graph_functions.py +++ b/hamilton/execution/graph_functions.py @@ -291,7 +291,7 @@ def nodes_between( search_condition: lambda node_: bool, ) -> Tuple[Optional[node.Node], List[node.Node]]: """Utility function to search backwards from an end node to a start node. - This returns all nodes for which the following conditions are met: + This returns all nodes for which both of the following conditions are met: 1. It contains a node that matches the start_condition as an ancestor 2. It contains a node that matches the end node as a dependent @@ -299,49 +299,42 @@ def nodes_between( Note that currently it is assumed that only one node will match search_condition. + This just grabs the search node when it finds it -- the nonlocal is a bit hacky but more fun + than passing a ton of data back and forth (who the parent is, etc...). + :param end_node: Node to trace back from :param search_condition: Condition to stop the search for ancestors :return: A tuple of [start_node, between], where start_node is None if there is no path (and between will be empty). """ + out = set() visited = set() + search_node = None def dfs_traverse(node_: node.Node): - """Recursive call. Note that it returns None to signify - that we should not traverse any nodes, and a list to say that - we should continue traversing""" + if search_condition(node_): + # if we hit the end, we want to include all others in it + nonlocal search_node + search_node = node_ + return True if node_ in visited: - return [] + # if we've already seen it, we want to include it + return node_ in out + # now we mark that we've seen it visited.add(node_) - if search_condition(node_): - return [node_] - if node_.user_defined: - return None - out = [] + + any_deps_included = False for n in node_.dependencies: - traverse = dfs_traverse(n) - if traverse is not None: - out.extend(traverse) - out.append(n) - if len(out) == 0: - return None - return out - - output = [] - for node_ in dfs_traverse(end_node) or []: - output.append(node_) - begin_node = None - nodes = [] - for node_ in output: - # TODO -- handle the case that there are multiple nodes that match the search condition - if search_condition(node_): - begin_node = node_ - elif node_ == end_node: - continue - else: - nodes.append(node_) - return begin_node, nodes + any_deps_included |= dfs_traverse(n) + if any_deps_included: + out.add(node_) + return any_deps_included + + for dep in end_node.dependencies: + dfs_traverse(dep) + + return search_node, list(out) def node_is_required_by_anything(node_: node.Node, node_set: Set[node.Node]) -> bool: diff --git a/tests/execution/test_graph_functions.py b/tests/execution/test_graph_functions.py index d1449346a..fa9a63968 100644 --- a/tests/execution/test_graph_functions.py +++ b/tests/execution/test_graph_functions.py @@ -131,8 +131,22 @@ def _inner(n: node.Node) -> bool: "d", ), ({"a": [], "b": [], "c": ["a", "b"], "d": "c"}, {"c"}, "a", "d"), + # https://github.com/DAGWorks-Inc/hamilton/issues/948 + ( + { + "random_int": [], + "numbers": ["random_int"], + "add1": ["numbers", "random_int"], + "add2": ["add1", "random_int"], + "collect_numbers": ["add2"], + "final_result": ["collect_numbers"], + }, + {"add1", "add2"}, + "numbers", + "collect_numbers", + ), ], - ids=["simple_base", "longer_chain", "complex_dag", "subdag_with_external_dep"], + ids=["simple_base", "longer_chain", "complex_dag", "subdag_with_external_dep", "issue_948"], ) def test_find_nodes_between(dag_repr, expected_nodes_in_between, start_node, end_node): nodes = _create_dummy_dag(dag_repr, dict_output=True) diff --git a/tests/execution/test_node_grouping.py b/tests/execution/test_node_grouping.py index e8912c43f..270159fc3 100644 --- a/tests/execution/test_node_grouping.py +++ b/tests/execution/test_node_grouping.py @@ -64,7 +64,8 @@ def test_group_nodes_by_repeatable_blocks_complex(): } assert len(nodes_grouped_by_name["collect-steps"].nodes) == 1 assert len(nodes_grouped_by_name["expand-steps"].nodes) == 1 - assert len(nodes_grouped_by_name["block-steps"].nodes) == 6 + assert len(nodes_grouped_by_name["block-steps"].nodes) == 5 + # See comments in parallel_complex.py for why this is -- between start/end of parallelizable block assert nodes_grouped_by_name["number_of_steps"].purpose == NodeGroupPurpose.EXECUTE_SINGLE assert nodes_grouped_by_name["block-steps"].purpose == NodeGroupPurpose.EXECUTE_BLOCK assert nodes_grouped_by_name["collect-steps"].purpose == NodeGroupPurpose.GATHER diff --git a/tests/resources/dynamic_parallelism/parallel_complex.py b/tests/resources/dynamic_parallelism/parallel_complex.py index c46557d3b..c805f48f6 100644 --- a/tests/resources/dynamic_parallelism/parallel_complex.py +++ b/tests/resources/dynamic_parallelism/parallel_complex.py @@ -17,6 +17,9 @@ def steps(number_of_steps: int) -> Parallelizable[int]: yield from range(number_of_steps) +# Parallelizable block Start + + def step_modified(steps: int, second_param_external_to_block: int) -> int: return steps + second_param_external_to_block @@ -39,6 +42,9 @@ def double_plus_triple_plus_param_external_to_block( return double_plus_triple_step + param_external_to_block +# Parallelizable block ends here + + def sum_of_some_things(double_plus_triple_plus_param_external_to_block: Collect[int]) -> int: return sum(double_plus_triple_plus_param_external_to_block)