Skip to content

Commit

Permalink
Model registration tags come from parameters.json (#237)
Browse files Browse the repository at this point in the history
  • Loading branch information
jotaylo authored Mar 24, 2020
1 parent 2d54311 commit 352ebbe
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 12 deletions.
2 changes: 1 addition & 1 deletion diabetes_regression/evaluate/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
"--model_name",
type=str,
help="Name of the Model",
default="sklearn_regression_model.pkl",
default="diabetes_model.pkl",
)

parser.add_argument(
Expand Down
4 changes: 4 additions & 0 deletions diabetes_regression/parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
"evaluation":
{

},
"registration":
{
"tags": ["mse"]
},
"scoring":
{
Expand Down
40 changes: 30 additions & 10 deletions diabetes_regression/register/register_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ARISING IN ANY WAY OUT OF THE USE OF THE SOFTWARE CODE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
"""
import json
import os
import sys
import argparse
Expand Down Expand Up @@ -69,8 +70,9 @@ def main():
"--model_name",
type=str,
help="Name of the Model",
default="sklearn_regression_model.pkl",
default="diabetes_model.pkl",
)

parser.add_argument(
"--step_input",
type=str,
Expand All @@ -85,40 +87,58 @@ def main():
model_name = args.model_name
model_path = args.step_input

print("Getting registration parameters")

# Load the registration parameters from the parameters file
with open("parameters.json") as f:
pars = json.load(f)
try:
register_args = pars["registration"]
except KeyError:
print("Could not load registration values from file")
register_args = {"tags": []}

model_tags = {}
for tag in register_args["tags"]:
try:
mtag = run.parent.get_metrics()[tag]
model_tags[tag] = mtag
except KeyError:
print(f"Could not find {tag} metric on parent run.")

# load the model
print("Loading model from " + model_path)
model_file = os.path.join(model_path, model_name)
model = joblib.load(model_file)
model_mse = run.parent.get_metrics()["mse"]
parent_tags = run.parent.get_tags()
try:
build_id = parent_tags["BuildId"]
except KeyError:
build_id = None
print("BuildId tag not found on parent run.")
print("Tags present: {parent_tags}")
print(f"Tags present: {parent_tags}")
try:
build_uri = parent_tags["BuildUri"]
except KeyError:
build_uri = None
print("BuildUri tag not found on parent run.")
print("Tags present: {parent_tags}")
print(f"Tags present: {parent_tags}")

if (model is not None):
dataset_id = parent_tags["dataset_id"]
if (build_id is None):
register_aml_model(
model_file,
model_name,
model_mse,
model_tags,
exp,
run_id,
dataset_id)
elif (build_uri is None):
register_aml_model(
model_file,
model_name,
model_mse,
model_tags,
exp,
run_id,
dataset_id,
Expand All @@ -127,7 +147,7 @@ def main():
register_aml_model(
model_file,
model_name,
model_mse,
model_tags,
exp,
run_id,
dataset_id,
Expand All @@ -152,7 +172,7 @@ def model_already_registered(model_name, exp, run_id):
def register_aml_model(
model_path,
model_name,
model_mse,
model_tags,
exp,
run_id,
dataset_id,
Expand All @@ -162,8 +182,8 @@ def register_aml_model(
try:
tagsValue = {"area": "diabetes_regression",
"run_id": run_id,
"experiment_name": exp.name,
"mse": model_mse}
"experiment_name": exp.name}
tagsValue.update(model_tags)
if (build_id != 'none'):
model_already_registered(model_name, exp, run_id)
tagsValue["BuildId"] = build_id
Expand Down
2 changes: 1 addition & 1 deletion diabetes_regression/training/train_aml.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def main():
"--model_name",
type=str,
help="Name of the Model",
default="sklearn_regression_model.pkl",
default="diabetes_model.pkl",
)

parser.add_argument(
Expand Down

0 comments on commit 352ebbe

Please sign in to comment.