Skip to content

Commit

Permalink
Fix for parallelism issue #948
Browse files Browse the repository at this point in the history
Previously, we had difficulty with the traversal when determining a parallel
block. The algorithm was broken and would often traverse to the end,
including source nodes. This just rewrites it -- I didn't bother
debugging because the approach was not great from the start. This does a
simple DFS and uses a nonlocal variable to track the origin (and update
it), so we can return it.

See #948 for context.
  • Loading branch information
elijahbenizzy authored and skrawcz committed Jun 13, 2024
1 parent 881010e commit ea33f4a
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 34 deletions.
57 changes: 25 additions & 32 deletions hamilton/execution/graph_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,57 +291,50 @@ def nodes_between(
search_condition: lambda node_: bool,
) -> Tuple[Optional[node.Node], List[node.Node]]:
"""Utility function to search backwards from an end node to a start node.
This returns all nodes for which the following conditions are met:
This returns all nodes for which both of the following conditions are met:
1. It contains a node that matches the start_condition as an ancestor
2. It contains a node that matches the end node as a dependent
Note that currently it is assumed that only one node will
match search_condition.
This just grabs the search node when it finds it -- the nonlocal is a bit hacky but more fun
than passing a ton of data back and forth (who the parent is, etc...).
:param end_node: Node to trace back from
:param search_condition: Condition to stop the search for ancestors
:return: A tuple of [start_node, between], where start_node is None
if there is no path (and between will be empty).
"""

out = set()
visited = set()
search_node = None

def dfs_traverse(node_: node.Node):
"""Recursive call. Note that it returns None to signify
that we should not traverse any nodes, and a list to say that
we should continue traversing"""
if search_condition(node_):
# if we hit the end, we want to include all others in it
nonlocal search_node
search_node = node_
return True
if node_ in visited:
return []
# if we've already seen it, we want to include it
return node_ in out
# now we mark that we've seen it
visited.add(node_)
if search_condition(node_):
return [node_]
if node_.user_defined:
return None
out = []

any_deps_included = False
for n in node_.dependencies:
traverse = dfs_traverse(n)
if traverse is not None:
out.extend(traverse)
out.append(n)
if len(out) == 0:
return None
return out

output = []
for node_ in dfs_traverse(end_node) or []:
output.append(node_)
begin_node = None
nodes = []
for node_ in output:
# TODO -- handle the case that there are multiple nodes that match the search condition
if search_condition(node_):
begin_node = node_
elif node_ == end_node:
continue
else:
nodes.append(node_)
return begin_node, nodes
any_deps_included |= dfs_traverse(n)
if any_deps_included:
out.add(node_)
return any_deps_included

for dep in end_node.dependencies:
dfs_traverse(dep)

return search_node, list(out)


def node_is_required_by_anything(node_: node.Node, node_set: Set[node.Node]) -> bool:
Expand Down
16 changes: 15 additions & 1 deletion tests/execution/test_graph_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,22 @@ def _inner(n: node.Node) -> bool:
"d",
),
({"a": [], "b": [], "c": ["a", "b"], "d": "c"}, {"c"}, "a", "d"),
# https://github.com/DAGWorks-Inc/hamilton/issues/948
(
{
"random_int": [],
"numbers": ["random_int"],
"add1": ["numbers", "random_int"],
"add2": ["add1", "random_int"],
"collect_numbers": ["add2"],
"final_result": ["collect_numbers"],
},
{"add1", "add2"},
"numbers",
"collect_numbers",
),
],
ids=["simple_base", "longer_chain", "complex_dag", "subdag_with_external_dep"],
ids=["simple_base", "longer_chain", "complex_dag", "subdag_with_external_dep", "issue_948"],
)
def test_find_nodes_between(dag_repr, expected_nodes_in_between, start_node, end_node):
nodes = _create_dummy_dag(dag_repr, dict_output=True)
Expand Down
3 changes: 2 additions & 1 deletion tests/execution/test_node_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def test_group_nodes_by_repeatable_blocks_complex():
}
assert len(nodes_grouped_by_name["collect-steps"].nodes) == 1
assert len(nodes_grouped_by_name["expand-steps"].nodes) == 1
assert len(nodes_grouped_by_name["block-steps"].nodes) == 6
assert len(nodes_grouped_by_name["block-steps"].nodes) == 5
# See comments in parallel_complex.py for why this is -- between start/end of parallelizable block
assert nodes_grouped_by_name["number_of_steps"].purpose == NodeGroupPurpose.EXECUTE_SINGLE
assert nodes_grouped_by_name["block-steps"].purpose == NodeGroupPurpose.EXECUTE_BLOCK
assert nodes_grouped_by_name["collect-steps"].purpose == NodeGroupPurpose.GATHER
Expand Down
6 changes: 6 additions & 0 deletions tests/resources/dynamic_parallelism/parallel_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ def steps(number_of_steps: int) -> Parallelizable[int]:
yield from range(number_of_steps)


# Parallelizable block starts here


def step_modified(steps: int, second_param_external_to_block: int) -> int:
    """Return a step value offset by a parameter supplied from outside the block.

    :param steps: A single step value.
    :param second_param_external_to_block: Value added to ``steps``; by its name it
        is presumably computed outside the parallelizable block -- confirm against
        the rest of the module.
    :return: ``steps + second_param_external_to_block``.
    """
    return steps + second_param_external_to_block

Expand All @@ -39,6 +42,9 @@ def double_plus_triple_plus_param_external_to_block(
return double_plus_triple_step + param_external_to_block


# Parallelizable block ends here


def sum_of_some_things(double_plus_triple_plus_param_external_to_block: Collect[int]) -> int:
    """Return the total of all values passed in.

    :param double_plus_triple_plus_param_external_to_block: Iterable of integers
        to add up (annotated as ``Collect[int]``).
    :return: The sum of the values; ``0`` when the iterable is empty.
    """
    return sum(double_plus_triple_plus_param_external_to_block)

Expand Down

0 comments on commit ea33f4a

Please sign in to comment.