Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add metadata to model versions #2109

Merged
merged 17 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/update-templates-to-examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ jobs:
python-version: ${{ inputs.python-version }}
stack-name: local
ref-zenml: ${{ github.ref }}
ref-template: '2023.11.24' # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py
ref-template: '2023.12.06' # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py

- name: Clean-up
run: |
rm -rf ./local_checkout
Expand Down
2 changes: 1 addition & 1 deletion examples/e2e/.copier-answers.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Changes here will be overwritten by Copier
_commit: 2023.11.24
_commit: 2023.11.23-2-gc19b794
_src_path: gh:zenml-io/template-e2e-batch
data_quality_checks: true
email: ''
Expand Down
7 changes: 4 additions & 3 deletions examples/e2e/configs/train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@ steps:
model_trainer:
parameters:
name: e2e_use_case
compute_performance_metrics_on_current_data:
parameters:
target_env: staging
promote_with_metric_compare:
parameters:
mlflow_model_name: e2e_use_case
target_env: staging
notify_on_success:
parameters:
notify_on_success: False
Expand All @@ -61,9 +65,6 @@ model_version:
# pipeline level extra configurations
extra:
notify_on_failure: True
# pipeline level parameters
parameters:
target_env: staging
# This set contains all the model configurations that you want
# to evaluate during the hyperparameter tuning stage.
model_search_space:
Expand Down
19 changes: 8 additions & 11 deletions examples/e2e/pipelines/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


import random
from typing import Any, Dict, List, Optional
from typing import List, Optional

from steps import (
compute_performance_metrics_on_current_data,
Expand All @@ -33,16 +33,14 @@
train_data_splitter,
)

from zenml import pipeline
from zenml import get_pipeline_context, pipeline
from zenml.logger import get_logger

logger = get_logger(__name__)


