Skip to content

Commit

Permalink
Improve pipe_output first node naming
Browse files Browse the repository at this point in the history
Assigning the same naming convention for first node in pipe chain to
avoid user naming clashes.
  • Loading branch information
jernejfrank committed Oct 21, 2024
1 parent 33fd61d commit bad02cf
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
9 changes: 5 additions & 4 deletions hamilton/function_modifiers/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,7 @@ class pipe_output(base.NodeTransformer):
def B(...):
return ...
we obtain the new DAG **A --> B_raw --> B1 --> B2 --> B --> C**, where we can think of the **B_raw --> B1 --> B2 --> B** as a "pipe" that takes the raw output of **B**, applies to it
we obtain the new DAG **A --> B.with_raw --> B1 --> B2 --> B --> C**, where we can think of the **B.with_raw --> B1 --> B2 --> B** as a "pipe" that takes the raw output of **B**, applies to it
**B1**, takes the output of **B1** applies to it **B2** and then gets renamed to **B** to re-connect to the rest of the DAG.
The rules for chaining nodes are the same as for ``pipe_input``.
Expand Down Expand Up @@ -1282,7 +1282,7 @@ def transform_node(
) -> Collection[node.Node]:
"""Injects nodes into the graph.
We create a copy of the original function and rename it to `function_name_raw` to be the
We create a copy of the original function and rename it to `function_name.with_raw` to be the
initial node. Then we create a node for each step in `post-pipe` and chain them together.
The last node is an identity to the previous one with the original name `function_name` to
represent an exit point of `pipe_output`.
Expand All @@ -1299,7 +1299,8 @@ def transform_node(
else:
_namespace = self.namespace

original_node = node_.copy_with(name=f"{node_.name}_raw")
# We pick a reserved prefix that ovoids clashes with user defined functions / nodes
original_node = node_.copy_with(name=f"{node_.name}.with_raw")

def __identity(foo: Any) -> Any:
return foo
Expand Down Expand Up @@ -1455,7 +1456,7 @@ def _transform1(...):
def _transform2(...):
return ...
we obtain the new pipe-like subDAGs **A_raw --> _transform1 --> A** and **B_raw --> _transform1 --> _transform2 --> B**,
we obtain the new pipe-like subDAGs **A.with_raw --> _transform1 --> A** and **B.with_raw --> _transform1 --> _transform2 --> B**,
where the behavior is the same as ``pipe_output``.
While it is generally reasonable to use ``pipe_output``, you should consider ``mutate`` in the following scenarios:
Expand Down
16 changes: 12 additions & 4 deletions tests/function_modifiers/test_macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,11 @@ def test_pipe_output_decorator_positional_single_node():
nodes = decorator.transform_dag([n], {}, result_from_downstream_function)
nodes_by_name = {item.name: item for item in nodes}
chain_node = nodes_by_name["node_1"]
assert chain_node(result_from_downstream_function_raw=2, bar_upstream=10) == 112
assert sorted(chain_node.input_types) == ["bar_upstream", "result_from_downstream_function_raw"]
assert chain_node(**{"result_from_downstream_function.with_raw": 2, "bar_upstream": 10}) == 112
assert sorted(chain_node.input_types) == [
"bar_upstream",
"result_from_downstream_function.with_raw",
]
final_node = nodes_by_name["result_from_downstream_function"]
assert final_node(foo=112) == 112 # original arg name
assert final_node(node_1=112) == 112 # renamed to match the last node
Expand All @@ -800,7 +803,12 @@ def test_pipe_output_decorator_no_collapse_multi_node():
nodes_by_name = {item.name: item for item in nodes}
final_node = nodes_by_name["result_from_downstream_function"]
assert len(nodes_by_name) == 4 # We add fn_raw and identity
assert nodes_by_name["node_1"](result_from_downstream_function_raw=1, bar_upstream=10) == 111
assert (
nodes_by_name["node_1"](
**{"result_from_downstream_function.with_raw": 1, "bar_upstream": 10}
)
== 111
)
assert nodes_by_name["node_2"](node_1=4) == 114
assert final_node(node_2=13) == 13

Expand All @@ -827,7 +835,7 @@ def test_pipe_output_inherits_null_namespace():
decorator.validate(result_from_downstream_function)
nodes = decorator.transform_dag([n], {}, result_from_downstream_function)
assert "node_1" in {item.name for item in nodes}
assert "result_from_downstream_function_raw" in {item.name for item in nodes}
assert "result_from_downstream_function.with_raw" in {item.name for item in nodes}
assert "result_from_downstream_function" in {item.name for item in nodes}


Expand Down

0 comments on commit bad02cf

Please sign in to comment.