From 239d1421e5ae4089687f50c0eb0e3519189c492c Mon Sep 17 00:00:00 2001 From: jernejfrank Date: Thu, 7 Nov 2024 09:15:49 +0800 Subject: [PATCH] Bugfix pipe_output after config.when resolution In case no pipe_output functions meet confgi.when conditions return original node and skip pipe_output. --- hamilton/function_modifiers/macros.py | 4 ++ tests/function_modifiers/test_macros.py | 65 +++++++++++++++++++++++++ tests/resources/pipe_output.py | 31 ++++++++++++ 3 files changed, 100 insertions(+) diff --git a/hamilton/function_modifiers/macros.py b/hamilton/function_modifiers/macros.py index 22ffabacd..7cc074985 100644 --- a/hamilton/function_modifiers/macros.py +++ b/hamilton/function_modifiers/macros.py @@ -1314,6 +1314,10 @@ def __identity(foo: Any) -> Any: fn=fn, ) + # In case config resolves to no pipe functions applied we return the original node and skip pipe + if len(nodes) == 1: + return [node_] + last_node = nodes[-1].copy_with(name=f"{node_.name}", typ=nodes[-2].type) out = [original_node] diff --git a/tests/function_modifiers/test_macros.py b/tests/function_modifiers/test_macros.py index e1538a3f4..64b53bf8e 100644 --- a/tests/function_modifiers/test_macros.py +++ b/tests/function_modifiers/test_macros.py @@ -976,6 +976,71 @@ def test_pipe_output_end_to_end(): assert result["chain_2_using_pipe_output"] == result["chain_2_not_using_pipe_output"] +def test_pipe_output_end_to_end_with_config(): + inputs = { + "input_1": 10, + "input_2": 20, + "input_3": 30, + } + + dr = ( + driver.Builder() + .with_modules(tests.resources.pipe_output) + .with_adapter(base.DefaultAdapter()) + .with_config({"key": "Yes"}) + .build() + ) + + result = dr.execute( + [ + "chain_3_using_pipe_output", + "chain_3_not_using_pipe_output_config_true", + ], + inputs=inputs, + ) + assert ( + result["chain_3_using_pipe_output"] == result["chain_3_not_using_pipe_output_config_true"] + ) + + dr = ( + driver.Builder() + .with_modules(tests.resources.pipe_output) + .with_adapter(base.DefaultAdapter()) + .with_config({"key": "No"}) + .build() + ) + + result = dr.execute( + [ + "chain_3_using_pipe_output", + "chain_3_not_using_pipe_output_config_false", + ], + inputs=inputs, + ) + assert ( + result["chain_3_using_pipe_output"] == result["chain_3_not_using_pipe_output_config_false"] + ) + + dr = ( + driver.Builder() + .with_modules(tests.resources.pipe_output) + .with_adapter(base.DefaultAdapter()) + .with_config({"key": "skip"}) + .build() + ) + result = dr.execute( + [ + "chain_3_using_pipe_output", + "chain_3_not_using_pipe_output_config_no_conditions_met", + ], + inputs=inputs, + ) + assert ( + result["chain_3_using_pipe_output"] + == result["chain_3_not_using_pipe_output_config_no_conditions_met"] + ) + + # Mutate will mark the modules (and leave a mark). # Thus calling it a second time (for instance through pmultiple tests) might mess it up slightly... # Using fixtures just to be sure. diff --git a/tests/resources/pipe_output.py b/tests/resources/pipe_output.py index 61ebcd365..31b2b6645 100644 --- a/tests/resources/pipe_output.py +++ b/tests/resources/pipe_output.py @@ -105,3 +105,34 @@ def chain_2_not_using_pipe_output(v: int, input_3: int, calc_c: bool = False) -> d = _add_n(c, n=input_3) # Assuming "upstream" refers to the same value as "v" here e = _add_two(d) return e + + +@pipe_output( + step(_square).named("a").when(key="Yes"), + step(_multiply_n, n=value(2)).named("b").when(key="No"), + step(_add_n, n=10).named("c").when(key="Yes"), + step(_add_n, n=source("input_3")).named("d").when(key="No"), + step(_add_two).named("e").when(key="Yes"), +) +def chain_3_using_pipe_output(v: int) -> int: + return v + 10 + + +def chain_3_not_using_pipe_output_config_true(v: int, input_3: int) -> int: + start = v + 10 + a = _square(start) + c = _add_n(a, n=10) + e = _add_two(c) + return e + + +def chain_3_not_using_pipe_output_config_false(v: int, input_3: int) -> int: + start = v + 10 + b = _multiply_n(start, n=2) + d = _add_n(b, n=input_3) + return d + + +def chain_3_not_using_pipe_output_config_no_conditions_met(v: int, input_3: int) -> int: + start = v + 10 + return start