From 352ebbeba745ae1f421dddc7f36f81c4db766eee Mon Sep 17 00:00:00 2001 From: jotaylo Date: Tue, 24 Mar 2020 10:48:35 -0700 Subject: [PATCH] Model registration tags come from parameters.json (#237) --- .../evaluate/evaluate_model.py | 2 +- diabetes_regression/parameters.json | 4 ++ .../register/register_model.py | 40 ++++++++++++++----- diabetes_regression/training/train_aml.py | 2 +- 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/diabetes_regression/evaluate/evaluate_model.py b/diabetes_regression/evaluate/evaluate_model.py index 83e422e8..125a16a5 100644 --- a/diabetes_regression/evaluate/evaluate_model.py +++ b/diabetes_regression/evaluate/evaluate_model.py @@ -83,7 +83,7 @@ "--model_name", type=str, help="Name of the Model", - default="sklearn_regression_model.pkl", + default="diabetes_model.pkl", ) parser.add_argument( diff --git a/diabetes_regression/parameters.json b/diabetes_regression/parameters.json index 859fd84d..48f7227d 100644 --- a/diabetes_regression/parameters.json +++ b/diabetes_regression/parameters.json @@ -6,6 +6,10 @@ "evaluation": { + }, + "registration": + { + "tags": ["mse"] }, "scoring": { diff --git a/diabetes_regression/register/register_model.py b/diabetes_regression/register/register_model.py index 3376285e..bca55a83 100644 --- a/diabetes_regression/register/register_model.py +++ b/diabetes_regression/register/register_model.py @@ -23,6 +23,7 @@ ARISING IN ANY WAY OUT OF THE USE OF THE SOFTWARE CODE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ +import json import os import sys import argparse @@ -69,8 +70,9 @@ def main(): "--model_name", type=str, help="Name of the Model", - default="sklearn_regression_model.pkl", + default="diabetes_model.pkl", ) + parser.add_argument( "--step_input", type=str, @@ -85,24 +87,42 @@ def main(): model_name = args.model_name model_path = args.step_input + print("Getting registration parameters") + + # Load the registration parameters from the parameters file + with open("parameters.json") as f: + pars = json.load(f) + try: + register_args = pars["registration"] + except KeyError: + print("Could not load registration values from file") + register_args = {"tags": []} + + model_tags = {} + for tag in register_args["tags"]: + try: + mtag = run.parent.get_metrics()[tag] + model_tags[tag] = mtag + except KeyError: + print(f"Could not find {tag} metric on parent run.") + # load the model print("Loading model from " + model_path) model_file = os.path.join(model_path, model_name) model = joblib.load(model_file) - model_mse = run.parent.get_metrics()["mse"] parent_tags = run.parent.get_tags() try: build_id = parent_tags["BuildId"] except KeyError: build_id = None print("BuildId tag not found on parent run.") - print("Tags present: {parent_tags}") + print(f"Tags present: {parent_tags}") try: build_uri = parent_tags["BuildUri"] except KeyError: build_uri = None print("BuildUri tag not found on parent run.") - print("Tags present: {parent_tags}") + print(f"Tags present: {parent_tags}") if (model is not None): dataset_id = parent_tags["dataset_id"] @@ -110,7 +130,7 @@ def main(): register_aml_model( model_file, model_name, - model_mse, + model_tags, exp, run_id, dataset_id) @@ -118,7 +138,7 @@ def main(): register_aml_model( model_file, model_name, - model_mse, + model_tags, exp, run_id, dataset_id, @@ -127,7 +147,7 @@ def main(): register_aml_model( model_file, model_name, - model_mse, + model_tags, exp, run_id, dataset_id, @@ -152,7 +172,7 @@ def model_already_registered(model_name, exp, run_id): def register_aml_model( model_path, model_name, - model_mse, + model_tags, exp, run_id, dataset_id, @@ -162,8 +182,8 @@ def register_aml_model( try: tagsValue = {"area": "diabetes_regression", "run_id": run_id, - "experiment_name": exp.name, - "mse": model_mse} + "experiment_name": exp.name} + tagsValue.update(model_tags) if (build_id != 'none'): model_already_registered(model_name, exp, run_id) tagsValue["BuildId"] = build_id diff --git a/diabetes_regression/training/train_aml.py b/diabetes_regression/training/train_aml.py index 65904b20..9303198b 100644 --- a/diabetes_regression/training/train_aml.py +++ b/diabetes_regression/training/train_aml.py @@ -55,7 +55,7 @@ def main(): "--model_name", type=str, help="Name of the Model", - default="sklearn_regression_model.pkl", + default="diabetes_model.pkl", ) parser.add_argument(