From be630a37062cfa3ac6be920de71fda24a9a334d0 Mon Sep 17 00:00:00 2001 From: elijahbenizzy Date: Wed, 22 Mar 2023 19:54:08 -0700 Subject: [PATCH] Adds ability to specify external inputs in subdag/parameterized_subdag See https://github.com/DAGWorks-Inc/hamilton/issues/115 --- hamilton/function_modifiers/recursive.py | 29 ++++++++++-- tests/function_modifiers/test_recursive.py | 51 ++++++++++++++++++++++ 2 files changed, 76 insertions(+), 4 deletions(-) diff --git a/hamilton/function_modifiers/recursive.py b/hamilton/function_modifiers/recursive.py index 0e25cd4c0..f46dc3ba0 100644 --- a/hamilton/function_modifiers/recursive.py +++ b/hamilton/function_modifiers/recursive.py @@ -127,6 +127,10 @@ def feature_engineering(feature_df: pd.DataFrame) -> pd.DataFrame: E.G. take a certain set of nodes, and run them with specified parameters. + @subdag declares parameters that are outputs of its subdags. Note that, if you want to use outputs of other + components of the DAG, you can use the `external_inputs` parameter to declare the parameters that do *not* come + from the subDAG. + Why might you want to use this? Let's take a look at some examples: 1. You have a feature engineering pipeline that you want to run on multiple datasets. If its exactly the same, \ @@ -148,6 +152,7 @@ def __init__( config: Dict[str, Any] = None, namespace: str = None, final_node_name: str = None, + external_inputs: List[str] = None, ): """Adds a subDAG to the main DAG. @@ -160,10 +165,14 @@ def __init__( this will default to the function name. :param final_node_name: Name of the final node in the subDAG. This is optional -- if not included, this will default to the function name. + :param external_inputs: Parameters in the function that are not produced by the functions + passed to the subdag. This is useful if you want to perform some logic with other inputs + in the subdag's processing function. 
""" self.subdag_functions = subdag.collect_functions(load_from) self.inputs = inputs if inputs is not None else {} self.config = config if config is not None else {} + self.external_inputs = external_inputs if external_inputs is not None else [] self._validate_config_inputs(self.config, self.inputs) self.namespace = namespace self.final_node_name = final_node_name @@ -307,9 +316,14 @@ def add_final_node(self, fn: Callable, node_name: str, namespace: str): :return: """ node_ = node.Node.from_fn(fn) - namespaced_input_map = {assign_namespace(key, namespace): key for key in node_.input_types} + namespaced_input_map = { + (assign_namespace(key, namespace) if key not in self.external_inputs else key): key + for key in node_.input_types + } + new_input_types = { - assign_namespace(key, namespace): value for key, value in node_.input_types.items() + (assign_namespace(key, namespace) if key not in self.external_inputs else key): value + for key, value in node_.input_types.items() } def new_function(**kwargs): @@ -377,11 +391,12 @@ def validate(self, fn): class SubdagParams(TypedDict): inputs: NotRequired[Dict[str, ParametrizedDependency]] config: NotRequired[Dict[str, Any]] + external_inputs: NotRequired[List[str]] class parameterized_subdag(base.NodeCreator): """parameterized subdag is when you want to create multiple subdags at one time. - Why do you want to do this? + Why might you want to do this? 1. You have multiple data sets you want to run the same feature engineering pipeline on. 2. You want to run some sort of optimization routine with a variety of results @@ -444,6 +459,7 @@ def __init__( str, Union[dependencies.ParametrizedDependency, dependencies.LiteralDependency] ] = None, config: Dict[str, Any] = None, + external_inputs: List[str] = None, **parameterization: SubdagParams, ): """Initializes a parameterized_subdag decorator. 
@@ -451,6 +467,8 @@ def __init__( :param load_from: Modules to load from :param inputs: Inputs for each subdag generated by the decorated function :param config: Config for each subdag generated by the decorated function + :param external_inputs: External inputs to all parameterized subdags. Note that + if you pass in any external inputs from local subdags, it overrides this (does not merge). :param parameterization: Parameterizations for each subdag generated. Note that this *overrides* any inputs/config passed to the decorator itself. @@ -460,12 +478,14 @@ def __init__( allowed to name these `load_from`, `inputs`, or `config`. That's a good thing, as these are not good names for variables anyway. - 2. Any empty items (not included) will default to an empty dict + 2. Any empty items (not included) will default to an empty dict (or an empty list in + the case of `external_inputs`) """ self.load_from = load_from self.inputs = inputs if inputs is not None else {} self.config = config if config is not None else {} self.parameterization = parameterization + self.external_inputs = external_inputs if external_inputs is not None else [] def _gather_subdag_generators(self) -> List[subdag]: subdag_generators = [] @@ -475,6 +495,7 @@ def _gather_subdag_generators(self) -> List[subdag]: *self.load_from, inputs={**self.inputs, **parameterization.get("inputs", {})}, config={**self.config, **parameterization.get("config", {})}, + external_inputs=parameterization.get("external_inputs", self.external_inputs), namespace=key, final_node_name=key, ) diff --git a/tests/function_modifiers/test_recursive.py b/tests/function_modifiers/test_recursive.py index c78ed0180..e9b45a48d 100644 --- a/tests/function_modifiers/test_recursive.py +++ b/tests/function_modifiers/test_recursive.py @@ -307,3 +307,54 @@ def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int: assert res["sum_all"] == sum_all( outer_subdag_1(inner_subdag(bar(2), foo(10))), outer_subdag_2(inner_subdag(bar(2), foo(3))) ) 
+
+
+def test_subdag_with_external_nodes_input():
+    def bar(input_1: int) -> int:
+        return input_1 + 1
+
+    def foo(input_2: int) -> int:
+        return input_2 + 1
+
+    @subdag(foo, bar, external_inputs=["baz"])
+    def foo_bar_baz(foo: int, bar: int, baz: int) -> int:
+        return foo + bar + baz
+
+    full_module = ad_hoc_utils.create_temporary_module(foo_bar_baz)
+    fg = graph.FunctionGraph(full_module, config={})
+    # baz is listed in external_inputs, so it must appear un-namespaced as a user-defined input
+    assert "baz" in fg.nodes
+    assert fg.nodes["baz"].user_defined
+    res = fg.execute(nodes=[fg.nodes["foo_bar_baz"]], inputs={"input_1": 2, "input_2": 3, "baz": 4})
+    assert res["foo_bar_baz"] == foo_bar_baz(foo(3), bar(2), 4)
+
+
+def test_parameterized_subdag_with_external_inputs_global():
+    def bar(input_1: int) -> int:
+        return input_1 + 1
+
+    def foo(input_2: int) -> int:
+        return input_2 + 1
+
+    @parameterized_subdag(
+        foo,
+        bar,
+        external_inputs=["baz"],
+        foo_bar_baz_input_1={"inputs": {"input_1": value(10), "input_2": value(20)}},
+        foo_bar_baz_input_2={"inputs": {"input_1": value(30), "input_2": value(40)}},
+    )
+    def foo_bar_baz(foo: int, bar: int, baz: int) -> int:
+        return foo + bar + baz
+
+    def foo_bar_baz_summed(foo_bar_baz_input_1: int, foo_bar_baz_input_2: int) -> int:
+        return foo_bar_baz_input_1 + foo_bar_baz_input_2
+
+    full_module = ad_hoc_utils.create_temporary_module(foo_bar_baz, foo_bar_baz_summed)
+    fg = graph.FunctionGraph(full_module, config={})
+    # baz is a decorator-level external input: one shared, un-namespaced, user-defined node
+    assert "baz" in fg.nodes
+    assert fg.nodes["baz"].user_defined
+    res = fg.execute(nodes=[fg.nodes["foo_bar_baz_summed"]], inputs={"baz": 100})
+    expected_1 = foo_bar_baz(foo(20), bar(10), 100)  # foo consumes input_2, bar consumes input_1
+    expected_2 = foo_bar_baz(foo(40), bar(30), 100)
+    assert res["foo_bar_baz_summed"] == expected_1 + expected_2