Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add metadata to model versions #2109

Merged
merged 17 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/zenml/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def _model_version_to_print(
"number": model_version.number,
"description": model_version.description,
"stage": model_version.stage,
"metadata": model_version.to_model_version().metadata,
"data_artifacts_count": len(model_version.data_artifact_ids),
"model_artifacts_count": len(model_version.model_artifact_ids),
"endpoint_artifacts_count": len(model_version.endpoint_artifact_ids),
Expand Down
1 change: 1 addition & 0 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,4 @@ class MetadataResourceTypes(StrEnum):
PIPELINE_RUN = "pipeline_run"
STEP_RUN = "step_run"
ARTIFACT_VERSION = "artifact_version"
MODEL_VERSION = "model_version"
36 changes: 35 additions & 1 deletion src/zenml/model/model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@

from pydantic import BaseModel, PrivateAttr, root_validator

from zenml.enums import ModelStages
from zenml.enums import MetadataResourceTypes, ModelStages
from zenml.exceptions import EntityExistsError
from zenml.logger import get_logger

if TYPE_CHECKING:
from zenml import ExternalArtifact
from zenml.metadata.metadata_types import MetadataType
from zenml.models import (
ArtifactVersionResponse,
ModelResponseModel,
Expand Down Expand Up @@ -306,6 +307,39 @@ def set_stage(
"""
self._get_or_create_model_version().set_stage(stage=stage, force=force)

def log_metadata(
    self,
    metadata: Dict[str, "MetadataType"],
) -> None:
    """Attach metadata to the current model version.

    The model version is resolved through
    `_get_or_create_model_version()` and the given key/value pairs are
    stored as run metadata of resource type `MODEL_VERSION` for it.

    Args:
        metadata: The metadata to log.
    """
    # Imported lazily inside the method, as in the rest of this module.
    from zenml.client import Client

    model_version_response = self._get_or_create_model_version()
    client = Client()
    client.create_run_metadata(
        metadata=metadata,
        resource_id=model_version_response.id,
        resource_type=MetadataResourceTypes.MODEL_VERSION,
    )

@property
def metadata(self) -> Dict[str, "MetadataType"]:
    """Return the metadata logged for this model version.

    Returns:
        A dictionary mapping each run metadata key of the model version
        to the value stored under it.
    """
    model_version_response = self._get_or_create_model_version()
    values = {}
    # Keep only the raw value of every run metadata entry.
    for key, entry in model_version_response.run_metadata.items():
        values[key] = entry.value
    return values

#########################
# Internal methods #
#########################
Expand Down
45 changes: 45 additions & 0 deletions src/zenml/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
from typing import (
Dict,
Optional,
Union,
)
from uuid import UUID

from zenml.artifacts.artifact_config import ArtifactConfig
from zenml.client import Client
from zenml.enums import ModelStages
from zenml.exceptions import StepContextError
from zenml.logger import get_logger
from zenml.metadata.metadata_types import MetadataType
from zenml.model.model_version import ModelVersion
from zenml.models.model_models import (
ModelVersionArtifactRequestModel,
Expand Down Expand Up @@ -113,3 +116,45 @@ def link_artifact_config_to_model_version(
is_endpoint_artifact=artifact_config.is_endpoint_artifact,
)
client.zen_store.create_model_version_artifact_link(request)


def log_model_version_metadata(
    metadata: Dict[str, "MetadataType"],
    model_name: Optional[str] = None,
    model_version: Optional[Union[ModelStages, int, str]] = None,
) -> None:
    """Log model version metadata.

    This function can be used to log metadata for existing model versions.

    If both `model_name` and `model_version` are given they always
    identify the target model version, even when the function is called
    inside a step. Otherwise the model version of the current step
    context is used.

    Args:
        metadata: The metadata to log.
        model_name: The name of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model_version` in decorator.
        model_version: The version of the model to log metadata for. Can
            be omitted when being called inside a step with configured
            `model_version` in decorator.

    Raises:
        ValueError: If no model name/version is provided and the function is not
            called inside a step with configured `model_version` in decorator.
    """
    if model_name and model_version:
        # Explicit arguments take precedence, so a caller inside a step can
        # still target a model version other than the step's own one.
        from zenml import ModelVersion

        mv = ModelVersion(name=model_name, version=model_version)
    else:
        mv = None
        try:
            mv = get_step_context().model_version
        except RuntimeError:
            # A RuntimeError here is treated as "not running inside a
            # step"; `mv` stays None and is reported below.
            pass

        if mv is None:
            # Covers both "no step context" and "step context without a
            # model version", so we never build a ModelVersion from None.
            raise ValueError(
                "Model name and version must be provided unless the function is "
                "called inside a step with configured `model_version` in decorator."
            )

    mv.log_metadata(metadata)
1 change: 1 addition & 0 deletions src/zenml/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@
ModelVersionResponseModel.update_forward_refs(
UserResponse=UserResponse,
WorkspaceResponse=WorkspaceResponse,
RunMetadataResponse=RunMetadataResponse,
)
ModelVersionArtifactRequestModel.update_forward_refs(
UserResponse=UserResponse,
Expand Down
4 changes: 4 additions & 0 deletions src/zenml/models/model_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@

if TYPE_CHECKING:
from zenml.model.model_version import ModelVersion
from zenml.models.v2.core.run_metadata import (
RunMetadataResponse,
)
from zenml.zen_stores.schemas import BaseSchema

AnySchema = TypeVar("AnySchema", bound=BaseSchema)
Expand Down Expand Up @@ -158,6 +161,7 @@ class ModelVersionResponseModel(
description="Pipeline runs linked to the model version",
default={},
)
run_metadata: Dict[str, "RunMetadataResponse"] = {}
avishniakov marked this conversation as resolved.
Show resolved Hide resolved

def to_model_version(
self,
Expand Down
15 changes: 14 additions & 1 deletion src/zenml/zen_stores/schemas/model_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column
from sqlmodel import Field, Relationship

from zenml.enums import TaggableResourceTypes
from zenml.enums import MetadataResourceTypes, TaggableResourceTypes
from zenml.models import (
ModelRequestModel,
ModelResponseModel,
Expand All @@ -34,6 +34,7 @@
from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema
from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema
from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema
from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema
from zenml.zen_stores.schemas.user_schemas import UserSchema
Expand Down Expand Up @@ -229,6 +230,15 @@ class ModelVersionSchema(NamedSchema, table=True):
description: str = Field(sa_column=Column(TEXT, nullable=True))
stage: str = Field(sa_column=Column(TEXT, nullable=True))

run_metadata: List["RunMetadataSchema"] = Relationship(
back_populates="model_version",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ModelVersionSchema.id)",
cascade="delete",
overlaps="run_metadata",
),
)

@classmethod
def from_request(
cls, model_version_request: ModelVersionRequestModel
Expand Down Expand Up @@ -310,6 +320,9 @@ def to_model(
endpoint_artifact_ids=endpoint_artifact_ids,
data_artifact_ids=data_artifact_ids,
pipeline_run_ids=pipeline_run_ids,
run_metadata={
rm.key: rm.to_model(hydrate=False) for rm in self.run_metadata
},
)

def update(
Expand Down
24 changes: 17 additions & 7 deletions src/zenml/zen_stores/schemas/run_metadata_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


import json
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
from uuid import UUID

from sqlalchemy import TEXT, VARCHAR, Column
Expand All @@ -28,15 +28,18 @@
RunMetadataResponseBody,
RunMetadataResponseMetadata,
)
from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema
from zenml.zen_stores.schemas.base_schemas import BaseSchema
from zenml.zen_stores.schemas.component_schemas import StackComponentSchema
from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema
from zenml.zen_stores.schemas.user_schemas import UserSchema
from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema

if TYPE_CHECKING:
from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema
from zenml.zen_stores.schemas.model_schemas import ModelVersionSchema
from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema


class RunMetadataSchema(BaseSchema, table=True):
"""SQL Model for run metadata."""
Expand All @@ -49,21 +52,28 @@ class RunMetadataSchema(BaseSchema, table=True):
back_populates="run_metadata",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataSchema.resource_id)==PipelineRunSchema.id)",
overlaps="run_metadata,step_run,artifact_version",
overlaps="run_metadata,step_run,artifact_version,model_version",
),
)
step_run: List["StepRunSchema"] = Relationship(
back_populates="run_metadata",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataSchema.resource_id)==StepRunSchema.id)",
overlaps="run_metadata,pipeline_run,artifact_version",
overlaps="run_metadata,pipeline_run,artifact_version,model_version",
),
)
artifact_version: List["ArtifactVersionSchema"] = Relationship(
back_populates="run_metadata",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ArtifactVersionSchema.id)",
overlaps="run_metadata,pipeline_run,step_run",
overlaps="run_metadata,pipeline_run,step_run,model_version",
),
)
model_version: List["ModelVersionSchema"] = Relationship(
back_populates="run_metadata",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(RunMetadataSchema.resource_type=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataSchema.resource_id)==ModelVersionSchema.id)",
overlaps="run_metadata,pipeline_run,step_run,artifact_version",
),
)
stack_component_id: Optional[UUID] = build_foreign_key_field(
Expand Down
72 changes: 72 additions & 0 deletions tests/integration/functional/model/test_model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import pytest

from tests.integration.functional.utils import model_killer, tags_killer
from zenml import get_step_context, pipeline, step
from zenml.client import Client
from zenml.enums import ModelStages
from zenml.model.model_version import ModelVersion
from zenml.model.utils import log_model_version_metadata
from zenml.models.tag_models import TagRequestModel

MODEL_NAME = "super_model"
Expand Down Expand Up @@ -62,6 +64,12 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
pass


@step
def step_metadata_logging_functional():
    """Log metadata via the functional API and read it back in the step."""
    log_model_version_metadata({"foo": "bar"})
    # The value must be visible through the step context's model version.
    logged_metadata = get_step_context().model_version.metadata
    assert logged_metadata["foo"] == "bar"


class TestModelVersion:
def test_model_created_with_warning(self):
"""Test if the model is created with a warning.
Expand Down Expand Up @@ -217,3 +225,67 @@ def test_tags_properly_updated(self):
model = mv._get_or_create_model()
assert len(model.tags) == 2
assert {t.name for t in model.tags} == {"foo", "bar"}

def test_metadata_logging(self):
    """Test that model version can be used to track metadata from object."""
    with model_killer():
        model_version = ModelVersion(name=MODEL_NAME, description="foo")

        # First entry.
        model_version.log_metadata({"foo": "bar"})
        assert len(model_version.metadata) == 1
        assert model_version.metadata["foo"] == "bar"

        # A second call adds to the existing metadata.
        model_version.log_metadata({"bar": "foo"})
        assert len(model_version.metadata) == 2
        for key, expected in (("foo", "bar"), ("bar", "foo")):
            assert model_version.metadata[key] == expected

def test_metadata_logging_functional(self):
    """Test that model version can be used to track metadata from function."""
    with model_killer():
        model_version = ModelVersion(name=MODEL_NAME, description="foo")
        model_version._get_or_create_model_version()

        # Address the model version by its number.
        log_model_version_metadata(
            {"foo": "bar"},
            model_name=model_version.name,
            model_version=model_version.number,
        )
        assert len(model_version.metadata) == 1
        assert model_version.metadata["foo"] == "bar"

        # Outside of a step, omitting name and version must fail.
        with pytest.raises(ValueError):
            log_model_version_metadata({"foo": "bar"})

        # Address the same model version as "latest".
        log_model_version_metadata(
            {"bar": "foo"},
            model_name=model_version.name,
            model_version="latest",
        )
        assert len(model_version.metadata) == 2
        for key, expected in (("foo", "bar"), ("bar", "foo")):
            assert model_version.metadata[key] == expected

def test_metadata_logging_in_steps(self):
    """Test that model version can be used to track metadata from function in steps."""
    with model_killer():
        pipeline_model_version = ModelVersion(name=MODEL_NAME)

        @pipeline(model_version=pipeline_model_version, enable_cache=False)
        def my_pipeline():
            step_metadata_logging_functional()

        my_pipeline()

        # Metadata logged inside the step must be readable afterwards.
        latest = ModelVersion(name=MODEL_NAME, version="latest")
        assert len(latest.metadata) == 1
        assert latest.metadata["foo"] == "bar"
Loading