Skip to content

Commit

Permalink
Adds configuration object and refactors code
Browse files Browse the repository at this point in the history
Named it configuration to not clash with config decorator.
  • Loading branch information
skrawcz committed Dec 12, 2024
1 parent 4d93f0a commit 169862d
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 33 deletions.
Binary file modified examples/model_examples/modular_example/my_dag.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions examples/model_examples/modular_example/notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@
"\n",
"import pandas as pd\n",
"\n",
"from hamilton.function_modifiers import subdag, extract_fields, value, source\n",
"from hamilton.function_modifiers import subdag, extract_fields, configuration, source\n",
"import features\n",
"import train\n",
"import inference\n",
Expand All @@ -316,7 +316,7 @@
" },\n",
" # there are several ways to pass in configuration.\n",
" # config={ \n",
" # \"model\": source(\"model\")\n",
" # \"model\": configuration(\"model\")\n",
" # },\n",
")\n",
"def trained_pipeline(fit_model: Any, predicted_data: pd.DataFrame) -> dict:\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/model_examples/modular_example/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pandas as pd
import train

from hamilton.function_modifiers import extract_fields, source, subdag
from hamilton.function_modifiers import configuration, extract_fields, source, subdag


@extract_fields({"fit_model": Any, "training_prediction": pd.DataFrame})
Expand All @@ -18,7 +18,7 @@
"model_params": source("model_params"),
},
config={
"model": source("model"), # not strictly required but allows us to remap.
"model": configuration("train_model_type"), # not strictly required but allows us to remap.
},
)
def trained_pipeline(fit_model: Any, predicted_data: pd.DataFrame) -> dict:
Expand Down
2 changes: 1 addition & 1 deletion examples/model_examples/modular_example/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def run():
dr = (
driver.Builder()
.with_config({"model": "RandomForest", "model_params": {"n_estimators": 100}})
.with_config({"train_model_type": "RandomForest", "model_params": {"n_estimators": 100}})
.with_modules(pipeline)
.build()
)
Expand Down
1 change: 1 addition & 0 deletions hamilton/function_modifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
value = dependencies.value
source = dependencies.source
group = dependencies.group
configuration = dependencies.configuration

# These aren't strictly part of the API but we should have them here for safety
LiteralDependency = dependencies.LiteralDependency
Expand Down
28 changes: 24 additions & 4 deletions hamilton/function_modifiers/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class ParametrizedDependencySource(enum.Enum):
UPSTREAM = "upstream"
GROUPED_LIST = "grouped_list"
GROUPED_DICT = "grouped_dict"
CONFIGURATION = "configuration"


class ParametrizedDependency:
Expand Down Expand Up @@ -44,6 +45,14 @@ def get_dependency_type(self) -> ParametrizedDependencySource:
return ParametrizedDependencySource.UPSTREAM


@dataclasses.dataclass
class ConfigDependency(SingleDependency):
source: str

def get_dependency_type(self) -> ParametrizedDependencySource:
return ParametrizedDependencySource.CONFIGURATION


class GroupedDependency(ParametrizedDependency, abc.ABC):
@classmethod
@abc.abstractmethod
Expand Down Expand Up @@ -123,8 +132,8 @@ def value(literal_value: Any) -> LiteralDependency:
E.G. value("foo") means that the value is actually the string value "foo".
:param literal_value: Python literal value to use. :return: A LiteralDependency object -- a
signifier to the internal framework of the dependency type.
:param literal_value: Python literal value to use.
:return: A LiteralDependency object -- a signifier to the internal framework of the dependency type.
"""
if isinstance(literal_value, LiteralDependency):
return literal_value
Expand All @@ -138,14 +147,25 @@ def source(dependency_on: Any) -> UpstreamDependency:
be assigned the value that "foo" outputs.
:param dependency_on: Upstream function (i.e. node) to come from.
:return: An
UpstreamDependency object -- a signifier to the internal framework of the dependency type.
:return: An UpstreamDependency object -- a signifier to the internal framework of the dependency type.
"""
if isinstance(dependency_on, UpstreamDependency):
return dependency_on
return UpstreamDependency(source=dependency_on)


def configuration(dependency_on: str) -> ConfigDependency:
"""Specifies that a parameterized dependency comes from the global `config` passed in.
This means that it comes from a global configuration key value. E.G. config("foo") means that it should
be assigned the value that the "foo" key in global configuration passed to Hamilton maps to.
:param dependency_on: name of the configuration key to pull from.
:return: An ConfigDependency object -- a signifier to the internal framework of the dependency type.
"""
return ConfigDependency(source=dependency_on)


def _validate_group_params(
dependency_args: List[ParametrizedDependency],
dependency_kwargs: Dict[str, ParametrizedDependency],
Expand Down
56 changes: 34 additions & 22 deletions hamilton/function_modifiers/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,38 @@ def _validate_config_inputs(config: Dict[str, Any], inputs: Dict[str, Any]):
)


