Skip to content

Commit

Permalink
Adds modular subdag example and configuration object for subdag config (
Browse files Browse the repository at this point in the history
#1251)

* Enables subdag to handle `configuration` and `value` for config values

To be consistent and to allow people to remap external values we need
this functionality.

E.g. subdag has a config key "foo", but we need two different values
because we are reusing it in two subdags. So we can
remap the values from external config:

Global config: {"foo1": "foo_v1", "foo2": "foo_v2"}
Subdag1 config: {"foo": configuration("foo1")}
Subdag2 config: {"foo": configuration("foo2")}

This also handles if someone wants to use value().

Otherwise to preserve backwards compatibility we need
to do this in a specific order and only override config values
appropriately.

Squashed commits:

* Adds modular model pipeline example

This shows how you can construct a pipeline from components
and then use subdag to parameterize it for reuse.

* Adds more notes  to module notebook example

* Adds configuration object and refactors code

Named it configuration to not clash with config decorator.

* Adds more tests

* Updates modular example image
  • Loading branch information
skrawcz authored Dec 12, 2024
1 parent 622866a commit ed90580
Show file tree
Hide file tree
Showing 12 changed files with 752 additions and 5 deletions.
34 changes: 34 additions & 0 deletions examples/model_examples/modular_example/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Modular pipeline example

In this example we show how you can compose a pipeline from multiple modules.
This is a common pattern in Hamilton, where you can define a module that encapsulates
a set of "assets" and then use that module in a parameterized manner.

The use case here is that:

1. we have common data/feature engineering code.
2. we have a training step that creates a model
3. we have an inference step that given a model and a dataset, predicts the outcome on that dataset.

With these 3 things we want to create a single pipeline that:

1. trains a model and predicts on the training set.
2. uses that trained model to then predict on a separate dataset.

We do this by creating our base components:

1. Creating a module that contains the common data/feature engineering code.
2. Creating a module that trains a model.
3. Creating a module that predicts on a dataset.

We can then create two pipelines that use these modules in different ways:

1. For training and predicting on the training set we use all 3 modules.
2. For predicting on a separate dataset we use only the feature engineering module and the prediction module.
3. We wire the two together so that the trained model then gets used in the prediction step for the separate dataset.

By using `@subdag` we namespace the reuse of the modules and that's how we can
reuse the same functions in different pipelines.

See:
![single_pipeline](my_dag_annotated.png)
9 changes: 9 additions & 0 deletions examples/model_examples/modular_example/features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pandas as pd


def raw_data(path: str) -> pd.DataFrame:
    """Load the raw dataset from a CSV file.

    :param path: location of the CSV file to read.
    :return: the file contents as a dataframe.
    """
    frame = pd.read_csv(path)
    return frame


def transformed_data(raw_data: pd.DataFrame) -> pd.DataFrame:
    """Drop every row containing a missing value.

    :param raw_data: the dataframe to clean.
    :return: ``raw_data`` with all NaN-containing rows removed.
    """
    cleaned = raw_data.dropna()
    return cleaned
7 changes: 7 additions & 0 deletions examples/model_examples/modular_example/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import Any

import pandas as pd


def predicted_data(transformed_data: pd.DataFrame, fit_model: Any) -> pd.DataFrame:
    """Score a feature dataframe with a fitted model.

    :param transformed_data: features to predict on.
    :param fit_model: any fitted estimator exposing a ``.predict`` method.
    :return: the model's predictions for ``transformed_data``.
    """
    predictions = fit_model.predict(transformed_data)
    return predictions
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
432 changes: 432 additions & 0 deletions examples/model_examples/modular_example/notebook.ipynb

Large diffs are not rendered by default.

37 changes: 37 additions & 0 deletions examples/model_examples/modular_example/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Any

import features
import inference
import pandas as pd
import train

from hamilton.function_modifiers import configuration, extract_fields, source, subdag


@extract_fields({"fit_model": Any, "training_prediction": pd.DataFrame})
@subdag(
    features,
    train,
    inference,
    inputs={
        "path": source("path"),
        "model_params": source("model_params"),
    },
    config={
        # remaps the global "train_model_type" key onto this subdag's "model" key
        "model": configuration("train_model_type"),  # not strictly required but allows us to remap.
    },
)
def trained_pipeline(fit_model: Any, predicted_data: pd.DataFrame) -> dict:
    """Train a model and predict on the training set.

    Composes the features, train, and inference modules into one namespaced subdag.

    :param fit_model: the fitted model produced inside the subdag.
    :param predicted_data: the subdag's predictions on the training data.
    :return: dict with "fit_model" and "training_prediction" keys,
        split into separate nodes by ``@extract_fields``.
    """
    return {"fit_model": fit_model, "training_prediction": predicted_data}


@subdag(
    features,
    inference,
    inputs={
        # a different input dataset than the training pipeline
        "path": source("predict_path"),
        # wires in the model trained by `trained_pipeline`
        "fit_model": source("fit_model"),
    },
)
def predicted_data(predicted_data: pd.DataFrame) -> pd.DataFrame:
    """Predict on a separate dataset, reusing only the feature-engineering
    and inference modules plus the already-trained model.

    :param predicted_data: the namespaced prediction node from the subdag.
    :return: the predictions, surfaced as a top-level node.
    """
    return predicted_data
18 changes: 18 additions & 0 deletions examples/model_examples/modular_example/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pipeline

from hamilton import driver


def run():
    """Build the Hamilton driver for the modular pipeline and render its DAG."""
    pipeline_config = {
        "train_model_type": "RandomForest",
        "model_params": {"n_estimators": 100},
    }
    builder = driver.Builder().with_config(pipeline_config).with_modules(pipeline)
    dr = builder.build()
    dr.display_all_functions("./my_dag.png")
    # dr.execute(["trained_pipeline", "predicted_data"])


# Entry point: renders the pipeline DAG image when run as a script.
if __name__ == "__main__":
    run()
32 changes: 32 additions & 0 deletions examples/model_examples/modular_example/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Any

import pandas as pd

from hamilton.function_modifiers import config


@config.when(model="RandomForest")
def base_model__rf(model_params: dict) -> Any:
    """Instantiate a random forest classifier (selected when config model == "RandomForest").

    :param model_params: keyword arguments forwarded to the estimator constructor.
    :return: an unfitted RandomForestClassifier.
    """
    # import deferred so sklearn is only required when this variant is selected
    from sklearn.ensemble import RandomForestClassifier

    model = RandomForestClassifier(**model_params)
    return model


@config.when(model="LogisticRegression")
def base_model__lr(model_params: dict) -> Any:
    """Instantiate a logistic regression model (selected when config model == "LogisticRegression").

    :param model_params: keyword arguments forwarded to the estimator constructor.
    :return: an unfitted LogisticRegression.
    """
    # import deferred so sklearn is only required when this variant is selected
    from sklearn.linear_model import LogisticRegression

    model = LogisticRegression(**model_params)
    return model


@config.when(model="XGBoost")
def base_model__xgb(model_params: dict) -> Any:
    """Instantiate an XGBoost classifier (selected when config model == "XGBoost").

    :param model_params: keyword arguments forwarded to the estimator constructor.
    :return: an unfitted XGBClassifier.
    """
    # import deferred so xgboost is only required when this variant is selected
    from xgboost import XGBClassifier

    model = XGBClassifier(**model_params)
    return model


def fit_model(transformed_data: pd.DataFrame, base_model: Any) -> Any:
    """Fit a model to transformed data.

    Treats every column except "target" as a feature and "target" as the label.

    :param transformed_data: feature dataframe that includes a "target" column.
    :param base_model: an unfitted estimator exposing a ``.fit(X, y)`` method.
    :return: the same estimator instance, now fitted.
    """
    feature_frame = transformed_data.drop("target", axis=1)
    labels = transformed_data["target"]
    base_model.fit(feature_frame, labels)
    return base_model
1 change: 1 addition & 0 deletions hamilton/function_modifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
# Public re-exports of the dependency-specifier factories so users can write
# `from hamilton.function_modifiers import value, source, group, configuration`.
value = dependencies.value
source = dependencies.source
group = dependencies.group
configuration = dependencies.configuration

# These aren't strictly part of the API but we should have them here for safety
LiteralDependency = dependencies.LiteralDependency
Expand Down
28 changes: 24 additions & 4 deletions hamilton/function_modifiers/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class ParametrizedDependencySource(enum.Enum):
UPSTREAM = "upstream"
GROUPED_LIST = "grouped_list"
GROUPED_DICT = "grouped_dict"
CONFIGURATION = "configuration"


class ParametrizedDependency:
Expand Down Expand Up @@ -44,6 +45,14 @@ def get_dependency_type(self) -> ParametrizedDependencySource:
return ParametrizedDependencySource.UPSTREAM


@dataclasses.dataclass
class ConfigDependency(SingleDependency):
    """Dependency whose value comes from a key in the global Hamilton configuration.

    Created via the ``configuration()`` helper (named to avoid clashing with the
    ``@config`` decorator).
    """

    # name of the global configuration key to pull the value from
    source: str

    def get_dependency_type(self) -> ParametrizedDependencySource:
        # marks this dependency for configuration-based resolution
        return ParametrizedDependencySource.CONFIGURATION


class GroupedDependency(ParametrizedDependency, abc.ABC):
@classmethod
@abc.abstractmethod
Expand Down Expand Up @@ -123,8 +132,8 @@ def value(literal_value: Any) -> LiteralDependency:
E.G. value("foo") means that the value is actually the string value "foo".
:param literal_value: Python literal value to use. :return: A LiteralDependency object -- a
signifier to the internal framework of the dependency type.
:param literal_value: Python literal value to use.
:return: A LiteralDependency object -- a signifier to the internal framework of the dependency type.
"""
if isinstance(literal_value, LiteralDependency):
return literal_value
Expand All @@ -138,14 +147,25 @@ def source(dependency_on: Any) -> UpstreamDependency:
be assigned the value that "foo" outputs.
:param dependency_on: Upstream function (i.e. node) to come from.
:return: An
UpstreamDependency object -- a signifier to the internal framework of the dependency type.
:return: An UpstreamDependency object -- a signifier to the internal framework of the dependency type.
"""
if isinstance(dependency_on, UpstreamDependency):
return dependency_on
return UpstreamDependency(source=dependency_on)


def configuration(dependency_on: str) -> ConfigDependency:
    """Specifies that a parameterized dependency comes from the global `config` passed in.

    This means that it comes from a global configuration key value. E.G. configuration("foo")
    means that it should be assigned the value that the "foo" key in the global configuration
    passed to Hamilton maps to.

    :param dependency_on: name of the configuration key to pull from.
    :return: A ConfigDependency object -- a signifier to the internal framework of the dependency type.
    """
    # Mirror value()/source(): an already-constructed ConfigDependency passes
    # through unchanged, keeping the helper idempotent.
    if isinstance(dependency_on, ConfigDependency):
        return dependency_on
    return ConfigDependency(source=dependency_on)


def _validate_group_params(
dependency_args: List[ParametrizedDependency],
dependency_kwargs: Dict[str, ParametrizedDependency],
Expand Down
41 changes: 40 additions & 1 deletion hamilton/function_modifiers/recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,43 @@ def _validate_config_inputs(config: Dict[str, Any], inputs: Dict[str, Any]):
)


def _resolve_subdag_configuration(
    configuration: Dict[str, Any], fields: Dict[str, Any], function_name: str
) -> Dict[str, Any]:
    """Resolves the configuration for a subdag.

    Plain values and ``value()`` literals in ``fields`` override the global
    configuration; ``configuration()`` entries are remapped from another key of
    the already-merged configuration. Any other dependency type is rejected.

    :param configuration: the global Hamilton configuration.
    :param fields: the ``config=`` mapping passed to the subdag decorator.
    :param function_name: name of the decorated function, used in error messages.
    :return: resolved configuration to use for this subdag.
    :raises InvalidDecoratorException: if a disallowed dependency type is passed,
        or a remapped source key is missing from the configuration.
    """
    remapped_keys = {}   # subdag config key -> global config key to copy from
    literal_values = {}  # subdag config key -> literal from value()
    for key, field_value in fields.items():
        if isinstance(field_value, dependencies.ConfigDependency):
            remapped_keys[key] = field_value.source
        elif isinstance(field_value, dependencies.LiteralDependency):
            literal_values[key] = field_value.value
        elif isinstance(
            field_value, (dependencies.GroupedDependency, dependencies.SingleDependency)
        ):
            # source()/group() make no sense for config values -- fail loudly.
            raise InvalidDecoratorException(
                f"`{field_value}` is not allowed in the config= part of the subdag decorator. "
                "Please use `configuration()` or `value()` or literal python values."
            )
    # anything not handled above is a plain python value and overrides directly
    plain_configs = {
        k: v for k, v in fields.items() if k not in remapped_keys and k not in literal_values
    }
    resolved_config = dict(configuration, **plain_configs, **literal_values)

    # override any values from sources
    for key, source_key in remapped_keys.items():
        try:
            resolved_config[key] = resolved_config[source_key]
        except KeyError as e:
            raise InvalidDecoratorException(
                f"Source {source_key} was not found in the configuration. This is required for the {function_name} subdag."
            ) from e
    return resolved_config


NON_FINAL_TAGS = {NodeTransformer.NON_FINAL_TAG: True}


Expand Down Expand Up @@ -423,7 +460,9 @@ def _derive_name(self, fn: Callable) -> str:

def generate_nodes(self, fn: Callable, configuration: Dict[str, Any]) -> Collection[node.Node]:
# Resolve all nodes from passed in functions
resolved_config = dict(configuration, **self.config)
# if self.config has configuration() or value() in it, we need to resolve it
resolved_config = _resolve_subdag_configuration(configuration, self.config, fn.__name__)
# resolved_config = dict(configuration, **self.config)
nodes = self.collect_nodes(config=resolved_config, subdag_functions=self.subdag_functions)
# Derive the namespace under which all these nodes will live
namespace = self._derive_namespace(fn)
Expand Down
118 changes: 118 additions & 0 deletions tests/function_modifiers/test_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from hamilton.function_modifiers import (
InvalidDecoratorException,
config,
configuration,
group,
parameterized_subdag,
recursive,
subdag,
Expand Down Expand Up @@ -392,6 +394,122 @@ def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
)


def test_nested_subdag_with_config_remapping():
    """Tests that we can remap config values and source and value are resolved correctly."""

    def bar(input_1: int) -> int:
        return input_1 + 1

    # only resolvable when the subdag config supplies broken=False
    @config.when(broken=False)
    def foo(input_2: int) -> int:
        return input_2 + 1

    @subdag(
        foo,
        bar,
    )
    def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
        return foo, bar

    # "broken" set via a literal value()
    @subdag(inner_subdag, inputs={"input_2": value(10)}, config={"broken": value(False)})
    def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int:
        return sum(inner_subdag)

    # "broken" remapped from the global "broken2" key via configuration()
    @subdag(inner_subdag, inputs={"input_2": value(3)}, config={"broken": configuration("broken2")})
    def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int:
        return sum(inner_subdag)

    def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
        return outer_subdag_1 + outer_subdag_2

    # we only need to generate from the outer subdag
    # as it refers to the inner one
    full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all)
    fg = graph.FunctionGraph.from_modules(full_module, config={"broken2": False})
    assert "outer_subdag_1" in fg.nodes
    assert "outer_subdag_2" in fg.nodes
    res = fg.execute(nodes=[fg.nodes["sum_all"]], inputs={"input_1": 2})
    # This is effectively the function graph
    assert res["sum_all"] == sum_all(
        outer_subdag_1(inner_subdag(bar(2), foo(10))), outer_subdag_2(inner_subdag(bar(2), foo(3)))
    )


