Skip to content

Commit

Permalink
Improve pipe_output first node naming
Browse files Browse the repository at this point in the history
Assigning the same naming convention for first node in pipe chain to
avoid user naming clashes.
  • Loading branch information
jernejfrank committed Oct 21, 2024
1 parent 33fd61d commit 84b7b2a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
3 changes: 2 additions & 1 deletion hamilton/function_modifiers/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,7 +1299,8 @@ def transform_node(
else:
_namespace = self.namespace

original_node = node_.copy_with(name=f"{node_.name}_raw")
# We pick a reserved prefix that ovoids clashes with user defined functions / nodes
original_node = node_.copy_with(name=f"{node_.name}.with_raw")

def __identity(foo: Any) -> Any:
return foo
Expand Down
16 changes: 12 additions & 4 deletions tests/function_modifiers/test_macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,11 @@ def test_pipe_output_decorator_positional_single_node():
nodes = decorator.transform_dag([n], {}, result_from_downstream_function)
nodes_by_name = {item.name: item for item in nodes}
chain_node = nodes_by_name["node_1"]
assert chain_node(result_from_downstream_function_raw=2, bar_upstream=10) == 112
assert sorted(chain_node.input_types) == ["bar_upstream", "result_from_downstream_function_raw"]
assert chain_node(**{"result_from_downstream_function.with_raw": 2, "bar_upstream": 10}) == 112
assert sorted(chain_node.input_types) == [
"bar_upstream",
"result_from_downstream_function.with_raw",
]
final_node = nodes_by_name["result_from_downstream_function"]
assert final_node(foo=112) == 112 # original arg name
assert final_node(node_1=112) == 112 # renamed to match the last node
Expand All @@ -800,7 +803,12 @@ def test_pipe_output_decorator_no_collapse_multi_node():
nodes_by_name = {item.name: item for item in nodes}
final_node = nodes_by_name["result_from_downstream_function"]
assert len(nodes_by_name) == 4 # We add fn_raw and identity
assert nodes_by_name["node_1"](result_from_downstream_function_raw=1, bar_upstream=10) == 111
assert (
nodes_by_name["node_1"](
**{"result_from_downstream_function.with_raw": 1, "bar_upstream": 10}
)
== 111
)
assert nodes_by_name["node_2"](node_1=4) == 114
assert final_node(node_2=13) == 13

Expand All @@ -827,7 +835,7 @@ def test_pipe_output_inherits_null_namespace():
decorator.validate(result_from_downstream_function)
nodes = decorator.transform_dag([n], {}, result_from_downstream_function)
assert "node_1" in {item.name for item in nodes}
assert "result_from_downstream_function_raw" in {item.name for item in nodes}
assert "result_from_downstream_function.with_raw" in {item.name for item in nodes}
assert "result_from_downstream_function" in {item.name for item in nodes}


Expand Down

0 comments on commit 84b7b2a

Please sign in to comment.