diff --git a/.gitignore b/.gitignore index ed067acc544..44c0ca1aa7a 100644 --- a/.gitignore +++ b/.gitignore @@ -195,3 +195,5 @@ zenml_tutorial/ # script for testing mlstacks_reset.sh + +.local/ \ No newline at end of file diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 007aa1e7a37..f36ef7db778 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -222,6 +222,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: SERVICE_CONNECTOR_RESOURCES = "/resources" SERVICE_CONNECTOR_CLIENT = "/client" MODELS = "/models" +MODEL_VERSIONS = "/model_versions" # mandatory stack component attributes MANDATORY_COMPONENT_ATTRIBUTES = ["name", "uuid"] diff --git a/src/zenml/model/__init__.py b/src/zenml/model/__init__.py new file mode 100644 index 00000000000..af2e669c0d5 --- /dev/null +++ b/src/zenml/model/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Initialization of ZenML model. +ZenML model support Model WatchTower feature. +""" + +from zenml.model.model_stages import ModelStages + +__all__ = [ + "ModelStages", +] diff --git a/src/zenml/model/model_stages.py b/src/zenml/model/model_stages.py new file mode 100644 index 00000000000..04feaab5ac3 --- /dev/null +++ b/src/zenml/model/model_stages.py @@ -0,0 +1,28 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""ModelStages lists supported stages of a Model Version.""" + +from zenml.utils.enum_utils import StrEnum + + +class ModelStages(StrEnum): + """All possible stages of a Model Version.""" + + NONE = "none" + STAGING = "staging" + PRODUCTION = "production" + ARCHIVED = "archived" + # technical stages + LATEST = "latest" + RUNNING = "running" diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index 074ede5aec1..10ed29ae571 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -151,12 +151,19 @@ ModelResponseModel, ModelRequestModel, ModelUpdateModel, - ModelConfigBaseModel, - ModelConfigResponseModel, - ModelConfigRequestModel, ModelVersionBaseModel, ModelVersionResponseModel, ModelVersionRequestModel, + ModelVersionArtifactBaseModel, + ModelVersionArtifactFilterModel, + ModelVersionArtifactRequestModel, + ModelVersionArtifactResponseModel, + ModelVersionPipelineRunBaseModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunRequestModel, + ModelVersionPipelineRunResponseModel, + ModelVersionFilterModel, + ModelVersionUpdateModel, ) ComponentResponseModel.update_forward_refs( @@ -277,22 +284,29 @@ WorkspaceResponseModel=WorkspaceResponseModel, ) -ModelConfigRequestModel.update_forward_refs( +ModelVersionRequestModel.update_forward_refs( UserResponseModel=UserResponseModel, WorkspaceResponseModel=WorkspaceResponseModel, ) -ModelConfigResponseModel.update_forward_refs( +ModelVersionResponseModel.update_forward_refs( UserResponseModel=UserResponseModel, WorkspaceResponseModel=WorkspaceResponseModel, ) -ModelVersionRequestModel.update_forward_refs( +ModelVersionArtifactRequestModel.update_forward_refs( UserResponseModel=UserResponseModel, WorkspaceResponseModel=WorkspaceResponseModel, ) - -ModelVersionResponseModel.update_forward_refs( +ModelVersionArtifactResponseModel.update_forward_refs( + UserResponseModel=UserResponseModel, + WorkspaceResponseModel=WorkspaceResponseModel, +) +ModelVersionPipelineRunRequestModel.update_forward_refs( + UserResponseModel=UserResponseModel, + WorkspaceResponseModel=WorkspaceResponseModel, +) +ModelVersionPipelineRunResponseModel.update_forward_refs( UserResponseModel=UserResponseModel, WorkspaceResponseModel=WorkspaceResponseModel, ) @@ -398,6 +412,16 @@ "ModelConfigRequestModel", "ModelConfigResponseModel", "ModelVersionBaseModel", + "ModelVersionFilterModel", "ModelVersionRequestModel", "ModelVersionResponseModel", + "ModelVersionUpdateModel", + "ModelVersionArtifactBaseModel", + "ModelVersionArtifactFilterModel", + "ModelVersionArtifactRequestModel", + "ModelVersionArtifactResponseModel", + "ModelVersionPipelineRunBaseModel", + "ModelVersionPipelineRunFilterModel", + "ModelVersionPipelineRunRequestModel", + "ModelVersionPipelineRunResponseModel", ] diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 574de528589..543de9ca721 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -13,23 +13,37 @@ # permissions and limitations under the License. """Model implementation to support Model WatchTower feature.""" -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union from uuid import UUID -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator +from zenml.model import ModelStages +from zenml.models.artifact_models import ArtifactResponseModel from zenml.models.base_models import ( WorkspaceScopedRequestModel, WorkspaceScopedResponseModel, ) from zenml.models.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH from zenml.models.filter_models import WorkspaceScopedFilterModel +from zenml.models.pipeline_run_models import PipelineRunResponseModel class ModelVersionBaseModel(BaseModel): """Model Version base model.""" - pass + version: str = Field( + title="The name of the model version", + max_length=STR_FIELD_MAX_LENGTH, + ) + description: Optional[str] = Field( + title="The description of the model version", + max_length=TEXT_FIELD_MAX_LENGTH, + ) + stage: Optional[str] = Field( + title="The stage of the model version", + max_length=STR_FIELD_MAX_LENGTH, + ) class ModelVersionRequestModel( @@ -38,7 +52,9 @@ class ModelVersionRequestModel( ): """Model Version request model.""" - pass + model: UUID = Field( + title="The ID of the model containing version", + ) class ModelVersionResponseModel( @@ -47,31 +63,296 @@ class ModelVersionResponseModel( ): """Model Version response model.""" - pass + model: "ModelResponseModel" = Field( + title="The model containing version", + ) + model_object_ids: Dict[str, UUID] = Field( + title="Model Objects linked to the model version", + default={}, + ) + artifact_object_ids: Dict[str, UUID] = Field( + title="Artifacts linked to the model version", + default={}, + ) + deployment_ids: Dict[str, UUID] = Field( + title="Deployments linked to the model version", + default={}, + ) + pipeline_run_ids: Dict[str, UUID] = Field( + title="Pipeline runs linked to the model version", + default={}, + ) + @property + def model_objects(self) -> Dict[str, ArtifactResponseModel]: + """Get all model objects linked to this version. -class ModelConfigBaseModel(BaseModel): - """Model Config base model.""" + Returns: + Dictionary of Model Objects as ArtifactResponseModel + """ + from zenml.client import Client - pass + return { + name: Client().get_artifact(a) + for name, a in self.model_object_ids.items() + } + @property + def artifact_objects(self) -> Dict[str, ArtifactResponseModel]: + """Get all artifacts linked to this version. -class ModelConfigRequestModel( - ModelConfigBaseModel, - WorkspaceScopedRequestModel, + Returns: + Dictionary of Artifact Objects as ArtifactResponseModel + """ + from zenml.client import Client + + return { + name: Client().get_artifact(a) + for name, a in self.artifact_object_ids.items() + } + + @property + def deployments(self) -> Dict[str, ArtifactResponseModel]: + """Get all deployments linked to this version. + + Returns: + Dictionary of Deployments as ArtifactResponseModel + """ + from zenml.client import Client + + return { + name: Client().get_artifact(a) + for name, a in self.deployment_ids.items() + } + + @property + def pipeline_runs(self) -> Dict[str, PipelineRunResponseModel]: + """Get all pipeline runs linked to this version. + + Returns: + Dictionary of Pipeline Runs as PipelineRunResponseModel + """ + from zenml.client import Client + + return { + name: Client().get_pipeline_run(pr) + for name, pr in self.pipeline_run_ids.items() + } + + def get_model_object(self, name: str) -> ArtifactResponseModel: + """Get model object linked to this version. + + Args: + name: The name of the model object to retrieve. + + Returns: + Model Object as ArtifactResponseModel + """ + from zenml.client import Client + + return Client().get_artifact(self.model_object_ids[name]) + + def get_artifact_object(self, name: str) -> ArtifactResponseModel: + """Get artifact linked to this version. + + Args: + name: The name of the artifact to retrieve. + + Returns: + Artifact Object as ArtifactResponseModel + """ + from zenml.client import Client + + return Client().get_artifact(self.artifact_object_ids[name]) + + def get_deployment(self, name: str) -> ArtifactResponseModel: + """Get deployment linked to this version. + + Args: + name: The name of the deployment to retrieve. + + Returns: + Deployment as ArtifactResponseModel + """ + from zenml.client import Client + + return Client().get_artifact(self.deployment_ids[name]) + + def get_pipeline_run(self, name: str) -> PipelineRunResponseModel: + """Get pipeline run linked to this version. + + Args: + name: The name of the pipeline run to retrieve. + + Returns: + PipelineRun as PipelineRunResponseModel + """ + from zenml.client import Client + + return Client().get_pipeline_run(self.pipeline_run_ids[name]) + + def set_stage( + self, stage: ModelStages, force: bool = False + ) -> "ModelVersionResponseModel": + """Sets this Model Version to a desired stage. + + Args: + stage: the target stage for model version. + force: whether to force archiving of current model version in target stage or raise. + + Returns: + Dictionary of Model Objects as model_version_name_or_id + """ + from zenml.client import Client + + return Client().zen_store.update_model_version( + model_version_id=self.id, + model_version_update_model=ModelVersionUpdateModel( + model=self.model.id, + stage=stage, + force=force, + ), + ) + + # TODO in https://zenml.atlassian.net/browse/OSS-2433 + # def generate_model_card(self, template_name: str) -> str: + # """Return HTML/PDF based on input template""" + + +class ModelVersionFilterModel(WorkspaceScopedFilterModel): + """Filter Model for Model Version.""" + + model_id: Union[str, UUID] = Field( + description="The ID of the Model", + ) + version: Optional[Union[str, UUID]] = Field( + default=None, + description="The name of the Model Version", + ) + workspace_id: Optional[Union[UUID, str]] = Field( + default=None, description="The workspace of the Model Version" + ) + user_id: Optional[Union[UUID, str]] = Field( + default=None, description="The user of the Model Version" + ) + + +class ModelVersionUpdateModel(BaseModel): + """Update Model for Model Version.""" + + model: UUID = Field( + title="The ID of the model containing version", + ) + stage: ModelStages = Field( + title="Target model version stage to be set", + ) + force: bool = Field( + title="Whether existing model version in target stage should be silently archived " + "or an error should be raised.", + default=False, + ) + + +class ModelVersionArtifactBaseModel(BaseModel): + """Model version links with artifact base model.""" + + name: Optional[str] = Field( + title="The name of the artifact inside model version.", + max_length=STR_FIELD_MAX_LENGTH, + ) + artifact: UUID + model: UUID + model_version: UUID + is_model_object: bool = False + is_deployment: bool = False + + @validator("is_deployment") + def _validate_is_deployment( + cls, is_deployment: bool, values: Dict[str, Any] + ) -> bool: + is_model_object = values.get("is_model_object", False) + if is_model_object and is_deployment: + raise ValueError( + "Artifact cannot be a model object and deployment at the same time." + ) + return is_deployment + + +class ModelVersionArtifactRequestModel( + ModelVersionArtifactBaseModel, WorkspaceScopedRequestModel ): - """Model Config request model.""" + """Model version link with artifact request model.""" - pass +class ModelVersionArtifactResponseModel( + ModelVersionArtifactBaseModel, WorkspaceScopedResponseModel +): + """Model version link with artifact response model.""" -class ModelConfigResponseModel( - ModelConfigBaseModel, - WorkspaceScopedResponseModel, + +class ModelVersionArtifactFilterModel(WorkspaceScopedFilterModel): + """Model version pipeline run links filter model.""" + + model_id: Union[str, UUID] = Field( + description="The name or ID of the Model", + ) + model_version_id: Union[str, UUID] = Field( + description="The name or ID of the Model Version", + ) + name: Optional[str] = Field( + title="The name of the artifact inside model version.", + max_length=STR_FIELD_MAX_LENGTH, + ) + workspace_id: Optional[Union[UUID, str]] = Field( + default=None, description="The workspace of the Model Version" + ) + user_id: Optional[Union[UUID, str]] = Field( + default=None, description="The user of the Model Version" + ) + only_artifacts: Optional[bool] = False + only_model_objects: Optional[bool] = False + only_deployments: Optional[bool] = False + + +class ModelVersionPipelineRunBaseModel(BaseModel): + """Model version links with pipeline run base model.""" + + name: Optional[str] = Field( + title="The name of the pipeline run inside model version.", + max_length=STR_FIELD_MAX_LENGTH, + ) + pipeline_run: UUID + model: UUID + model_version: UUID + + +class ModelVersionPipelineRunRequestModel( + ModelVersionPipelineRunBaseModel, WorkspaceScopedRequestModel ): - """Model Config response model.""" + """Model version link with pipeline run request model.""" - pass + +class ModelVersionPipelineRunResponseModel( + ModelVersionPipelineRunBaseModel, WorkspaceScopedResponseModel +): + """Model version link with pipeline run response model.""" + + +class ModelVersionPipelineRunFilterModel(WorkspaceScopedFilterModel): + """Model version pipeline run links filter model.""" + + model_id: Union[str, UUID] = Field( + description="The name or ID of the Model", + ) + model_version_id: Union[str, UUID] = Field( + description="The name or ID of the Model Version", + ) + workspace_id: Optional[Union[UUID, str]] = Field( + default=None, description="The workspace of the Model Version" + ) + user_id: Optional[Union[UUID, str]] = Field( + default=None, description="The user of the Model Version" + ) class ModelBaseModel(BaseModel): diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 93c312497ad..f5c08bd49e9 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -18,12 +18,26 @@ from fastapi import APIRouter, Depends, Security -from zenml.constants import API, MODELS, VERSION_1 +from zenml.constants import ( + API, + ARTIFACTS, + MODEL_VERSIONS, + MODELS, + RUNS, + VERSION_1, +) from zenml.enums import PermissionType from zenml.models import ( ModelFilterModel, ModelResponseModel, ModelUpdateModel, + ModelVersionArtifactFilterModel, + ModelVersionArtifactResponseModel, + ModelVersionFilterModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunResponseModel, + ModelVersionResponseModel, + ModelVersionUpdateModel, ) from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize @@ -34,6 +48,10 @@ zen_store, ) +######### +# Models +######### + router = APIRouter( prefix=API + VERSION_1 + MODELS, tags=["models"], @@ -68,23 +86,6 @@ def list_models( ) -@router.delete( - "/{model_name_or_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_model( - model_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> None: - """Delete a model by name or ID. - - Args: - model_name_or_id: The name or ID of the model to delete. - """ - zen_store().delete_model(model_name_or_id) - - @router.get( "/{model_name_or_id}", response_model=ModelResponseModel, @@ -130,3 +131,249 @@ def update_model( model_id=model_id, model_update=model_update, ) + + +@router.delete( + "/{model_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model( + model_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Delete a model by name or ID. + + Args: + model_name_or_id: The name or ID of the model to delete. + """ + zen_store().delete_model(model_name_or_id) + + +################# +# Model Versions +################# + + +@router.get( + "/{model_name_or_id}" + MODEL_VERSIONS, + response_model=Page[ModelVersionResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_model_versions( + model_version_filter_model: ModelVersionFilterModel = Depends( + make_dependable(ModelVersionFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionResponseModel]: + """Get model versions according to query filters. + + Args: + model_version_filter_model: Filter model used for pagination, sorting, + filtering + + Returns: + The model versions according to query filters. + """ + return zen_store().list_model_versions( + model_version_filter_model=model_version_filter_model, + ) + + +@router.get( + "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}", + response_model=ModelVersionResponseModel, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def get_model_version( + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> ModelVersionResponseModel: + """Get a model version by name or ID. + + Args: + model_name_or_id: The name or ID of the model containing version. + model_version_name_or_id: The name or ID of the model version to get. + + Returns: + The model version with the given name or ID. + """ + return zen_store().get_model_version( + model_name_or_id, model_version_name_or_id + ) + + +@router.put( + "/{model_id}" + MODEL_VERSIONS + "/{model_version_id}", + response_model=ModelVersionResponseModel, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def update_model_version( + model_version_id: UUID, + model_version_update_model: ModelVersionUpdateModel, + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> ModelVersionResponseModel: + """Get all model versions by filter. + + Args: + model_version_id: The ID of model version to be updated. + model_version_update_model: The model version to be updated. + + Returns: + An updated model version. + """ + return zen_store().update_model_version( + model_version_id=model_version_id, + model_version_update_model=model_version_update_model, + ) + + +@router.delete( + "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model_version( + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Delete a model by name or ID. + + Args: + model_name_or_id: The name or ID of the model containing version. + model_version_name_or_id: The name or ID of the model version to delete. + """ + zen_store().delete_model_version( + model_name_or_id, model_version_name_or_id + ) + + +########################## +# Model Version Artifacts +########################## + + +@router.get( + "/{model_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + ARTIFACTS, + response_model=Page[ModelVersionArtifactResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_model_version_artifact_links( + model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( + make_dependable(ModelVersionArtifactFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionArtifactResponseModel]: + """Get model version to artifact links according to query filters. + + Args: + model_version_artifact_link_filter_model: Filter model used for pagination, sorting, + filtering + + Returns: + The model version to artifact links according to query filters. + """ + return zen_store().list_model_version_artifact_links( + model_version_artifact_link_filter_model=model_version_artifact_link_filter_model, + ) + + +@router.delete( + "/{model_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + ARTIFACTS + + "/{model_version_artifact_link_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model_version_artifact_link( + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_artifact_link_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Deletes a model version link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. + """ + zen_store().delete_model_version_artifact_link( + model_name_or_id, + model_version_name_or_id, + model_version_artifact_link_name_or_id, + ) + + +############################## +# Model Version Pipeline Runs +############################## + + +@router.get( + "/{model_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + RUNS, + response_model=Page[ModelVersionPipelineRunResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_model_version_pipeline_run_links( + model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends( + make_dependable(ModelVersionPipelineRunFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionPipelineRunResponseModel]: + """Get model version to pipeline run links according to query filters. + + Args: + model_version_pipeline_run_link_filter_model: Filter model used for pagination, sorting, + and filtering + + Returns: + The model version to pipeline run links according to query filters. + """ + return zen_store().list_model_version_pipeline_run_links( + model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model, + ) + + +@router.delete( + "/{model_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + RUNS + + "/{model_version_pipeline_run_link_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model_version_pipeline_run_link( + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_pipeline_run_link_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Deletes a model version link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_pipeline_run_link_name_or_id: name or ID of the model version link to be deleted. + """ + zen_store().delete_model_version_pipeline_run_link( + model_name_or_id, + model_version_name_or_id, + model_version_pipeline_run_link_name_or_id, + ) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 622dd1ee213..cc300ccb368 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -19,8 +19,10 @@ from zenml.constants import ( API, + ARTIFACTS, CODE_REPOSITORIES, GET_OR_CREATE, + MODEL_VERSIONS, MODELS, PIPELINE_BUILDS, PIPELINE_DEPLOYMENTS, @@ -51,6 +53,15 @@ ModelFilterModel, ModelRequestModel, ModelResponseModel, + ModelVersionArtifactFilterModel, + ModelVersionArtifactRequestModel, + ModelVersionArtifactResponseModel, + ModelVersionFilterModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunRequestModel, + ModelVersionPipelineRunResponseModel, + ModelVersionRequestModel, + ModelVersionResponseModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -1221,3 +1232,276 @@ def list_workspace_models( return zen_store().list_models( model_filter_model=model_filter_model, ) + + +@router.post( + WORKSPACES + + "/{workspace_name_or_id}" + + MODELS + + "/{model_name_or_id}" + + MODEL_VERSIONS, + response_model=ModelVersionResponseModel, + responses={401: error_response, 409: error_response, 422: error_response}, +) +@handle_exceptions +def create_model_version( + workspace_name_or_id: Union[str, UUID], + model_name_or_id: Union[str, UUID], + model_version: ModelVersionRequestModel, + auth_context: AuthContext = Security( + authorize, scopes=[PermissionType.WRITE] + ), +) -> ModelVersionResponseModel: + """Create a new model version. + + Args: + model_name_or_id: Name or ID of the model. + workspace_name_or_id: Name or ID of the workspace. + model_version: The model version to create. + auth_context: Authentication context. + + Returns: + The created model version. + + Raises: + IllegalOperationError: If the workspace or user specified in the + model version does not match the current workspace or authenticated + user. + """ + workspace = zen_store().get_workspace(workspace_name_or_id) + + if model_version.workspace != workspace.id: + raise IllegalOperationError( + "Creating model versions outside of the workspace scope " + f"of this endpoint `{workspace_name_or_id}` is " + f"not supported." + ) + if model_version.user != auth_context.user.id: + raise IllegalOperationError( + "Creating models for a user other than yourself " + "is not supported." + ) + mv = zen_store().create_model_version(model_version) + return mv + + +@router.get( + WORKSPACES + "/{workspace_name_or_id}" + MODEL_VERSIONS, + response_model=Page[ModelVersionResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_workspace_model_versions( + workspace_name_or_id: Union[str, UUID], + model_version_filter_model: ModelVersionFilterModel = Depends( + make_dependable(ModelVersionFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionResponseModel]: + """Get model versions according to query filters. + + Args: + workspace_name_or_id: Name or ID of the workspace. + model_version_filter_model: Filter model used for pagination, sorting, + filtering + + Returns: + The model versions according to query filters. + """ + workspace_id = zen_store().get_workspace(workspace_name_or_id).id + model_version_filter_model.set_scope_workspace(workspace_id) + return zen_store().list_model_versions( + model_version_filter_model=model_version_filter_model, + ) + + +@router.post( + WORKSPACES + + "/{workspace_name_or_id}" + + MODELS + + "/{model_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + ARTIFACTS, + response_model=ModelVersionArtifactResponseModel, + responses={401: error_response, 409: error_response, 422: error_response}, +) +@handle_exceptions +def create_model_version_artifact_link( + workspace_name_or_id: Union[str, UUID], + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_artifact_link: ModelVersionArtifactRequestModel, + auth_context: AuthContext = Security( + authorize, scopes=[PermissionType.WRITE] + ), +) -> ModelVersionArtifactResponseModel: + """Create a new model version to artifact link. + + Args: + model_name_or_id: Name or ID of the model. + workspace_name_or_id: Name or ID of the workspace. + model_version_name_or_id: Name or ID of the model version. + model_version_artifact_link: The model version to artifact link to create. + auth_context: Authentication context. + + Returns: + The created model version to artifact link. + + Raises: + IllegalOperationError: If the workspace or user specified in the + model version does not match the current workspace or authenticated + user. + """ + workspace = zen_store().get_workspace(workspace_name_or_id) + + if model_version_artifact_link.workspace != workspace.id: + raise IllegalOperationError( + "Creating model version to artifact links outside of the workspace scope " + f"of this endpoint `{workspace_name_or_id}` is " + f"not supported." + ) + if model_version_artifact_link.user != auth_context.user.id: + raise IllegalOperationError( + "Creating model to artifact links for a user other than yourself " + "is not supported." + ) + mv = zen_store().create_model_version_artifact_link( + model_version_artifact_link + ) + return mv + + +@router.get( + WORKSPACES + + "/{workspace_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + ARTIFACTS, + response_model=Page[ModelVersionArtifactResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_workspace_model_version_artifact_links( + workspace_name_or_id: Union[str, UUID], + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( + make_dependable(ModelVersionArtifactFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionArtifactResponseModel]: + """Get model version to artifact links according to query filters. + + Args: + model_name_or_id: Name or ID of the model. + workspace_name_or_id: Name or ID of the workspace. + model_version_name_or_id: Name or ID of the model version. + model_version_artifact_link_filter_model: Filter model used for pagination, sorting, + filtering + + Returns: + The model version to artifact links according to query filters. + """ + workspace_id = zen_store().get_workspace(workspace_name_or_id).id + model_version_artifact_link_filter_model.set_scope_workspace(workspace_id) + return zen_store().list_model_version_artifact_links( + model_version_artifact_link_filter_model=model_version_artifact_link_filter_model, + ) + + +@router.post( + WORKSPACES + + "/{workspace_name_or_id}" + + MODELS + + "/{model_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + RUNS, + response_model=ModelVersionPipelineRunResponseModel, + responses={401: error_response, 409: error_response, 422: error_response}, +) +@handle_exceptions +def create_model_version_pipeline_run_link( + workspace_name_or_id: Union[str, UUID], + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_pipeline_run_link: ModelVersionPipelineRunRequestModel, + auth_context: AuthContext = Security( + authorize, scopes=[PermissionType.WRITE] + ), +) -> ModelVersionPipelineRunResponseModel: + """Create a new model version to pipeline run link. + + Args: + model_name_or_id: Name or ID of the model. + workspace_name_or_id: Name or ID of the workspace. + model_version_name_or_id: Name or ID of the model version. + model_version_pipeline_run_link: The model version to pipeline run link to create. + auth_context: Authentication context. + + Returns: + The created model version to pipeline run link. + + Raises: + IllegalOperationError: If the workspace or user specified in the + model version does not match the current workspace or authenticated + user. + """ + workspace = zen_store().get_workspace(workspace_name_or_id) + + if model_version_pipeline_run_link.workspace != workspace.id: + raise IllegalOperationError( + "Creating model versions outside of the workspace scope " + f"of this endpoint `{workspace_name_or_id}` is " + f"not supported." + ) + if model_version_pipeline_run_link.user != auth_context.user.id: + raise IllegalOperationError( + "Creating models for a user other than yourself " + "is not supported." + ) + mv = zen_store().create_model_version_pipeline_run_link( + model_version_pipeline_run_link + ) + return mv + + +@router.get( + WORKSPACES + + "/{workspace_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + RUNS, + response_model=Page[ModelVersionPipelineRunResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_workspace_model_version_pipeline_run_links( + workspace_name_or_id: Union[str, UUID], + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends( + make_dependable(ModelVersionPipelineRunResponseModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionPipelineRunResponseModel]: + """Get model version to pipeline links according to query filters. + + Args: + model_name_or_id: Name or ID of the model. + workspace_name_or_id: Name or ID of the workspace. + model_version_name_or_id: Name or ID of the model version. + model_version_pipeline_run_link_filter_model: Filter model used for pagination, sorting, + filtering + + Returns: + The model version to pipeline run links according to query filters. + """ + workspace_id = zen_store().get_workspace(workspace_name_or_id).id + model_version_pipeline_run_link_filter_model.set_scope_workspace( + workspace_id + ) + return zen_store().list_model_version_pipeline_run_links( + model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model, + ) diff --git a/src/zenml/zen_stores/migrations/versions/cdd9599a008d_add_model_version_and_links.py b/src/zenml/zen_stores/migrations/versions/cdd9599a008d_add_model_version_and_links.py new file mode 100644 index 00000000000..e694bb0f574 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/cdd9599a008d_add_model_version_and_links.py @@ -0,0 +1,162 @@ +"""add model_version and links [cdd9599a008d]. + +Revision ID: cdd9599a008d +Revises: 3b68abe58f44 +Create Date: 2023-09-15 17:53:23.963414 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "cdd9599a008d" +down_revision = "3b68abe58f44" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "model_version", + sa.Column( + "workspace_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("model_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("version", sa.TEXT(), nullable=False), + sa.Column("description", sa.TEXT(), nullable=True), + sa.Column("stage", sa.TEXT(), nullable=True), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("updated", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["model_id"], + ["model.id"], + name="fk_model_version_model_id_model", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + name="fk_model_version_user_id_user", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.id"], + name="fk_model_version_workspace_id_workspace", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "model_versions_artifacts", + sa.Column( + "workspace_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("model_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column( + "model_version_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column("artifact_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("is_model_object", sa.BOOLEAN(), nullable=True), + sa.Column("is_deployment", sa.BOOLEAN(), nullable=True), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("updated", sa.DateTime(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.ForeignKeyConstraint( + ["artifact_id"], + ["artifact.id"], + name="fk_model_versions_artifacts_artifact_id_artifact", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["model_id"], + ["model.id"], + name="fk_model_versions_artifacts_model_id_model", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["model_version_id"], + ["model_version.id"], + name="fk_model_versions_artifacts_model_version_id_model_version", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + name="fk_model_versions_artifacts_user_id_user", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.id"], + name="fk_model_versions_artifacts_workspace_id_workspace", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "model_versions_runs", + sa.Column( + "workspace_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("model_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column( + "model_version_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column( + "pipeline_run_id", sqlmodel.sql.sqltypes.GUID(), nullable=True + ), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("updated", sa.DateTime(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.ForeignKeyConstraint( + ["model_id"], + ["model.id"], + name="fk_model_versions_runs_model_id_model", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["model_version_id"], + ["model_version.id"], + name="fk_model_versions_runs_model_version_id_model_version", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["pipeline_run_id"], + ["pipeline_run.id"], + name="fk_model_versions_runs_run_id_pipeline_run", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + name="fk_model_versions_runs_user_id_user", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.id"], + name="fk_model_versions_runs_workspace_id_workspace", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("model_versions_runs") + op.drop_table("model_versions_artifacts") + op.drop_table("model_version") + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 32349e7a7e2..4ebbd308b03 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -49,6 +49,7 @@ GET_OR_CREATE, INFO, LOGIN, + MODEL_VERSIONS, MODELS, PIPELINE_BUILDS, PIPELINE_DEPLOYMENTS, @@ -99,6 +100,16 @@ ModelRequestModel, ModelResponseModel, ModelUpdateModel, + ModelVersionArtifactFilterModel, + ModelVersionArtifactRequestModel, + ModelVersionArtifactResponseModel, + ModelVersionFilterModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunRequestModel, + ModelVersionPipelineRunResponseModel, + ModelVersionRequestModel, + ModelVersionResponseModel, + ModelVersionUpdateModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -2365,6 +2376,221 @@ def list_models( filter_model=model_filter_model, ) + ################# + # Model Versions + ################# + + def create_model_version( + self, model_version: ModelVersionRequestModel + ) -> ModelVersionResponseModel: + """Creates a new model version. + + Args: + model_version: the Model Version to be created. + + Returns: + The newly created model version. + """ + return self._create_workspace_scoped_resource( + resource=model_version, + response_model=ModelVersionResponseModel, + route=f"{MODELS}/{model_version.model}{MODEL_VERSIONS}", + ) + + def delete_model_version( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version. + + Args: + model_name_or_id: name or id of the model containing the model version. + model_version_name_or_id: name or id of the model version to be deleted. + """ + self._delete_resource( + resource_id=model_version_name_or_id, + route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", + ) + + def get_model_version( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + ) -> ModelVersionResponseModel: + """Get an existing model version. + + Args: + model_name_or_id: name or id of the model containing the model version. + model_version_name_or_id: name or id of the model version to be retrieved. + + Returns: + The model version of interest. + """ + return self._get_resource( + resource_id=model_version_name_or_id, + route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", + response_model=ModelVersionResponseModel, + ) + + def list_model_versions( + self, + model_version_filter_model: ModelVersionFilterModel, + ) -> Page[ModelVersionResponseModel]: + """Get all model versions by filter. + + Args: + model_version_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model versions. + """ + return self._list_paginated_resources( + route=f"{MODELS}/{model_version_filter_model.model_id}{MODEL_VERSIONS}", + response_model=ModelVersionResponseModel, + filter_model=model_version_filter_model, + ) + + def update_model_version( + self, + model_version_id: UUID, + model_version_update_model: ModelVersionUpdateModel, + ) -> ModelVersionResponseModel: + """Get all model versions by filter. + + Args: + model_version_id: The ID of model version to be updated. + model_version_update_model: The model version to be updated. + + Returns: + An updated model version. + + """ + return self._update_resource( + resource_id=model_version_id, + resource_update=model_version_update_model, + route=f"{MODELS}/{model_version_update_model.model}{MODEL_VERSIONS}", + response_model=ModelVersionResponseModel, + ) + + ########################### + # Model Versions Artifacts + ########################### + + def create_model_version_artifact_link( + self, model_version_artifact_link: ModelVersionArtifactRequestModel + ) -> ModelVersionArtifactResponseModel: + """Creates a new model version link. + + Args: + model_version_artifact_link: the Model Version to Artifact Link to be created. + + Returns: + The newly created model version to artifact link. + """ + return self._create_workspace_scoped_resource( + resource=model_version_artifact_link, + response_model=ModelVersionArtifactResponseModel, + route=f"{MODELS}/{model_version_artifact_link.model}{MODEL_VERSIONS}/{model_version_artifact_link.model_version}{ARTIFACTS}", + ) + + def list_model_version_artifact_links( + self, + model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, + ) -> Page[ModelVersionArtifactResponseModel]: + """Get all model version to artifact links by filter. + + Args: + model_version_artifact_link_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model version to artifact links. + """ + return self._list_paginated_resources( + route=f"{MODELS}/{model_version_artifact_link_filter_model.model_id}{MODEL_VERSIONS}/{model_version_artifact_link_filter_model.model_version_id}{ARTIFACTS}", + response_model=ModelVersionArtifactResponseModel, + filter_model=model_version_artifact_link_filter_model, + ) + + def delete_model_version_artifact_link( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_artifact_link_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version to artifact link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. + """ + self._delete_resource( + resource_id=model_version_artifact_link_name_or_id, + route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}/{model_version_name_or_id}{ARTIFACTS}", + ) + + ############################### + # Model Versions Pipeline Runs + ############################### + + def create_model_version_pipeline_run_link( + self, + model_version_pipeline_run_link: ModelVersionPipelineRunRequestModel, + ) -> ModelVersionPipelineRunResponseModel: + """Creates a new model version to pipeline run link. + + Args: + model_version_pipeline_run_link: the Model Version to Pipeline Run Link to be created. + + Returns: + The newly created model version to pipeline run link. + """ + return self._create_workspace_scoped_resource( + resource=model_version_pipeline_run_link, + response_model=ModelVersionPipelineRunResponseModel, + route=f"{MODELS}/{model_version_pipeline_run_link.model}{MODEL_VERSIONS}/{model_version_pipeline_run_link.model_version}{RUNS}", + ) + + def list_model_version_pipeline_run_links( + self, + model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, + ) -> Page[ModelVersionPipelineRunResponseModel]: + """Get all model version to pipeline run links by filter. + + Args: + model_version_pipeline_run_link_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model version to pipeline run links. + """ + return self._list_paginated_resources( + route=f"{MODELS}/{model_version_pipeline_run_link_filter_model.model_id}{MODEL_VERSIONS}/{model_version_pipeline_run_link_filter_model.model_version_id}{RUNS}", + response_model=ModelVersionPipelineRunResponseModel, + filter_model=model_version_pipeline_run_link_filter_model, + ) + + def delete_model_version_pipeline_run_link( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_pipeline_run_link_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version to pipeline run link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_pipeline_run_link_name_or_id: name or ID of the model version to pipeline run link to be deleted. + """ + self._delete_resource( + resource_id=model_version_pipeline_run_link_name_or_id, + route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}/{model_version_name_or_id}{RUNS}", + ) + # ======================= # Internal helper methods # ======================= diff --git a/src/zenml/zen_stores/schemas/__init__.py b/src/zenml/zen_stores/schemas/__init__.py index dcd454d173f..69dfa63e92a 100644 --- a/src/zenml/zen_stores/schemas/__init__.py +++ b/src/zenml/zen_stores/schemas/__init__.py @@ -57,7 +57,12 @@ ) from zenml.zen_stores.schemas.user_schemas import UserSchema from zenml.zen_stores.schemas.logs_schemas import LogsSchema -from zenml.zen_stores.schemas.model_schemas import ModelSchema +from zenml.zen_stores.schemas.model_schemas import ( + ModelSchema, + ModelVersionSchema, + ModelVersionArtifactSchema, + ModelVersionPipelineRunSchema, +) __all__ = [ "ArtifactSchema", @@ -93,4 +98,7 @@ "UserSchema", "LogsSchema", "ModelSchema", + "ModelVersionSchema", + "ModelVersionArtifactSchema", + "ModelVersionPipelineRunSchema", ] diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index d52c3c13248..bd55820fe19 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -32,6 +32,9 @@ from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema if TYPE_CHECKING: + from zenml.zen_stores.schemas.model_schemas import ( + ModelVersionArtifactSchema, + ) from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema from zenml.zen_stores.schemas.step_run_schemas import ( StepRunInputArtifactSchema, @@ -94,6 +97,12 @@ class ArtifactSchema(NamedSchema, table=True): back_populates="artifact", sa_relationship_kwargs={"cascade": "delete"}, ) + model_versions_artifacts_links: List[ + "ModelVersionArtifactSchema" + ] = Relationship( + back_populates="artifact", + sa_relationship_kwargs={"cascade": "delete"}, + ) @classmethod def from_request( diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index b3a0c34f15e..56ff415a30c 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -16,18 +16,26 @@ import json from datetime import datetime -from typing import Optional +from typing import List, Optional from uuid import UUID -from sqlalchemy import TEXT, Column +from sqlalchemy import BOOLEAN, TEXT, Column from sqlmodel import Field, Relationship from zenml.models import ( ModelRequestModel, ModelResponseModel, ModelUpdateModel, + ModelVersionArtifactRequestModel, + ModelVersionArtifactResponseModel, + ModelVersionPipelineRunRequestModel, + ModelVersionPipelineRunResponseModel, + ModelVersionRequestModel, + ModelVersionResponseModel, ) -from zenml.zen_stores.schemas.base_schemas import NamedSchema +from zenml.zen_stores.schemas.artifact_schemas import ArtifactSchema +from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema +from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.user_schemas import UserSchema from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema @@ -66,6 +74,18 @@ class ModelSchema(NamedSchema, table=True): trade_offs: str = Field(sa_column=Column(TEXT, nullable=True)) ethic: str = Field(sa_column=Column(TEXT, nullable=True)) tags: str = Field(sa_column=Column(TEXT, nullable=True)) + model_versions: List["ModelVersionSchema"] = Relationship( + back_populates="model", + sa_relationship_kwargs={"cascade": "delete"}, + ) + artifact_links: List["ModelVersionArtifactSchema"] = Relationship( + back_populates="model", + sa_relationship_kwargs={"cascade": "delete"}, + ) + pipeline_run_links: List["ModelVersionPipelineRunSchema"] = Relationship( + back_populates="model", + sa_relationship_kwargs={"cascade": "delete"}, + ) @classmethod def from_request(cls, model_request: ModelRequestModel) -> "ModelSchema": @@ -88,7 +108,9 @@ def from_request(cls, model_request: ModelRequestModel) -> "ModelSchema": limitations=model_request.limitations, trade_offs=model_request.trade_offs, ethic=model_request.ethic, - tags=json.dumps(model_request.tags), + tags=json.dumps(model_request.tags) + if model_request.tags + else None, ) def to_model(self) -> ModelResponseModel: @@ -133,3 +155,339 @@ def update( setattr(self, field, value) self.updated = datetime.utcnow() return self + + +class ModelVersionSchema(BaseSchema, table=True): + """SQL Model for model version.""" + + __tablename__ = "model_version" + + workspace_id: UUID = build_foreign_key_field( + source=__tablename__, + target=WorkspaceSchema.__tablename__, + source_column="workspace_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + workspace: "WorkspaceSchema" = Relationship( + back_populates="model_versions" + ) + + user_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=UserSchema.__tablename__, + source_column="user_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + user: Optional["UserSchema"] = Relationship( + back_populates="model_versions" + ) + + model_id: UUID = build_foreign_key_field( + source=__tablename__, + target=ModelSchema.__tablename__, + source_column="model_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + model: "ModelSchema" = Relationship(back_populates="model_versions") + artifact_links: List["ModelVersionArtifactSchema"] = Relationship( + back_populates="model_version", + sa_relationship_kwargs={"cascade": "delete"}, + ) + pipeline_run_links: List["ModelVersionPipelineRunSchema"] = Relationship( + back_populates="model_version", + sa_relationship_kwargs={"cascade": "delete"}, + ) + + version: str = Field(sa_column=Column(TEXT, nullable=False)) + description: str = Field(sa_column=Column(TEXT, nullable=True)) + stage: str = Field(sa_column=Column(TEXT, nullable=True)) + + @classmethod + def from_request( + cls, model_version_request: ModelVersionRequestModel + ) -> "ModelVersionSchema": + """Convert an `ModelVersionRequestModel` to an `ModelVersionSchema`. + + Args: + model_version_request: The request model version to convert. + + Returns: + The converted schema. + """ + return cls( + workspace_id=model_version_request.workspace, + user_id=model_version_request.user, + model_id=model_version_request.model, + version=model_version_request.version, + description=model_version_request.description, + stage=model_version_request.stage, + ) + + def to_model(self) -> ModelVersionResponseModel: + """Convert an `ModelVersionSchema` to an `ModelVersionResponseModel`. + + Returns: + The created `ModelVersionResponseModel`. + """ + return ModelVersionResponseModel( + id=self.id, + user=self.user.to_model() if self.user else None, + workspace=self.workspace.to_model(), + created=self.created, + updated=self.updated, + model=self.model.to_model(), + version=self.version, + description=self.description, + stage=self.stage, + model_object_ids={ + al.name: al.artifact_id + for al in self.artifact_links + if al.artifact_id is not None and al.is_model_object + }, + deployment_ids={ + al.name: al.artifact_id + for al in self.artifact_links + if al.artifact_id is not None and al.is_deployment + }, + artifact_object_ids={ + al.name: al.artifact_id + for al in self.artifact_links + if al.artifact_id is not None + and not (al.is_deployment or al.is_model_object) + }, + pipeline_run_ids={ + al.name: al.pipeline_run_id for al in self.pipeline_run_links + }, + ) + + def update( + self, + target_stage: str, + ) -> "ModelVersionSchema": + """Updates a `ModelVersionSchema` to a target stage. + + Args: + target_stage: The stage to be updated. + + Returns: + The updated `ModelVersionSchema`. + """ + self.stage = target_stage + self.updated = datetime.utcnow() + return self + + +class ModelVersionArtifactSchema(NamedSchema, table=True): + """SQL Model for linking of Model Versions and Artifacts M:M.""" + + __tablename__ = "model_versions_artifacts" + + workspace_id: UUID = build_foreign_key_field( + source=__tablename__, + target=WorkspaceSchema.__tablename__, + source_column="workspace_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + workspace: "WorkspaceSchema" = Relationship( + back_populates="model_versions_artifacts_links" + ) + + user_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=UserSchema.__tablename__, + source_column="user_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + user: Optional["UserSchema"] = Relationship( + back_populates="model_versions_artifacts_links" + ) + + model_id: UUID = build_foreign_key_field( + source=__tablename__, + target=ModelSchema.__tablename__, + source_column="model_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + model: "ModelSchema" = Relationship(back_populates="artifact_links") + model_version_id: UUID = build_foreign_key_field( + source=__tablename__, + target=ModelVersionSchema.__tablename__, + source_column="model_version_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + model_version: "ModelVersionSchema" = Relationship( + back_populates="artifact_links" + ) + artifact_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=ArtifactSchema.__tablename__, + source_column="artifact_id", + target_column="id", + ondelete="CASCADE", + nullable=True, + ) + artifact: Optional["ArtifactSchema"] = Relationship( + back_populates="model_versions_artifacts_links" + ) + + is_model_object: bool = Field(sa_column=Column(BOOLEAN, nullable=True)) + is_deployment: bool = Field(sa_column=Column(BOOLEAN, nullable=True)) + + @classmethod + def from_request( + cls, model_version_artifact_request: ModelVersionArtifactRequestModel + ) -> "ModelVersionArtifactSchema": + """Convert an `ModelVersionArtifactRequestModel` to a `ModelVersionArtifactSchema`. + + Args: + model_version_artifact_request: The request link to convert. + + Returns: + The converted schema. + """ + return cls( + name=model_version_artifact_request.name, + workspace_id=model_version_artifact_request.workspace, + user_id=model_version_artifact_request.user, + model_id=model_version_artifact_request.model, + model_version_id=model_version_artifact_request.model_version, + artifact_id=model_version_artifact_request.artifact, + is_model_object=model_version_artifact_request.is_model_object, + is_deployment=model_version_artifact_request.is_deployment, + ) + + def to_model(self) -> ModelVersionArtifactResponseModel: + """Convert an `ModelVersionArtifactSchema` to an `ModelVersionArtifactResponseModel`. + + Returns: + The created `ModelVersionArtifactResponseModel`. + """ + return ModelVersionArtifactResponseModel( + id=self.id, + name=self.name, + user=self.user.to_model() if self.user else None, + workspace=self.workspace.to_model(), + created=self.created, + updated=self.updated, + model=self.model_id, + model_version=self.model_version_id, + artifact=self.artifact_id, + is_model_object=self.is_model_object, + is_deployment=self.is_deployment, + ) + + +class ModelVersionPipelineRunSchema(NamedSchema, table=True): + """SQL Model for linking of Model Versions and Pipeline Runs M:M.""" + + __tablename__ = "model_versions_runs" + + workspace_id: UUID = build_foreign_key_field( + source=__tablename__, + target=WorkspaceSchema.__tablename__, + source_column="workspace_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + workspace: "WorkspaceSchema" = Relationship( + back_populates="model_versions_pipeline_runs_links" + ) + + user_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=UserSchema.__tablename__, + source_column="user_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + user: Optional["UserSchema"] = Relationship( + back_populates="model_versions_pipeline_runs_links" + ) + + model_id: UUID = build_foreign_key_field( + source=__tablename__, + target=ModelSchema.__tablename__, + source_column="model_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + model: "ModelSchema" = Relationship(back_populates="pipeline_run_links") + model_version_id: UUID = build_foreign_key_field( + source=__tablename__, + target=ModelVersionSchema.__tablename__, + source_column="model_version_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + model_version: "ModelVersionSchema" = Relationship( + back_populates="pipeline_run_links" + ) + pipeline_run_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=PipelineRunSchema.__tablename__, + source_column="run_id", + target_column="id", + ondelete="CASCADE", + nullable=True, + ) + pipeline_run: Optional["PipelineRunSchema"] = Relationship( + back_populates="model_versions_pipeline_runs_links" + ) + + @classmethod + def from_request( + cls, + model_version_pipeline_run_request: ModelVersionPipelineRunRequestModel, + ) -> "ModelVersionPipelineRunSchema": + """Convert an `ModelVersionPipelineRunRequestModel` to an `ModelVersionPipelineRunSchema`. + + Args: + model_version_pipeline_run_request: The request link to convert. + + Returns: + The converted schema. + """ + return cls( + workspace_id=model_version_pipeline_run_request.workspace, + user_id=model_version_pipeline_run_request.user, + name=model_version_pipeline_run_request.name, + model_id=model_version_pipeline_run_request.model, + model_version_id=model_version_pipeline_run_request.model_version, + pipeline_run_id=model_version_pipeline_run_request.pipeline_run, + ) + + def to_model(self) -> ModelVersionPipelineRunResponseModel: + """Convert an `ModelVersionPipelineRunSchema` to an `ModelVersionPipelineRunResponseModel`. + + Returns: + The created `ModelVersionPipelineRunResponseModel`. + """ + return ModelVersionPipelineRunResponseModel( + id=self.id, + name=self.name, + user=self.user.to_model() if self.user else None, + workspace=self.workspace.to_model(), + created=self.created, + updated=self.updated, + model=self.model_id, + model_version=self.model_version_id, + pipeline_run=self.pipeline_run_id, + ) diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 9cb8cbe33a0..629b44a0341 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -43,6 +43,9 @@ if TYPE_CHECKING: from zenml.zen_stores.schemas.logs_schemas import LogsSchema + from zenml.zen_stores.schemas.model_schemas import ( + ModelVersionPipelineRunSchema, + ) from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema @@ -157,6 +160,12 @@ class PipelineRunSchema(NamedSchema, table=True): back_populates="pipeline_run", sa_relationship_kwargs={"cascade": "delete", "uselist": False}, ) + model_versions_pipeline_runs_links: List[ + "ModelVersionPipelineRunSchema" + ] = Relationship( + back_populates="pipeline_run", + sa_relationship_kwargs={"cascade": "delete"}, + ) @classmethod def from_request( diff --git a/src/zenml/zen_stores/schemas/user_schemas.py b/src/zenml/zen_stores/schemas/user_schemas.py index e9ea6aeb522..6be3c152509 100644 --- a/src/zenml/zen_stores/schemas/user_schemas.py +++ b/src/zenml/zen_stores/schemas/user_schemas.py @@ -28,6 +28,9 @@ CodeRepositorySchema, FlavorSchema, ModelSchema, + ModelVersionArtifactSchema, + ModelVersionPipelineRunSchema, + ModelVersionSchema, PipelineBuildSchema, PipelineDeploymentSchema, PipelineRunSchema, @@ -95,6 +98,15 @@ class UserSchema(NamedSchema, table=True): models: List["ModelSchema"] = Relationship( back_populates="user", ) + model_versions: List["ModelVersionSchema"] = Relationship( + back_populates="user", + ) + model_versions_artifacts_links: List[ + "ModelVersionArtifactSchema" + ] = Relationship(back_populates="user") + model_versions_pipeline_runs_links: List[ + "ModelVersionPipelineRunSchema" + ] = Relationship(back_populates="user") @classmethod def from_request(cls, model: UserRequestModel) -> "UserSchema": diff --git a/src/zenml/zen_stores/schemas/workspace_schemas.py b/src/zenml/zen_stores/schemas/workspace_schemas.py index 649848b3ca3..f0eb0792d5e 100644 --- a/src/zenml/zen_stores/schemas/workspace_schemas.py +++ b/src/zenml/zen_stores/schemas/workspace_schemas.py @@ -30,6 +30,9 @@ CodeRepositorySchema, FlavorSchema, ModelSchema, + ModelVersionArtifactSchema, + ModelVersionPipelineRunSchema, + ModelVersionSchema, PipelineBuildSchema, PipelineDeploymentSchema, PipelineRunSchema, @@ -121,6 +124,22 @@ class WorkspaceSchema(NamedSchema, table=True): back_populates="workspace", sa_relationship_kwargs={"cascade": "delete"}, ) + model_versions: List["ModelVersionSchema"] = Relationship( + back_populates="workspace", + sa_relationship_kwargs={"cascade": "delete"}, + ) + model_versions_artifacts_links: List[ + "ModelVersionArtifactSchema" + ] = Relationship( + back_populates="workspace", + sa_relationship_kwargs={"cascade": "delete"}, + ) + model_versions_pipeline_runs_links: List[ + "ModelVersionPipelineRunSchema" + ] = Relationship( + back_populates="workspace", + sa_relationship_kwargs={"cascade": "delete"}, + ) @classmethod def from_request( diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 92dd71d150f..ca50eff4a8d 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -76,6 +76,7 @@ ) from zenml.io import fileio from zenml.logger import get_console_handler, get_logger, get_logging_level +from zenml.model import ModelStages from zenml.models import ( ArtifactFilterModel, ArtifactRequestModel, @@ -98,6 +99,16 @@ ModelRequestModel, ModelResponseModel, ModelUpdateModel, + ModelVersionArtifactFilterModel, + ModelVersionArtifactRequestModel, + ModelVersionArtifactResponseModel, + ModelVersionFilterModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunRequestModel, + ModelVersionPipelineRunResponseModel, + ModelVersionRequestModel, + ModelVersionResponseModel, + ModelVersionUpdateModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -194,6 +205,9 @@ FlavorSchema, IdentitySchema, ModelSchema, + ModelVersionArtifactSchema, + ModelVersionPipelineRunSchema, + ModelVersionSchema, NamedSchema, PipelineBuildSchema, PipelineDeploymentSchema, @@ -5516,3 +5530,510 @@ def update_model( # Refresh the Model that was just created session.refresh(existing_model) return existing_model.to_model() + + ################# + # Model Versions + ################# + + def create_model_version( + self, model_version: ModelVersionRequestModel + ) -> ModelVersionResponseModel: + """Creates a new model version. + + Args: + model_version: the Model Version to be created. + + Returns: + The newly created model version. + + Raises: + EntityExistsError: If a workspace with the given name already exists. + """ + with Session(self.engine) as session: + model = self.get_model(model_version.model) + existing_model_version = session.exec( + select(ModelVersionSchema) + .where(ModelVersionSchema.model_id == model.id) + .where(ModelVersionSchema.version == model_version.version) + ).first() + if existing_model_version is not None: + raise EntityExistsError( + f"Unable to create model version {model_version.version}: " + f"A model version with this name already exists in {model.name} model." + ) + + model_version_schema = ModelVersionSchema.from_request( + model_version + ) + session.add(model_version_schema) + + session.commit() + mv = ModelVersionSchema.to_model(model_version_schema) + return mv + + def get_model_version( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + ) -> ModelVersionResponseModel: + """Get an existing model version. + + Args: + model_name_or_id: name or id of the model containing the model version. + model_version_name_or_id: name or id of the model version to be retrieved. + + Returns: + The model version of interest. + + Raises: + KeyError: specified ID or name not found. + """ + with Session(self.engine) as session: + model = self.get_model(model_name_or_id) + query = select(ModelVersionSchema).where( + ModelVersionSchema.model_id == model.id + ) + try: + UUID(str(model_version_name_or_id)) + query = query.where( + ModelVersionSchema.id == model_version_name_or_id + ) + except ValueError: + query = query.where( + ModelVersionSchema.version == model_version_name_or_id + ) + model_version = session.exec(query).first() + if model_version is None: + raise KeyError( + f"Unable to get model version with name `{model_version_name_or_id}`: " + f"No model version with this name found." + ) + return ModelVersionSchema.to_model(model_version) + + def list_model_versions( + self, + model_version_filter_model: ModelVersionFilterModel, + ) -> Page[ModelVersionResponseModel]: + """Get all model versions by filter. + + Args: + model_version_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model versions. + """ + with Session(self.engine) as session: + query = select(ModelVersionSchema) + return self.filter_and_paginate( + session=session, + query=query, + table=ModelVersionSchema, + filter_model=model_version_filter_model, + ) + + def delete_model_version( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version. + + Args: + model_name_or_id: name or id of the model containing the model version. + model_version_name_or_id: name or id of the model version to be deleted. + + Raises: + KeyError: specified ID or name not found. + """ + with Session(self.engine) as session: + model = self.get_model(model_name_or_id) + query = select(ModelVersionSchema).where( + ModelVersionSchema.model_id == model.id + ) + try: + UUID(str(model_version_name_or_id)) + query = query.where( + ModelVersionSchema.id == model_version_name_or_id + ) + except ValueError: + query = query.where( + ModelVersionSchema.version == model_version_name_or_id + ) + model_version = session.exec(query).first() + if model_version is None: + raise KeyError( + f"Unable to delete model version with name `{model_version_name_or_id}`: " + f"No model version with this name found." + ) + session.delete(model_version) + session.commit() + + def update_model_version( + self, + model_version_id: UUID, + model_version_update_model: ModelVersionUpdateModel, + ) -> ModelVersionResponseModel: + """Get all model versions by filter. + + Args: + model_version_id: The ID of model version to be updated. + model_version_update_model: The model version to be updated. + + Returns: + An updated model version. + + Raises: + KeyError: If the model version not found + RuntimeError: If there is a model version with target stage, but `force` flag is off + """ + with Session(self.engine) as session: + existing_model_version = session.exec( + select(ModelVersionSchema) + .where( + ModelVersionSchema.model_id + == model_version_update_model.model + ) + .where(ModelVersionSchema.id == model_version_id) + ).first() + + if not existing_model_version: + raise KeyError(f"Model version {model_version_id} not found.") + + existing_model_version_in_target_stage = session.exec( + select(ModelVersionSchema) + .where( + ModelVersionSchema.model_id + == model_version_update_model.model + ) + .where( + ModelVersionSchema.stage + == model_version_update_model.stage.value + ) + ).first() + + if existing_model_version_in_target_stage is not None: + if not model_version_update_model.force: + raise RuntimeError( + f"Model version {existing_model_version_in_target_stage.version} is " + f"in {model_version_update_model.stage.value}, but `force` flag is False." + ) + else: + existing_model_version_in_target_stage.update( + ModelStages.ARCHIVED.value + ) + session.add(existing_model_version_in_target_stage) + session.commit() + session.refresh(existing_model_version_in_target_stage) + + logger.info( + f"Model version {existing_model_version_in_target_stage.version} has been set to {ModelStages.ARCHIVED.value}." + ) + existing_model_version.update( + model_version_update_model.stage.value + ) + session.add(existing_model_version) + session.commit() + session.refresh(existing_model_version) + + return existing_model_version.to_model() + + ########################### + # Model Versions Artifacts + ########################### + + def create_model_version_artifact_link( + self, model_version_artifact_link: ModelVersionArtifactRequestModel + ) -> ModelVersionArtifactResponseModel: + """Creates a new model version link. + + Args: + model_version_artifact_link: the Model Version to Artifact Link to be created. + + Returns: + The newly created model version to artifact link. + + Raises: + EntityExistsError: If a link with the given name already exists. + """ + with Session(self.engine) as session: + existing_model_version_artifact_link = session.exec( + select(ModelVersionArtifactSchema) + .where( + ModelVersionArtifactSchema.model_version_id + == model_version_artifact_link.model_version + ) + .where( + or_( + ModelVersionArtifactSchema.name + == model_version_artifact_link.name, + ModelVersionArtifactSchema.artifact_id + == model_version_artifact_link.artifact, + ) + ) + ).first() + if existing_model_version_artifact_link is not None: + raise EntityExistsError( + f"Unable to create model version link {existing_model_version_artifact_link.name}: " + f"A model version link with this name already exists in {existing_model_version_artifact_link.model_version} model version." + ) + + if model_version_artifact_link.name is None: + model_version_artifact_link.name = self.get_artifact( + model_version_artifact_link.artifact + ).name + + model_version_artifact_link_schema = ( + ModelVersionArtifactSchema.from_request( + model_version_artifact_link + ) + ) + session.add(model_version_artifact_link_schema) + + session.commit() + mvl = ModelVersionArtifactSchema.to_model( + model_version_artifact_link_schema + ) + return mvl + + def list_model_version_artifact_links( + self, + model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, + ) -> Page[ModelVersionArtifactResponseModel]: + """Get all model version to artifact links by filter. + + Args: + model_version_artifact_link_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model version to artifact links. + """ + with Session(self.engine) as session: + # issue: https://github.com/tiangolo/sqlmodel/issues/109 + if model_version_artifact_link_filter_model.only_artifacts: + query = ( + select(ModelVersionArtifactSchema) + .where( + ModelVersionArtifactSchema.is_model_object + == False # noqa: E712 + ) + .where( + ModelVersionArtifactSchema.is_deployment + == False # noqa: E712 + ) + .where( + ModelVersionArtifactSchema.artifact + != None # noqa: E712, E711 + ) + ) + elif model_version_artifact_link_filter_model.only_deployments: + query = ( + select(ModelVersionArtifactSchema) + .where(ModelVersionArtifactSchema.is_deployment) + .where( + ModelVersionArtifactSchema.is_model_object + == False # noqa: E712 + ) + .where( + ModelVersionArtifactSchema.artifact + != None # noqa: E712, E711 + ) + ) + elif model_version_artifact_link_filter_model.only_model_objects: + query = ( + select(ModelVersionArtifactSchema) + .where(ModelVersionArtifactSchema.is_model_object) + .where( + ModelVersionArtifactSchema.is_deployment + == False # noqa: E712 + ) + .where( + ModelVersionArtifactSchema.artifact + != None # noqa: E712, E711 + ) + ) + else: + query = select(ModelVersionArtifactSchema) + model_version_artifact_link_filter_model.only_artifacts = None + model_version_artifact_link_filter_model.only_deployments = None + model_version_artifact_link_filter_model.only_model_objects = None + return self.filter_and_paginate( + session=session, + query=query, + table=ModelVersionArtifactSchema, + filter_model=model_version_artifact_link_filter_model, + ) + + def delete_model_version_artifact_link( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_artifact_link_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version to artifact link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. + + Raises: + KeyError: specified ID or name not found. + """ + with Session(self.engine) as session: + self.get_model(model_name_or_id) + model_version = self.get_model_version( + model_name_or_id, model_version_name_or_id + ) + query = select(ModelVersionArtifactSchema).where( + ModelVersionArtifactSchema.model_version_id == model_version.id + ) + try: + UUID(str(model_version_artifact_link_name_or_id)) + query = query.where( + ModelVersionArtifactSchema.id + == model_version_artifact_link_name_or_id + ) + except ValueError: + query = query.where( + ModelVersionArtifactSchema.name + == model_version_artifact_link_name_or_id + ) + + model_version_artifact_link = session.exec(query).first() + if model_version_artifact_link is None: + raise KeyError( + f"Unable to delete model version link with name `{model_version_artifact_link_name_or_id}`: " + f"No model version link with this name found." + ) + + session.delete(model_version_artifact_link) + session.commit() + + ############################### + # Model Versions Pipeline Runs + ############################### + + def create_model_version_pipeline_run_link( + self, + model_version_pipeline_run_link: ModelVersionPipelineRunRequestModel, + ) -> ModelVersionPipelineRunResponseModel: + """Creates a new model version to pipeline run link. + + Args: + model_version_pipeline_run_link: the Model Version to Pipeline Run Link to be created. + + Returns: + The newly created model version to pipeline run link. + + Raises: + EntityExistsError: If a link with the given ID already exists. + """ + with Session(self.engine) as session: + existing_model_version_pipeline_run_link = session.exec( + select(ModelVersionPipelineRunSchema) + .where( + ModelVersionPipelineRunSchema.model_version_id + == model_version_pipeline_run_link.model_version + ) + .where( + or_( + ModelVersionPipelineRunSchema.pipeline_run_id + == model_version_pipeline_run_link.pipeline_run, + ModelVersionPipelineRunSchema.name + == model_version_pipeline_run_link.name, + ) + ) + ).first() + if existing_model_version_pipeline_run_link is not None: + raise EntityExistsError( + f"Unable to create model version link {existing_model_version_pipeline_run_link.name}: " + f"A model version link with this name already exists in {existing_model_version_pipeline_run_link.model_version} model version." + ) + + if model_version_pipeline_run_link.name is None: + model_version_pipeline_run_link.name = self.get_run( + model_version_pipeline_run_link.pipeline_run + ).name + + model_version_pipeline_run_link_schema = ( + ModelVersionPipelineRunSchema.from_request( + model_version_pipeline_run_link + ) + ) + session.add(model_version_pipeline_run_link_schema) + + session.commit() + mvl = ModelVersionPipelineRunSchema.to_model( + model_version_pipeline_run_link_schema + ) + return mvl + + def list_model_version_pipeline_run_links( + self, + model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, + ) -> Page[ModelVersionPipelineRunResponseModel]: + """Get all model version to pipeline run links by filter. + + Args: + model_version_pipeline_run_link_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model version to pipeline run links. + """ + with Session(self.engine) as session: + return self.filter_and_paginate( + session=session, + query=select(ModelVersionPipelineRunSchema), + table=ModelVersionPipelineRunSchema, + filter_model=model_version_pipeline_run_link_filter_model, + ) + + def delete_model_version_pipeline_run_link( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_pipeline_run_link_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version to pipeline run link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_pipeline_run_link_name_or_id: name or ID of the model version to pipeline run link to be deleted. + + Raises: + KeyError: specified ID not found. + """ + with Session(self.engine) as session: + self.get_model(model_name_or_id) + model_version = self.get_model_version( + model_name_or_id, model_version_name_or_id + ) + query = select(ModelVersionPipelineRunSchema).where( + ModelVersionPipelineRunSchema.model_version_id + == model_version.id + ) + try: + UUID(str(model_version_pipeline_run_link_name_or_id)) + query = query.where( + ModelVersionPipelineRunSchema.id + == model_version_pipeline_run_link_name_or_id + ) + except ValueError: + query = query.where( + ModelVersionPipelineRunSchema.name + == model_version_pipeline_run_link_name_or_id + ) + + model_version_pipeline_run_link = session.exec(query).first() + if model_version_pipeline_run_link is None: + raise KeyError( + f"Unable to delete model version link with name `{model_version_pipeline_run_link_name_or_id}`: " + f"No model version link with this name found." + ) + + session.delete(model_version_pipeline_run_link) + session.commit() diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index f97dc27d628..e5bdc518c20 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -36,6 +36,16 @@ ModelRequestModel, ModelResponseModel, ModelUpdateModel, + ModelVersionArtifactFilterModel, + ModelVersionArtifactRequestModel, + ModelVersionArtifactResponseModel, + ModelVersionFilterModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunRequestModel, + ModelVersionPipelineRunResponseModel, + ModelVersionRequestModel, + ModelVersionResponseModel, + ModelVersionUpdateModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -1686,7 +1696,7 @@ def create_model(self, model: ModelRequestModel) -> ModelResponseModel: The newly created model. Raises: - EntityExistsError: If a workspace with the given name already exists. + EntityExistsError: If a model with the given name already exists. """ @abstractmethod @@ -1724,6 +1734,9 @@ def get_model( Returns: The model of interest. + + Raises: + KeyError: specified ID or name not found. """ @abstractmethod @@ -1740,3 +1753,200 @@ def list_models( Returns: A page of all models. """ + + ################# + # Model Versions + ################# + + @abstractmethod + def create_model_version( + self, model_version: ModelVersionRequestModel + ) -> ModelVersionResponseModel: + """Creates a new model version. + + Args: + model_version: the Model Version to be created. + + Returns: + The newly created model version. + + Raises: + EntityExistsError: If a model version with the given name already exists. + """ + + @abstractmethod + def delete_model_version( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version. + + Args: + model_name_or_id: name or id of the model containing the model version. + model_version_name_or_id: name or id of the model version to be deleted. + + Raises: + KeyError: specified ID or name not found. + """ + + @abstractmethod + def get_model_version( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + ) -> ModelVersionResponseModel: + """Get an existing model version. + + Args: + model_name_or_id: name or id of the model containing the model version. + model_version_name_or_id: name or id of the model version to be retrieved. + + Returns: + The model version of interest. + + Raises: + KeyError: specified ID or name not found. + """ + + @abstractmethod + def list_model_versions( + self, + model_version_filter_model: ModelVersionFilterModel, + ) -> Page[ModelVersionResponseModel]: + """Get all model versions by filter. + + Args: + model_version_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model versions. + """ + + @abstractmethod + def update_model_version( + self, + model_version_id: UUID, + model_version_update_model: ModelVersionUpdateModel, + ) -> ModelVersionResponseModel: + """Get all model versions by filter. + + Args: + model_version_id: The ID of model version to be updated. + model_version_update_model: The model version to be updated. + + Returns: + An updated model version. + + Raises: + KeyError: If the model version not found + RuntimeError: If there is a model version with target stage, but `force` flag is off + """ + + ########################### + # Model Versions Artifacts + ########################### + + @abstractmethod + def create_model_version_artifact_link( + self, model_version_artifact_link: ModelVersionArtifactRequestModel + ) -> ModelVersionArtifactResponseModel: + """Creates a new model version link. + + Args: + model_version_artifact_link: the Model Version to Artifact Link to be created. + + Returns: + The newly created model version to artifact link. + + Raises: + EntityExistsError: If a link with the given name already exists. + """ + + @abstractmethod + def list_model_version_artifact_links( + self, + model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, + ) -> Page[ModelVersionArtifactResponseModel]: + """Get all model version to artifact links by filter. + + Args: + model_version_artifact_link_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model version to artifact links. + """ + + @abstractmethod + def delete_model_version_artifact_link( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_artifact_link_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version to artifact link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. + + Raises: + KeyError: specified ID or name not found. + """ + + ############################### + # Model Versions Pipeline Runs + ############################### + + @abstractmethod + def create_model_version_pipeline_run_link( + self, + model_version_pipeline_run_link: ModelVersionPipelineRunRequestModel, + ) -> ModelVersionPipelineRunResponseModel: + """Creates a new model version to pipeline run link. + + Args: + model_version_pipeline_run_link: the Model Version to Pipeline Run Link to be created. + + Returns: + The newly created model version to pipeline run link. + + Raises: + EntityExistsError: If a link with the given ID already exists. + """ + + @abstractmethod + def list_model_version_pipeline_run_links( + self, + model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, + ) -> Page[ModelVersionPipelineRunResponseModel]: + """Get all model version to pipeline run links by filter. + + Args: + model_version_pipeline_run_link_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model version to pipeline run links. + """ + + @abstractmethod + def delete_model_version_pipeline_run_link( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_pipeline_run_link_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version to pipeline run link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_pipeline_run_link_name_or_id: name or ID of the model version to pipeline run link to be deleted. + + Raises: + KeyError: specified ID not found. + """ diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index b8c11ddf16f..04cbf55d373 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -16,6 +16,7 @@ from contextlib import ExitStack as does_not_raise from datetime import datetime from typing import Dict, List, Optional, Tuple +from uuid import uuid4 import pytest from pydantic import SecretStr @@ -25,6 +26,7 @@ CodeRepositoryContext, ComponentContext, CrudTestConfig, + ModelVersionContext, PipelineRunContext, RoleContext, ServiceConnectorContext, @@ -48,9 +50,18 @@ from zenml.logging.step_logging import prepare_logs_uri from zenml.models import ( ArtifactFilterModel, + ArtifactResponseModel, ComponentFilterModel, ComponentUpdateModel, + ModelVersionArtifactFilterModel, + ModelVersionArtifactRequestModel, + ModelVersionFilterModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunRequestModel, + ModelVersionRequestModel, + ModelVersionUpdateModel, PipelineRunFilterModel, + PipelineRunResponseModel, RoleFilterModel, RoleRequestModel, RoleUpdateModel, @@ -2423,3 +2434,647 @@ def test_connector_validation(): secrets=secrets, ): pass + + +################# +# Models +################# + + +class TestModelVersion: + def test_model_version_create_pass(self): + with ModelVersionContext() as model: + zs = Client().zen_store + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + + def test_model_version_create_duplicated(self): + with ModelVersionContext() as model: + zs = Client().zen_store + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + with pytest.raises(EntityExistsError): + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + + def test_model_version_create_no_model(self): + with ModelVersionContext() as model: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=uuid4(), + version="great one", + ) + ) + + def test_model_version_get_not_found(self): + with ModelVersionContext() as model: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.get_model_version( + model_name_or_id=model.id, model_version_name_or_id="1.0.0" + ) + + def test_model_version_get_found(self): + with ModelVersionContext() as model: + zs = Client().zen_store + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id="great one", + ) + + def test_model_version_list_empty(self): + with ModelVersionContext() as model: + zs = Client().zen_store + mvs = zs.list_model_versions( + ModelVersionFilterModel(model_id=model.id) + ) + assert len(mvs) == 0 + + def test_model_version_list_not_empty(self): + with ModelVersionContext() as model: + zs = Client().zen_store + mv1 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + mv2 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="and yet another one", + ) + ) + mvs = zs.list_model_versions( + ModelVersionFilterModel(model_id=model.id) + ) + assert len(mvs) == 2 + assert mv1 in mvs + assert mv2 in mvs + + def test_model_version_delete_not_found(self): + with ModelVersionContext() as model: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.delete_model_version( + model_name_or_id=model.id, + model_version_name_or_id="1.0.0", + ) + + def test_model_version_delete_found(self): + with ModelVersionContext() as model: + zs = Client().zen_store + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + zs.delete_model_version( + model_name_or_id=model.id, + model_version_name_or_id="great one", + ) + with pytest.raises(KeyError): + zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id="great one", + ) + + def test_model_version_update_not_found(self): + with ModelVersionContext() as model: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.update_model_version( + model_version_id=uuid4(), + model_version_update_model=ModelVersionUpdateModel( + model=model.id, + stage="staging", + force=False, + ), + ) + + def test_model_version_update_not_forced(self): + with ModelVersionContext() as model: + zs = Client().zen_store + mv1 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + mv2 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="yet another one", + ) + ) + zs.update_model_version( + model_version_id=mv1.id, + model_version_update_model=ModelVersionUpdateModel( + model=model.id, + stage="staging", + force=False, + ), + ) + with pytest.raises(RuntimeError): + zs.update_model_version( + model_version_id=mv2.id, + model_version_update_model=ModelVersionUpdateModel( + model=model.id, + stage="staging", + force=False, + ), + ) + + def test_model_version_update_forced(self): + with ModelVersionContext() as model: + zs = Client().zen_store + mv1 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + mv2 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="yet another one", + ) + ) + zs.update_model_version( + model_version_id=mv1.id, + model_version_update_model=ModelVersionUpdateModel( + model=model.id, + stage="staging", + force=False, + ), + ) + assert ( + zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id=mv1.version, + ).stage + == "staging" + ) + zs.update_model_version( + model_version_id=mv2.id, + model_version_update_model=ModelVersionUpdateModel( + model=model.id, + stage="staging", + force=True, + ), + ) + + assert ( + zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id=mv1.version, + ).stage + == "archived" + ) + assert ( + zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id=mv2.version, + ).stage + == "staging" + ) + + def test_model_version_update_public_interface(self): + with ModelVersionContext() as model: + zs = Client().zen_store + mv1 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + assert ( + zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id=mv1.version, + ).stage + is None + ) + mv1.set_stage("staging") + assert ( + zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id=mv1.version, + ).stage + == "staging" + ) + + +class TestModelVersionArtifactLinks: + def test_link_create_pass(self): + with ModelVersionContext(True, create_artifacts=1) as ( + model_version, + artifacts, + ): + zs = Client().zen_store + zs.create_model_version_artifact_link( + ModelVersionArtifactRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + artifact=artifacts[0].id, + ) + ) + + def test_link_create_duplicated(self): + with ModelVersionContext(True, create_artifacts=1) as ( + model_version, + artifacts, + ): + zs = Client().zen_store + zs.create_model_version_artifact_link( + ModelVersionArtifactRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + artifact=artifacts[0].id, + ) + ) + # name collision + with pytest.raises(EntityExistsError): + zs.create_model_version_artifact_link( + ModelVersionArtifactRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + artifact=uuid4(), + ) + ) + # id collision + with pytest.raises(EntityExistsError): + zs.create_model_version_artifact_link( + ModelVersionArtifactRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link2", + artifact=artifacts[0].id, + ) + ) + + def test_link_delete_found(self): + with ModelVersionContext(True, create_artifacts=1) as ( + model_version, + artifacts, + ): + zs = Client().zen_store + zs.create_model_version_artifact_link( + ModelVersionArtifactRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + artifact=artifacts[0].id, + ) + ) + zs.delete_model_version_artifact_link( + model_name_or_id=model_version.model.id, + model_version_name_or_id=model_version.id, + model_version_artifact_link_name_or_id="link", + ) + mvls = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(mvls) == 0 + + def test_link_delete_not_found(self): + with ModelVersionContext(True) as model_version: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.delete_model_version_artifact_link( + model_name_or_id=model_version.model.id, + model_version_name_or_id=model_version.id, + model_version_artifact_link_name_or_id="link", + ) + + def test_link_list_empty(self): + with ModelVersionContext(True) as model_version: + zs = Client().zen_store + mvls = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(mvls) == 0 + + def test_link_list_populated(self): + with ModelVersionContext(True, create_artifacts=3) as ( + model_version, + artifacts, + ): + zs = Client().zen_store + mvls = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(mvls) == 0 + for n, mo, dep, artifact in [ + ("link1", False, False, artifacts[0]), + ("link2", True, False, artifacts[1]), + ("link3", False, True, artifacts[2]), + ]: + zs.create_model_version_artifact_link( + ModelVersionArtifactRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name=n, + artifact=artifact.id, + is_model_object=mo, + is_deployment=dep, + ) + ) + mvls = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(mvls) == 3 + + mvls = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + only_artifacts=True, + ) + ) + assert len(mvls) == 1 and mvls[0].name == "link1" + + mvls = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + only_model_objects=True, + ) + ) + assert len(mvls) == 1 and mvls[0].name == "link2" + + mvls = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + only_deployments=True, + ) + ) + assert len(mvls) == 1 and mvls[0].name == "link3" + + mv = zs.get_model_version( + model_name_or_id=model_version.model.id, + model_version_name_or_id=model_version.id, + ) + + assert len(mv.model_object_ids) == 1 + assert len(mv.artifact_object_ids) == 1 + assert len(mv.deployment_ids) == 1 + + assert isinstance( + mv.model_objects["link2"], + ArtifactResponseModel, + ) + assert isinstance( + mv.artifact_objects["link1"], + ArtifactResponseModel, + ) + assert isinstance( + mv.deployments["link3"], + ArtifactResponseModel, + ) + + assert mv.model_objects["link2"].id == artifacts[1].id + + assert mv.get_model_object("link2") == mv.model_objects["link2"] + assert ( + mv.get_artifact_object("link1") == mv.artifact_objects["link1"] + ) + assert mv.get_deployment("link3") == mv.deployments["link3"] + + +class TestModelVersionPipelineRunLinks: + def test_link_create_pass(self): + with ModelVersionContext(True, create_prs=1) as ( + model_version, + prs, + ): + zs = Client().zen_store + zs.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + pipeline_run=prs[0].id, + ) + ) + + def test_link_create_duplicated(self): + with ModelVersionContext(True, create_prs=1) as ( + model_version, + prs, + ): + zs = Client().zen_store + zs.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + pipeline_run=prs[0].id, + ) + ) + # name collision + with pytest.raises(EntityExistsError): + zs.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + pipeline_run=uuid4(), + ) + ) + # id collision + with pytest.raises(EntityExistsError): + zs.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + pipeline_run=prs[0].id, + ) + ) + + def test_link_delete_found(self): + with ModelVersionContext(True, create_prs=1) as ( + model_version, + prs, + ): + zs = Client().zen_store + zs.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + pipeline_run=prs[0].id, + ) + ) + zs.delete_model_version_pipeline_run_link( + model_version.model.id, model_version.id, "link" + ) + mvls = zs.list_model_version_pipeline_run_links( + ModelVersionPipelineRunFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(mvls) == 0 + + def test_link_delete_not_found(self): + with ModelVersionContext(True) as model_version: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.delete_model_version_pipeline_run_link( + model_version.model.id, model_version.id, "link" + ) + + def test_link_list_empty(self): + with ModelVersionContext(True) as model_version: + zs = Client().zen_store + mvls = zs.list_model_version_pipeline_run_links( + ModelVersionPipelineRunFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(mvls) == 0 + + def test_link_list_populated(self): + with ModelVersionContext(True, create_prs=2) as ( + model_version, + prs, + ): + zs = Client().zen_store + mvls = zs.list_model_version_pipeline_run_links( + ModelVersionPipelineRunFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(mvls) == 0 + for n, pr in zip(["link4", None], prs): + zs.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name=n, + pipeline_run=pr.id, + ) + ) + mvls = zs.list_model_version_pipeline_run_links( + ModelVersionPipelineRunFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(mvls) == 2 + + mv = zs.get_model_version( + model_name_or_id=model_version.model.id, + model_version_name_or_id=model_version.id, + ) + + assert len(mv.pipeline_run_ids) == 2 + + assert isinstance( + mv.pipeline_runs["link4"], + PipelineRunResponseModel, + ) + assert isinstance( + mv.pipeline_runs[prs[1].name], + PipelineRunResponseModel, + ) + + assert mv.pipeline_runs["link4"].id == prs[0].id + assert mv.pipeline_runs[prs[1].name].id == prs[1].id + + assert mv.get_pipeline_run("link4") == mv.pipeline_runs["link4"] + assert ( + mv.get_pipeline_run(prs[1].name) + == mv.pipeline_runs[prs[1].name] + ) diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index 4c0ba0046be..00f04863f62 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -45,6 +45,7 @@ ModelFilterModel, ModelRequestModel, ModelUpdateModel, + ModelVersionRequestModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineDeploymentFilterModel, @@ -509,6 +510,122 @@ def __exit__(self, exc_type, exc_value, exc_traceback): pass +class ModelVersionContext: + def __init__( + self, + create_version: bool = False, + create_artifacts: int = 0, + create_prs: int = 0, + ): + client = Client() + self.workspace = client.active_workspace.id + self.user = client.active_user.id + self.model = "su_model" + self.model_version = "2.0.0" + self.del_ws = False + self.del_user = False + self.del_model = False + + self.create_version = create_version + self.create_artifacts = create_artifacts + self.artifacts = [] + self.create_prs = create_prs + self.prs = [] + + def __enter__(self): + zs = Client().zen_store + try: + ws = zs.get_workspace(self.workspace) + except KeyError: + ws = zs.create_workspace( + WorkspaceRequestModel(name=self.workspace) + ) + self.del_ws = True + try: + user = zs.get_user(self.user) + except KeyError: + user = zs.create_user(UserRequestModel(name=self.user)) + self.del_user = True + try: + model = zs.get_model(self.model) + except KeyError: + model = zs.create_model( + ModelRequestModel( + name=self.model, user=user.id, workspace=ws.id + ) + ) + self.del_model = True + if self.create_version: + try: + mv = zs.get_model_version(self.model, self.model_version) + except KeyError: + mv = zs.create_model_version( + ModelVersionRequestModel( + user=user.id, + workspace=ws.id, + model=model.id, + version=self.model_version, + ) + ) + + for _ in range(self.create_artifacts): + self.artifacts.append( + zs.create_artifact( + ArtifactRequestModel( + name=sample_name("sample_artifact"), + data_type="module.class", + materializer="module.class", + type=ArtifactType.DATA, + uri="", + user=user.id, + workspace=ws.id, + ) + ) + ) + for _ in range(self.create_prs): + self.prs.append( + zs.create_run( + PipelineRunRequestModel( + id=uuid.uuid4(), + name=sample_name("sample_pipeline_run"), + status="running", + config=PipelineConfiguration(name="aria_pipeline"), + user=user.id, + workspace=ws.id, + ) + ) + ) + if self.create_version: + if self.create_artifacts: + return mv, self.artifacts + if self.create_prs: + return mv, self.prs + else: + return mv + else: + if self.create_artifacts: + return model, self.artifacts + if self.create_prs: + return model, self.prs + else: + return model + + def __exit__(self, exc_type, exc_value, exc_traceback): + zs = Client().zen_store + if self.create_version: + zs.delete_model_version(self.model, self.model_version) + if self.del_model: + zs.delete_model(self.model) + for artifact in self.artifacts: + zs.delete_artifact(artifact.id) + for run in self.prs: + zs.delete_run(run.id) + if self.del_user: + zs.delete_user(self.user) + if self.del_ws: + zs.delete_workspace(self.workspace) + + class CatClawMarks(AuthenticationConfig): """Cat claw marks authentication credentials."""