Skip to content

Commit

Permalink
Enables subdag to handle source and value
Browse files Browse the repository at this point in the history
To be consistent and to allow people to remap external values,

e.g. subdag config is foo, but we need two different values
because we have two subdags that need it. So we can
remap the values from external config:

Global config: {"foo1": "foo_v1", "foo2": "foo_v2"}
Subdag1 config: {"foo": source("foo1")}
Subdag2 config: {"foo": source("foo2")}

This also handles if someone wants to use value().

Otherwise to preserve backwards compatibility we need
to do this in a specific order and only override config values
appropriately.
  • Loading branch information
skrawcz committed Dec 7, 2024
1 parent 4e8dc43 commit fe9b77a
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
24 changes: 23 additions & 1 deletion hamilton/function_modifiers/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,29 @@ def _derive_name(self, fn: Callable) -> str:

def generate_nodes(self, fn: Callable, configuration: Dict[str, Any]) -> Collection[node.Node]:
# Resolve all nodes from passed in functions
resolved_config = dict(configuration, **self.config)
# if self.config has source() or value() in it, we need to resolve it
sources_to_map = {}
values_to_include = {}
for key, value in self.config.items():
if isinstance(value, dependencies.UpstreamDependency):
sources_to_map[key] = value.source
elif isinstance(value, dependencies.LiteralDependency):
values_to_include[key] = value.value
plain_configs = {
k: v
for k, v in self.config.items()
if k not in sources_to_map and k not in values_to_include
}
resolved_config = dict(configuration, **plain_configs, **values_to_include)
# override any values from sources
for key, source in sources_to_map.items():
try:
resolved_config[key] = resolved_config[source]
except KeyError as e:
raise InvalidDecoratorException(
f"Source {source} was not found in the configuration. This is required for the {fn.__name__} subdag."
) from e
# resolved_config = dict(configuration, **self.config)
nodes = self.collect_nodes(config=resolved_config, subdag_functions=self.subdag_functions)
# Derive the namespace under which all these nodes will live
namespace = self._derive_namespace(fn)
Expand Down
76 changes: 76 additions & 0 deletions tests/function_modifiers/test_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,82 @@ def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
)


def test_nested_subdag_with_config_remapping():
    """Tests that we can remap config values and source and value are resolved correctly.

    ``foo`` is only defined when the config key ``broken`` is ``False``. The two
    outer subdags supply that key in different ways:

    * ``outer_subdag_1`` hard-codes it with ``value(False)``.
    * ``outer_subdag_2`` remaps it from the global config key ``broken2`` via
      ``source("broken2")``.

    The graph is built with ``{"broken2": False}`` and the result of ``sum_all``
    is compared against calling the plain Python functions directly.
    """

    def bar(input_1: int) -> int:
        return input_1 + 1

    @config.when(broken=False)
    def foo(input_2: int) -> int:
        return input_2 + 1

    @subdag(
        foo,
        bar,
    )
    def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
        return foo, bar

    @subdag(inner_subdag, inputs={"input_2": value(10)}, config={"broken": value(False)})
    def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int:
        return sum(inner_subdag)

    @subdag(inner_subdag, inputs={"input_2": value(3)}, config={"broken": source("broken2")})
    def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int:
        return sum(inner_subdag)

    def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
        return outer_subdag_1 + outer_subdag_2

    # we only need to generate from the outer subdag
    # as it refers to the inner one
    full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all)
    fg = graph.FunctionGraph.from_modules(full_module, config={"broken2": False})
    assert "outer_subdag_1" in fg.nodes
    assert "outer_subdag_2" in fg.nodes
    res = fg.execute(nodes=[fg.nodes["sum_all"]], inputs={"input_1": 2})
    # This is effectively the function graph. Keyword arguments are used so each
    # value lands on the parameter it is named after (inner_subdag takes foo, bar).
    expected = sum_all(
        outer_subdag_1(inner_subdag(foo=foo(10), bar=bar(2))),
        outer_subdag_2(inner_subdag(foo=foo(3), bar=bar(2))),
    )
    assert res["sum_all"] == expected


def test_nested_subdag_with_config_remapping_missing_error():
    """Tests that we error if we can't remap a config value.

    ``outer_subdag_2`` asks for ``source("broken_missing")``, but the graph is
    built with a config that only contains ``broken2``. Building the graph is
    expected to raise ``InvalidDecoratorException`` naming the missing source.
    """

    def bar(input_1: int) -> int:
        return input_1 + 1

    @config.when(broken=False)
    def foo(input_2: int) -> int:
        return input_2 + 1

    @subdag(
        foo,
        bar,
    )
    def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
        return foo, bar

    @subdag(inner_subdag, inputs={"input_2": value(10)}, config={"broken": value(False)})
    def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int:
        return sum(inner_subdag)

    @subdag(
        inner_subdag,
        inputs={"input_2": value(3)},
        config={"broken": source("broken_missing")},
    )
    def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int:
        return sum(inner_subdag)

    def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
        return outer_subdag_1 + outer_subdag_2

    # we only need to generate from the outer subdag
    # as it refers to the inner one
    full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all)
    # Match on the source name so an unrelated decorator error cannot
    # satisfy this test by accident.
    with pytest.raises(InvalidDecoratorException, match="broken_missing"):
        graph.FunctionGraph.from_modules(full_module, config={"broken2": False})


def test_subdag_with_external_nodes_input():
def bar(input_1: int) -> int:
return input_1 + 1
Expand Down

0 comments on commit fe9b77a

Please sign in to comment.