Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes to allow nesting of subdags #116

Merged
merged 2 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions hamilton/function_modifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
EllipsisType = type(...)
from typing import Any, Callable, Collection, Dict, List, Optional, Union

from hamilton import node, registry
from hamilton import node, registry, settings

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -476,16 +476,22 @@ def resolve_config(
:param config_optional_with_defaults:
:return:
"""
config_optional_with_global_defaults_applied = config_optional_with_defaults.copy()
config_optional_with_global_defaults_applied[
settings.ENABLE_POWER_USER_MODE
] = config_optional_with_global_defaults_applied.get(settings.ENABLE_POWER_USER_MODE, False)
missing_keys = (
set(config_required) - set(config.keys()) - set(config_optional_with_defaults.keys())
set(config_required)
- set(config.keys())
- set(config_optional_with_global_defaults_applied.keys())
)
if len(missing_keys) > 0:
raise MissingConfigParametersException(
f"The following configurations are required by {name_for_error}: {missing_keys}"
)
config_out = {key: config[key] for key in config_required}
for key in config_optional_with_defaults:
config_out[key] = config.get(key, config_optional_with_defaults[key])
for key in config_optional_with_global_defaults_applied:
config_out[key] = config.get(key, config_optional_with_global_defaults_applied[key])
return config_out


Expand Down
12 changes: 6 additions & 6 deletions hamilton/function_modifiers/delayed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class ResolveAt(enum.Enum):
VALID_PARAM_KINDS = [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY]


def extract_and_validate_params(fn: Callable) -> Tuple[List[str], List[str]]:
def extract_and_validate_params(fn: Callable) -> Tuple[List[str], Dict[str, Any]]:
"""Gets the parameters from a function, while validating that
the function has *only* named arguments.

Expand All @@ -27,14 +27,14 @@ def extract_and_validate_params(fn: Callable) -> Tuple[List[str], List[str]]:
"""
invalid_params = []
required_params = []
optional_params = []
optional_params = {}
sig = inspect.signature(fn)
for key, value in inspect.signature(fn).parameters.items():
if value.kind not in VALID_PARAM_KINDS:
invalid_params.append(key)
else:
if value.default is not value.empty:
optional_params.append(key)
optional_params[key] = value.default
else:
required_params.append(key)
if invalid_params:
Expand Down Expand Up @@ -122,11 +122,11 @@ def __init__(self, *, when: ResolveAt, decorate_with: Callable[..., NodeTransfor
def required_config(self) -> Optional[List[str]]:
return self._required_config

def optional_config(self) -> Optional[List[str]]:
return self._optional_config + [settings.ENABLE_POWER_USER_MODE]
def optional_config(self) -> Optional[Dict[str, Any]]:
return self._optional_config

def resolve(self, config: Dict[str, Any], fn: Callable) -> NodeTransformLifecycle:
if not config.get(settings.ENABLE_POWER_USER_MODE, False):
if not config[settings.ENABLE_POWER_USER_MODE]:
raise InvalidDecoratorException(
"Dynamic functions are only allowed in power user mode!"
"Why? This is occasionally needed to enable highly flexible "
Expand Down
10 changes: 1 addition & 9 deletions hamilton/function_modifiers/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,16 +256,13 @@ def _add_namespace(self, nodes: List[node.Node], namespace: str) -> List[node.No
:param nodes:
:return:
"""
already_namespaced_nodes = []
# already_namespaced_nodes = []
new_nodes = []
new_name_map = {}
# First pass we validate + collect names so we can alter dependencies
for node_ in nodes:
new_name = assign_namespace(node_.name, namespace)
new_name_map[node_.name] = new_name
current_node_namespaces = node_.namespace
if current_node_namespaces:
already_namespaced_nodes.append(node_)
for dep, value in self.inputs.items():
# We create nodes for both namespace assignment and source assignment
# Why? Cause we need unique parameter names, and with source() some can share params
Expand All @@ -274,11 +271,6 @@ def _add_namespace(self, nodes: List[node.Node], namespace: str) -> List[node.No
for dep, value in self.config.items():
new_name_map[dep] = assign_namespace(dep, namespace)

if already_namespaced_nodes:
raise ValueError(
f"The following nodes are already namespaced: {already_namespaced_nodes}. "
f"We currently do not allow for multiple namespaces (E.G. layered subDAGs)."
)
# Reassign sources
for node_ in nodes:
new_name = new_name_map[node_.name]
Expand Down
32 changes: 26 additions & 6 deletions tests/function_modifiers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest as pytest

from hamilton import node
from hamilton import node, settings
from hamilton.function_modifiers import InvalidDecoratorException, base
from hamilton.function_modifiers.base import (
MissingConfigParametersException,
Expand All @@ -11,16 +11,36 @@
)
from hamilton.node import Node

power_mode_k = settings.ENABLE_POWER_USER_MODE


@pytest.mark.parametrize(
"config,config_required,config_optional_with_defaults,expected",
[
({"foo": 1}, ["foo"], {}, {"foo": 1}),
({"foo": 1, "bar": 2}, ["foo"], {}, {"foo": 1}),
({"foo": 1, "bar": 2}, ["foo"], {"bar": 3}, {"foo": 1, "bar": 2}),
({"foo": 1}, [], {"bar": 3}, {"bar": 3}),
({"foo": 1}, ["foo"], {}, {"foo": 1, power_mode_k: False}),
({"foo": 1, "bar": 2}, ["foo"], {}, {"foo": 1, power_mode_k: False}),
({"foo": 1, "bar": 2}, ["foo"], {"bar": 3}, {"foo": 1, "bar": 2, power_mode_k: False}),
({"foo": 1}, [], {"bar": 3}, {"bar": 3, power_mode_k: False}),
({"foo": 1, power_mode_k: True}, ["foo"], {}, {"foo": 1, power_mode_k: True}),
({"foo": 1, "bar": 2, power_mode_k: True}, ["foo"], {}, {"foo": 1, power_mode_k: True}),
(
{"foo": 1, "bar": 2, power_mode_k: True},
["foo"],
{"bar": 3},
{"foo": 1, "bar": 2, power_mode_k: True},
),
({"foo": 1, power_mode_k: True}, [], {"bar": 3}, {"bar": 3, power_mode_k: True}),
],
ids=[
"all_present",
"all_present_extra",
"no_apply_default",
"yes_apply_default",
"all_present_with_power_user_mode",
"all_present_extra_with_power_user_mode",
"no_apply_default_with_power_user_mode",
"yes_apply_default_with_power_user_mode",
],
ids=["all_present", "all_present_extra", "no_apply_default", "yes_apply_default"],
)
def test_merge_config_happy(config, config_required, config_optional_with_defaults, expected):
assert (
Expand Down
28 changes: 17 additions & 11 deletions tests/function_modifiers/test_delayed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,18 @@
settings.ENABLE_POWER_USER_MODE: True,
}

CONFIG_WITH_POWER_MODE_DISABLED = {
settings.ENABLE_POWER_USER_MODE: False,
}


@pytest.mark.parametrize(
"fn,required,optional",
[
(lambda: 1, [], []),
(lambda a, b: 1, ["a", "b"], []),
(lambda a, b=1: 1, ["a"], ["b"]),
(lambda a=1, b=1: 1, [], ["a", "b"]),
(lambda: 1, [], {}),
(lambda a, b: 1, ["a", "b"], {}),
(lambda a, b=1: 1, ["a"], {"b": 1}),
(lambda a=1, b=1: 1, [], {"a": 1, "b": 1}),
],
)
def test_extract_and_validate_params_happy(fn: Callable, required: Callable, optional: Callable):
Expand Down Expand Up @@ -54,14 +58,14 @@ def test_dynamic_resolves():
assert decorator_resolved.columns == ("a", "b")


def test_dynamic_fails_without_config_provided():
def test_dynamic_fails_without_power_mode_fails():
decorator = resolve(
when=ResolveAt.CONFIG_AVAILABLE,
decorate_with=lambda cols_to_extract: extract_columns(*cols_to_extract),
)
with pytest.raises(base.InvalidDecoratorException):
decorator_resolved = decorator.resolve(
CONFIG_WITH_POWER_MODE_ENABLED, fn=test_dynamic_fails_without_config_provided
CONFIG_WITH_POWER_MODE_DISABLED, fn=test_dynamic_fails_without_power_mode_fails
)
# This uses an internal component of extract_columns
# We may want to add a little more comprehensive testing
Expand All @@ -77,10 +81,9 @@ def test_config_derivation():
),
)
assert decorator.required_config() == ["cols_to_extract"]
assert decorator.optional_config() == [
"some_cols_you_might_want_to_extract",
settings.ENABLE_POWER_USER_MODE,
]
assert decorator.optional_config() == {
"some_cols_you_might_want_to_extract": [],
}


