Skip to content

Commit

Permalink
Adds ability to specify external inputs in subdag/parameterized_subdag
Browse files Browse the repository at this point in the history
See #115
  • Loading branch information
elijahbenizzy committed Mar 23, 2023
1 parent d4d86b4 commit be630a3
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 4 deletions.
29 changes: 25 additions & 4 deletions hamilton/function_modifiers/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def feature_engineering(feature_df: pd.DataFrame) -> pd.DataFrame:
E.G. take a certain set of nodes, and run them with specified parameters.
@subdag declares parameters that are outputs of its subdags. Note that, if you want to use outputs of other
components of the DAG, you can use the `external_inputs` parameter to declare the parameters that do *not* come
from the subDAG.
Why might you want to use this? Let's take a look at some examples:
1. You have a feature engineering pipeline that you want to run on multiple datasets. If its exactly the same, \
Expand All @@ -148,6 +152,7 @@ def __init__(
config: Dict[str, Any] = None,
namespace: str = None,
final_node_name: str = None,
external_inputs: List[str] = None,
):
"""Adds a subDAG to the main DAG.
Expand All @@ -160,10 +165,14 @@ def __init__(
this will default to the function name.
:param final_node_name: Name of the final node in the subDAG. This is optional -- if not included,
this will default to the function name.
:param external_inputs: Parameters in the function that are not produced by the functions
passed to the subdag. This is useful if you want to perform some logic with other inputs
in the subdag's processing function.
"""
self.subdag_functions = subdag.collect_functions(load_from)
self.inputs = inputs if inputs is not None else {}
self.config = config if config is not None else {}
self.external_inputs = external_inputs if external_inputs is not None else []
self._validate_config_inputs(self.config, self.inputs)
self.namespace = namespace
self.final_node_name = final_node_name
Expand Down Expand Up @@ -307,9 +316,14 @@ def add_final_node(self, fn: Callable, node_name: str, namespace: str):
:return:
"""
node_ = node.Node.from_fn(fn)
namespaced_input_map = {assign_namespace(key, namespace): key for key in node_.input_types}
namespaced_input_map = {
(assign_namespace(key, namespace) if key not in self.external_inputs else key): key
for key in node_.input_types
}

new_input_types = {
assign_namespace(key, namespace): value for key, value in node_.input_types.items()
(assign_namespace(key, namespace) if key not in self.external_inputs else key): value
for key, value in node_.input_types.items()
}

def new_function(**kwargs):
Expand Down Expand Up @@ -377,11 +391,12 @@ def validate(self, fn):
class SubdagParams(TypedDict):
inputs: NotRequired[Dict[str, ParametrizedDependency]]
config: NotRequired[Dict[str, Any]]
external_inputs: NotRequired[List[str]]


class parameterized_subdag(base.NodeCreator):
"""parameterized subdag is when you want to create multiple subdags at one time.
Why do you want to do this?
Why might you want to do this?
1. You have multiple data sets you want to run the same feature engineering pipeline on.
2. You want to run some sort of optimization routine with a variety of results
Expand Down Expand Up @@ -444,13 +459,16 @@ def __init__(
str, Union[dependencies.ParametrizedDependency, dependencies.LiteralDependency]
] = None,
config: Dict[str, Any] = None,
external_inputs: List[str] = None,
**parameterization: SubdagParams,
):
"""Initializes a parameterized_subdag decorator.
:param load_from: Modules to load from
:param inputs: Inputs for each subdag generated by the decorated function
:param config: Config for each subdag generated by the decorated function
:param external_inputs: External inputs to all parameterized subdags. Note that
if you pass in any external inputs from local subdags, it overrides this (does not merge).
:param parameterization: Parameterizations for each subdag generated.
Note that this *overrides* any inputs/config passed to the decorator itself.
Expand All @@ -460,12 +478,14 @@ def __init__(
allowed to name these `load_from`, `inputs`, or `config`. That's a good thing, as these
are not good names for variables anyway.
2. Any empty items (not included) will default to an empty dict
2. Any empty items (not included) will default to an empty dict (or an empty list in
the case of parameterization)
"""
self.load_from = load_from
self.inputs = inputs if inputs is not None else {}
self.config = config if config is not None else {}
self.parameterization = parameterization
self.external_inputs = external_inputs if external_inputs is not None else []

def _gather_subdag_generators(self) -> List[subdag]:
subdag_generators = []
Expand All @@ -475,6 +495,7 @@ def _gather_subdag_generators(self) -> List[subdag]:
*self.load_from,
inputs={**self.inputs, **parameterization.get("inputs", {})},
config={**self.config, **parameterization.get("config", {})},
external_inputs=parameterization.get("external_inputs", self.external_inputs),
namespace=key,
final_node_name=key,
)
Expand Down
51 changes: 51 additions & 0 deletions tests/function_modifiers/test_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,54 @@ def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
assert res["sum_all"] == sum_all(
outer_subdag_1(inner_subdag(bar(2), foo(10))), outer_subdag_2(inner_subdag(bar(2), foo(3)))
)


def test_subdag_with_external_nodes_input():
def bar(input_1: int) -> int:
return input_1 + 1

def foo(input_2: int) -> int:
return input_2 + 1

@subdag(foo, bar, external_inputs=["baz"])
def foo_bar_baz(foo: int, bar: int, baz: int) -> int:
return foo + bar + baz

full_module = ad_hoc_utils.create_temporary_module(foo_bar_baz)
fg = graph.FunctionGraph(full_module, config={})
# since we've provided it above,
assert "baz" in fg.nodes
assert fg.nodes["baz"].user_defined
res = fg.execute(nodes=[fg.nodes["foo_bar_baz"]], inputs={"input_1": 2, "input_2": 3, "baz": 4})
assert res["foo_bar_baz"] == foo_bar_baz(foo(3), bar(2), 4)


def test_parameterized_subdag_with_external_inputs_global():
def bar(input_1: int) -> int:
return input_1 + 1

def foo(input_2: int) -> int:
return input_2 + 1

@parameterized_subdag(
foo,
bar,
external_inputs=["baz"],
foo_bar_baz_input_1={"inputs": {"input_1": value(10), "input_2": value(20)}},
foo_bar_baz_input_2={"inputs": {"input_1": value(30), "input_2": value(40)}},
)
def foo_bar_baz(foo: int, bar: int, baz: int) -> int:
return foo + bar + baz

def foo_bar_baz_summed(foo_bar_baz_input_1: int, foo_bar_baz_input_2: int) -> int:
return foo_bar_baz_input_1 + foo_bar_baz_input_2

full_module = ad_hoc_utils.create_temporary_module(foo_bar_baz, foo_bar_baz_summed)
fg = graph.FunctionGraph(full_module, config={})
# since we've provided it above,
assert "baz" in fg.nodes
assert fg.nodes["baz"].user_defined
res = fg.execute(nodes=[fg.nodes["foo_bar_baz_summed"]], inputs={"baz": 100})
assert res["foo_bar_baz_summed"] == foo_bar_baz(foo(10), foo(20), 100) + foo_bar_baz(
bar(30), foo(40), 100
)

0 comments on commit be630a3

Please sign in to comment.