diff --git a/examples/model_examples/modular_example/README.md b/examples/model_examples/modular_example/README.md
new file mode 100644
index 000000000..ef48d4dee
--- /dev/null
+++ b/examples/model_examples/modular_example/README.md
@@ -0,0 +1,34 @@
+# Modular pipeline example
+
+In this example we show how you can compose a pipeline from multiple modules.
+This is a common pattern in Hamilton, where you can define a module that encapsulates
+a set of "assets" and then use that module in a parameterized manner.
+
+The use case is that:
+
+1. we have common data/feature engineering code.
+2. we have a training step that creates a model.
+3. we have an inference step that, given a model and a dataset, predicts the outcome on that dataset.
+
+With these three pieces we want to create a single pipeline that:
+
+1. trains a model and predicts on the training set.
+2. uses that trained model to then predict on a separate dataset.
+
+We do this by creating three base components:
+
+1. A module that contains the common data/feature engineering code.
+2. A module that trains a model.
+3. A module that predicts on a dataset.
+
+We can then create two pipelines that use these modules in different ways:
+
+1. For training and predicting on the training set we use all three modules.
+2. For predicting on a separate dataset we use only the feature engineering module and the prediction module.
+3. We wire the two together so that the trained model then gets used in the prediction step for the separate dataset.
+
+By using `@subdag` we namespace each reuse of the modules; that is how the same
+functions can be reused in both pipelines within a single DAG.
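+
+A condensed sketch of the wiring (mirroring `pipeline.py` in this example; the optional
+config remapping is omitted for brevity):
+
+```python
+from typing import Any
+
+import pandas as pd
+
+import features
+import inference
+import train
+
+from hamilton.function_modifiers import extract_fields, source, subdag
+
+
+@extract_fields({"fit_model": Any, "training_prediction": pd.DataFrame})
+@subdag(
+    features,
+    train,
+    inference,
+    inputs={"path": source("path"), "model_params": source("model_params")},
+)
+def trained_pipeline(fit_model: Any, predicted_data: pd.DataFrame) -> dict:
+    return {"fit_model": fit_model, "training_prediction": predicted_data}
+
+
+@subdag(
+    features,
+    inference,
+    inputs={"path": source("predict_path"), "fit_model": source("fit_model")},
+)
+def predicted_data(predicted_data: pd.DataFrame) -> pd.DataFrame:
+    return predicted_data
+```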
+
+See:
+![single_pipeline](my_dag_annotated.png)
diff --git a/examples/model_examples/modular_example/features.py b/examples/model_examples/modular_example/features.py
new file mode 100644
index 000000000..a18f7e84b
--- /dev/null
+++ b/examples/model_examples/modular_example/features.py
@@ -0,0 +1,9 @@
+import pandas as pd
+
+
+def raw_data(path: str) -> pd.DataFrame:
+    """Loads the raw dataset from a CSV file at the given path."""
+    return pd.read_csv(path)
+
+
+def transformed_data(raw_data: pd.DataFrame) -> pd.DataFrame:
+    """Common feature engineering step: here we just drop rows with missing values."""
+    return raw_data.dropna()
diff --git a/examples/model_examples/modular_example/inference.py b/examples/model_examples/modular_example/inference.py
new file mode 100644
index 000000000..fdf79e73a
--- /dev/null
+++ b/examples/model_examples/modular_example/inference.py
@@ -0,0 +1,7 @@
+from typing import Any
+
+import pandas as pd
+
+
+def predicted_data(transformed_data: pd.DataFrame, fit_model: Any) -> pd.DataFrame:
+    # drop the target if present so columns match training; wrap to match the DataFrame annotation
+    features = transformed_data.drop(columns=["target"], errors="ignore")
+    return pd.DataFrame({"prediction": fit_model.predict(features)})
diff --git a/examples/model_examples/modular_example/my_dag_annotated.png b/examples/model_examples/modular_example/my_dag_annotated.png
new file mode 100644
index 000000000..81e4e4eab
Binary files /dev/null and b/examples/model_examples/modular_example/my_dag_annotated.png differ
diff --git a/examples/model_examples/modular_example/notebook.ipynb b/examples/model_examples/modular_example/notebook.ipynb
new file mode 100644
index 000000000..f6128da4d
--- /dev/null
+++ b/examples/model_examples/modular_example/notebook.ipynb
@@ -0,0 +1,432 @@
+{
+ "cells": [
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null,
+ "source": "%pip install sf-hamilton[visualization]",
+ "id": "6b12abc0bf96a1fa"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "# Modular Pipeline Example [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/model_examples/modular_example/notebook.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/dagworks-inc/hamilton/blob/main/examples/model_examples/modular_example/notebook.ipynb)\n",
+ "This uses the jupyter magic commands to create a simple example of how to reuse pipelines in a modular manner with subdag. "
+ ],
+ "id": "5fdf2bac7ddc6f79"
+ },
+ {
+ "metadata": {
+ "collapsed": true,
+ "ExecuteTime": {
+ "end_time": "2024-12-07T06:57:19.359572Z",
+ "start_time": "2024-12-07T06:57:13.119759Z"
+ }
+ },
+ "cell_type": "code",
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "execution_count": 1,
+ "source": "%load_ext hamilton.plugins.jupyter_magic",
+ "id": "initial_id"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "# Define features module\n",
+ "\n",
+ "This is the common data preprocessing step."
+ ],
+ "id": "29ebd0ec7fc5b800"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-12-07T06:57:19.627950Z",
+ "start_time": "2024-12-07T06:57:19.368576Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "%%cell_to_module features --display\n",
+ "\n",
+ "import pandas as pd\n",
+ "\n",
+ "def raw_data(path: str) -> pd.DataFrame:\n",
+ " return pd.read_csv(path)\n",
+ "\n",
+ "def transformed_data(raw_data: pd.DataFrame) -> pd.DataFrame:\n",
+ " return raw_data.dropna()"
+ ],
+ "id": "7fafbffaf2f6f68a",
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": "\n\n\n\n\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "execution_count": 2
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "# Define train module\n",
+ "\n",
+ "This is the training bit of the dataflow."
+ ],
+ "id": "ee170ce894848eae"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-12-07T06:57:19.971271Z",
+ "start_time": "2024-12-07T06:57:19.724804Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "%%cell_to_module train --config '{\"model\":\"RandomForest\"}'--display\n",
+ "\n",
+ "from typing import Any\n",
+ "import pandas as pd\n",
+ "\n",
+ "from hamilton.function_modifiers import config\n",
+ "\n",
+ "@config.when(model=\"RandomForest\")\n",
+ "def base_model__rf(model_params: dict) -> Any:\n",
+ " from sklearn.ensemble import RandomForestClassifier\n",
+ " return RandomForestClassifier(**model_params)\n",
+ "\n",
+ "@config.when(model=\"LogisticRegression\")\n",
+ "def base_model__lr(model_params: dict) -> Any:\n",
+ " from sklearn.linear_model import LogisticRegression\n",
+ " return LogisticRegression(**model_params)\n",
+ "\n",
+ "@config.when(model=\"XGBoost\")\n",
+ "def base_model__xgb(model_params: dict) -> Any:\n",
+ " from xgboost import XGBClassifier\n",
+ " return XGBClassifier(**model_params)\n",
+ "\n",
+ "\n",
+ "def fit_model(transformed_data: pd.DataFrame, base_model: Any) -> Any:\n",
+ " \"\"\"Fit a model to transformed data.\"\"\"\n",
+ " base_model.fit(transformed_data.drop(\"target\", axis=1), transformed_data[\"target\"])\n",
+ " return base_model\n"
+ ],
+ "id": "eae523c3fba37c93",
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": "\n\n\n\n\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "execution_count": 3
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "# Define the inference module\n",
+ "\n",
+ "This houses what we need for inference."
+ ],
+ "id": "8cae5e1a9c682ea5"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-12-07T06:57:20.363768Z",
+ "start_time": "2024-12-07T06:57:20.114344Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "%%cell_to_module inference --display\n",
+ "from typing import Any\n",
+ "import pandas as pd\n",
+ "\n",
+ "\n",
+ "def predicted_data(transformed_data: pd.DataFrame, fit_model: Any) -> pd.DataFrame:\n",
+ " return fit_model.predict(transformed_data)\n",
+ "\n"
+ ],
+ "id": "2ad9e61062f6516a",
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": "\n\n\n\n\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "execution_count": 4
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "# We can combine the modules independently with different drivers\n",
+ "\n",
+ "But this won't provide us with a single dataflow or DAG."
+ ],
+ "id": "3a1a0d9aca3944b1"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-12-07T18:08:40.538779Z",
+ "start_time": "2024-12-07T18:08:39.642181Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# train\n",
+ "from hamilton import driver\n",
+ "\n",
+ "train_dr = (\n",
+ " driver.Builder()\n",
+ " .with_config({\"model\": \"RandomForest\", \"model_params\": {\"n_estimators\": 100}})\n",
+ " .with_modules(features, train, inference)\n",
+ " .build()\n",
+ ")\n",
+ "train_dr.display_all_functions()"
+ ],
+ "id": "9ac29701bdd31fb5",
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": "\n\n\n\n\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 9
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-12-07T18:09:13.265102Z",
+ "start_time": "2024-12-07T18:09:12.750662Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# Inference\n",
+ "from hamilton import driver\n",
+ "\n",
+ "inference_dr = (\n",
+ " driver.Builder()\n",
+ " .with_config({})\n",
+ " .with_modules(features, inference)\n",
+ " .build()\n",
+ ")\n",
+ "inference_dr.display_all_functions()"
+ ],
+ "id": "cc9401ed081df22f",
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": "\n\n\n\n\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 10
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "# To combine into a single dataflow we can use @subdag\n",
+ "\n",
+ "So if we want a single pipeline that enables us to:\n",
+ "\n",
+ "1. train the model & get training set predictions.\n",
+ "2. then use the fit model to predict on a separate dataset.\n",
+ "\n",
+ "To do that we define another module that uses the `@subdag` constructs that we wire together."
+ ],
+ "id": "d85c51388733ce96"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-12-07T07:00:23.770491Z",
+ "start_time": "2024-12-07T07:00:23.481869Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "%%cell_to_module pipeline --config '{\"model\":\"RandomForest\"}' --display\n",
+ "from typing import Any\n",
+ "\n",
+ "import pandas as pd\n",
+ "\n",
+ "from hamilton.function_modifiers import subdag, extract_fields, configuration, source\n",
+ "import features\n",
+ "import train\n",
+ "import inference\n",
+ "\n",
+ "@extract_fields(\n",
+ " {'fit_model': Any, 'training_prediction': pd.DataFrame}\n",
+ ")\n",
+ "@subdag(\n",
+ " features, train, inference,\n",
+ " inputs={\n",
+ " \"path\": source(\"path\"),\n",
+ " \"model_params\": source(\"model_params\"),\n",
+ " },\n",
+ " # there are several ways to pass in configuration.\n",
+ " # config={ \n",
+ " # \"model\": configuration(\"model\")\n",
+ " # },\n",
+ ")\n",
+ "def trained_pipeline(fit_model: Any, predicted_data: pd.DataFrame) -> dict:\n",
+ " return {'fit_model': fit_model, 'training_prediction': predicted_data}\n",
+ "\n",
+ "@subdag(\n",
+ " features, inference,\n",
+ " inputs={\n",
+ " \"path\": source(\"predict_path\"),\n",
+ " \"fit_model\": source(\"fit_model\"),\n",
+ " },\n",
+ ")\n",
+ "def predicted_data(predicted_data: pd.DataFrame) -> pd.DataFrame:\n",
+ " return predicted_data"
+ ],
+ "id": "6d1585dad64464d7",
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": "\n\n\n\n\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "execution_count": 8
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-12-07T06:57:20.874962Z",
+ "start_time": "2024-12-07T06:57:20.643256Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "from hamilton import driver\n",
+ "\n",
+ "dr = (\n",
+ " driver.Builder()\n",
+ " .with_config({\"model\": \"RandomForest\", \"model_params\": {\"n_estimators\": 100}})\n",
+ " .with_modules(pipeline)\n",
+ " .build()\n",
+ ")\n",
+ "dr.display_all_functions()"
+ ],
+ "id": "f72146c07a654ca4",
+ "outputs": [
+ {
+ "data": {
+ "image/svg+xml": "\n\n\n\n\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 6
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-12-07T06:57:34.959772Z",
+ "start_time": "2024-12-07T06:57:34.956204Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "# this wont work because we don't actually have data...\n",
+ "# dr.execute([\"trained_pipeline\", \"predicted_data\"], \n",
+ "# inputs={\"path\": \"data.csv\", \"predict_path\": \"data.csv\"})"
+ ],
+ "id": "b3abca24b1a86329",
+ "outputs": [],
+ "execution_count": 7
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/model_examples/modular_example/pipeline.py b/examples/model_examples/modular_example/pipeline.py
new file mode 100644
index 000000000..0c7ebb0eb
--- /dev/null
+++ b/examples/model_examples/modular_example/pipeline.py
@@ -0,0 +1,37 @@
+from typing import Any
+
+import features
+import inference
+import pandas as pd
+import train
+
+from hamilton.function_modifiers import configuration, extract_fields, source, subdag
+
+
+@extract_fields({"fit_model": Any, "training_prediction": pd.DataFrame})
+@subdag(
+ features,
+ train,
+ inference,
+ inputs={
+ "path": source("path"),
+ "model_params": source("model_params"),
+ },
+ config={
+ "model": configuration("train_model_type"), # not strictly required but allows us to remap.
+ },
+)
+def trained_pipeline(fit_model: Any, predicted_data: pd.DataFrame) -> dict:
+ return {"fit_model": fit_model, "training_prediction": predicted_data}
+
+
+@subdag(
+ features,
+ inference,
+ inputs={
+ "path": source("predict_path"),
+ "fit_model": source("fit_model"),
+ },
+)
+def predicted_data(predicted_data: pd.DataFrame) -> pd.DataFrame:
+ return predicted_data
diff --git a/examples/model_examples/modular_example/run.py b/examples/model_examples/modular_example/run.py
new file mode 100644
index 000000000..415bd2863
--- /dev/null
+++ b/examples/model_examples/modular_example/run.py
@@ -0,0 +1,18 @@
+import pipeline
+
+from hamilton import driver
+
+
+def run():
+ dr = (
+ driver.Builder()
+ .with_config({"train_model_type": "RandomForest", "model_params": {"n_estimators": 100}})
+ .with_modules(pipeline)
+ .build()
+ )
+ dr.display_all_functions("./my_dag.png")
+    # executing requires real data at these (placeholder) paths:
+    # dr.execute(["trained_pipeline", "predicted_data"],
+    #            inputs={"path": "data.csv", "predict_path": "data.csv"})
+
+
+if __name__ == "__main__":
+ run()
diff --git a/examples/model_examples/modular_example/train.py b/examples/model_examples/modular_example/train.py
new file mode 100644
index 000000000..9d0d93bea
--- /dev/null
+++ b/examples/model_examples/modular_example/train.py
@@ -0,0 +1,32 @@
+from typing import Any
+
+import pandas as pd
+
+from hamilton.function_modifiers import config
+
+
+@config.when(model="RandomForest")
+def base_model__rf(model_params: dict) -> Any:
+ from sklearn.ensemble import RandomForestClassifier
+
+ return RandomForestClassifier(**model_params)
+
+
+@config.when(model="LogisticRegression")
+def base_model__lr(model_params: dict) -> Any:
+ from sklearn.linear_model import LogisticRegression
+
+ return LogisticRegression(**model_params)
+
+
+@config.when(model="XGBoost")
+def base_model__xgb(model_params: dict) -> Any:
+ from xgboost import XGBClassifier
+
+ return XGBClassifier(**model_params)
+
+
+def fit_model(transformed_data: pd.DataFrame, base_model: Any) -> Any:
+ """Fit a model to transformed data."""
+ base_model.fit(transformed_data.drop("target", axis=1), transformed_data["target"])
+ return base_model
diff --git a/hamilton/function_modifiers/__init__.py b/hamilton/function_modifiers/__init__.py
index 958d07540..3113b13ff 100644
--- a/hamilton/function_modifiers/__init__.py
+++ b/hamilton/function_modifiers/__init__.py
@@ -36,6 +36,7 @@
value = dependencies.value
source = dependencies.source
group = dependencies.group
+configuration = dependencies.configuration
# These aren't strictly part of the API but we should have them here for safety
LiteralDependency = dependencies.LiteralDependency
diff --git a/hamilton/function_modifiers/dependencies.py b/hamilton/function_modifiers/dependencies.py
index 26785505c..4f0c6d159 100644
--- a/hamilton/function_modifiers/dependencies.py
+++ b/hamilton/function_modifiers/dependencies.py
@@ -16,6 +16,7 @@ class ParametrizedDependencySource(enum.Enum):
UPSTREAM = "upstream"
GROUPED_LIST = "grouped_list"
GROUPED_DICT = "grouped_dict"
+ CONFIGURATION = "configuration"
class ParametrizedDependency:
@@ -44,6 +45,14 @@ def get_dependency_type(self) -> ParametrizedDependencySource:
return ParametrizedDependencySource.UPSTREAM
+@dataclasses.dataclass
+class ConfigDependency(SingleDependency):
+ source: str
+
+ def get_dependency_type(self) -> ParametrizedDependencySource:
+ return ParametrizedDependencySource.CONFIGURATION
+
+
class GroupedDependency(ParametrizedDependency, abc.ABC):
@classmethod
@abc.abstractmethod
@@ -123,8 +132,8 @@ def value(literal_value: Any) -> LiteralDependency:
E.G. value("foo") means that the value is actually the string value "foo".
- :param literal_value: Python literal value to use. :return: A LiteralDependency object -- a
- signifier to the internal framework of the dependency type.
+ :param literal_value: Python literal value to use.
+ :return: A LiteralDependency object -- a signifier to the internal framework of the dependency type.
"""
if isinstance(literal_value, LiteralDependency):
return literal_value
@@ -138,14 +147,25 @@ def source(dependency_on: Any) -> UpstreamDependency:
be assigned the value that "foo" outputs.
:param dependency_on: Upstream function (i.e. node) to come from.
- :return: An
- UpstreamDependency object -- a signifier to the internal framework of the dependency type.
+ :return: An UpstreamDependency object -- a signifier to the internal framework of the dependency type.
"""
if isinstance(dependency_on, UpstreamDependency):
return dependency_on
return UpstreamDependency(source=dependency_on)
+def configuration(dependency_on: str) -> ConfigDependency:
+ """Specifies that a parameterized dependency comes from the global `config` passed in.
+
+ This means that it comes from a global configuration key value. E.G. config("foo") means that it should
+ be assigned the value that the "foo" key in global configuration passed to Hamilton maps to.
+
+ :param dependency_on: name of the configuration key to pull from.
+ :return: An ConfigDependency object -- a signifier to the internal framework of the dependency type.
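+
+    Example usage with ``@subdag`` (assuming a ``train`` module and a ``train_model_type``
+    config key, as in the modular_example in this repository):
+
+    .. code-block:: python
+
+        @subdag(
+            train,
+            config={"model": configuration("train_model_type")},
+        )
+        def trained_pipeline(fit_model: Any) -> Any:
+            ...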
+ """
+ return ConfigDependency(source=dependency_on)
+
+
def _validate_group_params(
dependency_args: List[ParametrizedDependency],
dependency_kwargs: Dict[str, ParametrizedDependency],
diff --git a/hamilton/function_modifiers/recursive.py b/hamilton/function_modifiers/recursive.py
index a330714a9..d86ad74f1 100644
--- a/hamilton/function_modifiers/recursive.py
+++ b/hamilton/function_modifiers/recursive.py
@@ -131,6 +131,43 @@ def _validate_config_inputs(config: Dict[str, Any], inputs: Dict[str, Any]):
)
+def _resolve_subdag_configuration(
+ configuration: Dict[str, Any], fields: Dict[str, Any], function_name: str
+) -> Dict[str, Any]:
+ """Resolves the configuration for a subdag.
+
+ :param configuration: the Hamilton configuration
+    :param fields: the fields passed to the subdag decorator's `config=` argument
+    :param function_name: the name of the subdag function -- used in error messages
+    :return: resolved configuration to use for this subdag.
+ """
+ sources_to_map = {}
+ values_to_include = {}
+ for key, value in fields.items():
+ if isinstance(value, dependencies.ConfigDependency):
+ sources_to_map[key] = value.source
+ elif isinstance(value, dependencies.LiteralDependency):
+ values_to_include[key] = value.value
+ elif isinstance(value, (dependencies.GroupedDependency, dependencies.SingleDependency)):
+ raise InvalidDecoratorException(
+ f"`{value}` is not allowed in the config= part of the subdag decorator. "
+ "Please use `configuration()` or `value()` or literal python values."
+ )
+ plain_configs = {
+ k: v for k, v in fields.items() if k not in sources_to_map and k not in values_to_include
+ }
+ resolved_config = dict(configuration, **plain_configs, **values_to_include)
+
+    # override any values that are mapped from the resolved configuration
+    for key, config_key in sources_to_map.items():
+        try:
+            resolved_config[key] = resolved_config[config_key]
+        except KeyError as e:
+            raise InvalidDecoratorException(
+                f"Source {config_key} was not found in the configuration. "
+                f"This is required for the {function_name} subdag."
+            ) from e
+    return resolved_config
+
+
NON_FINAL_TAGS = {NodeTransformer.NON_FINAL_TAG: True}
@@ -423,7 +460,9 @@ def _derive_name(self, fn: Callable) -> str:
def generate_nodes(self, fn: Callable, configuration: Dict[str, Any]) -> Collection[node.Node]:
# Resolve all nodes from passed in functions
- resolved_config = dict(configuration, **self.config)
+        # self.config may contain configuration() or value() markers that need resolving
+        resolved_config = _resolve_subdag_configuration(configuration, self.config, fn.__name__)
nodes = self.collect_nodes(config=resolved_config, subdag_functions=self.subdag_functions)
# Derive the namespace under which all these nodes will live
namespace = self._derive_namespace(fn)
diff --git a/tests/function_modifiers/test_recursive.py b/tests/function_modifiers/test_recursive.py
index e9b76686c..f2f35ee71 100644
--- a/tests/function_modifiers/test_recursive.py
+++ b/tests/function_modifiers/test_recursive.py
@@ -10,6 +10,8 @@
from hamilton.function_modifiers import (
InvalidDecoratorException,
config,
+ configuration,
+ group,
parameterized_subdag,
recursive,
subdag,
@@ -392,6 +394,122 @@ def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
)
+def test_nested_subdag_with_config_remapping():
+ """Tests that we can remap config values and source and value are resolved correctly."""
+
+ def bar(input_1: int) -> int:
+ return input_1 + 1
+
+ @config.when(broken=False)
+ def foo(input_2: int) -> int:
+ return input_2 + 1
+
+ @subdag(
+ foo,
+ bar,
+ )
+ def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
+ return foo, bar
+
+ @subdag(inner_subdag, inputs={"input_2": value(10)}, config={"broken": value(False)})
+ def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int:
+ return sum(inner_subdag)
+
+ @subdag(inner_subdag, inputs={"input_2": value(3)}, config={"broken": configuration("broken2")})
+ def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int:
+ return sum(inner_subdag)
+
+ def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
+ return outer_subdag_1 + outer_subdag_2
+
+ # we only need to generate from the outer subdag
+ # as it refers to the inner one
+ full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all)
+ fg = graph.FunctionGraph.from_modules(full_module, config={"broken2": False})
+ assert "outer_subdag_1" in fg.nodes
+ assert "outer_subdag_2" in fg.nodes
+ res = fg.execute(nodes=[fg.nodes["sum_all"]], inputs={"input_1": 2})
+ # This is effectively the function graph
+ assert res["sum_all"] == sum_all(
+ outer_subdag_1(inner_subdag(bar(2), foo(10))), outer_subdag_2(inner_subdag(bar(2), foo(3)))
+ )
+
+
+def test_nested_subdag_with_config_remapping_missing_error():
+ """Tests that we error if we can't remap a config value."""
+
+ def bar(input_1: int) -> int:
+ return input_1 + 1
+
+ @config.when(broken=False)
+ def foo(input_2: int) -> int:
+ return input_2 + 1
+
+ @subdag(
+ foo,
+ bar,
+ )
+ def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
+ return foo, bar
+
+ @subdag(inner_subdag, inputs={"input_2": value(10)}, config={"broken": value(False)})
+ def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int:
+ return sum(inner_subdag)
+
+ @subdag(
+ inner_subdag,
+ inputs={"input_2": value(3)},
+ config={"broken": configuration("broken_missing")},
+ )
+ def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int:
+ return sum(inner_subdag)
+
+ def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
+ return outer_subdag_1 + outer_subdag_2
+
+ # we only need to generate from the outer subdag
+ # as it refers to the inner one
+ full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all)
+ with pytest.raises(InvalidDecoratorException):
+ graph.FunctionGraph.from_modules(full_module, config={"broken2": False})
+
+
+@pytest.mark.parametrize(
+ "configuration,fields,expected",
+ [
+ ({"a": 1, "b": 2}, {}, {"a": 1, "b": 2}),
+ ({"a": 1, "b": 2}, {"c": value(3)}, {"a": 1, "b": 2, "c": 3}),
+ (
+ {"a": 1, "b": 2},
+ {"c": value(3), "d": configuration("a")},
+ {"a": 1, "b": 2, "c": 3, "d": 1},
+ ),
+ ],
+)
+def test_resolve_subdag_configuration_happy(_configuration, fields, expected):
+    actual = recursive._resolve_subdag_configuration(_configuration, fields, "test")
+ assert actual == expected
+
+
+def test_resolve_subdag_configuration_bad_mapping():
+ _configuration = {"a": 1, "b": 2}
+ fields = {"c": value(3), "d": configuration("e")}
+ with pytest.raises(InvalidDecoratorException):
+ recursive._resolve_subdag_configuration(_configuration, fields, "test")
+
+
+def test_resolve_subdag_configuration_flag_incorrect_source_group_deps():
+ _configuration = {"a": 1, "b": 2}
+ with pytest.raises(InvalidDecoratorException):
+ recursive._resolve_subdag_configuration(
+ _configuration, {"c": value(3), "d": source("e")}, "test"
+ )
+ with pytest.raises(InvalidDecoratorException):
+ recursive._resolve_subdag_configuration(
+ _configuration, {"c": value(3), "d": group(source("e"))}, "test"
+ )
+
+
def test_subdag_with_external_nodes_input():
def bar(input_1: int) -> int:
return input_1 + 1