def test_delayed_with_optional():
Expand Down Expand Up @@ -114,4 +117,7 @@ def test_delayed_without_power_mode_fails():
]: extract_columns(*cols_to_extract + some_cols_you_might_want_to_extract),
)
with pytest.raises(base.InvalidDecoratorException):
decorator.resolve({"cols_to_extract": ["a", "b"]}, fn=test_delayed_with_optional)
decorator.resolve(
{"cols_to_extract": ["a", "b"], **CONFIG_WITH_POWER_MODE_DISABLED},
fn=test_delayed_with_optional,
)
40 changes: 39 additions & 1 deletion tests/function_modifiers/test_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import tests.resources.reuse_subdag
from hamilton import ad_hoc_utils, graph
from hamilton.function_modifiers import config, parameterized_subdag, recursive, value
from hamilton.function_modifiers import config, parameterized_subdag, recursive, subdag, value
from hamilton.function_modifiers.dependencies import source


Expand Down Expand Up @@ -269,3 +269,41 @@ def subdag_processor(foo: int, bar: int, baz: int) -> Tuple[int, int, int]:
assert nodes_by_name["v0.baz"].callable(**{"v0.foo": 1, "v0.bar": 2}) == 3
assert nodes_by_name["v1.baz"].callable(**{"v1.foo": 1, "v1.bar": 2}) == 3
assert nodes_by_name["v2.baz"].callable(**{"v2.foo": 1, "v2.bar": 2}) == 2


def test_nested_subdag():
def bar(input_1: int) -> int:
return input_1 + 1

def foo(input_2: int) -> int:
return input_2 + 1

@subdag(
foo,
bar,
)
def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
return foo, bar

@subdag(inner_subdag, inputs={"input_2": value(10)}, config={"plus_one": True})
def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int:
return sum(inner_subdag)

@subdag(inner_subdag, inputs={"input_2": value(3)}, config={"plus_one": False})
def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int:
return sum(inner_subdag)

def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
return outer_subdag_1 + outer_subdag_2

# we only need to generate from the outer subdag
# as it refers to the inner one
full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all)
fg = graph.FunctionGraph(full_module, config={})
assert "outer_subdag_1" in fg.nodes
assert "outer_subdag_2" in fg.nodes
res = fg.execute(nodes=[fg.nodes["sum_all"]], inputs={"input_1": 2})
# This is effectively the function graph
assert res["sum_all"] == sum_all(
outer_subdag_1(inner_subdag(bar(2), foo(10))), outer_subdag_2(inner_subdag(bar(2), foo(3)))
)