Skip to content

Commit

Permalink
Fixes a bug in which nested subdags were not passed required
Browse files Browse the repository at this point in the history
configuration parameters

We aggressively prune the configuration parameters we pass to decorators, so
that they don't receive everything. Subdags didn't declare the parameters they
need, because they only know which configurations are required once they have
parsed the function that they decorate and imported the subdags they use.

Rather than implementing the more correct, but more complicated, fix, this
uses the escape hatch: required_config is allowed to return None, in which case
everything gets passed in. We still need to think about the best way to do
this, but for now this does little harm and we can walk it back -- the contract
is solid.
  • Loading branch information
elijahbenizzy committed Mar 27, 2023
1 parent be630a3 commit 7783d21
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 1 deletion.
23 changes: 22 additions & 1 deletion hamilton/function_modifiers/recursive.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from types import ModuleType
from typing import Any, Callable, Collection, Dict, List, Tuple, Type, Union
from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union

from typing_extensions import NotRequired, TypedDict

Expand Down Expand Up @@ -387,6 +387,20 @@ def validate(self, fn):

self._validate_parameterization()

def required_config(self) -> Optional[List[str]]:
    """Declares which configuration keys this decorator requires.

    Subdags are currently not filtered, because the required configuration is not known
    *statically*: figuring it out means parsing the decorated function, and that function
    is not available at the time required_config is called. It's likely that we'll want to
    let required_config consume the function itself, passing it in when it's called, but
    there is not sufficient justification to do that yet.

    :return: None for now, meaning this decorator has access to all configuration
        variables.
    """
    return None


class SubdagParams(TypedDict):
inputs: NotRequired[Dict[str, ParametrizedDependency]]
Expand Down Expand Up @@ -511,3 +525,10 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node
def validate(self, fn: Callable):
    """Runs validation of ``fn`` through every gathered subdag generator.

    :param fn: Function to validate; it is forwarded unchanged to the ``validate`` method
        of each generator returned by ``self._gather_subdag_generators()``.
    """
    for generator in self._gather_subdag_generators():
        generator.validate(fn)

def required_config(self) -> Optional[List[str]]:
    """Declares which configuration keys this decorator requires.

    This follows the same pattern as subdag.required_config -- see the note there.

    :return: Any required config items. Always None here, so (per the note on
        subdag.required_config) all configuration variables are made available.
    """
    return None
114 changes: 114 additions & 0 deletions tests/function_modifiers/test_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,45 @@ def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
)


def test_nested_subdag_with_config():
    """Checks that config reaches functions loaded through two levels of @subdag.

    ``foo`` is gated by ``@config.when(broken=False)`` and is only reachable through
    ``inner_subdag``, which is itself only reachable through the two outer subdags. The
    graph is built with ``config={"broken": False}`` and the result is compared against
    calling the plain Python functions directly.
    """

    def bar(input_1: int) -> int:
        return input_1 + 1

    @config.when(broken=False)
    def foo(input_2: int) -> int:
        return input_2 + 1

    @subdag(
        foo,
        bar,
    )
    def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
        return foo, bar

    @subdag(inner_subdag, inputs={"input_2": value(10)})
    def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int:
        return sum(inner_subdag)

    @subdag(inner_subdag, inputs={"input_2": value(3)})
    def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int:
        return sum(inner_subdag)

    def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
        return outer_subdag_1 + outer_subdag_2

    # we only need to generate from the outer subdag
    # as it refers to the inner one
    full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all)
    fg = graph.FunctionGraph(full_module, config={"broken": False})
    assert "outer_subdag_1" in fg.nodes
    assert "outer_subdag_2" in fg.nodes
    res = fg.execute(nodes=[fg.nodes["sum_all"]], inputs={"input_1": 2})
    # This is effectively the function graph. Arguments follow inner_subdag's declared
    # (foo, bar) order so the expression mirrors the real wiring rather than relying on
    # sum() being order-insensitive.
    assert res["sum_all"] == sum_all(
        outer_subdag_1(inner_subdag(foo(10), bar(2))),
        outer_subdag_2(inner_subdag(foo(3), bar(2))),
    )


