From fe9b77a83f142a2dbb8062b43d43513de60e09cd Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Fri, 6 Dec 2024 23:13:14 -0800 Subject: [PATCH] Enables subdag to handle `source` and `value` To be consistent and to allow people to remap external values, e.g. subdag config is foo, but we need two different values because we have two subdags that need it. So we can remap the values from external config: Global config: {"foo1": "foo_v1", "foo2": "foo_v2"} Subdag1 config: {"foo": source("foo1")} Subdag2 config: {"foo": source("foo2")} This also handles if someone wants to use value(). Otherwise to preserve backwards compatibility we need to do this in a specific order and only override config values appropriately. --- hamilton/function_modifiers/recursive.py | 24 ++++++- tests/function_modifiers/test_recursive.py | 76 ++++++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/hamilton/function_modifiers/recursive.py b/hamilton/function_modifiers/recursive.py index a330714a9..e30aa38ef 100644 --- a/hamilton/function_modifiers/recursive.py +++ b/hamilton/function_modifiers/recursive.py @@ -423,7 +423,29 @@ def _derive_name(self, fn: Callable) -> str: def generate_nodes(self, fn: Callable, configuration: Dict[str, Any]) -> Collection[node.Node]: # Resolve all nodes from passed in functions - resolved_config = dict(configuration, **self.config) + # if self.config has source() or value() in it, we need to resolve it + sources_to_map = {} + values_to_include = {} + for key, value in self.config.items(): + if isinstance(value, dependencies.UpstreamDependency): + sources_to_map[key] = value.source + elif isinstance(value, dependencies.LiteralDependency): + values_to_include[key] = value.value + plain_configs = { + k: v + for k, v in self.config.items() + if k not in sources_to_map and k not in values_to_include + } + resolved_config = dict(configuration, **plain_configs, **values_to_include) + # override any values from sources + for key, source in 
sources_to_map.items(): + try: + resolved_config[key] = resolved_config[source] + except KeyError as e: + raise InvalidDecoratorException( + f"Source {source} was not found in the configuration. This is required for the {fn.__name__} subdag." + ) from e + # resolved_config = dict(configuration, **self.config) nodes = self.collect_nodes(config=resolved_config, subdag_functions=self.subdag_functions) # Derive the namespace under which all these nodes will live namespace = self._derive_namespace(fn) diff --git a/tests/function_modifiers/test_recursive.py b/tests/function_modifiers/test_recursive.py index e9b76686c..67d31d892 100644 --- a/tests/function_modifiers/test_recursive.py +++ b/tests/function_modifiers/test_recursive.py @@ -392,6 +392,82 @@ def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int: ) +def test_nested_subdag_with_config_remapping(): + """Tests that we can remap config values and source and value are resolved correctly.""" + + def bar(input_1: int) -> int: + return input_1 + 1 + + @config.when(broken=False) + def foo(input_2: int) -> int: + return input_2 + 1 + + @subdag( + foo, + bar, + ) + def inner_subdag(foo: int, bar: int) -> Tuple[int, int]: + return foo, bar + + @subdag(inner_subdag, inputs={"input_2": value(10)}, config={"broken": value(False)}) + def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int: + return sum(inner_subdag) + + @subdag(inner_subdag, inputs={"input_2": value(3)}, config={"broken": source("broken2")}) + def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int: + return sum(inner_subdag) + + def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int: + return outer_subdag_1 + outer_subdag_2 + + # we only need to generate from the outer subdag + # as it refers to the inner one + full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all) + fg = graph.FunctionGraph.from_modules(full_module, config={"broken2": False}) + assert "outer_subdag_1" in fg.nodes + assert "outer_subdag_2" in 
fg.nodes + res = fg.execute(nodes=[fg.nodes["sum_all"]], inputs={"input_1": 2}) + # This is effectively the function graph + assert res["sum_all"] == sum_all( + outer_subdag_1(inner_subdag(bar(2), foo(10))), outer_subdag_2(inner_subdag(bar(2), foo(3))) + ) + + +def test_nested_subdag_with_config_remapping_missing_error(): + """Tests that we error if we can't remap a config value.""" + + def bar(input_1: int) -> int: + return input_1 + 1 + + @config.when(broken=False) + def foo(input_2: int) -> int: + return input_2 + 1 + + @subdag( + foo, + bar, + ) + def inner_subdag(foo: int, bar: int) -> Tuple[int, int]: + return foo, bar + + @subdag(inner_subdag, inputs={"input_2": value(10)}, config={"broken": value(False)}) + def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int: + return sum(inner_subdag) + + @subdag(inner_subdag, inputs={"input_2": value(3)}, config={"broken": source("broken_missing")}) + def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int: + return sum(inner_subdag) + + def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int: + return outer_subdag_1 + outer_subdag_2 + + # we only need to generate from the outer subdag + # as it refers to the inner one + full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all) + with pytest.raises(InvalidDecoratorException): + graph.FunctionGraph.from_modules(full_module, config={"broken2": False}) + + def test_subdag_with_external_nodes_input(): def bar(input_1: int) -> int: return input_1 + 1