@pipeline(on_failure=notify_on_failure)
def e2e_use_case_training(
model_search_space: Dict[str, Any],
target_env: str,
test_size: float = 0.2,
drop_na: Optional[bool] = None,
normalize: Optional[bool] = None,
Expand All @@ -59,8 +57,6 @@ def e2e_use_case_training(
trains and evaluates a model.

Args:
model_search_space: Search space for hyperparameter tuning
target_env: The environment to promote the model to
test_size: Size of holdout set for training 0.0..1.0
drop_na: If `True` NA values will be removed from dataset
normalize: If `True` dataset will be normalized with MinMaxScaler
Expand All @@ -69,10 +65,12 @@ def e2e_use_case_training(
min_test_accuracy: Threshold to stop execution if test set accuracy is lower
fail_on_accuracy_quality_gates: If `True` and `min_train_accuracy` or `min_test_accuracy`
are not met - execution will be interrupted early

"""
### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ###
# Link all the steps together by calling them and passing the output
# of one step as the input of the next step.
pipeline_extra = get_pipeline_context().extra
########## ETL stage ##########
raw_data, target, _ = data_loader(random_state=random.randint(0, 100))
dataset_trn, dataset_tst = train_data_splitter(
Expand All @@ -89,7 +87,9 @@ def e2e_use_case_training(
########## Hyperparameter tuning stage ##########
after = []
search_steps_prefix = "hp_tuning_search_"
for config_name, model_search_configuration in model_search_space.items():
for config_name, model_search_configuration in pipeline_extra[
"model_search_space"
].items():
step_name = f"{search_steps_prefix}{config_name}"
hp_tuning_single_search(
id=step_name,
Expand Down Expand Up @@ -123,15 +123,12 @@ def e2e_use_case_training(
latest_metric,
current_metric,
) = compute_performance_metrics_on_current_data(
dataset_tst=dataset_tst,
target_env=target_env,
after=["model_evaluator"],
dataset_tst=dataset_tst, after=["model_evaluator"]
)

promote_with_metric_compare(
latest_metric=latest_metric,
current_metric=current_metric,
target_env=target_env,
)
last_step = "promote_with_metric_compare"

Expand Down
6 changes: 3 additions & 3 deletions examples/e2e/steps/deployment/deployment_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def deployment_deploy() -> (
# deploy predictor service
deployment_service = mlflow_model_registry_deployer_step.entrypoint(
registry_model_name=model_version.name,
registry_model_version=model_version.get_model_artifact("model")
.run_metadata["model_registry_version"]
.value,
registry_model_version=model_version.metadata[
"model_registry_version"
],
replace_existing=True,
)
else:
Expand Down
22 changes: 7 additions & 15 deletions examples/e2e/steps/promotion/promote_with_metric_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,21 +90,17 @@ def promote_with_metric_compare(
logger.info(f"Current model version was promoted to '{target_env}'.")

# Promote in Model Registry
latest_version_model_registry_number = (
latest_version.get_model_artifact("model")
.run_metadata["model_registry_version"]
.value
)
latest_version_model_registry_number = latest_version.metadata[
"model_registry_version"
]
if current_version_number is None:
current_version_model_registry_number = (
latest_version_model_registry_number
)
else:
current_version_model_registry_number = (
current_version.get_model_artifact("model")
.run_metadata["model_registry_version"]
.value
)
current_version_model_registry_number = current_version.metadata[
"model_registry_version"
]
promote_in_model_registry(
latest_version=latest_version_model_registry_number,
current_version=current_version_model_registry_number,
Expand All @@ -113,11 +109,7 @@ def promote_with_metric_compare(
)
promoted_version = latest_version_model_registry_number
else:
promoted_version = (
current_version.get_model_artifact("model")
.run_metadata["model_registry_version"]
.value
)
promoted_version = current_version.metadata["model_registry_version"]

logger.info(
f"Current model version in `{target_env}` is `{promoted_version}` registered in Model Registry"
Expand Down
8 changes: 4 additions & 4 deletions examples/e2e/steps/training/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from sklearn.base import ClassifierMixin
from typing_extensions import Annotated

from zenml import ArtifactConfig, log_artifact_metadata, step
from zenml import ArtifactConfig, get_step_context, step
from zenml.client import Client
from zenml.integrations.mlflow.experiment_trackers import (
MLFlowExperimentTracker,
Expand Down Expand Up @@ -103,9 +103,9 @@ def model_trainer(
if model_registry:
versions = model_registry.list_model_versions(name=name)
if versions:
log_artifact_metadata(
metadata={"model_registry_version": versions[-1].version},
artifact_name="model",
model_version = get_step_context().model_version
model_version.log_metadata(
{"model_registry_version": versions[-1].version}
)
### YOUR CODE ENDS HERE ###

Expand Down
2 changes: 1 addition & 1 deletion src/zenml/cli/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def copier_github_url(self) -> str:
ZENML_PROJECT_TEMPLATES = dict(
e2e_batch=ZenMLProjectTemplateLocation(
github_url="zenml-io/template-e2e-batch",
github_tag="2023.11.24", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
github_tag="2023.12.06", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
),
starter=ZenMLProjectTemplateLocation(
github_url="zenml-io/template-starter",
Expand Down
1 change: 1 addition & 0 deletions src/zenml/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def _model_version_to_print(
"number": model_version.number,
"description": model_version.description,
"stage": model_version.stage,
"metadata": model_version.to_model_version().metadata,
"tags": [t.name for t in model_version.tags],
"data_artifacts_count": len(model_version.data_artifact_ids),
"model_artifacts_count": len(model_version.model_artifact_ids),
Expand Down
1 change: 1 addition & 0 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,4 @@ class MetadataResourceTypes(StrEnum):
PIPELINE_RUN = "pipeline_run"
STEP_RUN = "step_run"
ARTIFACT_VERSION = "artifact_version"
MODEL_VERSION = "model_version"
50 changes: 48 additions & 2 deletions src/zenml/model/model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@

from pydantic import BaseModel, PrivateAttr, root_validator

from zenml.enums import ModelStages
from zenml.enums import MetadataResourceTypes, ModelStages
from zenml.exceptions import EntityExistsError
from zenml.logger import get_logger

if TYPE_CHECKING:
from zenml import ExternalArtifact
from zenml.metadata.metadata_types import MetadataType
from zenml.models import (
ArtifactVersionResponse,
ModelResponse,
Expand Down Expand Up @@ -306,6 +307,46 @@ def set_stage(
"""
self._get_or_create_model_version().set_stage(stage=stage, force=force)

def log_metadata(
    self,
    metadata: Dict[str, "MetadataType"],
) -> None:
    """Attach metadata to this model version.

    The model version is resolved through `_get_or_create_model_version()`
    and the given metadata is then stored against its ID with the
    `MODEL_VERSION` resource type.

    Args:
        metadata: The metadata to log.
    """
    # Imported lazily, as in the rest of this module's client usage.
    from zenml.client import Client

    # Resolve the model version first so the metadata has a target ID.
    model_version_id = self._get_or_create_model_version().id

    Client().create_run_metadata(
        metadata=metadata,
        resource_id=model_version_id,
        resource_type=MetadataResourceTypes.MODEL_VERSION,
    )

@property
def metadata(self) -> Dict[str, "MetadataType"]:
    """Return the metadata logged for this model version.

    A hydrated model version is requested and each of its run metadata
    entries is reduced to its plain `value`.

    Returns:
        A mapping from metadata name to metadata value.

    Raises:
        RuntimeError: If the model version metadata cannot be fetched.
    """
    model_version = self._get_or_create_model_version(hydrate=True)
    run_metadata = model_version.run_metadata
    if run_metadata is None:
        raise RuntimeError("Failed to fetch metadata of this model version.")

    # Distinct loop names avoid shadowing the fetched model version.
    values = {}
    for key, entry in run_metadata.items():
        values[key] = entry.value
    return values

#########################
# Internal methods #
#########################
Expand Down Expand Up @@ -431,7 +472,9 @@ def _get_model_version(self) -> "ModelVersionResponse":

return mv

def _get_or_create_model_version(self) -> "ModelVersionResponse":
def _get_or_create_model_version(
self, hydrate: bool = False
) -> "ModelVersionResponse":
"""This method should get or create a model and a model version from Model Control Plane.

A new model is created implicitly if missing, otherwise existing model is fetched. Model
Expand All @@ -444,6 +487,9 @@ def _get_or_create_model_version(self) -> "ModelVersionResponse":
- If `version` is set to a string, the model version with the matching version will be fetched.
- If `version` is set to a `ModelStage`, the model version with the matching stage will be fetched.

Args:
hydrate: Whether to return a hydrated version of the model version.

Returns:
The model version based on configuration.

Expand Down
45 changes: 45 additions & 0 deletions src/zenml/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
from typing import (
Dict,
Optional,
Union,
)
from uuid import UUID

from zenml.artifacts.artifact_config import ArtifactConfig
from zenml.client import Client
from zenml.enums import ModelStages
from zenml.exceptions import StepContextError
from zenml.logger import get_logger
from zenml.metadata.metadata_types import MetadataType
from zenml.model.model_version import ModelVersion
from zenml.models import ModelVersionArtifactRequest
from zenml.new.steps.step_context import get_step_context
Expand Down Expand Up @@ -111,3 +114,45 @@ def link_artifact_config_to_model_version(
is_endpoint_artifact=artifact_config.is_endpoint_artifact,
)
client.zen_store.create_model_version_artifact_link(request)


def log_model_version_metadata(
    metadata: Dict[str, "MetadataType"],
    model_name: Optional[str] = None,
    model_version: Optional[Union[ModelStages, int, str]] = None,
) -> None:
    """Log model version metadata.

    This function can be used to log metadata for existing model versions.

    When both `model_name` and `model_version` are given, that model version
    is the target, even if the call happens inside a step that has its own
    model version. Otherwise the model version of the current step context
    is used.

    Args:
        metadata: The metadata to log.
        model_name: The name of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model_version` in decorator.
        model_version: The version of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model_version` in decorator.

    Raises:
        ValueError: If no model name/version is provided and the function is not
            called inside a step with configured `model_version` in decorator.
    """
    if model_name and model_version:
        # An explicitly requested target must not be silently replaced by
        # whatever model version the surrounding step happens to carry.
        mv = ModelVersion(name=model_name, version=model_version)
    else:
        try:
            mv = get_step_context().model_version
        except RuntimeError:
            # `get_step_context()` is not usable outside of a running step.
            mv = None

        if mv is None:
            # Fail with the documented error instead of constructing a
            # `ModelVersion` from missing name/version values.
            raise ValueError(
                "Model name and version must be provided unless the function is "
                "called inside a step with configured `model_version` in decorator."
            )

    mv.log_metadata(metadata)
2 changes: 2 additions & 0 deletions src/zenml/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,11 @@
ModelVersionResponseBody.update_forward_refs(
UserResponse=UserResponse,
ModelResponse=ModelResponse,
RunMetadataResponse=RunMetadataResponse,
)
ModelVersionResponseMetadata.update_forward_refs(
WorkspaceResponse=WorkspaceResponse,
RunMetadataResponse=RunMetadataResponse,
)
ModelVersionArtifactResponseBody.update_forward_refs(
ArtifactVersionResponse=ArtifactVersionResponse,
Expand Down
Loading
Loading