def test_subdag_with_external_nodes_input():
def bar(input_1: int) -> int:
return input_1 + 1
Expand Down Expand Up @@ -358,3 +397,78 @@ def foo_bar_baz_summed(foo_bar_baz_input_1: int, foo_bar_baz_input_2: int) -> in
assert res["foo_bar_baz_summed"] == foo_bar_baz(foo(10), foo(20), 100) + foo_bar_baz(
bar(30), foo(40), 100
)


def test_parameterized_subdag_with_config():
    """Checks that config reaches functions loaded through a @parameterized_subdag.

    The inner layer is a plain ``@subdag`` over ``foo`` (gated by
    ``@config.when(broken=False)``) and ``bar``; the outer layer is a single
    ``@parameterized_subdag`` that produces ``outer_subdag_1`` and ``outer_subdag_2``
    with different values for ``input_2``. The graph is built with
    ``config={"broken": False}`` and the result is compared against calling the plain
    Python functions directly.
    """

    def bar(input_1: int) -> int:
        return input_1 + 1

    @config.when(broken=False)
    def foo(input_2: int) -> int:
        return input_2 + 1

    @subdag(
        foo,
        bar,
    )
    def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
        return foo, bar

    # Exercise parameterized_subdag itself, so that this test covers what its name says
    # rather than duplicating the plain nested-subdag test.
    @parameterized_subdag(
        inner_subdag,
        outer_subdag_1={"inputs": {"input_2": value(10)}},
        outer_subdag_2={"inputs": {"input_2": value(3)}},
    )
    def outer_subdag_n(inner_subdag: Tuple[int, int]) -> int:
        return sum(inner_subdag)

    def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
        return outer_subdag_1 + outer_subdag_2

    # we only need to generate from the outer subdag
    # as it refers to the inner one
    full_module = ad_hoc_utils.create_temporary_module(outer_subdag_n, sum_all)
    fg = graph.FunctionGraph(full_module, config={"broken": False})
    assert "outer_subdag_1" in fg.nodes
    assert "outer_subdag_2" in fg.nodes
    res = fg.execute(nodes=[fg.nodes["sum_all"]], inputs={"input_1": 2})
    # This is effectively the function graph. Arguments follow inner_subdag's declared
    # (foo, bar) order.
    assert res["sum_all"] == sum_all(
        outer_subdag_n(inner_subdag(foo(10), bar(2))),
        outer_subdag_n(inner_subdag(foo(3), bar(2))),
    )


def test_nested_parameterized_subdag_with_config():
    """Checks that config reaches functions loaded through nested @parameterized_subdag.

    Both layers use ``@parameterized_subdag``: the inner one wraps ``foo`` (gated by
    ``@config.when(broken=False)``) and ``bar`` under the single name ``inner_subdag``;
    the outer one produces ``outer_subdag_1`` and ``outer_subdag_2`` with different values
    for ``input_2``. The graph is built with ``config={"broken": False}`` and the result
    is compared against calling the plain Python functions directly.
    """

    def bar(input_1: int) -> int:
        return input_1 + 1

    @config.when(broken=False)
    def foo(input_2: int) -> int:
        return input_2 + 1

    @parameterized_subdag(foo, bar, inner_subdag={})
    def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
        return foo, bar

    @parameterized_subdag(
        inner_subdag,
        outer_subdag_1={"inputs": {"input_2": value(10)}},
        outer_subdag_2={"inputs": {"input_2": value(3)}},
    )
    def outer_subdag_n(inner_subdag: Tuple[int, int]) -> int:
        return sum(inner_subdag)

    def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
        return outer_subdag_1 + outer_subdag_2

    # we only need to generate from the outer subdag
    # as it refers to the inner one
    full_module = ad_hoc_utils.create_temporary_module(outer_subdag_n, sum_all)
    fg = graph.FunctionGraph(full_module, config={"broken": False})
    assert "outer_subdag_1" in fg.nodes
    assert "outer_subdag_2" in fg.nodes
    res = fg.execute(nodes=[fg.nodes["sum_all"]], inputs={"input_1": 2})
    # This is effectively the function graph. Arguments follow inner_subdag's declared
    # (foo, bar) order so the expression mirrors the real wiring rather than relying on
    # sum() being order-insensitive.
    assert res["sum_all"] == sum_all(
        outer_subdag_n(inner_subdag(foo(10), bar(2))),
        outer_subdag_n(inner_subdag(foo(3), bar(2))),
    )

0 comments on commit 7783d21

Please sign in to comment.