def test_nested_subdag_with_config_remapping_missing_error():
    """Tests that we error if we can't remap a config value."""

    def bar(input_1: int) -> int:
        return input_1 + 1

    @config.when(broken=False)
    def foo(input_2: int) -> int:
        return input_2 + 1

    @subdag(
        foo,
        bar,
    )
    def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
        return foo, bar

    @subdag(inner_subdag, inputs={"input_2": value(10)}, config={"broken": value(False)})
    def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int:
        return sum(inner_subdag)

    # "broken_missing" is deliberately absent from the global config below
    @subdag(
        inner_subdag,
        inputs={"input_2": value(3)},
        config={"broken": configuration("broken_missing")},
    )
    def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int:
        return sum(inner_subdag)

    def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
        return outer_subdag_1 + outer_subdag_2

    # we only need to generate from the outer subdag
    # as it refers to the inner one
    full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all)
    # graph construction must fail on the unresolvable remap
    with pytest.raises(InvalidDecoratorException):
        graph.FunctionGraph.from_modules(full_module, config={"broken2": False})


@pytest.mark.parametrize(
    "configuration,fields,expected",
    [
        # no fields: global config passes through untouched
        ({"a": 1, "b": 2}, {}, {"a": 1, "b": 2}),
        # value() literal is added to the resolved config
        ({"a": 1, "b": 2}, {"c": value(3)}, {"a": 1, "b": 2, "c": 3}),
        # configuration("a") copies key "a"'s value onto "d"
        (
            {"a": 1, "b": 2},
            {"c": value(3), "d": configuration("a")},
            {"a": 1, "b": 2, "c": 3, "d": 1},
        ),
    ],
)
def test_resolve_subdag_configuration_happy(configuration, fields, expected):
    """Happy-path resolution of subdag config with literals and remaps."""
    actual = recursive._resolve_subdag_configuration(configuration, fields, "test")
    assert actual == expected


def test_resolve_subdag_configuration_bad_mapping():
    """A configuration() remap pointing at a missing key raises."""
    base_config = {"a": 1, "b": 2}
    with pytest.raises(InvalidDecoratorException):
        recursive._resolve_subdag_configuration(
            base_config, {"c": value(3), "d": configuration("e")}, "test"
        )


def test_resolve_subdag_configuration_flag_incorrect_source_group_deps():
    """source() and group() dependencies are rejected in the subdag config= mapping."""
    base_config = {"a": 1, "b": 2}
    for bad_dep in (source("e"), group(source("e"))):
        with pytest.raises(InvalidDecoratorException):
            recursive._resolve_subdag_configuration(
                base_config, {"c": value(3), "d": bad_dep}, "test"
            )


def test_subdag_with_external_nodes_input():
def bar(input_1: int) -> int:
return input_1 + 1
Expand Down

0 comments on commit ed90580

Please sign in to comment.