diff --git a/hamilton/function_modifiers/recursive.py b/hamilton/function_modifiers/recursive.py index f46dc3ba0..13e117734 100644 --- a/hamilton/function_modifiers/recursive.py +++ b/hamilton/function_modifiers/recursive.py @@ -1,5 +1,5 @@ from types import ModuleType -from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union +from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union from typing_extensions import NotRequired, TypedDict @@ -387,6 +387,20 @@ def validate(self, fn): self._validate_parameterization() + def required_config(self) -> Optional[List[str]]: + """Currently we do not filter for subdag as we do not *statically* know what configuration + is required. This is because we need to parse the function so that we can figure it out, + and that is not available at the time that we call required_config. We need to think about + the best way to do this, but its likely that we'll want to allow required_config to consume + the function itself, and pass it in when its called with that. + + That said, we don't have sufficient justification to do that yet, so we're just going to + return None for now, meaning that it has access to all configuration variables. + + :return: + """ + return None + class SubdagParams(TypedDict): inputs: NotRequired[Dict[str, ParametrizedDependency]] @@ -511,3 +525,10 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node def validate(self, fn: Callable): for subdag_generator in self._gather_subdag_generators(): subdag_generator.validate(fn) + + def required_config(self) -> Optional[List[str]]: + """See note for subdag.required_config -- this is the same pattern. + + :return: Any required config items. + """ + return None diff --git a/tests/function_modifiers/test_recursive.py b/tests/function_modifiers/test_recursive.py index e9b45a48d..31e7bc9a4 100644 --- a/tests/function_modifiers/test_recursive.py +++ b/tests/function_modifiers/test_recursive.py @@ -309,6 +309,45 @@ def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int: ) +def test_nested_subdag_with_config(): + def bar(input_1: int) -> int: + return input_1 + 1 + + @config.when(broken=False) + def foo(input_2: int) -> int: + return input_2 + 1 + + @subdag( + foo, + bar, + ) + def inner_subdag(foo: int, bar: int) -> Tuple[int, int]: + return foo, bar + + @subdag(inner_subdag, inputs={"input_2": value(10)}) + def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int: + return sum(inner_subdag) + + @subdag(inner_subdag, inputs={"input_2": value(3)}) + def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int: + return sum(inner_subdag) + + def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int: + return outer_subdag_1 + outer_subdag_2 + + # we only need to generate from the outer subdag + # as it refers to the inner one + full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all) + fg = graph.FunctionGraph(full_module, config={"broken": False}) + assert "outer_subdag_1" in fg.nodes + assert "outer_subdag_2" in fg.nodes + res = fg.execute(nodes=[fg.nodes["sum_all"]], inputs={"input_1": 2}) + # This is effectively the function graph + assert res["sum_all"] == sum_all( + outer_subdag_1(inner_subdag(bar(2), foo(10))), outer_subdag_2(inner_subdag(bar(2), foo(3))) + ) + + def test_subdag_with_external_nodes_input(): def bar(input_1: int) -> int: return input_1 + 1 @@ -358,3 +397,78 @@ def foo_bar_baz_summed(foo_bar_baz_input_1: int, foo_bar_baz_input_2: int) -> in assert res["foo_bar_baz_summed"] == foo_bar_baz(foo(10), foo(20), 100) + foo_bar_baz( bar(30), foo(40), 100 ) + + +def test_parameterized_subdag_with_config(): + def bar(input_1: int) -> int: + return input_1 + 1 + + @config.when(broken=False) + def foo(input_2: int) -> int: + return input_2 + 1 + + @subdag( + foo, + bar, + ) + def inner_subdag(foo: int, bar: int) -> Tuple[int, int]: + return foo, bar + + @subdag(inner_subdag, inputs={"input_2": value(10)}) + def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int: + return sum(inner_subdag) + + @subdag(inner_subdag, inputs={"input_2": value(3)}) + def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int: + return sum(inner_subdag) + + def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int: + return outer_subdag_1 + outer_subdag_2 + + # we only need to generate from the outer subdag + # as it refers to the inner one + full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all) + fg = graph.FunctionGraph(full_module, config={"broken": False}) + assert "outer_subdag_1" in fg.nodes + assert "outer_subdag_2" in fg.nodes + res = fg.execute(nodes=[fg.nodes["sum_all"]], inputs={"input_1": 2}) + # This is effectively the function graph + assert res["sum_all"] == sum_all( + outer_subdag_1(inner_subdag(bar(2), foo(10))), outer_subdag_2(inner_subdag(bar(2), foo(3))) + ) + + +def test_nested_parameterized_subdag_with_config(): + def bar(input_1: int) -> int: + return input_1 + 1 + + @config.when(broken=False) + def foo(input_2: int) -> int: + return input_2 + 1 + + @parameterized_subdag(foo, bar, inner_subdag={}) + def inner_subdag(foo: int, bar: int) -> Tuple[int, int]: + return foo, bar + + @parameterized_subdag( + inner_subdag, + outer_subdag_1={"inputs": {"input_2": value(10)}}, + outer_subdag_2={"inputs": {"input_2": value(3)}}, + ) + def outer_subdag_n(inner_subdag: Tuple[int, int]) -> int: + return sum(inner_subdag) + + def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int: + return outer_subdag_1 + outer_subdag_2 + + # we only need to generate from the outer subdag + # as it refers to the inner one + full_module = ad_hoc_utils.create_temporary_module(outer_subdag_n, sum_all) + fg = graph.FunctionGraph(full_module, config={"broken": False}) + assert "outer_subdag_1" in fg.nodes + assert "outer_subdag_2" in fg.nodes + res = fg.execute(nodes=[fg.nodes["sum_all"]], inputs={"input_1": 2}) + # This is effectively the function graph + assert res["sum_all"] == sum_all( + outer_subdag_n(inner_subdag(bar(2), foo(10))), outer_subdag_n(inner_subdag(bar(2), foo(3))) + )