def _resolve_subdag_configuration(
configuration: Dict[str, Any], fields: Dict[str, Any], function_name: str
) -> Dict[str, Any]:
"""Resolves the configuration for a subdag.
:param configuration: the Hamilton configuration
:param fields: the fields passed to the subdag decorator
:return: resolved configuration to use for this subdag.
"""
sources_to_map = {}
values_to_include = {}
for key, value in fields.items():
if isinstance(value, dependencies.ConfigDependency):
sources_to_map[key] = value.source
elif isinstance(value, dependencies.LiteralDependency):
values_to_include[key] = value.value
plain_configs = {
k: v for k, v in fields.items() if k not in sources_to_map and k not in values_to_include
}
resolved_config = dict(configuration, **plain_configs, **values_to_include)

# override any values from sources
for key, source in sources_to_map.items():
try:
resolved_config[key] = resolved_config[source]
except KeyError as e:
raise InvalidDecoratorException(
f"Source {source} was not found in the configuration. This is required for the {function_name} subdag."
) from e
return resolved_config


NON_FINAL_TAGS = {NodeTransformer.NON_FINAL_TAG: True}


Expand Down Expand Up @@ -423,28 +455,8 @@ def _derive_name(self, fn: Callable) -> str:

def generate_nodes(self, fn: Callable, configuration: Dict[str, Any]) -> Collection[node.Node]:
# Resolve all nodes from passed in functions
# if self.config has source() or value() in it, we need to resolve it
sources_to_map = {}
values_to_include = {}
for key, value in self.config.items():
if isinstance(value, dependencies.UpstreamDependency):
sources_to_map[key] = value.source
elif isinstance(value, dependencies.LiteralDependency):
values_to_include[key] = value.value
plain_configs = {
k: v
for k, v in self.config.items()
if k not in sources_to_map and k not in values_to_include
}
resolved_config = dict(configuration, **plain_configs, **values_to_include)
# override any values from sources
for key, source in sources_to_map.items():
try:
resolved_config[key] = resolved_config[source]
except KeyError as e:
raise InvalidDecoratorException(
f"Source {source} was not found in the configuration. This is required for the {fn.__name__} subdag."
) from e
# if self.config has configuration() or value() in it, we need to resolve it
resolved_config = _resolve_subdag_configuration(configuration, self.config, fn.__name__)
# resolved_config = dict(configuration, **self.config)
nodes = self.collect_nodes(config=resolved_config, subdag_functions=self.subdag_functions)
# Derive the namespace under which all these nodes will live
Expand Down
33 changes: 31 additions & 2 deletions tests/function_modifiers/test_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from hamilton.function_modifiers import (
InvalidDecoratorException,
config,
configuration,
parameterized_subdag,
recursive,
subdag,
Expand Down Expand Up @@ -413,7 +414,7 @@ def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int:
return sum(inner_subdag)

@subdag(inner_subdag, inputs={"input_2": value(3)}, config={"broken": source("broken2")})
@subdag(inner_subdag, inputs={"input_2": value(3)}, config={"broken": configuration("broken2")})
def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int:
return sum(inner_subdag)

Expand Down Expand Up @@ -454,7 +455,11 @@ def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int:
return sum(inner_subdag)

@subdag(inner_subdag, inputs={"input_2": value(3)}, config={"broken": source("broken_missing")})
@subdag(
inner_subdag,
inputs={"input_2": value(3)},
config={"broken": configuration("broken_missing")},
)
def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int:
return sum(inner_subdag)

Expand All @@ -468,6 +473,30 @@ def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
graph.FunctionGraph.from_modules(full_module, config={"broken2": False})


@pytest.mark.parametrize(
"configuration,fields,expected",
[
({"a": 1, "b": 2}, {}, {"a": 1, "b": 2}),
({"a": 1, "b": 2}, {"c": value(3)}, {"a": 1, "b": 2, "c": 3}),
(
{"a": 1, "b": 2},
{"c": value(3), "d": configuration("a")},
{"a": 1, "b": 2, "c": 3, "d": 1},
),
],
)
def test_resolve_subdag_configuration_happy(configuration, fields, expected):
actual = recursive._resolve_subdag_configuration(configuration, fields, "test")
assert actual == expected


def test_resolve_subdag_configuration_bad_mapping():
_configuration = {"a": 1, "b": 2}
fields = {"c": value(3), "d": configuration("e")}
with pytest.raises(InvalidDecoratorException):
recursive._resolve_subdag_configuration(_configuration, fields, "test")


def test_subdag_with_external_nodes_input():
def bar(input_1: int) -> int:
return input_1 + 1
Expand Down

0 comments on commit 169862d

Please sign in to comment.