Skip to content

Commit

Permalink
Added get_latest_model method (#231)
Browse files Browse the repository at this point in the history
  • Loading branch information
starlord-daniel authored Mar 15, 2020
1 parent 5887633 commit d531b2e
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 30 deletions.
6 changes: 3 additions & 3 deletions diabetes_regression/evaluate/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from azureml.core import Run
import argparse
import traceback
from util.model_helper import get_model_by_tag
from util.model_helper import get_latest_model

run = Run.get_context()

Expand All @@ -45,7 +45,7 @@
# sources_dir = 'diabetes_regression'
# path_to_util = os.path.join(".", sources_dir, "util")
# sys.path.append(os.path.abspath(path_to_util)) # NOQA: E402
# from model_helper import get_model_by_tag
# from model_helper import get_latest_model
# workspace_name = os.environ.get("WORKSPACE_NAME")
# experiment_name = os.environ.get("EXPERIMENT_NAME")
# resource_group = os.environ.get("RESOURCE_GROUP")
Expand Down Expand Up @@ -108,7 +108,7 @@
firstRegistration = False
tag_name = 'experiment_name'

model = get_model_by_tag(
model = get_latest_model(
model_name, tag_name, exp.name, ws)

if (model is not None):
Expand Down
57 changes: 32 additions & 25 deletions diabetes_regression/util/model_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,20 @@ def get_current_workspace() -> Workspace:
return experiment.workspace


def get_model_by_tag(
def get_latest_model(
model_name: str,
tag_name: str,
tag_value: str,
tag_name: str = None,
tag_value: str = None,
aml_workspace: Workspace = None
) -> AMLModel:
"""
Retrieves and returns the latest model from the workspace
by its name and tag.
by its name and (optional) tag.
Parameters:
aml_workspace (Workspace): aml.core Workspace that the model lives.
model_name (str): name of the model we are looking for
tag (str): the tag value the model was registered under.
(optional) tag (str): the tag value & name the model was registered under.
Return:
A single aml model from the workspace that matches the name and tag.
Expand All @@ -44,37 +44,44 @@ def get_model_by_tag(
# Validate params. cannot be None.
if model_name is None:
raise ValueError("model_name[:str] is required")
if tag_name is None:
raise ValueError("tag_name[:str] is required")
if tag_value is None:
raise ValueError("tag[:str] is required")

if aml_workspace is None:
print("No workspace defined - using current experiment workspace.")
aml_workspace = get_current_workspace()

# get model by tag.
model_list = AMLModel.list(
aml_workspace, name=model_name,
tags=[[tag_name, tag_value]], latest=True
)
model_list = None
tag_ext = ""

# Get lastest model
# True: by name and tags
if tag_name is not None and tag_value is not None:
model_list = AMLModel.list(
aml_workspace, name=model_name,
tags=[[tag_name, tag_value]], latest=True
)
tag_ext = f"tag_name: {tag_name}, tag_value: {tag_value}."
# False: Only by name
else:
model_list = AMLModel.list(
aml_workspace, name=model_name, latest=True)

# latest should only return 1 model, but if it does,
# then maybe sdk or source code changed.
should_not_happen = ("Found more than one model "
"for the latest with {{tag_name: {tag_name},"
"tag_value: {tag_value}. "
"Models found: {model_list}}}")\
.format(tag_name=tag_name, tag_value=tag_value,
model_list=model_list)
no_model_found = ("No Model found with {{tag_name: {tag_name} ,"
"tag_value: {tag_value}.}}")\
.format(tag_name=tag_name, tag_value=tag_value)

# define the error messages
too_many_model_message = ("Found more than one latest model. "
f"Models found: {model_list}. "
f"{tag_ext}")

no_model_found_message = (f"No Model found with name: {model_name}. "
f"{tag_ext}")

if len(model_list) > 1:
raise ValueError(should_not_happen)
raise ValueError(too_many_model_message)
if len(model_list) == 1:
return model_list[0]
else:
print(no_model_found)
print(no_model_found_message)
return None
except Exception:
raise
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from azureml.core import Run, Experiment, Workspace
from ml_service.util.env_variables import Env
from diabetes_regression.util.model_helper import get_model_by_tag
from diabetes_regression.util.model_helper import get_latest_model


def main():
Expand Down Expand Up @@ -53,7 +53,7 @@ def main():

try:
tag_name = 'BuildId'
model = get_model_by_tag(
model = get_latest_model(
model_name, tag_name, build_id, exp.workspace)
if (model is not None):
print("Model was registered for this build.")
Expand Down

0 comments on commit d531b2e

Please sign in to comment.