diff --git a/examples/model_examples/modular_example/my_dag.png b/examples/model_examples/modular_example/my_dag.png index 0f152fdea..936e7be29 100644 Binary files a/examples/model_examples/modular_example/my_dag.png and b/examples/model_examples/modular_example/my_dag.png differ diff --git a/examples/model_examples/modular_example/notebook.ipynb b/examples/model_examples/modular_example/notebook.ipynb index aae090fb6..f6128da4d 100644 --- a/examples/model_examples/modular_example/notebook.ipynb +++ b/examples/model_examples/modular_example/notebook.ipynb @@ -300,7 +300,7 @@ "\n", "import pandas as pd\n", "\n", - "from hamilton.function_modifiers import subdag, extract_fields, value, source\n", + "from hamilton.function_modifiers import subdag, extract_fields, configuration, source\n", "import features\n", "import train\n", "import inference\n", @@ -316,7 +316,7 @@ " },\n", " # there are several ways to pass in configuration.\n", " # config={ \n", - " # \"model\": source(\"model\")\n", + " # \"model\": configuration(\"model\")\n", " # },\n", ")\n", "def trained_pipeline(fit_model: Any, predicted_data: pd.DataFrame) -> dict:\n", diff --git a/examples/model_examples/modular_example/pipeline.py b/examples/model_examples/modular_example/pipeline.py index 7ccecbf0d..0c7ebb0eb 100644 --- a/examples/model_examples/modular_example/pipeline.py +++ b/examples/model_examples/modular_example/pipeline.py @@ -5,7 +5,7 @@ import pandas as pd import train -from hamilton.function_modifiers import extract_fields, source, subdag +from hamilton.function_modifiers import configuration, extract_fields, source, subdag @extract_fields({"fit_model": Any, "training_prediction": pd.DataFrame}) @@ -18,7 +18,7 @@ "model_params": source("model_params"), }, config={ - "model": source("model"), # not strictly required but allows us to remap. + "model": configuration("train_model_type"), # not strictly required but allows us to remap. 
}, ) def trained_pipeline(fit_model: Any, predicted_data: pd.DataFrame) -> dict: diff --git a/examples/model_examples/modular_example/run.py b/examples/model_examples/modular_example/run.py index 4b388c943..415bd2863 100644 --- a/examples/model_examples/modular_example/run.py +++ b/examples/model_examples/modular_example/run.py @@ -6,7 +6,7 @@ def run(): dr = ( driver.Builder() - .with_config({"model": "RandomForest", "model_params": {"n_estimators": 100}}) + .with_config({"train_model_type": "RandomForest", "model_params": {"n_estimators": 100}}) .with_modules(pipeline) .build() ) diff --git a/hamilton/function_modifiers/__init__.py b/hamilton/function_modifiers/__init__.py index 958d07540..3113b13ff 100644 --- a/hamilton/function_modifiers/__init__.py +++ b/hamilton/function_modifiers/__init__.py @@ -36,6 +36,7 @@ value = dependencies.value source = dependencies.source group = dependencies.group +configuration = dependencies.configuration # These aren't strictly part of the API but we should have them here for safety LiteralDependency = dependencies.LiteralDependency diff --git a/hamilton/function_modifiers/dependencies.py b/hamilton/function_modifiers/dependencies.py index 26785505c..4f0c6d159 100644 --- a/hamilton/function_modifiers/dependencies.py +++ b/hamilton/function_modifiers/dependencies.py @@ -16,6 +16,7 @@ class ParametrizedDependencySource(enum.Enum): UPSTREAM = "upstream" GROUPED_LIST = "grouped_list" GROUPED_DICT = "grouped_dict" + CONFIGURATION = "configuration" class ParametrizedDependency: @@ -44,6 +45,14 @@ def get_dependency_type(self) -> ParametrizedDependencySource: return ParametrizedDependencySource.UPSTREAM +@dataclasses.dataclass +class ConfigDependency(SingleDependency): + source: str # name of the configuration key to pull from + + def get_dependency_type(self) -> ParametrizedDependencySource: + return ParametrizedDependencySource.CONFIGURATION + + class GroupedDependency(ParametrizedDependency, abc.ABC): @classmethod @abc.abstractmethod @@ -123,8 +132,8 @@ def 
value(literal_value: Any) -> LiteralDependency: E.G. value("foo") means that the value is actually the string value "foo". - :param literal_value: Python literal value to use. :return: A LiteralDependency object -- a - signifier to the internal framework of the dependency type. + :param literal_value: Python literal value to use. + :return: A LiteralDependency object -- a signifier to the internal framework of the dependency type. """ if isinstance(literal_value, LiteralDependency): return literal_value @@ -138,14 +147,25 @@ def source(dependency_on: Any) -> UpstreamDependency: be assigned the value that "foo" outputs. :param dependency_on: Upstream function (i.e. node) to come from. - :return: An - UpstreamDependency object -- a signifier to the internal framework of the dependency type. + :return: An UpstreamDependency object -- a signifier to the internal framework of the dependency type. """ if isinstance(dependency_on, UpstreamDependency): return dependency_on return UpstreamDependency(source=dependency_on) +def configuration(dependency_on: str) -> ConfigDependency: + """Specifies that a parameterized dependency comes from the global `config` passed in. + + This means that it comes from a global configuration key value. E.G. configuration("foo") means that it should + be assigned the value that the "foo" key in global configuration passed to Hamilton maps to. + + :param dependency_on: name of the configuration key to pull from. + :return: A ConfigDependency object -- a signifier to the internal framework of the dependency type. 
+ """ + return ConfigDependency(source=dependency_on) + + def _validate_group_params( dependency_args: List[ParametrizedDependency], dependency_kwargs: Dict[str, ParametrizedDependency], diff --git a/hamilton/function_modifiers/recursive.py b/hamilton/function_modifiers/recursive.py index e30aa38ef..72c9160c4 100644 --- a/hamilton/function_modifiers/recursive.py +++ b/hamilton/function_modifiers/recursive.py @@ -131,6 +131,41 @@ def _validate_config_inputs(config: Dict[str, Any], inputs: Dict[str, Any]): ) +def _resolve_subdag_configuration( + configuration: Dict[str, Any], fields: Dict[str, Any], function_name: str +) -> Dict[str, Any]: + """Resolves the configuration for a subdag. + + :param configuration: the Hamilton configuration + :param fields: the fields passed to the subdag decorator + :param function_name: name of the decorated function; only used in the error message. + :return: resolved configuration to use for this subdag. + :raises InvalidDecoratorException: if a remapped key's source is missing from the configuration. + """ + sources_to_map = {} + values_to_include = {} + for key, value in fields.items(): + # source() is still honoured: it was the accepted spelling before configuration() was added + if isinstance(value, (dependencies.ConfigDependency, dependencies.UpstreamDependency)): + sources_to_map[key] = value.source + elif isinstance(value, dependencies.LiteralDependency): + values_to_include[key] = value.value + plain_configs = { + k: v for k, v in fields.items() if k not in sources_to_map and k not in values_to_include + } + resolved_config = dict(configuration, **plain_configs, **values_to_include) + + # override any values from sources + for key, source in sources_to_map.items(): + try: + resolved_config[key] = resolved_config[source] + except KeyError as e: + raise InvalidDecoratorException( + f"Source {source} was not found in the configuration. This is required for the {function_name} subdag." 
+ ) from e + return resolved_config + + NON_FINAL_TAGS = {NodeTransformer.NON_FINAL_TAG: True} @@ -423,28 +458,8 @@ def _derive_name(self, fn: Callable) -> str: def generate_nodes(self, fn: Callable, configuration: Dict[str, Any]) -> Collection[node.Node]: # Resolve all nodes from passed in functions - # if self.config has source() or value() in it, we need to resolve it - sources_to_map = {} - values_to_include = {} - for key, value in self.config.items(): - if isinstance(value, dependencies.UpstreamDependency): - sources_to_map[key] = value.source - elif isinstance(value, dependencies.LiteralDependency): - values_to_include[key] = value.value - plain_configs = { - k: v - for k, v in self.config.items() - if k not in sources_to_map and k not in values_to_include - } - resolved_config = dict(configuration, **plain_configs, **values_to_include) - # override any values from sources - for key, source in sources_to_map.items(): - try: - resolved_config[key] = resolved_config[source] - except KeyError as e: - raise InvalidDecoratorException( - f"Source {source} was not found in the configuration. This is required for the {fn.__name__} subdag." 
- ) from e + # if self.config has configuration() or value() in it, we need to resolve it + resolved_config = _resolve_subdag_configuration(configuration, self.config, fn.__name__) # resolved_config = dict(configuration, **self.config) nodes = self.collect_nodes(config=resolved_config, subdag_functions=self.subdag_functions) # Derive the namespace under which all these nodes will live diff --git a/tests/function_modifiers/test_recursive.py b/tests/function_modifiers/test_recursive.py index 67d31d892..6cdcf1065 100644 --- a/tests/function_modifiers/test_recursive.py +++ b/tests/function_modifiers/test_recursive.py @@ -10,6 +10,7 @@ from hamilton.function_modifiers import ( InvalidDecoratorException, config, + configuration, parameterized_subdag, recursive, subdag, @@ -413,7 +414,7 @@ def inner_subdag(foo: int, bar: int) -> Tuple[int, int]: def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int: return sum(inner_subdag) - @subdag(inner_subdag, inputs={"input_2": value(3)}, config={"broken": source("broken2")}) + @subdag(inner_subdag, inputs={"input_2": value(3)}, config={"broken": configuration("broken2")}) def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int: return sum(inner_subdag) @@ -454,7 +455,11 @@ def inner_subdag(foo: int, bar: int) -> Tuple[int, int]: def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int: return sum(inner_subdag) - @subdag(inner_subdag, inputs={"input_2": value(3)}, config={"broken": source("broken_missing")}) + @subdag( + inner_subdag, + inputs={"input_2": value(3)}, + config={"broken": configuration("broken_missing")}, + ) def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int: return sum(inner_subdag) @@ -468,6 +473,30 @@ def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int: graph.FunctionGraph.from_modules(full_module, config={"broken2": False}) +@pytest.mark.parametrize( + "configuration,fields,expected", + [ + ({"a": 1, "b": 2}, {}, {"a": 1, "b": 2}), + ({"a": 1, "b": 2}, {"c": value(3)}, {"a": 1, "b": 2, "c": 3}), 
+ ( + {"a": 1, "b": 2}, + {"c": value(3), "d": configuration("a")}, + {"a": 1, "b": 2, "c": 3, "d": 1}, + ), + ], +) +def test_resolve_subdag_configuration_happy(configuration, fields, expected): + actual = recursive._resolve_subdag_configuration(configuration, fields, "test") + assert actual == expected + + +def test_resolve_subdag_configuration_bad_mapping(): + _configuration = {"a": 1, "b": 2} + fields = {"c": value(3), "d": configuration("e")} + with pytest.raises(InvalidDecoratorException): + recursive._resolve_subdag_configuration(_configuration, fields, "test") + + def test_subdag_with_external_nodes_input(): def bar(input_1: int) -> int: return input_1 + 1