diff --git a/hamilton/function_modifiers/macros.py b/hamilton/function_modifiers/macros.py index 1cdcb733b..c8712a0da 100644 --- a/hamilton/function_modifiers/macros.py +++ b/hamilton/function_modifiers/macros.py @@ -1158,7 +1158,7 @@ class pipe_output(base.NodeTransformer): def B(...): return ... - we obtain the new DAG **A --> B_raw --> B1 --> B2 --> B --> C**, where we can think of the **B_raw --> B1 --> B2 --> B** as a "pipe" that takes the raw output of **B**, applies to it + we obtain the new DAG **A --> B.with_raw --> B1 --> B2 --> B --> C**, where we can think of the **B.with_raw --> B1 --> B2 --> B** as a "pipe" that takes the raw output of **B**, applies to it **B1**, takes the output of **B1** applies to it **B2** and then gets renamed to **B** to re-connect to the rest of the DAG. The rules for chaining nodes are the same as for ``pipe_input``. @@ -1282,7 +1282,7 @@ def transform_node( ) -> Collection[node.Node]: """Injects nodes into the graph. - We create a copy of the original function and rename it to `function_name_raw` to be the + We create a copy of the original function and rename it to `function_name.with_raw` to be the initial node. Then we create a node for each step in `post-pipe` and chain them together. The last node is an identity to the previous one with the original name `function_name` to represent an exit point of `pipe_output`. @@ -1299,7 +1299,8 @@ def transform_node( else: _namespace = self.namespace - original_node = node_.copy_with(name=f"{node_.name}_raw") + # We pick a reserved suffix that avoids clashes with user defined functions / nodes + original_node = node_.copy_with(name=f"{node_.name}.with_raw") def __identity(foo: Any) -> Any: return foo @@ -1455,7 +1456,7 @@ def _transform1(...): def _transform2(...): return ... 
- we obtain the new pipe-like subDAGs **A_raw --> _transform1 --> A** and **B_raw --> _transform1 --> _transform2 --> B**, + we obtain the new pipe-like subDAGs **A.with_raw --> _transform1 --> A** and **B.with_raw --> _transform1 --> _transform2 --> B**, where the behavior is the same as ``pipe_output``. While it is generally reasonable to use ``pipe_output``, you should consider ``mutate`` in the following scenarios: diff --git a/tests/function_modifiers/test_macros.py b/tests/function_modifiers/test_macros.py index 27f96b504..ca08827c0 100644 --- a/tests/function_modifiers/test_macros.py +++ b/tests/function_modifiers/test_macros.py @@ -781,8 +781,11 @@ def test_pipe_output_decorator_positional_single_node(): nodes = decorator.transform_dag([n], {}, result_from_downstream_function) nodes_by_name = {item.name: item for item in nodes} chain_node = nodes_by_name["node_1"] - assert chain_node(result_from_downstream_function_raw=2, bar_upstream=10) == 112 - assert sorted(chain_node.input_types) == ["bar_upstream", "result_from_downstream_function_raw"] + assert chain_node(**{"result_from_downstream_function.with_raw": 2, "bar_upstream": 10}) == 112 + assert sorted(chain_node.input_types) == [ + "bar_upstream", + "result_from_downstream_function.with_raw", + ] final_node = nodes_by_name["result_from_downstream_function"] assert final_node(foo=112) == 112 # original arg name assert final_node(node_1=112) == 112 # renamed to match the last node @@ -800,7 +803,12 @@ def test_pipe_output_decorator_no_collapse_multi_node(): nodes_by_name = {item.name: item for item in nodes} final_node = nodes_by_name["result_from_downstream_function"] assert len(nodes_by_name) == 4 # We add fn_raw and identity - assert nodes_by_name["node_1"](result_from_downstream_function_raw=1, bar_upstream=10) == 111 + assert ( + nodes_by_name["node_1"]( + **{"result_from_downstream_function.with_raw": 1, "bar_upstream": 10} + ) + == 111 + ) assert nodes_by_name["node_2"](node_1=4) == 114 assert 
final_node(node_2=13) == 13 @@ -827,7 +835,7 @@ def test_pipe_output_inherits_null_namespace(): decorator.validate(result_from_downstream_function) nodes = decorator.transform_dag([n], {}, result_from_downstream_function) assert "node_1" in {item.name for item in nodes} - assert "result_from_downstream_function_raw" in {item.name for item in nodes} + assert "result_from_downstream_function.with_raw" in {item.name for item in nodes} assert "result_from_downstream_function" in {item.name for item in nodes}