Skip to content

Commit

Permalink
Bugfix pipe_output after config.when resolution
Browse files Browse the repository at this point in the history
In case no pipe_output functions meet confgi.when conditions return
original node and skip pipe_output.
  • Loading branch information
jernejfrank authored and skrawcz committed Nov 12, 2024
1 parent 4e61771 commit 239d142
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 0 deletions.
4 changes: 4 additions & 0 deletions hamilton/function_modifiers/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,10 @@ def __identity(foo: Any) -> Any:
fn=fn,
)

# In case config resolves to no pipe functions applied we return the original node and skip pipe
if len(nodes) == 1:
return [node_]

last_node = nodes[-1].copy_with(name=f"{node_.name}", typ=nodes[-2].type)

out = [original_node]
Expand Down
65 changes: 65 additions & 0 deletions tests/function_modifiers/test_macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,71 @@ def test_pipe_output_end_to_end():
assert result["chain_2_using_pipe_output"] == result["chain_2_not_using_pipe_output"]


def test_pipe_output_end_to_end_with_config():
inputs = {
"input_1": 10,
"input_2": 20,
"input_3": 30,
}

dr = (
driver.Builder()
.with_modules(tests.resources.pipe_output)
.with_adapter(base.DefaultAdapter())
.with_config({"key": "Yes"})
.build()
)

result = dr.execute(
[
"chain_3_using_pipe_output",
"chain_3_not_using_pipe_output_config_true",
],
inputs=inputs,
)
assert (
result["chain_3_using_pipe_output"] == result["chain_3_not_using_pipe_output_config_true"]
)

dr = (
driver.Builder()
.with_modules(tests.resources.pipe_output)
.with_adapter(base.DefaultAdapter())
.with_config({"key": "No"})
.build()
)

result = dr.execute(
[
"chain_3_using_pipe_output",
"chain_3_not_using_pipe_output_config_false",
],
inputs=inputs,
)
assert (
result["chain_3_using_pipe_output"] == result["chain_3_not_using_pipe_output_config_false"]
)

dr = (
driver.Builder()
.with_modules(tests.resources.pipe_output)
.with_adapter(base.DefaultAdapter())
.with_config({"key": "skip"})
.build()
)
result = dr.execute(
[
"chain_3_using_pipe_output",
"chain_3_not_using_pipe_output_config_no_conditions_met",
],
inputs=inputs,
)
assert (
result["chain_3_using_pipe_output"]
== result["chain_3_not_using_pipe_output_config_no_conditions_met"]
)


# Mutate will mark the modules (and leave a mark).
# Thus calling it a second time (for instance through pmultiple tests) might mess it up slightly...
# Using fixtures just to be sure.
Expand Down
31 changes: 31 additions & 0 deletions tests/resources/pipe_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,34 @@ def chain_2_not_using_pipe_output(v: int, input_3: int, calc_c: bool = False) ->
d = _add_n(c, n=input_3) # Assuming "upstream" refers to the same value as "v" here
e = _add_two(d)
return e


@pipe_output(
step(_square).named("a").when(key="Yes"),
step(_multiply_n, n=value(2)).named("b").when(key="No"),
step(_add_n, n=10).named("c").when(key="Yes"),
step(_add_n, n=source("input_3")).named("d").when(key="No"),
step(_add_two).named("e").when(key="Yes"),
)
def chain_3_using_pipe_output(v: int) -> int:
return v + 10


def chain_3_not_using_pipe_output_config_true(v: int, input_3: int) -> int:
start = v + 10
a = _square(start)
c = _add_n(a, n=10)
e = _add_two(c)
return e


def chain_3_not_using_pipe_output_config_false(v: int, input_3: int) -> int:
start = v + 10
b = _multiply_n(start, n=2)
d = _add_n(b, n=input_3)
return d


def chain_3_not_using_pipe_output_config_no_conditions_met(v: int, input_3: int) -> int:
start = v + 10
return start

0 comments on commit 239d142

Please sign in to comment.