diff --git a/examples/model_examples/modular_example/README.md b/examples/model_examples/modular_example/README.md
new file mode 100644
index 000000000..ef48d4dee
--- /dev/null
+++ b/examples/model_examples/modular_example/README.md
@@ -0,0 +1,34 @@
+# Modular pipeline example
+
+In this example we show how you can compose a pipeline from multiple modules.
+This is a common pattern in Hamilton, where you define a module that encapsulates
+a set of "assets" and then use that module in a parameterized manner.
+
+The use case has three parts:
+
+1. We have common data/feature engineering code.
+2. We have a training step that fits a model.
+3. We have an inference step that, given a model and a dataset, predicts outcomes on that dataset.
+
+With these three pieces we want to create a single pipeline that:
+
+1. trains a model and predicts on the training set.
+2. uses that trained model to predict on a separate dataset.
+
+We start by creating our base components:
+
+1. A module that contains the common data/feature engineering code.
+2. A module that trains a model.
+3. A module that predicts on a dataset.
+
+We can then create two pipelines that use these modules in different ways:
+
+1. For training and predicting on the training set, we use all three modules.
+2. For predicting on a separate dataset, we use only the feature engineering module and the prediction module.
+3. We wire the two together so that the trained model gets used in the prediction step for the separate dataset.
+
+By using `@subdag` we namespace each reuse of the modules, which is what lets us
+reuse the same functions in different pipelines.
+
+See:
+![single_pipeline](my_dag_annotated.png)
diff --git a/examples/model_examples/modular_example/features.py b/examples/model_examples/modular_example/features.py
new file mode 100644
index 000000000..a18f7e84b
--- /dev/null
+++ b/examples/model_examples/modular_example/features.py
@@ -0,0 +1,9 @@
+import pandas as pd
+
+
+def raw_data(path: str) -> pd.DataFrame:
+    return pd.read_csv(path)
+
+
+def transformed_data(raw_data: pd.DataFrame) -> pd.DataFrame:
+    return raw_data.dropna()
diff --git a/examples/model_examples/modular_example/inference.py b/examples/model_examples/modular_example/inference.py
new file mode 100644
index 000000000..fdf79e73a
--- /dev/null
+++ b/examples/model_examples/modular_example/inference.py
@@ -0,0 +1,7 @@
+from typing import Any
+
+import pandas as pd
+
+
+def predicted_data(transformed_data: pd.DataFrame, fit_model: Any) -> pd.DataFrame:
+    return fit_model.predict(transformed_data)
diff --git a/examples/model_examples/modular_example/my_dag_annotated.png b/examples/model_examples/modular_example/my_dag_annotated.png
new file mode 100644
index 000000000..81e4e4eab
Binary files /dev/null and b/examples/model_examples/modular_example/my_dag_annotated.png differ
diff --git a/examples/model_examples/modular_example/notebook.ipynb b/examples/model_examples/modular_example/notebook.ipynb
new file mode 100644
index 000000000..f6128da4d
--- /dev/null
+++ b/examples/model_examples/modular_example/notebook.ipynb
@@ -0,0 +1,432 @@
+{
+ "cells": [
+  {
+   "metadata": {},
+   "cell_type": "code",
+   "outputs": [],
+   "execution_count": null,
+   "source": "%pip install sf-hamilton[visualization]",
+   "id": "6b12abc0bf96a1fa"
+  },
+  {
+   "metadata": {},
+   "cell_type": "markdown",
+   "source": [
+    "# Modular Pipeline Example [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/model_examples/modular_example/notebook.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/dagworks-inc/hamilton/blob/main/examples/model_examples/modular_example/notebook.ipynb)\n", + "This uses the jupyter magic commands to create a simple example of how to reuse pipelines in a modular manner with subdag. " + ], + "id": "5fdf2bac7ddc6f79" + }, + { + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-12-07T06:57:19.359572Z", + "start_time": "2024-12-07T06:57:13.119759Z" + } + }, + "cell_type": "code", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n", + " warnings.warn(\n" + ] + } + ], + "execution_count": 1, + "source": "%load_ext hamilton.plugins.jupyter_magic", + "id": "initial_id" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Define features module\n", + "\n", + "This is the common data preprocessing step." + ], + "id": "29ebd0ec7fc5b800" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T06:57:19.627950Z", + "start_time": "2024-12-07T06:57:19.368576Z" + } + }, + "cell_type": "code", + "source": [ + "%%cell_to_module features --display\n", + "\n", + "import pandas as pd\n", + "\n", + "def raw_data(path: str) -> pd.DataFrame:\n", + " return pd.read_csv(path)\n", + "\n", + "def transformed_data(raw_data: pd.DataFrame) -> pd.DataFrame:\n", + " return raw_data.dropna()" + ], + "id": "7fafbffaf2f6f68a", + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster__legend\n\nLegend\n\n\n\ntransformed_data\n\ntransformed_data\nDataFrame\n\n\n\nraw_data\n\nraw_data\nDataFrame\n\n\n\nraw_data->transformed_data\n\n\n\n\n\n_raw_data_inputs\n\npath\nstr\n\n\n\n_raw_data_inputs->raw_data\n\n\n\n\n\ninput\n\ninput\n\n\n\nfunction\n\nfunction\n\n\n\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 2 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Define train module\n", + "\n", + "This is the training bit of the dataflow." 
+ ], + "id": "ee170ce894848eae" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T06:57:19.971271Z", + "start_time": "2024-12-07T06:57:19.724804Z" + } + }, + "cell_type": "code", + "source": [ + "%%cell_to_module train --config '{\"model\":\"RandomForest\"}'--display\n", + "\n", + "from typing import Any\n", + "import pandas as pd\n", + "\n", + "from hamilton.function_modifiers import config\n", + "\n", + "@config.when(model=\"RandomForest\")\n", + "def base_model__rf(model_params: dict) -> Any:\n", + " from sklearn.ensemble import RandomForestClassifier\n", + " return RandomForestClassifier(**model_params)\n", + "\n", + "@config.when(model=\"LogisticRegression\")\n", + "def base_model__lr(model_params: dict) -> Any:\n", + " from sklearn.linear_model import LogisticRegression\n", + " return LogisticRegression(**model_params)\n", + "\n", + "@config.when(model=\"XGBoost\")\n", + "def base_model__xgb(model_params: dict) -> Any:\n", + " from xgboost import XGBClassifier\n", + " return XGBClassifier(**model_params)\n", + "\n", + "\n", + "def fit_model(transformed_data: pd.DataFrame, base_model: Any) -> Any:\n", + " \"\"\"Fit a model to transformed data.\"\"\"\n", + " base_model.fit(transformed_data.drop(\"target\", axis=1), transformed_data[\"target\"])\n", + " return base_model\n" + ], + "id": "eae523c3fba37c93", + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster__legend\n\nLegend\n\n\n\nmodel\n\n\n\nmodel\nRandomForest\n\n\n\nbase_model\n\nbase_model: model\ntyping.Any\n\n\n\nfit_model\n\nfit_model\ntyping.Any\n\n\n\nbase_model->fit_model\n\n\n\n\n\n_base_model_inputs\n\nmodel_params\ndict\n\n\n\n_base_model_inputs->base_model\n\n\n\n\n\n_fit_model_inputs\n\ntransformed_data\nDataFrame\n\n\n\n_fit_model_inputs->fit_model\n\n\n\n\n\nconfig\n\n\n\nconfig\n\n\n\ninput\n\ninput\n\n\n\nfunction\n\nfunction\n\n\n\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 3 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Define the inference module\n", + "\n", + "This houses what we need for inference." + ], + "id": "8cae5e1a9c682ea5" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T06:57:20.363768Z", + "start_time": "2024-12-07T06:57:20.114344Z" + } + }, + "cell_type": "code", + "source": [ + "%%cell_to_module inference --display\n", + "from typing import Any\n", + "import pandas as pd\n", + "\n", + "\n", + "def predicted_data(transformed_data: pd.DataFrame, fit_model: Any) -> pd.DataFrame:\n", + " return fit_model.predict(transformed_data)\n", + "\n" + ], + "id": "2ad9e61062f6516a", + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster__legend\n\nLegend\n\n\n\npredicted_data\n\npredicted_data\nDataFrame\n\n\n\n_predicted_data_inputs\n\ntransformed_data\nDataFrame\nfit_model\ntyping.Any\n\n\n\n_predicted_data_inputs->predicted_data\n\n\n\n\n\ninput\n\ninput\n\n\n\nfunction\n\nfunction\n\n\n\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 4 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# We can combine the modules independently with different drivers\n", + "\n", + "But this won't provide us with a single dataflow or DAG." 
+ ], + "id": "3a1a0d9aca3944b1" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T18:08:40.538779Z", + "start_time": "2024-12-07T18:08:39.642181Z" + } + }, + "cell_type": "code", + "source": [ + "# train\n", + "from hamilton import driver\n", + "\n", + "train_dr = (\n", + " driver.Builder()\n", + " .with_config({\"model\": \"RandomForest\", \"model_params\": {\"n_estimators\": 100}})\n", + " .with_modules(features, train, inference)\n", + " .build()\n", + ")\n", + "train_dr.display_all_functions()" + ], + "id": "9ac29701bdd31fb5", + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster__legend\n\nLegend\n\n\n\nmodel\n\n\n\nmodel\nRandomForest\n\n\n\nmodel_params\n\n\n\nmodel_params\n{'n_estimators': 100}\n\n\n\npredicted_data\n\npredicted_data\nDataFrame\n\n\n\nbase_model\n\nbase_model: model\ntyping.Any\n\n\n\nfit_model\n\nfit_model\ntyping.Any\n\n\n\nbase_model->fit_model\n\n\n\n\n\ntransformed_data\n\ntransformed_data\nDataFrame\n\n\n\ntransformed_data->predicted_data\n\n\n\n\n\ntransformed_data->fit_model\n\n\n\n\n\nfit_model->predicted_data\n\n\n\n\n\nraw_data\n\nraw_data\nDataFrame\n\n\n\nraw_data->transformed_data\n\n\n\n\n\n_base_model_inputs\n\nmodel_params\ndict\n\n\n\n_base_model_inputs->base_model\n\n\n\n\n\n_raw_data_inputs\n\npath\nstr\n\n\n\n_raw_data_inputs->raw_data\n\n\n\n\n\nconfig\n\n\n\nconfig\n\n\n\ninput\n\ninput\n\n\n\nfunction\n\nfunction\n\n\n\n", + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 9 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T18:09:13.265102Z", + "start_time": "2024-12-07T18:09:12.750662Z" + } + }, + "cell_type": "code", + "source": [ + "# Inference\n", + "from hamilton import driver\n", + "\n", + "inference_dr = (\n", + " driver.Builder()\n", + " .with_config({})\n", + " .with_modules(features, inference)\n", + " .build()\n", + ")\n", + "inference_dr.display_all_functions()" + ], + "id": "cc9401ed081df22f", + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster__legend\n\nLegend\n\n\n\npredicted_data\n\npredicted_data\nDataFrame\n\n\n\ntransformed_data\n\ntransformed_data\nDataFrame\n\n\n\ntransformed_data->predicted_data\n\n\n\n\n\nraw_data\n\nraw_data\nDataFrame\n\n\n\nraw_data->transformed_data\n\n\n\n\n\n_predicted_data_inputs\n\nfit_model\ntyping.Any\n\n\n\n_predicted_data_inputs->predicted_data\n\n\n\n\n\n_raw_data_inputs\n\npath\nstr\n\n\n\n_raw_data_inputs->raw_data\n\n\n\n\n\ninput\n\ninput\n\n\n\nfunction\n\nfunction\n\n\n\n", + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 10 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# To combine into a single dataflow we can use @subdag\n", + "\n", + "So if we want a single pipeline that enables us to:\n", + "\n", + "1. train the model & get training set predictions.\n", + "2. then use the fit model to predict on a separate dataset.\n", + "\n", + "To do that we define another module that uses the `@subdag` constructs that we wire together." 
+ ], + "id": "d85c51388733ce96" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T07:00:23.770491Z", + "start_time": "2024-12-07T07:00:23.481869Z" + } + }, + "cell_type": "code", + "source": [ + "%%cell_to_module pipeline --config '{\"model\":\"RandomForest\"}' --display\n", + "from typing import Any\n", + "\n", + "import pandas as pd\n", + "\n", + "from hamilton.function_modifiers import subdag, extract_fields, configuration, source\n", + "import features\n", + "import train\n", + "import inference\n", + "\n", + "@extract_fields(\n", + " {'fit_model': Any, 'training_prediction': pd.DataFrame}\n", + ")\n", + "@subdag(\n", + " features, train, inference,\n", + " inputs={\n", + " \"path\": source(\"path\"),\n", + " \"model_params\": source(\"model_params\"),\n", + " },\n", + " # there are several ways to pass in configuration.\n", + " # config={ \n", + " # \"model\": configuration(\"model\")\n", + " # },\n", + ")\n", + "def trained_pipeline(fit_model: Any, predicted_data: pd.DataFrame) -> dict:\n", + " return {'fit_model': fit_model, 'training_prediction': predicted_data}\n", + "\n", + "@subdag(\n", + " features, inference,\n", + " inputs={\n", + " \"path\": source(\"predict_path\"),\n", + " \"fit_model\": source(\"fit_model\"),\n", + " },\n", + ")\n", + "def predicted_data(predicted_data: pd.DataFrame) -> pd.DataFrame:\n", + " return predicted_data" + ], + "id": "6d1585dad64464d7", + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster__legend\n\nLegend\n\n\n\nmodel\n\n\n\nmodel\nRandomForest\n\n\n\ntrained_pipeline.base_model\n\ntrained_pipeline.base_model: model\ntyping.Any\n\n\n\ntrained_pipeline.fit_model\n\ntrained_pipeline.fit_model\ntyping.Any\n\n\n\ntrained_pipeline.base_model->trained_pipeline.fit_model\n\n\n\n\n\nfit_model\n\nfit_model\ntyping.Any\n\n\n\npredicted_data.fit_model\n\npredicted_data.fit_model\ntyping.Any\n\n\n\nfit_model->predicted_data.fit_model\n\n\n\n\n\npredicted_data.raw_data\n\npredicted_data.raw_data\nDataFrame\n\n\n\npredicted_data.transformed_data\n\npredicted_data.transformed_data\nDataFrame\n\n\n\npredicted_data.raw_data->predicted_data.transformed_data\n\n\n\n\n\npredicted_data.predicted_data\n\npredicted_data.predicted_data\nDataFrame\n\n\n\npredicted_data.transformed_data->predicted_data.predicted_data\n\n\n\n\n\ntrained_pipeline.predicted_data\n\ntrained_pipeline.predicted_data\nDataFrame\n\n\n\ntrained_pipeline\n\ntrained_pipeline\ndict\n\n\n\ntrained_pipeline.predicted_data->trained_pipeline\n\n\n\n\n\ntrained_pipeline.transformed_data\n\ntrained_pipeline.transformed_data\nDataFrame\n\n\n\ntrained_pipeline.transformed_data->trained_pipeline.predicted_data\n\n\n\n\n\ntrained_pipeline.transformed_data->trained_pipeline.fit_model\n\n\n\n\n\npredicted_data\n\npredicted_data\nDataFrame\n\n\n\npredicted_data.predicted_data->predicted_data\n\n\n\n\n\ntrained_pipeline.model_params\n\ntrained_pipeline.model_params\ndict\n\n\n\ntrained_pipeline.model_params->trained_pipeline.base_model\n\n\n\n\n\ntrained_pipeline.fit_model->trained_pipeline.predicted_data\n\n\n\n\n\ntrained_pipeline.fit_model->trained_pipeline\n\n\n\n\n\npredicted_data.fit_model->predicted_data.predicted_data\n\n\n\n\n\npredicted_data.path\n\npredicted_data.path\nstr\n\n\n\npredicted_data.path->predicted_data.raw_data\n\n\n\n\n\ntrained_pipeline.path\n\ntrained_pipeline.path\nstr\n\n\n\ntrained_pipeline.raw_data\n\ntrained_pipeline.raw_data\nDataFrame\n\n\n\ntrained_pipeline.path->trained_pipeline.raw_data\n\n\n\n\n\ntrained_pipeline->fit_model\n\n\n\n\n\
ntraining_prediction\n\ntraining_prediction\nDataFrame\n\n\n\ntrained_pipeline->training_prediction\n\n\n\n\n\ntrained_pipeline.raw_data->trained_pipeline.transformed_data\n\n\n\n\n\n_trained_pipeline.model_params_inputs\n\nmodel_params\ndict\n\n\n\n_trained_pipeline.model_params_inputs->trained_pipeline.model_params\n\n\n\n\n\n_predicted_data.path_inputs\n\npredict_path\nstr\n\n\n\n_predicted_data.path_inputs->predicted_data.path\n\n\n\n\n\n_trained_pipeline.path_inputs\n\npath\nstr\n\n\n\n_trained_pipeline.path_inputs->trained_pipeline.path\n\n\n\n\n\nconfig\n\n\n\nconfig\n\n\n\ninput\n\ninput\n\n\n\nfunction\n\nfunction\n\n\n\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 8 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T06:57:20.874962Z", + "start_time": "2024-12-07T06:57:20.643256Z" + } + }, + "cell_type": "code", + "source": [ + "from hamilton import driver\n", + "\n", + "dr = (\n", + " driver.Builder()\n", + " .with_config({\"model\": \"RandomForest\", \"model_params\": {\"n_estimators\": 100}})\n", + " .with_modules(pipeline)\n", + " .build()\n", + ")\n", + "dr.display_all_functions()" + ], + "id": "f72146c07a654ca4", + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster__legend\n\nLegend\n\n\n\nmodel\n\n\n\nmodel\nRandomForest\n\n\n\nmodel_params\n\n\n\nmodel_params\n{'n_estimators': 100}\n\n\n\ntrained_pipeline.base_model\n\ntrained_pipeline.base_model: model\ntyping.Any\n\n\n\ntrained_pipeline.fit_model\n\ntrained_pipeline.fit_model\ntyping.Any\n\n\n\ntrained_pipeline.base_model->trained_pipeline.fit_model\n\n\n\n\n\nfit_model\n\nfit_model\ntyping.Any\n\n\n\npredicted_data.fit_model\n\npredicted_data.fit_model\ntyping.Any\n\n\n\nfit_model->predicted_data.fit_model\n\n\n\n\n\ntrained_pipeline.model\n\ntrained_pipeline.model\nUpstreamDependency\n\n\n\npredicted_data.raw_data\n\npredicted_data.raw_data\nDataFrame\n\n\n\npredicted_data.transformed_data\n\npredicted_data.transformed_data\nDataFrame\n\n\n\npredicted_data.raw_data->predicted_data.transformed_data\n\n\n\n\n\npredicted_data.predicted_data\n\npredicted_data.predicted_data\nDataFrame\n\n\n\npredicted_data.transformed_data->predicted_data.predicted_data\n\n\n\n\n\ntrained_pipeline.predicted_data\n\ntrained_pipeline.predicted_data\nDataFrame\n\n\n\ntrained_pipeline\n\ntrained_pipeline\ndict\n\n\n\ntrained_pipeline.predicted_data->trained_pipeline\n\n\n\n\n\ntrained_pipeline.transformed_data\n\ntrained_pipeline.transformed_data\nDataFrame\n\n\n\ntrained_pipeline.transformed_data->trained_pipeline.predicted_data\n\n\n\n\n\ntrained_pipeline.transformed_data->trained_pipeline.fit_model\n\n\n\n\n\npredicted_data\n\npredicted_data\nDataFrame\n\n\n\npredicted_data.predicted_data->predicted_data\n\n\n\n\n\ntrained_pipeline.model_params\n\ntrained_pipeline.model_params\ndict\n\n\n\ntrained_pipeline.model_params->trained_pipeline.base_model\n\n\n\n\n\ntrained_pipeline.fit_model->trained_pipeline.predicted_data\n\n\n\n\n\ntrained_pipeline.fit_model->trained_pipeline\n\n\n\n\n\npredicted_data.fit_model->predicted_data.predicted_data\n\n\n\n\n\npredicted_data.path\n\npredicted_data.path\nstr\n\n\n\npredicted_data.path->predicted_data.raw_data\n\n\n\n\n\ntrained_pipeline.path\n\ntrained_pipeline.path\nstr\n\n\n\ntrained_pipeline.raw_data\n\ntrained_pipeline.raw_data\nDataFrame\n\n\n\ntrained_pipeline.path->trained_pipeline.raw_data\n\n\n\n\n\ntrained_pipeline->fit_model\n\n\n\n\n\ntraining_prediction\n\ntraining_prediction\nDataF
rame\n\n\n\ntrained_pipeline->training_prediction\n\n\n\n\n\ntrained_pipeline.raw_data->trained_pipeline.transformed_data\n\n\n\n\n\n_trained_pipeline.model_params_inputs\n\nmodel_params\ndict\n\n\n\n_trained_pipeline.model_params_inputs->trained_pipeline.model_params\n\n\n\n\n\n_predicted_data.path_inputs\n\npredict_path\nstr\n\n\n\n_predicted_data.path_inputs->predicted_data.path\n\n\n\n\n\n_trained_pipeline.path_inputs\n\npath\nstr\n\n\n\n_trained_pipeline.path_inputs->trained_pipeline.path\n\n\n\n\n\nconfig\n\n\n\nconfig\n\n\n\ninput\n\ninput\n\n\n\nfunction\n\nfunction\n\n\n\n", + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 6 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T06:57:34.959772Z", + "start_time": "2024-12-07T06:57:34.956204Z" + } + }, + "cell_type": "code", + "source": [ + "# this wont work because we don't actually have data...\n", + "# dr.execute([\"trained_pipeline\", \"predicted_data\"], \n", + "# inputs={\"path\": \"data.csv\", \"predict_path\": \"data.csv\"})" + ], + "id": "b3abca24b1a86329", + "outputs": [], + "execution_count": 7 + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "", + "id": "1b3dba37a6c00d7c" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/model_examples/modular_example/pipeline.py b/examples/model_examples/modular_example/pipeline.py new file mode 100644 index 000000000..0c7ebb0eb --- /dev/null +++ b/examples/model_examples/modular_example/pipeline.py @@ -0,0 +1,37 @@ +from typing import Any + +import features +import inference +import pandas as pd +import train + +from hamilton.function_modifiers import configuration, extract_fields, source, subdag + + +@extract_fields({"fit_model": Any, "training_prediction": pd.DataFrame}) +@subdag( + features, + train, + inference, + inputs={ + "path": source("path"), + "model_params": source("model_params"), + }, + config={ + "model": configuration("train_model_type"), # not strictly required but allows us to remap. 
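+        # configuration("train_model_type") looks up the driver-level config key
+        # "train_model_type" and passes its value into this subdag as its "model"
+        # config key, letting callers name the key differently from the subdag's.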
+ }, +) +def trained_pipeline(fit_model: Any, predicted_data: pd.DataFrame) -> dict: + return {"fit_model": fit_model, "training_prediction": predicted_data} + + +@subdag( + features, + inference, + inputs={ + "path": source("predict_path"), + "fit_model": source("fit_model"), + }, +) +def predicted_data(predicted_data: pd.DataFrame) -> pd.DataFrame: + return predicted_data diff --git a/examples/model_examples/modular_example/run.py b/examples/model_examples/modular_example/run.py new file mode 100644 index 000000000..415bd2863 --- /dev/null +++ b/examples/model_examples/modular_example/run.py @@ -0,0 +1,18 @@ +import pipeline + +from hamilton import driver + + +def run(): + dr = ( + driver.Builder() + .with_config({"train_model_type": "RandomForest", "model_params": {"n_estimators": 100}}) + .with_modules(pipeline) + .build() + ) + dr.display_all_functions("./my_dag.png") + # dr.execute(["trained_pipeline", "predicted_data"]) + + +if __name__ == "__main__": + run() diff --git a/examples/model_examples/modular_example/train.py b/examples/model_examples/modular_example/train.py new file mode 100644 index 000000000..9d0d93bea --- /dev/null +++ b/examples/model_examples/modular_example/train.py @@ -0,0 +1,32 @@ +from typing import Any + +import pandas as pd + +from hamilton.function_modifiers import config + + +@config.when(model="RandomForest") +def base_model__rf(model_params: dict) -> Any: + from sklearn.ensemble import RandomForestClassifier + + return RandomForestClassifier(**model_params) + + +@config.when(model="LogisticRegression") +def base_model__lr(model_params: dict) -> Any: + from sklearn.linear_model import LogisticRegression + + return LogisticRegression(**model_params) + + +@config.when(model="XGBoost") +def base_model__xgb(model_params: dict) -> Any: + from xgboost import XGBClassifier + + return XGBClassifier(**model_params) + + +def fit_model(transformed_data: pd.DataFrame, base_model: Any) -> Any: + """Fit a model to transformed data.""" + base_model.fit(transformed_data.drop("target", axis=1), transformed_data["target"]) + return base_model diff --git a/hamilton/function_modifiers/__init__.py b/hamilton/function_modifiers/__init__.py index 958d07540..3113b13ff 100644 --- a/hamilton/function_modifiers/__init__.py +++ b/hamilton/function_modifiers/__init__.py @@ -36,6 +36,7 @@ value = dependencies.value source = dependencies.source group = dependencies.group +configuration = dependencies.configuration # These aren't strictly part of the API but we should have them here for safety LiteralDependency = dependencies.LiteralDependency diff --git a/hamilton/function_modifiers/dependencies.py b/hamilton/function_modifiers/dependencies.py index 26785505c..4f0c6d159 100644 --- a/hamilton/function_modifiers/dependencies.py +++ b/hamilton/function_modifiers/dependencies.py @@ -16,6 +16,7 @@ class ParametrizedDependencySource(enum.Enum): UPSTREAM = "upstream" GROUPED_LIST = "grouped_list" GROUPED_DICT = "grouped_dict" + CONFIGURATION = "configuration" class ParametrizedDependency: @@ -44,6 +45,14 @@ def get_dependency_type(self) -> ParametrizedDependencySource: return ParametrizedDependencySource.UPSTREAM +@dataclasses.dataclass +class ConfigDependency(SingleDependency): + source: str + + def get_dependency_type(self) -> ParametrizedDependencySource: + return ParametrizedDependencySource.CONFIGURATION + + class GroupedDependency(ParametrizedDependency, abc.ABC): @classmethod @abc.abstractmethod @@ -123,8 +132,8 @@ def value(literal_value: Any) -> LiteralDependency: E.G. 
value("foo") means that the value is actually the string value "foo".
 
-    :param literal_value: Python literal value to use. :return: A LiteralDependency object -- a
-    signifier to the internal framework of the dependency type.
+    :param literal_value: Python literal value to use.
+    :return: A LiteralDependency object -- a signifier to the internal framework of the dependency type.
     """
     if isinstance(literal_value, LiteralDependency):
         return literal_value
@@ -138,14 +147,25 @@ def source(dependency_on: Any) -> UpstreamDependency:
     be assigned the value that "foo" outputs.
 
     :param dependency_on: Upstream function (i.e. node) to come from.
-    :return: An
-    UpstreamDependency object -- a signifier to the internal framework of the dependency type.
+    :return: An UpstreamDependency object -- a signifier to the internal framework of the dependency type.
     """
     if isinstance(dependency_on, UpstreamDependency):
         return dependency_on
     return UpstreamDependency(source=dependency_on)
 
 
+def configuration(dependency_on: str) -> ConfigDependency:
+    """Specifies that a parameterized dependency comes from the global `config` passed in.
+
+    This means that it comes from a global configuration key value. E.G. configuration("foo") means that it should
+    be assigned the value that the "foo" key in the global configuration passed to Hamilton maps to.
+
+    :param dependency_on: name of the configuration key to pull from.
+    :return: A ConfigDependency object -- a signifier to the internal framework of the dependency type.
+    """
+    return ConfigDependency(source=dependency_on)
+
+
 def _validate_group_params(
     dependency_args: List[ParametrizedDependency],
     dependency_kwargs: Dict[str, ParametrizedDependency],
diff --git a/hamilton/function_modifiers/recursive.py b/hamilton/function_modifiers/recursive.py
index a330714a9..d86ad74f1 100644
--- a/hamilton/function_modifiers/recursive.py
+++ b/hamilton/function_modifiers/recursive.py
@@ -131,6 +131,43 @@ def _validate_config_inputs(config: Dict[str, Any], inputs: Dict[str, Any]):
         )
 
 
+def _resolve_subdag_configuration(
+    configuration: Dict[str, Any], fields: Dict[str, Any], function_name: str
+) -> Dict[str, Any]:
+    """Resolves the configuration for a subdag.
+
+    :param configuration: the Hamilton configuration
+    :param fields: the config= fields passed to the subdag decorator
+    :param function_name: name of the subdag function; used for error messages
+    :return: resolved configuration to use for this subdag.
+    """
+    sources_to_map = {}
+    values_to_include = {}
+    for key, value in fields.items():
+        if isinstance(value, dependencies.ConfigDependency):
+            sources_to_map[key] = value.source
+        elif isinstance(value, dependencies.LiteralDependency):
+            values_to_include[key] = value.value
+        elif isinstance(value, (dependencies.GroupedDependency, dependencies.SingleDependency)):
+            raise InvalidDecoratorException(
+                f"`{value}` is not allowed in the config= part of the subdag decorator. "
+                "Please use `configuration()` or `value()` or literal python values."
+            )
+    plain_configs = {
+        k: v for k, v in fields.items() if k not in sources_to_map and k not in values_to_include
+    }
+    resolved_config = dict(configuration, **plain_configs, **values_to_include)
+
+    # override any values from sources
+    for key, source in sources_to_map.items():
+        try:
+            resolved_config[key] = resolved_config[source]
+        except KeyError as e:
+            raise InvalidDecoratorException(
+                f"Source {source} was not found in the configuration. This is required for the {function_name} subdag."
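+                # note: the lookup runs against the merged configuration (driver config
+                # plus any plain values / value() entries in this config= dict).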
+            ) from e
+    return resolved_config
+
+
 NON_FINAL_TAGS = {NodeTransformer.NON_FINAL_TAG: True}
@@ -423,7 +460,9 @@ def _derive_name(self, fn: Callable) -> str:
 
     def generate_nodes(self, fn: Callable, configuration: Dict[str, Any]) -> Collection[node.Node]:
         # Resolve all nodes from passed in functions
-        resolved_config = dict(configuration, **self.config)
+        # if self.config has configuration() or value() in it, we need to resolve it
+        resolved_config = _resolve_subdag_configuration(configuration, self.config, fn.__name__)
         nodes = self.collect_nodes(config=resolved_config, subdag_functions=self.subdag_functions)
         # Derive the namespace under which all these nodes will live
         namespace = self._derive_namespace(fn)
diff --git a/tests/function_modifiers/test_recursive.py b/tests/function_modifiers/test_recursive.py
index e9b76686c..f2f35ee71 100644
--- a/tests/function_modifiers/test_recursive.py
+++ b/tests/function_modifiers/test_recursive.py
@@ -10,6 +10,8 @@
 from hamilton.function_modifiers import (
     InvalidDecoratorException,
     config,
+    configuration,
+    group,
     parameterized_subdag,
     recursive,
     subdag,
@@ -392,6 +394,122 @@ def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
     )
 
 
+def test_nested_subdag_with_config_remapping():
+    """Tests that we can remap config values and that configuration() and value() are resolved correctly."""
+
+    def bar(input_1: int) -> int:
+        return input_1 + 1
+
+    @config.when(broken=False)
+    def foo(input_2: int) -> int:
+        return input_2 + 1
+
+    @subdag(
+        foo,
+        bar,
+    )
+    def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
+        return foo, bar
+
+    @subdag(inner_subdag, inputs={"input_2": value(10)}, config={"broken": value(False)})
+    def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int:
+        return sum(inner_subdag)
+
+    @subdag(inner_subdag, inputs={"input_2": value(3)}, config={"broken": configuration("broken2")})
+    def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int:
+        return sum(inner_subdag)
+
+    def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
+        return outer_subdag_1 + outer_subdag_2
+
+    # we only need to generate from the outer subdags
+    # as they refer to the inner one
+    full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all)
+    fg = graph.FunctionGraph.from_modules(full_module, config={"broken2": False})
+    assert "outer_subdag_1" in fg.nodes
+    assert "outer_subdag_2" in fg.nodes
+    res = fg.execute(nodes=[fg.nodes["sum_all"]], inputs={"input_1": 2})
+    # This is effectively the function graph
+    assert res["sum_all"] == sum_all(
+        outer_subdag_1(inner_subdag(bar(2), foo(10))), outer_subdag_2(inner_subdag(bar(2), foo(3)))
+    )
+
+
+def test_nested_subdag_with_config_remapping_missing_error():
+    """Tests that we error if we can't remap a config value."""
+
+    def bar(input_1: int) -> int:
+        return input_1 + 1
+
+    @config.when(broken=False)
+    def foo(input_2: int) -> int:
+        return input_2 + 1
+
+    @subdag(
+        foo,
+        bar,
+    )
+    def inner_subdag(foo: int, bar: int) -> Tuple[int, int]:
+        return foo, bar
+
+    @subdag(inner_subdag, inputs={"input_2": value(10)}, config={"broken": value(False)})
+    def outer_subdag_1(inner_subdag: Tuple[int, int]) -> int:
+        return sum(inner_subdag)
+
+    @subdag(
+        inner_subdag,
+        inputs={"input_2": value(3)},
+        config={"broken": configuration("broken_missing")},
+    )
+    def outer_subdag_2(inner_subdag: Tuple[int, int]) -> int:
+        return sum(inner_subdag)
+
+    def sum_all(outer_subdag_1: int, outer_subdag_2: int) -> int:
+        return outer_subdag_1 + outer_subdag_2
+
+    # we only need to generate from the outer subdags
+    # as they refer to the inner one
+    full_module = ad_hoc_utils.create_temporary_module(outer_subdag_1, outer_subdag_2, sum_all)
+    with pytest.raises(InvalidDecoratorException):
+        graph.FunctionGraph.from_modules(full_module, config={"broken2": False})
+
+
+@pytest.mark.parametrize(
+    "configuration,fields,expected",
+    [
+        ({"a": 1, "b": 2}, {}, {"a": 1, "b": 2}),
+        ({"a": 1, "b": 2}, {"c": value(3)}, {"a": 1, "b": 2, "c": 3}),
+        (
+            {"a": 1, "b": 2},
+            {"c": value(3), "d": configuration("a")},
+            {"a": 1, "b": 2, "c": 3, "d": 1},
+        ),
+    ],
+)
+def test_resolve_subdag_configuration_happy(configuration, fields, expected):
+    actual = recursive._resolve_subdag_configuration(configuration, fields, "test")
+    assert actual == expected
+
+
+def test_resolve_subdag_configuration_bad_mapping():
+    _configuration = {"a": 1, "b": 2}
+    fields = {"c": value(3), "d": configuration("e")}
+    with pytest.raises(InvalidDecoratorException):
+        recursive._resolve_subdag_configuration(_configuration, fields, "test")
+
+
+def test_resolve_subdag_configuration_flag_incorrect_source_group_deps():
+    _configuration = {"a": 1, "b": 2}
+    with pytest.raises(InvalidDecoratorException):
+        recursive._resolve_subdag_configuration(
+            _configuration, {"c": value(3), "d": source("e")}, "test"
+        )
+    with pytest.raises(InvalidDecoratorException):
+        recursive._resolve_subdag_configuration(
+            _configuration, {"c": value(3), "d": group(source("e"))}, "test"
+        )
+
+
 def test_subdag_with_external_nodes_input():
     def bar(input_1: int) -> int:
         return input_1 + 1
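For reference, here is a minimal, untested sketch of driving this diff's `configuration()` remapping end to end. It mirrors run.py and the execute call that is commented out in the notebook; the CSV paths ("train.csv", "predict.csv") are hypothetical stand-ins, and the files would need a "target" column for fit_model to work:

    import pipeline

    from hamilton import driver

    dr = (
        driver.Builder()
        # "train_model_type" is remapped to the subdag's "model" config key via
        # config={"model": configuration("train_model_type")} in pipeline.py.
        .with_config({"train_model_type": "RandomForest", "model_params": {"n_estimators": 100}})
        .with_modules(pipeline)
        .build()
    )
    results = dr.execute(
        ["trained_pipeline", "predicted_data"],
        inputs={"path": "train.csv", "predict_path": "predict.csv"},
    )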