From 320d27c620d15fb598a4cb1f7f1f0fd024234c12 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Fri, 8 Sep 2023 16:40:52 +0200 Subject: [PATCH 01/40] big bang commit --- src/zenml/constants.py | 1 + src/zenml/model/__init__.py | 26 ++++ src/zenml/model/model_config.py | 20 +++ src/zenml/models/__init__.py | 56 ++++++- src/zenml/models/model_models.py | 144 ++++++++++++++++++ src/zenml/zen_stores/rest_zen_store.py | 91 +++++++++++ src/zenml/zen_stores/schemas/__init__.py | 2 + src/zenml/zen_stores/schemas/model_schemas.py | 129 ++++++++++++++++ src/zenml/zen_stores/schemas/user_schemas.py | 4 + .../zen_stores/schemas/workspace_schemas.py | 5 + src/zenml/zen_stores/sql_zen_store.py | 140 +++++++++++++++++ src/zenml/zen_stores/zen_store_interface.py | 73 +++++++++ .../functional/zen_stores/utils.py | 24 +++ 13 files changed, 714 insertions(+), 1 deletion(-) create mode 100644 src/zenml/model/__init__.py create mode 100644 src/zenml/model/model_config.py create mode 100644 src/zenml/models/model_models.py create mode 100644 src/zenml/zen_stores/schemas/model_schemas.py diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 4db56eb62e7..007aa1e7a37 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -221,6 +221,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: SERVICE_CONNECTOR_VERIFY = "/verify" SERVICE_CONNECTOR_RESOURCES = "/resources" SERVICE_CONNECTOR_CLIENT = "/client" +MODELS = "/models" # mandatory stack component attributes MANDATORY_COMPONENT_ATTRIBUTES = ["name", "uuid"] diff --git a/src/zenml/model/__init__.py b/src/zenml/model/__init__.py new file mode 100644 index 00000000000..193f4295309 --- /dev/null +++ b/src/zenml/model/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Initialization of ZenML model. + +ZenML model support Model WatchTower feature. +""" + +from zenml.model.model_config import ModelConfig +from zenml.model.model import Model + +__all__ = [ + "Model", + "ModelConfig", +] diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py new file mode 100644 index 00000000000..8aa9cc70aeb --- /dev/null +++ b/src/zenml/model/model_config.py @@ -0,0 +1,20 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +""" ModelConfig user facing interface to pass into pipeline or step""" + +from pydantic import BaseModel + + +class ModelConfig(BaseModel): + pass diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index b7e39cc21a9..6ab749d457b 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# Copyright (c) ZenML GmbH 2022-2023. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -145,6 +145,19 @@ LogsRequestModel, LogsResponseModel, ) +from zenml.models.model_models import ( + ModelBaseModel, + ModelFilterModel, + ModelResponseModel, + ModelRequestModel, + ModelUpdateModel, + ModelConfigBaseModel, + ModelConfigResponseModel, + ModelConfigRequestModel, + ModelVersionBaseModel, + ModelVersionResponseModel, + ModelVersionRequestModel, +) ComponentResponseModel.update_forward_refs( UserResponseModel=UserResponseModel, @@ -254,6 +267,36 @@ WorkspaceResponseModel=WorkspaceResponseModel, ) +ModelRequestModel.update_forward_refs( + UserResponseModel=UserResponseModel, + WorkspaceResponseModel=WorkspaceResponseModel, +) + +ModelResponseModel.update_forward_refs( + UserResponseModel=UserResponseModel, + WorkspaceResponseModel=WorkspaceResponseModel, +) + +ModelConfigRequestModel.update_forward_refs( + UserResponseModel=UserResponseModel, + WorkspaceResponseModel=WorkspaceResponseModel, +) + +ModelConfigResponseModel.update_forward_refs( + UserResponseModel=UserResponseModel, + WorkspaceResponseModel=WorkspaceResponseModel, +) + +ModelVersionRequestModel.update_forward_refs( + UserResponseModel=UserResponseModel, + WorkspaceResponseModel=WorkspaceResponseModel, +) + +ModelVersionResponseModel.update_forward_refs( + UserResponseModel=UserResponseModel, + WorkspaceResponseModel=WorkspaceResponseModel, +) + __all__ = [ "ArtifactRequestModel", "ArtifactResponseModel", @@ -346,4 +389,15 @@ "LogsBaseModel", "LogsRequestModel", "LogsResponseModel", + "ModelBaseModel", + "ModelFilterModel", + "ModelRequestModel", + "ModelResponseModel", + "ModelUpdateModel", + "ModelConfigBaseModel", + "ModelConfigRequestModel", + "ModelConfigResponseModel", + "ModelVersionBaseModel", + "ModelVersionRequestModel", + "ModelVersionResponseModel", ] diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py new file mode 100644 index 00000000000..0c43818278f --- /dev/null +++ b/src/zenml/models/model_models.py @@ -0,0 +1,144 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from typing import List, Optional, Union +from uuid import UUID + +from pydantic import BaseModel, Field + +from zenml.models.base_models import ( + WorkspaceScopedRequestModel, + WorkspaceScopedResponseModel, +) +from zenml.models.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH +from zenml.models.filter_models import WorkspaceScopedFilterModel + + +class ModelVersionBaseModel(BaseModel): + pass + + +class ModelVersionRequestModel( + ModelVersionBaseModel, + WorkspaceScopedRequestModel, +): + pass + + +class ModelVersionResponseModel( + ModelVersionBaseModel, + WorkspaceScopedResponseModel, +): + pass + + +class ModelConfigBaseModel(BaseModel): + pass + + +class ModelConfigRequestModel( + ModelConfigBaseModel, + WorkspaceScopedRequestModel, +): + pass + + +class ModelConfigResponseModel( + ModelConfigBaseModel, + WorkspaceScopedResponseModel, +): + pass + + +class ModelBaseModel(BaseModel): + name: str = Field( + title="The name of the model", + max_length=STR_FIELD_MAX_LENGTH, + ) + license: Optional[str] = Field( + title="The license model created under", + max_length=TEXT_FIELD_MAX_LENGTH, + ) + description: Optional[str] = Field( + title="The description of the model", + max_length=TEXT_FIELD_MAX_LENGTH, + ) + audience: Optional[str] = Field( + title="The target audience of the model", + max_length=TEXT_FIELD_MAX_LENGTH, + ) + use_cases: Optional[str] = Field( + title="The use cases of the model", + max_length=TEXT_FIELD_MAX_LENGTH, + ) + limitations: Optional[str] = Field( + title="The know limitations of the model", + max_length=TEXT_FIELD_MAX_LENGTH, + ) + trade_offs: Optional[str] = Field( + title="The trade offs of the model", + max_length=TEXT_FIELD_MAX_LENGTH, + ) + ethic: Optional[str] = Field( + title="The ethical implications of the model", + max_length=TEXT_FIELD_MAX_LENGTH, + ) + tags: Optional[List[str]] = Field( + title="Tags associated with the model", + ) + + +class ModelRequestModel( + WorkspaceScopedRequestModel, + ModelBaseModel, +): + pass + + +class ModelResponseModel( + WorkspaceScopedResponseModel, + ModelBaseModel, +): + @property + def versions(self) -> List[ModelVersionResponseModel]: + pass + + def get_version(version: str) -> ModelVersionResponseModel: + pass + + +class ModelFilterModel(WorkspaceScopedFilterModel): + """Model to enable advanced filtering of all Workspaces.""" + + name: Optional[str] = Field( + default=None, + description="Name of the Model", + ) + workspace_id: Optional[Union[UUID, str]] = Field( + default=None, description="Workspace of the Model" + ) + user_id: Optional[Union[UUID, str]] = Field( + default=None, description="User of the Model" + ) + + +class ModelUpdateModel(ModelBaseModel): + license: Optional[str] + description: Optional[str] + audience: Optional[str] + use_cases: Optional[str] + limitations: Optional[str] + trade_offs: Optional[str] + ethic: Optional[str] + tags: Optional[List[str]] diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 6a4768cc493..b280c152955 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -49,6 +49,7 @@ GET_OR_CREATE, INFO, LOGIN, + MODELS, PIPELINE_BUILDS, PIPELINE_DEPLOYMENTS, PIPELINES, @@ -94,6 +95,10 @@ FlavorRequestModel, FlavorResponseModel, FlavorUpdateModel, + ModelFilterModel, + ModelRequestModel, + ModelResponseModel, + ModelUpdateModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -2276,6 +2281,92 @@ def get_service_connector_type( connector_type ) + ######### + # Model + ######### + + def create_model(self, model: ModelRequestModel) -> ModelResponseModel: + """Creates a new model. + + Args: + model: the Model to be created. + + Returns: + The newly created model. + """ + return self._create_workspace_scoped_resource( + resource=model, + response_model=ModelResponseModel, + route=MODELS, + ) + + def delete_model(self, model_name_or_id: Union[str, UUID]) -> None: + """Deletes a model. + + Args: + model_name_or_id: name or id of the model to be deleted. + + Returns: + The newly created or existing model. + """ + self._delete_resource(resource_id=model_name_or_id, route=MODELS) + + def update_model( + self, + model_id: UUID, + model_update: ModelUpdateModel, + ) -> ModelResponseModel: + """Updates an existing model. + + Args: + model_id: UUID of the model to be updated. + model: the Model to be updated. + + Returns: + The updated model. + """ + return self._update_resource( + resource_id=model_id, + resource_update=model_update, + route=MODELS, + response_model=ModelResponseModel, + ) + + def get_model( + self, model_name_or_id: Union[str, UUID] + ) -> ModelResponseModel: + """Get an existing model. + + Args: + model_name_or_id: name or id of the model to be retrieved. + + Returns: + The model of interest. + """ + return self._get_resource( + resource_id=model_name_or_id, + route=MODELS, + response_model=ModelResponseModel, + ) + + def list_models( + self, model_filter_model: ModelFilterModel + ) -> Page[ModelResponseModel]: + """Get all models by filter. + + Args: + model_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all models. + """ + return self._list_paginated_resources( + route=MODELS, + response_model=ModelResponseModel, + filter_model=model_filter_model, + ) + # ======================= # Internal helper methods # ======================= diff --git a/src/zenml/zen_stores/schemas/__init__.py b/src/zenml/zen_stores/schemas/__init__.py index a3efc01a231..dcd454d173f 100644 --- a/src/zenml/zen_stores/schemas/__init__.py +++ b/src/zenml/zen_stores/schemas/__init__.py @@ -57,6 +57,7 @@ ) from zenml.zen_stores.schemas.user_schemas import UserSchema from zenml.zen_stores.schemas.logs_schemas import LogsSchema +from zenml.zen_stores.schemas.model_schemas import ModelSchema __all__ = [ "ArtifactSchema", @@ -91,4 +92,5 @@ "UserRoleAssignmentSchema", "UserSchema", "LogsSchema", + "ModelSchema", ] diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py new file mode 100644 index 00000000000..e7a5f8543a9 --- /dev/null +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -0,0 +1,129 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""SQLModel implementation of model tables.""" + + +from datetime import datetime +from typing import Optional +from uuid import UUID + +from sqlalchemy import TEXT, Column +from sqlmodel import Field, Relationship + +from zenml.models import ( + ModelRequestModel, + ModelResponseModel, + ModelUpdateModel, +) +from zenml.zen_stores.schemas.base_schemas import NamedSchema +from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field +from zenml.zen_stores.schemas.user_schemas import UserSchema +from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema + + +class ModelSchema(NamedSchema, table=True): + """SQL Model for model.""" + + __tablename__ = "model" + + workspace_id: UUID = build_foreign_key_field( + source=__tablename__, + target=WorkspaceSchema.__tablename__, + source_column="workspace_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + workspace: "WorkspaceSchema" = Relationship(back_populates="models") + + user_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=UserSchema.__tablename__, + source_column="user_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + user: Optional["UserSchema"] = Relationship(back_populates="models") + + license: str = Field(sa_column=Column(TEXT, nullable=True)) + description: str = Field(sa_column=Column(TEXT, nullable=True)) + audience: str = Field(sa_column=Column(TEXT, nullable=True)) + use_cases: str = Field(sa_column=Column(TEXT, nullable=True)) + limitations: str = Field(sa_column=Column(TEXT, nullable=True)) + trade_offs: str = Field(sa_column=Column(TEXT, nullable=True)) + ethic: str = Field(sa_column=Column(TEXT, nullable=True)) + tags: str = Field(sa_column=Column(TEXT, nullable=True)) + + @classmethod + def from_request(cls, model_request: ModelRequestModel) -> "ModelSchema": + """Convert an `ModelRequestModel` to an `ModelSchema`. + + Args: + model_request: The request model to convert. + + Returns: + The converted schema. + """ + return cls( + name=model_request.name, + workspace_id=model_request.workspace, + user_id=model_request.user, + license=model_request.license, + description=model_request.description, + audience=model_request.audience, + use_cases=model_request.use_cases, + limitations=model_request.limitations, + trade_offs=model_request.trade_offs, + ethic=model_request.ethic, + tags=model_request.tags, + ) + + def to_model(self) -> ModelResponseModel: + """Convert an `ModelSchema` to an `ModelResponseModel`. + + Returns: + The created `ModelResponseModel`. + """ + return ModelResponseModel( + id=self.id, + name=self.name, + user=self.user.to_model() if self.user else None, + workspace=self.workspace.to_model(), + license=self.license, + description=self.description, + audience=self.audience, + use_cases=self.use_cases, + limitations=self.limitations, + trade_offs=self.trade_offs, + ethic=self.ethic, + tags=self.tags, + ) + + def update( + self, + model_update: ModelUpdateModel, + ) -> "ModelSchema": + """Updates a `ModelSchema` from a `ModelUpdateModel`. + + Args: + model_update: The `ModelUpdateModel` to update from. + + Returns: + The updated `ModelSchema`. + """ + for field, value in model_update.dict(exclude_unset=True).items(): + setattr(self, field, value) + self.updated = datetime.utcnow() + return self diff --git a/src/zenml/zen_stores/schemas/user_schemas.py b/src/zenml/zen_stores/schemas/user_schemas.py index 5710e05b735..33c4319ea9d 100644 --- a/src/zenml/zen_stores/schemas/user_schemas.py +++ b/src/zenml/zen_stores/schemas/user_schemas.py @@ -27,6 +27,7 @@ ArtifactSchema, CodeRepositorySchema, FlavorSchema, + ModelSchema, PipelineBuildSchema, PipelineDeploymentSchema, PipelineRunSchema, @@ -91,6 +92,9 @@ class UserSchema(NamedSchema, table=True): service_connectors: List["ServiceConnectorSchema"] = Relationship( back_populates="user", ) + service_connectors: List["ModelSchema"] = Relationship( + back_populates="user", + ) @classmethod def from_request(cls, model: UserRequestModel) -> "UserSchema": diff --git a/src/zenml/zen_stores/schemas/workspace_schemas.py b/src/zenml/zen_stores/schemas/workspace_schemas.py index a1b7d7bbe2c..649848b3ca3 100644 --- a/src/zenml/zen_stores/schemas/workspace_schemas.py +++ b/src/zenml/zen_stores/schemas/workspace_schemas.py @@ -29,6 +29,7 @@ ArtifactSchema, CodeRepositorySchema, FlavorSchema, + ModelSchema, PipelineBuildSchema, PipelineDeploymentSchema, PipelineRunSchema, @@ -116,6 +117,10 @@ class WorkspaceSchema(NamedSchema, table=True): back_populates="workspace", sa_relationship_kwargs={"cascade": "delete"}, ) + models: List["ModelSchema"] = Relationship( + back_populates="workspace", + sa_relationship_kwargs={"cascade": "delete"}, + ) @classmethod def from_request( diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 1fb101d8180..624dfda44f5 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -94,6 +94,10 @@ FlavorRequestModel, FlavorResponseModel, FlavorUpdateModel, + ModelFilterModel, + ModelRequestModel, + ModelResponseModel, + ModelUpdateModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -189,6 +193,7 @@ CodeRepositorySchema, FlavorSchema, IdentitySchema, + ModelSchema, NamedSchema, PipelineBuildSchema, PipelineDeploymentSchema, @@ -5354,3 +5359,138 @@ def _create_or_reuse_code_reference( session.add(new_reference) return new_reference.id + + ######## + # Model + ######## + + def create_model(self, model: ModelRequestModel) -> ModelResponseModel: + """Creates a new model. + + Args: + model: the Model to be created. + + Returns: + The newly created model. + """ + with Session(self.engine) as session: + # Save artifact. + model_schema = ModelSchema.from_request(model) + session.add(model_schema) + + session.commit() + return ModelSchema.to_model(model_schema) + + def get_model( + self, model_name_or_id: Union[str, UUID] + ) -> ModelResponseModel: + """Get an existing model. + + Args: + model_name_or_id: name or id of the model to be retrieved. + + Returns: + The model of interest. + """ + with Session(self.engine) as session: + is_id = type(model_name_or_id) == UUID + if is_id: + model = session.exec( + select(ModelSchema).where( + ModelSchema.id == model_name_or_id + ) + ).first() + else: + model = session.exec( + select(ModelSchema).where( + ModelSchema.name == model_name_or_id + ) + ).first() + if model is None: + raise KeyError( + f"Unable to get model with {'ID' if is_id else 'Name'} `{model_name_or_id}`: " + f"No model with this {'ID' if is_id else 'Name'} found." + ) + return ModelSchema.to_model(model) + + def list_models( + self, model_filter_model: ModelFilterModel + ) -> Page[ModelResponseModel]: + """Get all models by filter. + + Args: + model_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all models. + """ + with Session(self.engine) as session: + query = select(ModelSchema) + return self.filter_and_paginate( + session=session, + query=query, + table=ModelSchema, + filter_model=model_filter_model, + ) + + def delete_model(self, model_name_or_id: Union[str, UUID]) -> None: + """Deletes a model. + + Args: + model_name_or_id: name or id of the model to be deleted. + + Returns: + The newly created or existing model. + """ + with Session(self.engine) as session: + is_id = type(model_name_or_id) == UUID + if is_id: + model = session.exec( + select(ModelSchema).where( + ModelSchema.id == model_name_or_id + ) + ).first() + else: + model = session.exec( + select(ModelSchema).where( + ModelSchema.name == model_name_or_id + ) + ).first() + if model is None: + raise KeyError( + f"Unable to delete model with {'ID' if is_id else 'Name'} `{model_name_or_id}`: " + f"No model with this {'ID' if is_id else 'Name'} found." + ) + session.delete(model) + session.commit() + + def update_model( + self, + model_id: UUID, + model_update: ModelUpdateModel, + ) -> ModelResponseModel: + """Updates an existing model. + + Args: + model_id: UUID of the model to be updated. + model: the Model to be updated. + + Returns: + The updated model. + """ + with Session(self.engine) as session: + existing_model = session.exec( + select(ModelSchema).where(ModelSchema.id == model_id) + ).first() + + if not existing_model: + raise KeyError(f"Model with ID {model_id} not found.") + + existing_model.update(model_update=model_update) + session.add(existing_model) + session.commit() + + # Refresh the Model that was just created + session.refresh(existing_model) + return existing_model.to_model() diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 336e6d49874..ef262d2db70 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -32,6 +32,10 @@ FlavorRequestModel, FlavorResponseModel, FlavorUpdateModel, + ModelFilterModel, + ModelRequestModel, + ModelResponseModel, + ModelUpdateModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -1666,3 +1670,72 @@ def get_service_connector_type( Raises: KeyError: If no service connector type with the given ID exists. """ + + ######### + # Model + ######### + + @abstractmethod + def create_model(self, model: ModelRequestModel) -> ModelResponseModel: + """Creates a new model. + + Args: + model: the Model to be created. + + Returns: + The newly created model. + """ + + @abstractmethod + def delete_model(self, model_name_or_id: Union[str, UUID]) -> None: + """Deletes a model. + + Args: + model_name_or_id: name or id of the model to be deleted. + + Returns: + The newly created or existing model. + """ + + @abstractmethod + def update_model( + self, + model_id: UUID, + model_update: ModelUpdateModel, + ) -> ModelResponseModel: + """Updates an existing model. + + Args: + model_id: UUID of the model to be updated. + model: the Model to be updated. + + Returns: + The updated model. + """ + + @abstractmethod + def get_model( + self, model_name_or_id: Union[str, UUID] + ) -> ModelResponseModel: + """Get an existing model. + + Args: + model_name_or_id: name or id of the model to be retrieved. + + Returns: + The model of interest. + """ + + @abstractmethod + def list_models( + self, model_filter_model: ModelFilterModel + ) -> Page[ModelResponseModel]: + """Get all models by filter. + + Args: + model_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all models. + """ diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index 37c2d0f4e6f..89a3f3f87b2 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -42,6 +42,9 @@ ComponentUpdateModel, FlavorFilterModel, FlavorRequestModel, + ModelFilterModel, + ModelRequestModel, + ModelUpdateModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineDeploymentFilterModel, @@ -813,6 +816,26 @@ def update_method( filter_model=ServiceConnectorFilterModel, entity_name="service_connector", ) +model_crud_test_config = CrudTestConfig( + create_model=ModelRequestModel( + user=uuid.uuid4(), + workspace=uuid.uuid4(), + name="super_model", + license="who cares", + description="cool stuff", + audience="world", + use_cases="all", + limitations="none", + trade_offs="secret", + ethic="all good", + tags=["cool", "stuff"], + ), + update_model=ModelUpdateModel( + name=sample_name("updated_sample_service_connector"), + ), + filter_model=ModelFilterModel, + entity_name="model", +) # step_run_crud_test_config = CrudTestConfig( # create_model=StepRunRequestModel( @@ -848,4 +871,5 @@ def update_method( deployment_crud_test_config, code_repository_crud_test_config, service_connector_crud_test_config, + model_crud_test_config, ] From 5e72e3d767ebde58f9c4afe462a43f86eafc7f3c Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Fri, 8 Sep 2023 17:03:39 +0200 Subject: [PATCH 02/40] typo --- src/zenml/zen_stores/schemas/user_schemas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/zen_stores/schemas/user_schemas.py b/src/zenml/zen_stores/schemas/user_schemas.py index 33c4319ea9d..e9ea6aeb522 100644 --- a/src/zenml/zen_stores/schemas/user_schemas.py +++ b/src/zenml/zen_stores/schemas/user_schemas.py @@ -92,7 +92,7 @@ class UserSchema(NamedSchema, table=True): service_connectors: List["ServiceConnectorSchema"] = Relationship( back_populates="user", ) - service_connectors: List["ModelSchema"] = Relationship( + models: List["ModelSchema"] = Relationship( back_populates="user", ) From 18b5344f77d644f5473bf2ab91903f1623be5a16 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Fri, 8 Sep 2023 17:06:22 +0200 Subject: [PATCH 03/40] Apply suggestions from code review Co-authored-by: Felix Altenberger --- src/zenml/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index 6ab749d457b..074ede5aec1 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) ZenML GmbH 2022-2023. All Rights Reserved. +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From b32a1d3a3a7ecaf637eae9a4254f08d139dae42e Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 11 Sep 2023 07:59:00 +0200 Subject: [PATCH 04/40] add Alembic --- .../versions/3b68abe58f44_add_model_entity.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 src/zenml/zen_stores/migrations/versions/3b68abe58f44_add_model_entity.py diff --git a/src/zenml/zen_stores/migrations/versions/3b68abe58f44_add_model_entity.py b/src/zenml/zen_stores/migrations/versions/3b68abe58f44_add_model_entity.py new file mode 100644 index 00000000000..541699e23bf --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/3b68abe58f44_add_model_entity.py @@ -0,0 +1,61 @@ +"""add model entity [3b68abe58f44]. + +Revision ID: 3b68abe58f44 +Revises: 0.44.1 +Create Date: 2023-09-11 07:53:18.641081 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "3b68abe58f44" +down_revision = "0.44.1" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "model", + sa.Column( + "workspace_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("license", sa.TEXT(), nullable=True), + sa.Column("description", sa.TEXT(), nullable=True), + sa.Column("audience", sa.TEXT(), nullable=True), + sa.Column("use_cases", sa.TEXT(), nullable=True), + sa.Column("limitations", sa.TEXT(), nullable=True), + sa.Column("trade_offs", sa.TEXT(), nullable=True), + sa.Column("ethic", sa.TEXT(), nullable=True), + sa.Column("tags", sa.TEXT(), nullable=True), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("updated", sa.DateTime(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + name="fk_model_user_id_user", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.id"], + name="fk_model_workspace_id_workspace", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("model") + # ### end Alembic commands ### From 661729d65fe66ea2306a572f9b80dc674a1f0a7c Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 11 Sep 2023 08:08:37 +0200 Subject: [PATCH 05/40] lint --- src/zenml/model/__init__.py | 26 --------------------- src/zenml/model/model_config.py | 20 ---------------- src/zenml/models/model_models.py | 23 ++++++++++++++++++ src/zenml/zen_stores/rest_zen_store.py | 2 +- src/zenml/zen_stores/sql_zen_store.py | 2 +- src/zenml/zen_stores/zen_store_interface.py | 2 +- 6 files changed, 26 insertions(+), 49 deletions(-) delete mode 100644 src/zenml/model/__init__.py delete mode 100644 src/zenml/model/model_config.py diff --git a/src/zenml/model/__init__.py b/src/zenml/model/__init__.py deleted file mode 100644 index 193f4295309..00000000000 --- a/src/zenml/model/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) ZenML GmbH 2023. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Initialization of ZenML model. - -ZenML model support Model WatchTower feature. -""" - -from zenml.model.model_config import ModelConfig -from zenml.model.model import Model - -__all__ = [ - "Model", - "ModelConfig", -] diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py deleted file mode 100644 index 8aa9cc70aeb..00000000000 --- a/src/zenml/model/model_config.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) ZenML GmbH 2023. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -""" ModelConfig user facing interface to pass into pipeline or step""" - -from pydantic import BaseModel - - -class ModelConfig(BaseModel): - pass diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 0c43818278f..69b9ecd61a0 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. +"""Model implementation to support Model WatchTower feature.""" from typing import List, Optional, Union from uuid import UUID @@ -26,6 +27,8 @@ class ModelVersionBaseModel(BaseModel): + """Model Version base model.""" + pass @@ -33,6 +36,8 @@ class ModelVersionRequestModel( ModelVersionBaseModel, WorkspaceScopedRequestModel, ): + """Model Version request model.""" + pass @@ -40,10 +45,14 @@ class ModelVersionResponseModel( ModelVersionBaseModel, WorkspaceScopedResponseModel, ): + """Model Version response model.""" + pass class ModelConfigBaseModel(BaseModel): + """Model Config base model.""" + pass @@ -51,6 +60,8 @@ class ModelConfigRequestModel( ModelConfigBaseModel, WorkspaceScopedRequestModel, ): + """Model Config request model.""" + pass @@ -58,10 +69,14 @@ class ModelConfigResponseModel( ModelConfigBaseModel, WorkspaceScopedResponseModel, ): + """Model Config response model.""" + pass class ModelBaseModel(BaseModel): + """Model base model.""" + name: str = Field( title="The name of the model", max_length=STR_FIELD_MAX_LENGTH, @@ -103,6 +118,8 @@ class ModelRequestModel( WorkspaceScopedRequestModel, ModelBaseModel, ): + """Model request model.""" + pass @@ -110,11 +127,15 @@ class ModelResponseModel( WorkspaceScopedResponseModel, ModelBaseModel, ): + """Model response model.""" + @property def versions(self) -> List[ModelVersionResponseModel]: + """List all versions of the model.""" pass def get_version(version: str) -> ModelVersionResponseModel: + """Get specific version of the model.""" pass @@ -134,6 +155,8 @@ class ModelFilterModel(WorkspaceScopedFilterModel): class ModelUpdateModel(ModelBaseModel): + """Model update model.""" + license: Optional[str] description: Optional[str] audience: Optional[str] diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index b280c152955..08621cf7894 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2320,7 +2320,7 @@ def update_model( Args: model_id: UUID of the model to be updated. - model: the Model to be updated. + model_update: the Model to be updated. Returns: The updated model. diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 624dfda44f5..9438566247d 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5474,7 +5474,7 @@ def update_model( Args: model_id: UUID of the model to be updated. - model: the Model to be updated. + model_update: the Model to be updated. Returns: The updated model. diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index ef262d2db70..1c9eb6e16b4 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -1707,7 +1707,7 @@ def update_model( Args: model_id: UUID of the model to be updated. - model: the Model to be updated. + model_update: the Model to be updated. Returns: The updated model. From 6594e16ae3808c0d47c5475ceb8e04c42c87e3bd Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 11 Sep 2023 08:57:18 +0200 Subject: [PATCH 06/40] mypy --- src/zenml/models/model_models.py | 4 ++-- src/zenml/zen_stores/schemas/model_schemas.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 69b9ecd61a0..5a9f434a100 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -130,11 +130,11 @@ class ModelResponseModel( """Model response model.""" @property - def versions(self) -> List[ModelVersionResponseModel]: + def versions(self) -> List[ModelVersionResponseModel]: # type: ignore[empty-body] """List all versions of the model.""" pass - def get_version(version: str) -> ModelVersionResponseModel: + def get_version(self, version: str) -> ModelVersionResponseModel: # type: ignore[empty-body] """Get specific version of the model.""" pass diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index e7a5f8543a9..b725d39b4bd 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -101,6 +101,8 @@ def to_model(self) -> ModelResponseModel: name=self.name, user=self.user.to_model() if self.user else None, workspace=self.workspace.to_model(), + created=self.created, + updated=self.updated, license=self.license, description=self.description, audience=self.audience, From 353bff8fadada13ba519bc3d9d8d05bf1aa68e00 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 11 Sep 2023 10:07:15 +0200 Subject: [PATCH 07/40] darglint --- src/zenml/models/model_models.py | 8 ++++++-- src/zenml/zen_stores/rest_zen_store.py | 3 --- src/zenml/zen_stores/sql_zen_store.py | 10 ++++++++-- src/zenml/zen_stores/zen_store_interface.py | 3 --- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 5a9f434a100..ac28f039433 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -134,8 +134,12 @@ def versions(self) -> List[ModelVersionResponseModel]: # type: ignore[empty-bod """List all versions of the model.""" pass - def get_version(self, version: str) -> ModelVersionResponseModel: # type: ignore[empty-body] - """Get specific version of the model.""" + def get_version(self, version: Optional[str] = None) -> ModelVersionResponseModel: # type: ignore[empty-body] + """Get specific version of the model. + + Args: + version: version number, stage or None for latest version. + """ pass diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 08621cf7894..ccf96645263 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2305,9 +2305,6 @@ def delete_model(self, model_name_or_id: Union[str, UUID]) -> None: Args: model_name_or_id: name or id of the model to be deleted. - - Returns: - The newly created or existing model. """ self._delete_resource(resource_id=model_name_or_id, route=MODELS) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 9438566247d..07c271bb950 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5389,6 +5389,9 @@ def get_model( Args: model_name_or_id: name or id of the model to be retrieved. + Raises: + KeyError: specified ID or name not found. + Returns: The model of interest. """ @@ -5440,8 +5443,8 @@ def delete_model(self, model_name_or_id: Union[str, UUID]) -> None: Args: model_name_or_id: name or id of the model to be deleted. - Returns: - The newly created or existing model. + Raises: + KeyError: specified ID or name not found. """ with Session(self.engine) as session: is_id = type(model_name_or_id) == UUID @@ -5476,6 +5479,9 @@ def update_model( model_id: UUID of the model to be updated. model_update: the Model to be updated. + Raises: + KeyError: specified ID not found. + Returns: The updated model. """ diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 1c9eb6e16b4..79170be4306 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -1692,9 +1692,6 @@ def delete_model(self, model_name_or_id: Union[str, UUID]) -> None: Args: model_name_or_id: name or id of the model to be deleted. - - Returns: - The newly created or existing model. """ @abstractmethod From a856526ec843c164d3093c903d9694d1abda6f11 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 11 Sep 2023 11:08:50 +0200 Subject: [PATCH 08/40] wip --- src/zenml/models/model_models.py | 43 +++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index ac28f039433..58108f33135 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -13,23 +13,29 @@ # permissions and limitations under the License. """Model implementation to support Model WatchTower feature.""" -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union from uuid import UUID from pydantic import BaseModel, Field +from zenml.models.artifact_models import ArtifactResponseModel from zenml.models.base_models import ( WorkspaceScopedRequestModel, WorkspaceScopedResponseModel, ) from zenml.models.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH from zenml.models.filter_models import WorkspaceScopedFilterModel +from zenml.zen_server.utils import zen_store class ModelVersionBaseModel(BaseModel): """Model Version base model.""" - pass + version: str + stage: Optional[str] + _model_objects: Dict[str, UUID] = None + _artifact_objects: Dict[str, UUID] = None + _deployments: Dict[str, UUID] = None class ModelVersionRequestModel( @@ -47,7 +53,38 @@ class ModelVersionResponseModel( ): """Model Version response model.""" - pass + @staticmethod + def _fetch_artifacts_from_list( + artifacts: Dict[str, UUID] + ) -> Dict[str, ArtifactResponseModel]: + if artifacts: + return { + name: zen_store().get_artifact(a) + for name, a in artifacts.items() + } + else: + return {} + + @property + def model_objects(self) -> Dict[str, ArtifactResponseModel]: + return self._fetch_artifacts_from_list(self._model_objects) + + @property + def artifact_objects(self) -> Dict[str, ArtifactResponseModel]: + return self._fetch_artifacts_from_list(self._artifact_objects) + + @property + def deployments(self) -> Dict[str, ArtifactResponseModel]: + return self._fetch_artifacts_from_list(self._deployments) + + # TODO: after https://zenml.atlassian.net/browse/OSS-2419 + # def set_stage(self, stage: ModelStages): + # """Sets Model Version to a desired stage.""" + # ... + + # TODO in https://zenml.atlassian.net/browse/OSS-2433 + # def generate_model_card(self, template_name: str) -> str: + # """Return HTML/PDF based on input template""" class ModelConfigBaseModel(BaseModel): From e15ad92e8482c4ddc21083b2071b0d27d831ffa0 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 11 Sep 2023 14:37:50 +0200 Subject: [PATCH 09/40] add endpoints --- src/zenml/models/model_models.py | 2 +- .../zen_server/routers/models_endpoints.py | 102 ++++++++++++++++++ .../routers/workspaces_endpoints.py | 79 ++++++++++++++ src/zenml/zen_server/zen_server_api.py | 2 + src/zenml/zen_stores/rest_zen_store.py | 7 +- src/zenml/zen_stores/sql_zen_store.py | 70 ++++++------ src/zenml/zen_stores/zen_store_interface.py | 5 +- 7 files changed, 232 insertions(+), 35 deletions(-) create mode 100644 src/zenml/zen_server/routers/models_endpoints.py diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index ac28f039433..574de528589 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -158,7 +158,7 @@ class ModelFilterModel(WorkspaceScopedFilterModel): ) -class ModelUpdateModel(ModelBaseModel): +class ModelUpdateModel(BaseModel): """Model update model.""" license: Optional[str] diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py new file mode 100644 index 00000000000..380f320b721 --- /dev/null +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -0,0 +1,102 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Endpoint definitions for models.""" + +from typing import Union +from uuid import UUID + +from fastapi import APIRouter, Security + +from zenml.constants import API, MODELS, VERSION_1 +from zenml.enums import PermissionType +from zenml.models import ( + ModelResponseModel, + ModelUpdateModel, +) +from zenml.zen_server.auth import AuthContext, authorize +from zenml.zen_server.exceptions import error_response +from zenml.zen_server.utils import ( + handle_exceptions, + zen_store, +) + +router = APIRouter( + prefix=API + VERSION_1 + MODELS, + tags=["models"], + responses={401: error_response}, +) + + +@router.delete( + "/{model_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model( + model_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Delete a model by name or ID. + + Args: + model_name_or_id: The name or ID of the model to delete. + """ + zen_store().delete_model(model_name_or_id) + + +@router.get( + "/{model_name_or_id}", + response_model=ModelResponseModel, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def get_model( + model_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> ModelResponseModel: + """Get a model by name or ID. + + Args: + model_name_or_id: The name or ID of the model to get. + + Returns: + The model with the given name or ID. + """ + return zen_store().get_model(model_name_or_id) + + +@router.put( + "/{model_id}", + response_model=ModelResponseModel, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def update_model( + model_id: UUID, + model_update: ModelUpdateModel, + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> ModelResponseModel: + """Updates a model. + + Args: + model_id: Name of the stack. + model_update: Stack to use for the update. + + Returns: + The updated model. + """ + return zen_store().update_model( + model_id=model_id, + model_update=model_update, + ) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index b272d9128ef..5ec6fe72846 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -21,6 +21,7 @@ API, CODE_REPOSITORIES, GET_OR_CREATE, + MODELS, PIPELINE_BUILDS, PIPELINE_DEPLOYMENTS, PIPELINES, @@ -47,6 +48,9 @@ ComponentFilterModel, ComponentRequestModel, ComponentResponseModel, + ModelFilterModel, + ModelRequestModel, + ModelResponseModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -1142,3 +1146,78 @@ def list_service_connector_resources( resource_type=resource_type, resource_id=resource_id, ) + + +@router.post( + WORKSPACES + "/{workspace_name_or_id}" + MODELS, + response_model=ModelResponseModel, + responses={401: error_response, 409: error_response, 422: error_response}, +) +@handle_exceptions +def create_model( + workspace_name_or_id: Union[str, UUID], + model: ModelRequestModel, + auth_context: AuthContext = Security( + authorize, scopes=[PermissionType.WRITE] + ), +) -> ModelResponseModel: + """Create a new model. + + Args: + workspace_name_or_id: Name or ID of the workspace. + model: The model to create. + auth_context: Authentication context. + + Returns: + The created model. + + Raises: + IllegalOperationError: If the workspace or user specified in the + model does not match the current workspace or authenticated + user. + """ + workspace = zen_store().get_workspace(workspace_name_or_id) + + if model.workspace != workspace.id: + raise IllegalOperationError( + "Creating models outside of the workspace scope " + f"of this endpoint `{workspace_name_or_id}` is " + f"not supported." + ) + if model.user != auth_context.user.id: + raise IllegalOperationError( + "Creating models for a user other than yourself " + "is not supported." + ) + return zen_store().create_model(model) + + +@router.get( + WORKSPACES + "/{workspace_name_or_id}" + MODELS, + response_model=Page[ModelResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_models( + workspace_name_or_id: Union[str, UUID], + model_filter_model: ModelFilterModel = Depends( + make_dependable(ModelFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelResponseModel]: + """Get models according to query filters. + + Args: + workspace_name_or_id: Name or ID of the workspace. + model_filter_model: Filter model used for pagination, sorting, + filtering + + + Returns: + The models according to query filters. + """ + workspace_id = zen_store().get_workspace(workspace_name_or_id).id + return zen_store().list_models( + workspace_id=workspace_id, + model_filter_model=model_filter_model, + ) diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 896c5dc6b21..758a18cb5b1 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -35,6 +35,7 @@ auth_endpoints, code_repositories_endpoints, flavors_endpoints, + models_endpoints, pipeline_builds_endpoints, pipeline_deployments_endpoints, pipelines_endpoints, @@ -220,6 +221,7 @@ def dashboard(request: Request) -> Any: app.include_router(pipeline_builds_endpoints.router) app.include_router(pipeline_deployments_endpoints.router) app.include_router(code_repositories_endpoints.router) +app.include_router(models_endpoints.router) def get_root_static_files() -> List[str]: diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index ccf96645263..9c4e8981432 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2347,11 +2347,14 @@ def get_model( ) def list_models( - self, model_filter_model: ModelFilterModel + self, + workspace_id: UUID, + model_filter_model: ModelFilterModel, ) -> Page[ModelResponseModel]: """Get all models by filter. Args: + workspace_id: The name or ID of the workspace to scope to. model_filter_model: All filter parameters including pagination params. @@ -2359,7 +2362,7 @@ def list_models( A page of all models. """ return self._list_paginated_resources( - route=MODELS, + route=f"{WORKSPACES}/{workspace_id}{MODELS}", response_model=ModelResponseModel, filter_model=model_filter_model, ) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 07c271bb950..f3dd41865db 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5318,6 +5318,30 @@ def _get_run_schema( session=session, ) + def _get_model_schema( + self, + model_name_or_id: Union[str, UUID], + session: Session, + ) -> ModelSchema: + """Gets a model schema by name or ID. + + This is a helper method that is used in various places to find a run + by its name or ID. + + Args: + model_name_or_id: The name or ID of the run to get. + session: The database session to use. + + Returns: + The model schema. + """ + return self._get_schema_by_name_or_id( + object_name_or_id=model_name_or_id, + schema_class=ModelSchema, + schema_name="model", + session=session, + ) + def _create_or_reuse_code_reference( self, session: Session, @@ -5396,32 +5420,25 @@ def get_model( The model of interest. """ with Session(self.engine) as session: - is_id = type(model_name_or_id) == UUID - if is_id: - model = session.exec( - select(ModelSchema).where( - ModelSchema.id == model_name_or_id - ) - ).first() - else: - model = session.exec( - select(ModelSchema).where( - ModelSchema.name == model_name_or_id - ) - ).first() + model = self._get_model_schema( + model_name_or_id=model_name_or_id, session=session + ) if model is None: raise KeyError( - f"Unable to get model with {'ID' if is_id else 'Name'} `{model_name_or_id}`: " - f"No model with this {'ID' if is_id else 'Name'} found." + f"Unable to get model with ID `{model_name_or_id}`: " + f"No model with this ID found." ) return ModelSchema.to_model(model) def list_models( - self, model_filter_model: ModelFilterModel + self, + workspace_id: UUID, + model_filter_model: ModelFilterModel, ) -> Page[ModelResponseModel]: """Get all models by filter. Args: + workspace_id: The ID of the workspace to scope to. model_filter_model: All filter parameters including pagination params. @@ -5430,6 +5447,7 @@ def list_models( """ with Session(self.engine) as session: query = select(ModelSchema) + model_filter_model.set_scope_workspace(workspace_id) return self.filter_and_paginate( session=session, query=query, @@ -5447,23 +5465,13 @@ def delete_model(self, model_name_or_id: Union[str, UUID]) -> None: KeyError: specified ID or name not found. """ with Session(self.engine) as session: - is_id = type(model_name_or_id) == UUID - if is_id: - model = session.exec( - select(ModelSchema).where( - ModelSchema.id == model_name_or_id - ) - ).first() - else: - model = session.exec( - select(ModelSchema).where( - ModelSchema.name == model_name_or_id - ) - ).first() + model = self._get_model_schema( + model_name_or_id=model_name_or_id, session=session + ) if model is None: raise KeyError( - f"Unable to delete model with {'ID' if is_id else 'Name'} `{model_name_or_id}`: " - f"No model with this {'ID' if is_id else 'Name'} found." + f"Unable to delete model with ID `{model_name_or_id}`: " + f"No model with this ID found." ) session.delete(model) session.commit() diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 79170be4306..658a99d79e4 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -1725,11 +1725,14 @@ def get_model( @abstractmethod def list_models( - self, model_filter_model: ModelFilterModel + self, + workspace_id: UUID, + model_filter_model: ModelFilterModel, ) -> Page[ModelResponseModel]: """Get all models by filter. Args: + workspace_id: The name or ID of the workspace to scope to. model_filter_model: All filter parameters including pagination params. From a49d2a9da3c38ca2c3597d4c56500616a1cf99da Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 11 Sep 2023 15:28:16 +0200 Subject: [PATCH 10/40] add ModelStages --- src/zenml/model/__init__.py | 23 +++++++++++++++++++++++ src/zenml/model/model_stages.py | 26 ++++++++++++++++++++++++++ src/zenml/models/model_models.py | 9 +++++---- 3 files changed, 54 insertions(+), 4 deletions(-) create mode 100644 src/zenml/model/__init__.py create mode 100644 src/zenml/model/model_stages.py diff --git a/src/zenml/model/__init__.py b/src/zenml/model/__init__.py new file mode 100644 index 00000000000..af2e669c0d5 --- /dev/null +++ b/src/zenml/model/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Initialization of ZenML model. +ZenML model support Model WatchTower feature. +""" + +from zenml.model.model_stages import ModelStages + +__all__ = [ + "ModelStages", +] diff --git a/src/zenml/model/model_stages.py b/src/zenml/model/model_stages.py new file mode 100644 index 00000000000..5c53793ce4d --- /dev/null +++ b/src/zenml/model/model_stages.py @@ -0,0 +1,26 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""ModelStages lists supported stages of a Model Version.""" + +from enum import StrEnum + + +class ModelStages(StrEnum): + NONE = "none" + STAGING = "starting" + PRODUCTION = "production" + ARCHIVED = "archived" + # technical stages + LATEST = "latest" + RUNNING = "running" diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 90cc753c6ed..1ed444c43b3 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -18,6 +18,7 @@ from pydantic import BaseModel, Field +from zenml.model import ModelStages from zenml.models.artifact_models import ArtifactResponseModel from zenml.models.base_models import ( WorkspaceScopedRequestModel, @@ -32,6 +33,7 @@ class ModelVersionBaseModel(BaseModel): """Model Version base model.""" version: str + description: Optional[str] stage: Optional[str] _model_objects: Dict[str, UUID] = None _artifact_objects: Dict[str, UUID] = None @@ -77,10 +79,9 @@ def artifact_objects(self) -> Dict[str, ArtifactResponseModel]: def deployments(self) -> Dict[str, ArtifactResponseModel]: return self._fetch_artifacts_from_list(self._deployments) - # TODO: after https://zenml.atlassian.net/browse/OSS-2419 - # def set_stage(self, stage: ModelStages): - # """Sets Model Version to a desired stage.""" - # ... + def set_stage(self, stage: ModelStages): + """Sets Model Version to a desired stage.""" + pass # TODO in https://zenml.atlassian.net/browse/OSS-2433 # def generate_model_card(self, template_name: str) -> str: From 077c6ee868033fcae3cc5b230f421c87f524502b Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 11 Sep 2023 15:44:07 +0200 Subject: [PATCH 11/40] wip --- src/zenml/models/model_models.py | 44 +++++-- src/zenml/zen_stores/schemas/__init__.py | 6 +- src/zenml/zen_stores/schemas/model_schemas.py | 123 +++++++++++++++++- src/zenml/zen_stores/schemas/user_schemas.py | 4 + .../zen_stores/schemas/workspace_schemas.py | 5 + 5 files changed, 170 insertions(+), 12 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 1ed444c43b3..5eb2125e7e1 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -18,6 +18,7 @@ from pydantic import BaseModel, Field +from zenml.client import Client from zenml.model import ModelStages from zenml.models.artifact_models import ArtifactResponseModel from zenml.models.base_models import ( @@ -26,18 +27,40 @@ ) from zenml.models.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH from zenml.models.filter_models import WorkspaceScopedFilterModel -from zenml.zen_server.utils import zen_store +from zenml.models.pipeline_run_models import PipelineRunResponseModel class ModelVersionBaseModel(BaseModel): """Model Version base model.""" - version: str - description: Optional[str] - stage: Optional[str] - _model_objects: Dict[str, UUID] = None - _artifact_objects: Dict[str, UUID] = None - _deployments: Dict[str, UUID] = None + version: str = Field( + title="The name of the model version", + max_length=STR_FIELD_MAX_LENGTH, + ) + description: Optional[str] = Field( + title="The description of the model version", + max_length=TEXT_FIELD_MAX_LENGTH, + ) + stage: Optional[str] = Field( + title="The stage of the model version", + max_length=STR_FIELD_MAX_LENGTH, + ) + _model_objects: Dict[str, UUID] = Field( + title="Model Objects linked to the model version", + default={}, + ) + _artifact_objects: Dict[str, UUID] = Field( + title="Artifacts linked to the model version", + default={}, + ) + _deployments: Dict[str, UUID] = Field( + title="Deployments linked to the model version", + default={}, + ) + _pipeline_runs: List[UUID] = Field( + title="Pipeline runs linked to the model version", + default=[], + ) class ModelVersionRequestModel( @@ -61,8 +84,7 @@ def _fetch_artifacts_from_list( ) -> Dict[str, ArtifactResponseModel]: if artifacts: return { - name: zen_store().get_artifact(a) - for name, a in artifacts.items() + name: Client().get_artifact(a) for name, a in artifacts.items() } else: return {} @@ -79,6 +101,10 @@ def artifact_objects(self) -> Dict[str, ArtifactResponseModel]: def deployments(self) -> Dict[str, ArtifactResponseModel]: return self._fetch_artifacts_from_list(self._deployments) + @property + def pipeline_runs(self) -> List[PipelineRunResponseModel]: + return [Client().get_pipeline_run(pr) for pr in self._pipeline_runs] + def set_stage(self, stage: ModelStages): """Sets Model Version to a desired stage.""" pass diff --git a/src/zenml/zen_stores/schemas/__init__.py b/src/zenml/zen_stores/schemas/__init__.py index dcd454d173f..f1170ecf6f4 100644 --- a/src/zenml/zen_stores/schemas/__init__.py +++ b/src/zenml/zen_stores/schemas/__init__.py @@ -57,7 +57,10 @@ ) from zenml.zen_stores.schemas.user_schemas import UserSchema from zenml.zen_stores.schemas.logs_schemas import LogsSchema -from zenml.zen_stores.schemas.model_schemas import ModelSchema +from zenml.zen_stores.schemas.model_schemas import ( + ModelSchema, + ModelVersionSchema, +) __all__ = [ "ArtifactSchema", @@ -93,4 +96,5 @@ "UserSchema", "LogsSchema", "ModelSchema", + "ModelVersionSchema", ] diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index b725d39b4bd..1095cd7319d 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -15,7 +15,7 @@ from datetime import datetime -from typing import Optional +from typing import List, Optional from uuid import UUID from sqlalchemy import TEXT, Column @@ -26,7 +26,7 @@ ModelResponseModel, ModelUpdateModel, ) -from zenml.zen_stores.schemas.base_schemas import NamedSchema +from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.user_schemas import UserSchema from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema @@ -65,6 +65,125 @@ class ModelSchema(NamedSchema, table=True): trade_offs: str = Field(sa_column=Column(TEXT, nullable=True)) ethic: str = Field(sa_column=Column(TEXT, nullable=True)) tags: str = Field(sa_column=Column(TEXT, nullable=True)) + model_versions: List["ModelVersionSchema"] = Relationship( + back_populates="model", + sa_relationship_kwargs={"cascade": "delete"}, + ) + + @classmethod + def from_request(cls, model_request: ModelRequestModel) -> "ModelSchema": + """Convert an `ModelRequestModel` to an `ModelSchema`. + + Args: + model_request: The request model to convert. + + Returns: + The converted schema. + """ + return cls( + name=model_request.name, + workspace_id=model_request.workspace, + user_id=model_request.user, + license=model_request.license, + description=model_request.description, + audience=model_request.audience, + use_cases=model_request.use_cases, + limitations=model_request.limitations, + trade_offs=model_request.trade_offs, + ethic=model_request.ethic, + tags=model_request.tags, + ) + + def to_model(self) -> ModelResponseModel: + """Convert an `ModelSchema` to an `ModelResponseModel`. + + Returns: + The created `ModelResponseModel`. + """ + return ModelResponseModel( + id=self.id, + name=self.name, + user=self.user.to_model() if self.user else None, + workspace=self.workspace.to_model(), + created=self.created, + updated=self.updated, + license=self.license, + description=self.description, + audience=self.audience, + use_cases=self.use_cases, + limitations=self.limitations, + trade_offs=self.trade_offs, + ethic=self.ethic, + tags=self.tags, + ) + + def update( + self, + model_update: ModelUpdateModel, + ) -> "ModelSchema": + """Updates a `ModelSchema` from a `ModelUpdateModel`. + + Args: + model_update: The `ModelUpdateModel` to update from. + + Returns: + The updated `ModelSchema`. + """ + for field, value in model_update.dict(exclude_unset=True).items(): + setattr(self, field, value) + self.updated = datetime.utcnow() + return self + + +class ModelVersionSchema(BaseSchema, table=True): + """SQL Model for model version.""" + + __tablename__ = "model_version" + + workspace_id: UUID = build_foreign_key_field( + source=__tablename__, + target=WorkspaceSchema.__tablename__, + source_column="workspace_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + workspace: "WorkspaceSchema" = Relationship( + back_populates="model_versions" + ) + + user_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=UserSchema.__tablename__, + source_column="user_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + user: Optional["UserSchema"] = Relationship( + back_populates="model_versions" + ) + + model_id: UUID = build_foreign_key_field( + source=__tablename__, + target=ModelSchema.__tablename__, + source_column="model_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + model: Optional["ModelSchema"] = Relationship( + back_populates="model_versions" + ) + + version: str = Field(sa_column=Column(TEXT, nullable=False)) + description: str = Field(sa_column=Column(TEXT, nullable=True)) + stage: str = Field(sa_column=Column(TEXT, nullable=True)) + use_cases: str = Field(sa_column=Column(TEXT, nullable=True)) + limitations: str = Field(sa_column=Column(TEXT, nullable=True)) + trade_offs: str = Field(sa_column=Column(TEXT, nullable=True)) + ethic: str = Field(sa_column=Column(TEXT, nullable=True)) + tags: str = Field(sa_column=Column(TEXT, nullable=True)) @classmethod def from_request(cls, model_request: ModelRequestModel) -> "ModelSchema": diff --git a/src/zenml/zen_stores/schemas/user_schemas.py b/src/zenml/zen_stores/schemas/user_schemas.py index e9ea6aeb522..5a4a51040ac 100644 --- a/src/zenml/zen_stores/schemas/user_schemas.py +++ b/src/zenml/zen_stores/schemas/user_schemas.py @@ -28,6 +28,7 @@ CodeRepositorySchema, FlavorSchema, ModelSchema, + ModelVersionSchema, PipelineBuildSchema, PipelineDeploymentSchema, PipelineRunSchema, @@ -95,6 +96,9 @@ class UserSchema(NamedSchema, table=True): models: List["ModelSchema"] = Relationship( back_populates="user", ) + model_versions: List["ModelVersionSchema"] = Relationship( + back_populates="user", + ) @classmethod def from_request(cls, model: UserRequestModel) -> "UserSchema": diff --git a/src/zenml/zen_stores/schemas/workspace_schemas.py b/src/zenml/zen_stores/schemas/workspace_schemas.py index 649848b3ca3..f7b3d0693ac 100644 --- a/src/zenml/zen_stores/schemas/workspace_schemas.py +++ b/src/zenml/zen_stores/schemas/workspace_schemas.py @@ -30,6 +30,7 @@ CodeRepositorySchema, FlavorSchema, ModelSchema, + ModelVersionSchema, PipelineBuildSchema, PipelineDeploymentSchema, PipelineRunSchema, @@ -121,6 +122,10 @@ class WorkspaceSchema(NamedSchema, table=True): back_populates="workspace", sa_relationship_kwargs={"cascade": "delete"}, ) + model_versions: List["ModelVersionSchema"] = Relationship( + back_populates="workspace", + sa_relationship_kwargs={"cascade": "delete"}, + ) @classmethod def from_request( From 59f47329fe4c6fdb53695d480126027fdd7ca0e2 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 11 Sep 2023 16:03:49 +0200 Subject: [PATCH 12/40] work with client --- src/zenml/models/model_models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 5eb2125e7e1..0efa47677bb 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -18,7 +18,6 @@ from pydantic import BaseModel, Field -from zenml.client import Client from zenml.model import ModelStages from zenml.models.artifact_models import ArtifactResponseModel from zenml.models.base_models import ( @@ -82,6 +81,8 @@ class ModelVersionResponseModel( def _fetch_artifacts_from_list( artifacts: Dict[str, UUID] ) -> Dict[str, ArtifactResponseModel]: + from zenml.client import Client + if artifacts: return { name: Client().get_artifact(a) for name, a in artifacts.items() @@ -103,7 +104,9 @@ def deployments(self) -> Dict[str, ArtifactResponseModel]: @property def pipeline_runs(self) -> List[PipelineRunResponseModel]: - return [Client().get_pipeline_run(pr) for pr in self._pipeline_runs] + from zenml.client import Client + + return [Client().get_run(pr) for pr in self._pipeline_runs] def set_stage(self, stage: ModelStages): """Sets Model Version to a desired stage.""" From 9222d21b626875c57fa2e58752b90ea489ca993e Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 11 Sep 2023 16:05:11 +0200 Subject: [PATCH 13/40] handle tags --- src/zenml/zen_stores/schemas/model_schemas.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index b725d39b4bd..0f55a2c3f5a 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -14,6 +14,7 @@ """SQLModel implementation of model tables.""" +import json from datetime import datetime from typing import Optional from uuid import UUID @@ -87,7 +88,7 @@ def from_request(cls, model_request: ModelRequestModel) -> "ModelSchema": limitations=model_request.limitations, trade_offs=model_request.trade_offs, ethic=model_request.ethic, - tags=model_request.tags, + tags=json.dumps(model_request.tags), ) def to_model(self) -> ModelResponseModel: @@ -110,7 +111,7 @@ def to_model(self) -> ModelResponseModel: limitations=self.limitations, trade_offs=self.trade_offs, ethic=self.ethic, - tags=self.tags, + tags=json.loads(self.tags), ) def update( @@ -126,6 +127,9 @@ def update( The updated `ModelSchema`. """ for field, value in model_update.dict(exclude_unset=True).items(): - setattr(self, field, value) + if field == "tags": + setattr(self, field, json.dumps(value)) + else: + setattr(self, field, value) self.updated = datetime.utcnow() return self From 67c1286fb4e9f987456d7e3712f386898ee0866c Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 11 Sep 2023 16:22:51 +0200 Subject: [PATCH 14/40] fix integrations --- .../zen_server/routers/workspaces_endpoints.py | 2 +- src/zenml/zen_stores/rest_zen_store.py | 5 +++-- src/zenml/zen_stores/sql_zen_store.py | 16 ++++++++++++---- src/zenml/zen_stores/zen_store_interface.py | 5 +++-- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 5ec6fe72846..d8de89946ab 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1217,7 +1217,7 @@ def list_models( The models according to query filters. """ workspace_id = zen_store().get_workspace(workspace_name_or_id).id + model_filter_model.set_scope_workspace(workspace_id) return zen_store().list_models( - workspace_id=workspace_id, model_filter_model=model_filter_model, ) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 9c4e8981432..709a3ea4b62 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2293,6 +2293,9 @@ def create_model(self, model: ModelRequestModel) -> ModelResponseModel: Returns: The newly created model. + + Raises: + EntityExistsError: If a workspace with the given name already exists. """ return self._create_workspace_scoped_resource( resource=model, @@ -2348,13 +2351,11 @@ def get_model( def list_models( self, - workspace_id: UUID, model_filter_model: ModelFilterModel, ) -> Page[ModelResponseModel]: """Get all models by filter. Args: - workspace_id: The name or ID of the workspace to scope to. model_filter_model: All filter parameters including pagination params. diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index f3dd41865db..92dd71d150f 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5396,9 +5396,20 @@ def create_model(self, model: ModelRequestModel) -> ModelResponseModel: Returns: The newly created model. + + Raises: + EntityExistsError: If a workspace with the given name already exists. """ with Session(self.engine) as session: - # Save artifact. + existing_model = session.exec( + select(ModelSchema).where(ModelSchema.name == model.name) + ).first() + if existing_model is not None: + raise EntityExistsError( + f"Unable to create model {model.name}: " + "A model with this name already exists." + ) + model_schema = ModelSchema.from_request(model) session.add(model_schema) @@ -5432,13 +5443,11 @@ def get_model( def list_models( self, - workspace_id: UUID, model_filter_model: ModelFilterModel, ) -> Page[ModelResponseModel]: """Get all models by filter. Args: - workspace_id: The ID of the workspace to scope to. model_filter_model: All filter parameters including pagination params. @@ -5447,7 +5456,6 @@ def list_models( """ with Session(self.engine) as session: query = select(ModelSchema) - model_filter_model.set_scope_workspace(workspace_id) return self.filter_and_paginate( session=session, query=query, diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 658a99d79e4..f97dc27d628 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -1684,6 +1684,9 @@ def create_model(self, model: ModelRequestModel) -> ModelResponseModel: Returns: The newly created model. + + Raises: + EntityExistsError: If a workspace with the given name already exists. """ @abstractmethod @@ -1726,13 +1729,11 @@ def get_model( @abstractmethod def list_models( self, - workspace_id: UUID, model_filter_model: ModelFilterModel, ) -> Page[ModelResponseModel]: """Get all models by filter. Args: - workspace_id: The name or ID of the workspace to scope to. model_filter_model: All filter parameters including pagination params. From 4a00124f4f689ab0ef1b29076be9de40068eca64 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 11 Sep 2023 16:52:41 +0200 Subject: [PATCH 15/40] move list around --- .../zen_server/routers/models_endpoints.py | 32 ++++++++++++++++++- .../routers/workspaces_endpoints.py | 2 +- src/zenml/zen_stores/rest_zen_store.py | 2 +- src/zenml/zen_stores/schemas/model_schemas.py | 2 +- 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 380f320b721..93c312497ad 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -16,18 +16,21 @@ from typing import Union from uuid import UUID -from fastapi import APIRouter, Security +from fastapi import APIRouter, Depends, Security from zenml.constants import API, MODELS, VERSION_1 from zenml.enums import PermissionType from zenml.models import ( + ModelFilterModel, ModelResponseModel, ModelUpdateModel, ) +from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response from zenml.zen_server.utils import ( handle_exceptions, + make_dependable, zen_store, ) @@ -38,6 +41,33 @@ ) +@router.get( + "", + response_model=Page[ModelResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_models( + model_filter_model: ModelFilterModel = Depends( + make_dependable(ModelFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelResponseModel]: + """Get models according to query filters. + + Args: + model_filter_model: Filter model used for pagination, sorting, + filtering + + + Returns: + The models according to query filters. + """ + return zen_store().list_models( + model_filter_model=model_filter_model, + ) + + @router.delete( "/{model_name_or_id}", responses={401: error_response, 404: error_response, 422: error_response}, diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index d8de89946ab..622dd1ee213 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1198,7 +1198,7 @@ def create_model( responses={401: error_response, 404: error_response, 422: error_response}, ) @handle_exceptions -def list_models( +def list_workspace_models( workspace_name_or_id: Union[str, UUID], model_filter_model: ModelFilterModel = Depends( make_dependable(ModelFilterModel) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 709a3ea4b62..62c99ee2fc5 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2363,7 +2363,7 @@ def list_models( A page of all models. """ return self._list_paginated_resources( - route=f"{WORKSPACES}/{workspace_id}{MODELS}", + route=MODELS, response_model=ModelResponseModel, filter_model=model_filter_model, ) diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 0f55a2c3f5a..b3a0c34f15e 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -111,7 +111,7 @@ def to_model(self) -> ModelResponseModel: limitations=self.limitations, trade_offs=self.trade_offs, ethic=self.ethic, - tags=json.loads(self.tags), + tags=json.loads(self.tags) if self.tags else None, ) def update( From 3df40d58a98e09cf63c39f8322d737ce4323a2af Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 11 Sep 2023 18:10:03 +0200 Subject: [PATCH 16/40] update db schema --- src/zenml/models/__init__.py | 16 ++ src/zenml/models/model_models.py | 40 ++- ...8b82e9253a9_add_model_version_and_links.py | 114 +++++++++ src/zenml/zen_stores/schemas/__init__.py | 2 + .../zen_stores/schemas/artifact_schemas.py | 7 + src/zenml/zen_stores/schemas/model_schemas.py | 241 ++++++++++++++---- .../schemas/pipeline_run_schemas.py | 5 + src/zenml/zen_stores/schemas/user_schemas.py | 4 + .../zen_stores/schemas/workspace_schemas.py | 5 + 9 files changed, 384 insertions(+), 50 deletions(-) create mode 100644 src/zenml/zen_stores/migrations/versions/e8b82e9253a9_add_model_version_and_links.py diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index 074ede5aec1..fc49834b365 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -157,6 +157,9 @@ ModelVersionBaseModel, ModelVersionResponseModel, ModelVersionRequestModel, + ModelVersionLinkBaseModel, + ModelVersionLinkRequestModel, + ModelVersionLinkResponseModel, ) ComponentResponseModel.update_forward_refs( @@ -297,6 +300,16 @@ WorkspaceResponseModel=WorkspaceResponseModel, ) +ModelVersionLinkRequestModel.update_forward_refs( + UserResponseModel=UserResponseModel, + WorkspaceResponseModel=WorkspaceResponseModel, +) + +ModelVersionLinkResponseModel.update_forward_refs( + UserResponseModel=UserResponseModel, + WorkspaceResponseModel=WorkspaceResponseModel, +) + __all__ = [ "ArtifactRequestModel", "ArtifactResponseModel", @@ -400,4 +413,7 @@ "ModelVersionBaseModel", "ModelVersionRequestModel", "ModelVersionResponseModel", + "ModelVersionLinkBaseModel", + "ModelVersionLinkRequestModel", + "ModelVersionLinkResponseModel", ] diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 0efa47677bb..f988f5385f5 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -16,7 +16,7 @@ from typing import Dict, List, Optional, Union from uuid import UUID -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator from zenml.model import ModelStages from zenml.models.artifact_models import ArtifactResponseModel @@ -117,6 +117,44 @@ def set_stage(self, stage: ModelStages): # """Return HTML/PDF based on input template""" +class ModelVersionLinkBaseModel(BaseModel): + """Model version links base model.""" + + name: str = Field( + title="The name of the artifact inside model version.", + max_length=STR_FIELD_MAX_LENGTH, + ) + artifact_id: Optional[UUID] + pipeline_run_id: Optional[UUID] + model_version_id: UUID + is_model_object: bool = False + is_deployment: bool = False + + @validator("model_version_id") + def validate_links(cls, model_version_id, values): + artifact_id = values.get("artifact_id", None) + pipeline_run_id = values.get("pipeline_run_id", None) + if (artifact_id is None and pipeline_run_id is None) or ( + artifact_id is not None and pipeline_run_id is not None + ): + raise ValueError( + "You must provide only `artifact_id` or only `pipeline_run_id`." + ) + return model_version_id + + +class ModelVersionLinkRequestModel( + ModelVersionLinkBaseModel, WorkspaceScopedRequestModel +): + """Model version links request model.""" + + +class ModelVersionLinkResponseModel( + ModelVersionLinkBaseModel, WorkspaceScopedResponseModel +): + """Model version links response model.""" + + class ModelConfigBaseModel(BaseModel): """Model Config base model.""" diff --git a/src/zenml/zen_stores/migrations/versions/e8b82e9253a9_add_model_version_and_links.py b/src/zenml/zen_stores/migrations/versions/e8b82e9253a9_add_model_version_and_links.py new file mode 100644 index 00000000000..7dd1a6dabad --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/e8b82e9253a9_add_model_version_and_links.py @@ -0,0 +1,114 @@ +"""add model_version and links [e8b82e9253a9]. + +Revision ID: e8b82e9253a9 +Revises: 3b68abe58f44 +Create Date: 2023-09-11 18:05:43.367994 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "e8b82e9253a9" +down_revision = "3b68abe58f44" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "model_version", + sa.Column( + "workspace_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("model_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("version", sa.TEXT(), nullable=False), + sa.Column("description", sa.TEXT(), nullable=True), + sa.Column("stage", sa.TEXT(), nullable=True), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("updated", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["model_id"], + ["model.id"], + name="fk_model_version_model_id_model", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + name="fk_model_version_user_id_user", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.id"], + name="fk_model_version_workspace_id_workspace", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "model_version_links", + sa.Column( + "workspace_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column( + "model_version_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column("artifact_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column( + "pipeline_run_id", sqlmodel.sql.sqltypes.GUID(), nullable=True + ), + sa.Column("is_model_object", sa.BOOLEAN(), nullable=True), + sa.Column("is_deployment", sa.BOOLEAN(), nullable=True), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("updated", sa.DateTime(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.ForeignKeyConstraint( + ["artifact_id"], + ["artifact.id"], + name="fk_model_version_links_artifact_id_artifact", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["model_version_id"], + ["model_version.id"], + name="fk_model_version_links_model_version_id_model_version", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["pipeline_run_id"], + ["pipeline_run.id"], + name="fk_model_version_links_run_id_pipeline_run", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + name="fk_model_version_links_user_id_user", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.id"], + name="fk_model_version_links_workspace_id_workspace", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("model_version_links") + op.drop_table("model_version") + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/schemas/__init__.py b/src/zenml/zen_stores/schemas/__init__.py index f1170ecf6f4..3dc9b8d29cc 100644 --- a/src/zenml/zen_stores/schemas/__init__.py +++ b/src/zenml/zen_stores/schemas/__init__.py @@ -60,6 +60,7 @@ from zenml.zen_stores.schemas.model_schemas import ( ModelSchema, ModelVersionSchema, + ModelVersionLinkSchema, ) __all__ = [ @@ -97,4 +98,5 @@ "LogsSchema", "ModelSchema", "ModelVersionSchema", + "ModelVersionLinkSchema", ] diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index d52c3c13248..c7c980bfad6 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -32,6 +32,9 @@ from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema if TYPE_CHECKING: + from zenml.zen_stores.schemas.model_schemas import ( + ModelVersionLinkSchema, + ) from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema from zenml.zen_stores.schemas.step_run_schemas import ( StepRunInputArtifactSchema, @@ -94,6 +97,10 @@ class ArtifactSchema(NamedSchema, table=True): back_populates="artifact", sa_relationship_kwargs={"cascade": "delete"}, ) + model_version_links: List["ModelVersionLinkSchema"] = Relationship( + back_populates="artifact", + sa_relationship_kwargs={"cascade": "delete"}, + ) @classmethod def from_request( diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 27e717461af..71986236ba7 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -14,20 +14,25 @@ """SQLModel implementation of model tables.""" -import json from datetime import datetime from typing import List, Optional from uuid import UUID -from sqlalchemy import TEXT, Column +from sqlalchemy import BOOLEAN, TEXT, Column from sqlmodel import Field, Relationship from zenml.models import ( ModelRequestModel, ModelResponseModel, ModelUpdateModel, + ModelVersionLinkRequestModel, + ModelVersionLinkResponseModel, + ModelVersionRequestModel, + ModelVersionResponseModel, ) +from zenml.zen_stores.schemas.artifact_schemas import ArtifactSchema from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema +from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.user_schemas import UserSchema from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema @@ -173,82 +178,220 @@ class ModelVersionSchema(BaseSchema, table=True): ondelete="CASCADE", nullable=False, ) - model: Optional["ModelSchema"] = Relationship( - back_populates="model_versions" + model: "ModelSchema" = Relationship(back_populates="model_versions") + objects_links: List["ModelVersionLinkSchema"] = Relationship( + back_populates="model_version", + sa_relationship_kwargs={"cascade": "delete"}, ) version: str = Field(sa_column=Column(TEXT, nullable=False)) description: str = Field(sa_column=Column(TEXT, nullable=True)) stage: str = Field(sa_column=Column(TEXT, nullable=True)) - use_cases: str = Field(sa_column=Column(TEXT, nullable=True)) - limitations: str = Field(sa_column=Column(TEXT, nullable=True)) - trade_offs: str = Field(sa_column=Column(TEXT, nullable=True)) - ethic: str = Field(sa_column=Column(TEXT, nullable=True)) - tags: str = Field(sa_column=Column(TEXT, nullable=True)) @classmethod - def from_request(cls, model_request: ModelRequestModel) -> "ModelSchema": - """Convert an `ModelRequestModel` to an `ModelSchema`. + def from_request( + cls, model_version_request: ModelVersionRequestModel + ) -> "ModelVersionSchema": + """Convert an `ModelVersionRequestModel` to an `ModelVersionSchema`. Args: - model_request: The request model to convert. + model_version_request: The request model version to convert. Returns: The converted schema. """ return cls( - name=model_request.name, - workspace_id=model_request.workspace, - user_id=model_request.user, - license=model_request.license, - description=model_request.description, - audience=model_request.audience, - use_cases=model_request.use_cases, - limitations=model_request.limitations, - trade_offs=model_request.trade_offs, - ethic=model_request.ethic, - tags=json.dumps(model_request.tags), + workspace_id=model_version_request.workspace, + user_id=model_version_request.user, + version=model_version_request.version, + description=model_version_request.description, + stage=model_version_request.stage, ) - def to_model(self) -> ModelResponseModel: - """Convert an `ModelSchema` to an `ModelResponseModel`. + def to_model(self) -> ModelVersionResponseModel: + """Convert an `ModelVersionSchema` to an `ModelVersionResponseModel`. Returns: - The created `ModelResponseModel`. + The created `ModelVersionResponseModel`. """ - return ModelResponseModel( + return ModelVersionResponseModel( id=self.id, - name=self.name, user=self.user.to_model() if self.user else None, workspace=self.workspace.to_model(), created=self.created, updated=self.updated, - license=self.license, + version=self.version, description=self.description, - audience=self.audience, - use_cases=self.use_cases, - limitations=self.limitations, - trade_offs=self.trade_offs, - ethic=self.ethic, - tags=json.loads(self.tags) if self.tags else None, + stage=self.stage, + _model_objects={ + al.name: al.artifact_id + for al in self.objects_links + if al.artifact_id is not None and al.is_model_object + }, + _deployments={ + al.name: al.artifact_id + for al in self.objects_links + if al.artifact_id is not None and al.is_deployment + }, + _artifact_objects={ + al.name: al.artifact_id + for al in self.objects_links + if al.artifact_id is not None + and not (al.is_deployment or al.is_model_object) + }, + _pipeline_runs=[ + al.artifact_id + for al in self.objects_links + if al.pipeline_run_id is not None + ], ) - def update( - self, - model_update: ModelUpdateModel, - ) -> "ModelSchema": - """Updates a `ModelSchema` from a `ModelUpdateModel`. + # def update( + # self, + # model_update: ModelUpdateModel, + # ) -> "ModelSchema": + # """Updates a `ModelSchema` from a `ModelUpdateModel`. + + # Args: + # model_update: The `ModelUpdateModel` to update from. + + # Returns: + # The updated `ModelSchema`. + # """ + # # for field, value in model_update.dict(exclude_unset=True).items(): + # # if field == "tags": + # # setattr(self, field, json.dumps(value)) + # # else: + # # setattr(self, field, value) + # # self.updated = datetime.utcnow() + # # return self + + +class ModelVersionLinkSchema(NamedSchema, table=True): + """SQL Model for linking of Model Versions and Artifacts or Pipeline Runs M:M.""" + + __tablename__ = "model_version_links" + + workspace_id: UUID = build_foreign_key_field( + source=__tablename__, + target=WorkspaceSchema.__tablename__, + source_column="workspace_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + workspace: "WorkspaceSchema" = Relationship( + back_populates="model_version_links" + ) + + user_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=UserSchema.__tablename__, + source_column="user_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + user: Optional["UserSchema"] = Relationship( + back_populates="model_version_links" + ) + + model_version_id: UUID = build_foreign_key_field( + source=__tablename__, + target=ModelVersionSchema.__tablename__, + source_column="model_version_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + model_version: "ModelVersionSchema" = Relationship( + back_populates="objects_links" + ) + artifact_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=ArtifactSchema.__tablename__, + source_column="artifact_id", + target_column="id", + ondelete="CASCADE", + nullable=True, + ) + artifact: Optional["ArtifactSchema"] = Relationship( + back_populates="model_version_links" + ) + pipeline_run_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=PipelineRunSchema.__tablename__, + source_column="run_id", + target_column="id", + ondelete="CASCADE", + nullable=True, + ) + pipeline_run: Optional["PipelineRunSchema"] = Relationship( + back_populates="model_version_links" + ) + + is_model_object: bool = Field(sa_column=Column(BOOLEAN, nullable=True)) + is_deployment: bool = Field(sa_column=Column(BOOLEAN, nullable=True)) + + @classmethod + def from_request( + cls, model_version_artifact_request: ModelVersionLinkRequestModel + ) -> "ModelVersionLinkSchema": + """Convert an `ModelVersionArtifactRequestModel` to an `ModelVersionArtifactSchema`. Args: - model_update: The `ModelUpdateModel` to update from. + model_version_artifact_request: The request link to convert. Returns: - The updated `ModelSchema`. + The converted schema. """ - for field, value in model_update.dict(exclude_unset=True).items(): - if field == "tags": - setattr(self, field, json.dumps(value)) - else: - setattr(self, field, value) - self.updated = datetime.utcnow() - return self + return cls( + name=model_version_artifact_request.name, + workspace_id=model_version_artifact_request.workspace, + user_id=model_version_artifact_request.user, + model_version_id=model_version_artifact_request.model_version_id, + artifact_id=model_version_artifact_request.artifact_id, + pipeline_run_id=model_version_artifact_request.pipeline_run_id, + is_model_object=model_version_artifact_request.is_model_object, + is_deployment=model_version_artifact_request.is_deployment, + ) + + def to_model(self) -> ModelVersionLinkResponseModel: + """Convert an `ModelVersionArtifactSchema` to an `ModelVersionArtifactResponseModel`. + + Returns: + The created `ModelVersionArtifactResponseModel`. + """ + return ModelVersionLinkResponseModel( + id=self.id, + name=self.name, + user=self.user.to_model() if self.user else None, + workspace=self.workspace.to_model(), + created=self.created, + updated=self.updated, + model_version_id=self.model_version_id, + artifact_id=self.artifact_id, + pipeline_run_id=self.pipeline_run_id, + is_model_object=self.is_model_object, + is_deployment=self.is_deployment, + ) + + # def update( + # self, + # model_update: ModelUpdateModel, + # ) -> "ModelSchema": + # """Updates a `ModelSchema` from a `ModelUpdateModel`. + + # Args: + # model_update: The `ModelUpdateModel` to update from. + + # Returns: + # The updated `ModelSchema`. + # """ + # for field, value in model_update.dict(exclude_unset=True).items(): + # if field == "tags": + # setattr(self, field, json.dumps(value)) + # else: + # setattr(self, field, value) + # self.updated = datetime.utcnow() + # return self diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 9cb8cbe33a0..54c74a5a77f 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -43,6 +43,7 @@ if TYPE_CHECKING: from zenml.zen_stores.schemas.logs_schemas import LogsSchema + from zenml.zen_stores.schemas.model_schemas import ModelVersionLinkSchema from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema @@ -157,6 +158,10 @@ class PipelineRunSchema(NamedSchema, table=True): back_populates="pipeline_run", sa_relationship_kwargs={"cascade": "delete", "uselist": False}, ) + model_version_links: List["ModelVersionLinkSchema"] = Relationship( + back_populates="pipeline_run", + sa_relationship_kwargs={"cascade": "delete"}, + ) @classmethod def from_request( diff --git a/src/zenml/zen_stores/schemas/user_schemas.py b/src/zenml/zen_stores/schemas/user_schemas.py index 5a4a51040ac..e11c65b1cac 100644 --- a/src/zenml/zen_stores/schemas/user_schemas.py +++ b/src/zenml/zen_stores/schemas/user_schemas.py @@ -28,6 +28,7 @@ CodeRepositorySchema, FlavorSchema, ModelSchema, + ModelVersionLinkSchema, ModelVersionSchema, PipelineBuildSchema, PipelineDeploymentSchema, @@ -99,6 +100,9 @@ class UserSchema(NamedSchema, table=True): model_versions: List["ModelVersionSchema"] = Relationship( back_populates="user", ) + model_version_links: List["ModelVersionLinkSchema"] = Relationship( + back_populates="user", + ) @classmethod def from_request(cls, model: UserRequestModel) -> "UserSchema": diff --git a/src/zenml/zen_stores/schemas/workspace_schemas.py b/src/zenml/zen_stores/schemas/workspace_schemas.py index f7b3d0693ac..ff2647f67ad 100644 --- a/src/zenml/zen_stores/schemas/workspace_schemas.py +++ b/src/zenml/zen_stores/schemas/workspace_schemas.py @@ -30,6 +30,7 @@ CodeRepositorySchema, FlavorSchema, ModelSchema, + ModelVersionLinkSchema, ModelVersionSchema, PipelineBuildSchema, PipelineDeploymentSchema, @@ -126,6 +127,10 @@ class WorkspaceSchema(NamedSchema, table=True): back_populates="workspace", sa_relationship_kwargs={"cascade": "delete"}, ) + model_version_links: List["ModelVersionLinkSchema"] = Relationship( + back_populates="workspace", + sa_relationship_kwargs={"cascade": "delete"}, + ) @classmethod def from_request( From 6bb853fddaf7a76de4f46ea7996e0e0ebd9dbb6f Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 11 Sep 2023 18:40:46 +0200 Subject: [PATCH 17/40] wip --- src/zenml/constants.py | 1 + src/zenml/models/__init__.py | 2 + src/zenml/models/model_models.py | 22 ++++ .../zen_server/routers/models_endpoints.py | 105 ++++++++++++++++++ src/zenml/zen_stores/rest_zen_store.py | 74 ++++++++++++ src/zenml/zen_stores/schemas/model_schemas.py | 42 +------ src/zenml/zen_stores/sql_zen_store.py | 102 +++++++++++++++++ src/zenml/zen_stores/zen_store_interface.py | 57 ++++++++++ 8 files changed, 365 insertions(+), 40 deletions(-) diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 007aa1e7a37..f36ef7db778 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -222,6 +222,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: SERVICE_CONNECTOR_RESOURCES = "/resources" SERVICE_CONNECTOR_CLIENT = "/client" MODELS = "/models" +MODEL_VERSIONS = "/model_versions" # mandatory stack component attributes MANDATORY_COMPONENT_ATTRIBUTES = ["name", "uuid"] diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index fc49834b365..1f6f035b435 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -160,6 +160,7 @@ ModelVersionLinkBaseModel, ModelVersionLinkRequestModel, ModelVersionLinkResponseModel, + ModelVersionFilterModel, ) ComponentResponseModel.update_forward_refs( @@ -411,6 +412,7 @@ "ModelConfigRequestModel", "ModelConfigResponseModel", "ModelVersionBaseModel", + "ModelVersionFilterModel", "ModelVersionRequestModel", "ModelVersionResponseModel", "ModelVersionLinkBaseModel", diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index f988f5385f5..5a33f1dd868 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -32,6 +32,10 @@ class ModelVersionBaseModel(BaseModel): """Model Version base model.""" + model_id: str = Field( + title="The ID of the model", + max_length=STR_FIELD_MAX_LENGTH, + ) version: str = Field( title="The name of the model version", max_length=STR_FIELD_MAX_LENGTH, @@ -117,6 +121,24 @@ def set_stage(self, stage: ModelStages): # """Return HTML/PDF based on input template""" +class ModelVersionFilterModel(WorkspaceScopedFilterModel): + """Filter Model for Model Version.""" + + model_name: str = Field( + description="Name of the Model", + ) + model_version_name: Optional[str] = Field( + default=None, + description="Name of the Model Version", + ) + workspace_id: Optional[Union[UUID, str]] = Field( + default=None, description="Workspace of the Model Version" + ) + user_id: Optional[Union[UUID, str]] = Field( + default=None, description="User of the Model Version" + ) + + class ModelVersionLinkBaseModel(BaseModel): """Model version links base model.""" diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 93c312497ad..61e0d9785f8 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -34,6 +34,10 @@ zen_store, ) +######### +# Models +######### + router = APIRouter( prefix=API + VERSION_1 + MODELS, tags=["models"], @@ -130,3 +134,104 @@ def update_model( model_id=model_id, model_update=model_update, ) + + +################# +# Model Versions +################# + +# router = APIRouter( +# prefix=API + VERSION_1 + MODELS, +# tags=["models"], +# responses={401: error_response}, +# ) + +# @router.get( +# "", +# response_model=Page[ModelResponseModel], +# responses={401: error_response, 404: error_response, 422: error_response}, +# ) +# @handle_exceptions +# def list_models( +# model_filter_model: ModelFilterModel = Depends( +# make_dependable(ModelFilterModel) +# ), +# _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +# ) -> Page[ModelResponseModel]: +# """Get models according to query filters. + +# Args: +# model_filter_model: Filter model used for pagination, sorting, +# filtering + + +# Returns: +# The models according to query filters. +# """ +# return zen_store().list_models( +# model_filter_model=model_filter_model, +# ) + + +# @router.delete( +# "/{model_name_or_id}", +# responses={401: error_response, 404: error_response, 422: error_response}, +# ) +# @handle_exceptions +# def delete_model( +# model_name_or_id: Union[str, UUID], +# _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +# ) -> None: +# """Delete a model by name or ID. + +# Args: +# model_name_or_id: The name or ID of the model to delete. +# """ +# zen_store().delete_model(model_name_or_id) + + +# @router.get( +# "/{model_name_or_id}", +# response_model=ModelResponseModel, +# responses={401: error_response, 404: error_response, 422: error_response}, +# ) +# @handle_exceptions +# def get_model( +# model_name_or_id: Union[str, UUID], +# _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +# ) -> ModelResponseModel: +# """Get a model by name or ID. + +# Args: +# model_name_or_id: The name or ID of the model to get. + +# Returns: +# The model with the given name or ID. +# """ +# return zen_store().get_model(model_name_or_id) + + +# @router.put( +# "/{model_id}", +# response_model=ModelResponseModel, +# responses={401: error_response, 404: error_response, 422: error_response}, +# ) +# @handle_exceptions +# def update_model( +# model_id: UUID, +# model_update: ModelUpdateModel, +# _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +# ) -> ModelResponseModel: +# """Updates a model. + +# Args: +# model_id: Name of the stack. +# model_update: Stack to use for the update. + +# Returns: +# The updated model. +# """ +# return zen_store().update_model( +# model_id=model_id, +# model_update=model_update, +# ) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 62c99ee2fc5..1a5a709ff68 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -49,6 +49,7 @@ GET_OR_CREATE, INFO, LOGIN, + MODEL_VERSIONS, MODELS, PIPELINE_BUILDS, PIPELINE_DEPLOYMENTS, @@ -99,6 +100,9 @@ ModelRequestModel, ModelResponseModel, ModelUpdateModel, + ModelVersionFilterModel, + ModelVersionRequestModel, + ModelVersionResponseModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -2368,6 +2372,76 @@ def list_models( filter_model=model_filter_model, ) + ################# + # Model Versions + ################# + + def create_model_version( + self, model_version: ModelVersionRequestModel + ) -> ModelVersionResponseModel: + """Creates a new model version. + Args: + model: the Model Version to be created. + Returns: + The newly created model version. + Raises: + EntityExistsError: If a workspace with the given name already exists. + """ + return self._create_workspace_scoped_resource( + resource=model_version, + response_model=ModelVersionResponseModel, + route=MODEL_VERSIONS, + ) + + def delete_model_version( + self, model_name_or_id: Union[str, UUID], model_version_name: str + ) -> None: + """Deletes a model version. + Args: + model_name_or_id: name or id of the model containing the model version. + model_version_name: name of the model version to be deleted. + """ + self._delete_resource( + resource_id=(model_name_or_id, model_version_name), + route=MODEL_VERSIONS, + ) + + def get_model_version( + self, + model_name_or_id: Union[str, UUID], + model_version_name: str, + ) -> ModelVersionResponseModel: + """Get an existing model version. + Args: + model_name_or_id: name or id of the model containing the model version. + model_version_name_or_id: name or id of the model version to be retrieved. + Returns: + The model version of interest. + """ + return self._get_resource( + resource_id=(model_name_or_id, model_version_name), + route=MODEL_VERSIONS, + response_model=ModelVersionResponseModel, + ) + + def list_model_versions( + self, + model_version_filter_model: ModelVersionFilterModel, + ) -> Page[ModelVersionResponseModel]: + """Get all model versions by filter. + Args: + model_version_filter_model: All filter parameters including pagination + params. + Returns: + A page of all model versions. + """ + + return self._list_paginated_resources( + route=MODEL_VERSIONS, + response_model=ModelVersionFilterModel, + filter_model=model_version_filter_model, + ) + # ======================= # Internal helper methods # ======================= diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 71986236ba7..42a36416f67 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -203,6 +203,7 @@ def from_request( return cls( workspace_id=model_version_request.workspace, user_id=model_version_request.user, + model_id=model_version_request.model_id, version=model_version_request.version, description=model_version_request.description, stage=model_version_request.stage, @@ -220,6 +221,7 @@ def to_model(self) -> ModelVersionResponseModel: workspace=self.workspace.to_model(), created=self.created, updated=self.updated, + model_id=self.model_id, version=self.version, description=self.description, stage=self.stage, @@ -246,26 +248,6 @@ def to_model(self) -> ModelVersionResponseModel: ], ) - # def update( - # self, - # model_update: ModelUpdateModel, - # ) -> "ModelSchema": - # """Updates a `ModelSchema` from a `ModelUpdateModel`. - - # Args: - # model_update: The `ModelUpdateModel` to update from. - - # Returns: - # The updated `ModelSchema`. - # """ - # # for field, value in model_update.dict(exclude_unset=True).items(): - # # if field == "tags": - # # setattr(self, field, json.dumps(value)) - # # else: - # # setattr(self, field, value) - # # self.updated = datetime.utcnow() - # # return self - class ModelVersionLinkSchema(NamedSchema, table=True): """SQL Model for linking of Model Versions and Artifacts or Pipeline Runs M:M.""" @@ -375,23 +357,3 @@ def to_model(self) -> ModelVersionLinkResponseModel: is_model_object=self.is_model_object, is_deployment=self.is_deployment, ) - - # def update( - # self, - # model_update: ModelUpdateModel, - # ) -> "ModelSchema": - # """Updates a `ModelSchema` from a `ModelUpdateModel`. - - # Args: - # model_update: The `ModelUpdateModel` to update from. - - # Returns: - # The updated `ModelSchema`. - # """ - # for field, value in model_update.dict(exclude_unset=True).items(): - # if field == "tags": - # setattr(self, field, json.dumps(value)) - # else: - # setattr(self, field, value) - # self.updated = datetime.utcnow() - # return self diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 92dd71d150f..4da73269615 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -98,6 +98,9 @@ ModelRequestModel, ModelResponseModel, ModelUpdateModel, + ModelVersionFilterModel, + ModelVersionRequestModel, + ModelVersionResponseModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -194,6 +197,7 @@ FlavorSchema, IdentitySchema, ModelSchema, + ModelVersionSchema, NamedSchema, PipelineBuildSchema, PipelineDeploymentSchema, @@ -5516,3 +5520,101 @@ def update_model( # Refresh the Model that was just created session.refresh(existing_model) return existing_model.to_model() + + ################# + # Model Versions + ################# + + def create_model_version( + self, model_version: ModelVersionRequestModel + ) -> ModelVersionResponseModel: + """Creates a new model version. + Args: + model: the Model Version to be created. + Returns: + The newly created model version. + Raises: + EntityExistsError: If a workspace with the given name already exists. + """ + with Session(self.engine) as session: + model = self.get_model(model_version.model_id) + existing_model_version = session.exec( + select(ModelVersionSchema).where( + ModelVersionSchema.version == model_version.version + and ModelVersionSchema.model_id == model.id + ) + ).first() + if existing_model_version is not None: + raise EntityExistsError( + f"Unable to create model version {model_version.version}: " + f"A model version with this name already exists in {model.name} model." + ) + + model_version_schema = ModelVersionSchema.from_request( + model_version + ) + session.add(model_version_schema) + + session.commit() + return ModelVersionSchema.to_model(model_version_schema) + + def get_model_version( + self, + model_name_or_id: Union[str, UUID], + model_version_name: str, + ) -> ModelVersionResponseModel: + """Get an existing model version. + Args: + model_name_or_id: name or id of the model containing the model version. + model_version_name_or_id: name or id of the model version to be retrieved. + Returns: + The model version of interest. + """ + with Session(self.engine) as session: + model = self.get_model(model_name_or_id) + model_version = session.exec( + select(ModelVersionSchema).where( + ModelVersionSchema.version == model_version_name + and ModelVersionSchema.model_id == model + ) + ).first() + return ModelVersionSchema.to_model(model_version) + + def list_model_versions( + self, + model_version_filter_model: ModelVersionFilterModel, + ) -> Page[ModelVersionResponseModel]: + """Get all model versions by filter. + Args: + model_version_filter_model: All filter parameters including pagination + params. + Returns: + A page of all model versions. + """ + with Session(self.engine) as session: + query = select(ModelVersionSchema) + return self.filter_and_paginate( + session=session, + query=query, + table=ModelVersionSchema, + filter_model=model_version_filter_model, + ) + + def delete_model_version( + self, model_name_or_id: Union[str, UUID], model_version_name: str + ) -> None: + """Deletes a model version. + Args: + model_name_or_id: name or id of the model containing the model version. + model_version_name: name of the model version to be deleted. + """ + with Session(self.engine) as session: + model = self.get_model(model_name_or_id) + model_version = session.exec( + select(ModelVersionSchema).where( + ModelVersionSchema.version == model_version_name + and ModelVersionSchema.model_id == model + ) + ) + session.delete(model_version) + session.commit() diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index f97dc27d628..fc95ce37492 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -36,6 +36,9 @@ ModelRequestModel, ModelResponseModel, ModelUpdateModel, + ModelVersionFilterModel, + ModelVersionRequestModel, + ModelVersionResponseModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -1740,3 +1743,57 @@ def list_models( Returns: A page of all models. """ + + ################# + # Model Versions + ################# + + @abstractmethod + def create_model_version( + self, model_version: ModelVersionRequestModel + ) -> ModelVersionResponseModel: + """Creates a new model version. + Args: + model: the Model Version to be created. + Returns: + The newly created model version. + Raises: + EntityExistsError: If a workspace with the given name already exists. + """ + + @abstractmethod + def delete_model_version( + self, model_name_or_id: Union[str, UUID], model_version_name: str + ) -> None: + """Deletes a model version. + Args: + model_name_or_id: name or id of the model containing the model version. + model_version_name: name of the model version to be deleted. + """ + + @abstractmethod + def get_model_version( + self, + model_name_or_id: Union[str, UUID], + model_version_name: str, + ) -> ModelVersionResponseModel: + """Get an existing model version. + Args: + model_name_or_id: name or id of the model containing the model version. + model_version_name_or_id: name or id of the model version to be retrieved. + Returns: + The model version of interest. + """ + + @abstractmethod + def list_model_versions( + self, + model_version_filter_model: ModelVersionFilterModel, + ) -> Page[ModelVersionResponseModel]: + """Get all model versions by filter. + Args: + model_version_filter_model: All filter parameters including pagination + params. + Returns: + A page of all model versions. + """ From f8764a707d5b75a0f752a1aea51b48c9f952b7b0 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 12 Sep 2023 07:28:26 +0200 Subject: [PATCH 18/40] lint --- src/zenml/zen_stores/rest_zen_store.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 62c99ee2fc5..32349e7a7e2 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2293,9 +2293,6 @@ def create_model(self, model: ModelRequestModel) -> ModelResponseModel: Returns: The newly created model. - - Raises: - EntityExistsError: If a workspace with the given name already exists. """ return self._create_workspace_scoped_resource( resource=model, From 1205d175b85d7af2e12c28b54ec3b67ee5172b76 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 12 Sep 2023 09:29:43 +0200 Subject: [PATCH 19/40] sync with model branch --- src/zenml/zen_stores/schemas/model_schemas.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 42a36416f67..ffd4954d010 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -14,6 +14,7 @@ """SQLModel implementation of model tables.""" +import json from datetime import datetime from typing import List, Optional from uuid import UUID @@ -97,7 +98,9 @@ def from_request(cls, model_request: ModelRequestModel) -> "ModelSchema": limitations=model_request.limitations, trade_offs=model_request.trade_offs, ethic=model_request.ethic, - tags=model_request.tags, + tags=json.dumps(model_request.tags) + if model_request.tags + else None, ) def to_model(self) -> ModelResponseModel: @@ -120,7 +123,7 @@ def to_model(self) -> ModelResponseModel: limitations=self.limitations, trade_offs=self.trade_offs, ethic=self.ethic, - tags=self.tags, + tags=json.loads(self.tags) if self.tags else None, ) def update( @@ -136,7 +139,10 @@ def update( The updated `ModelSchema`. """ for field, value in model_update.dict(exclude_unset=True).items(): - setattr(self, field, value) + if field == "tags": + setattr(self, field, json.dumps(value)) + else: + setattr(self, field, value) self.updated = datetime.utcnow() return self From e4c5ee0d42cebd83334e387f324e0ce1d3575f81 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 12 Sep 2023 13:11:18 +0200 Subject: [PATCH 20/40] wip --- .gitignore | 2 + src/zenml/models/model_models.py | 22 +-- .../zen_server/routers/models_endpoints.py | 170 ++++++++---------- .../routers/workspaces_endpoints.py | 86 +++++++++ src/zenml/zen_stores/rest_zen_store.py | 18 +- src/zenml/zen_stores/schemas/model_schemas.py | 10 +- src/zenml/zen_stores/sql_zen_store.py | 40 +++-- src/zenml/zen_stores/zen_store_interface.py | 6 +- .../functional/zen_stores/test_zen_store.py | 152 ++++++++++++++++ .../functional/zen_stores/utils.py | 48 +++++ 10 files changed, 418 insertions(+), 136 deletions(-) diff --git a/.gitignore b/.gitignore index ed067acc544..44c0ca1aa7a 100644 --- a/.gitignore +++ b/.gitignore @@ -195,3 +195,5 @@ zenml_tutorial/ # script for testing mlstacks_reset.sh + +.local/ \ No newline at end of file diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 5a33f1dd868..f7dbb72162e 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -32,10 +32,6 @@ class ModelVersionBaseModel(BaseModel): """Model Version base model.""" - model_id: str = Field( - title="The ID of the model", - max_length=STR_FIELD_MAX_LENGTH, - ) version: str = Field( title="The name of the model version", max_length=STR_FIELD_MAX_LENGTH, @@ -72,7 +68,9 @@ class ModelVersionRequestModel( ): """Model Version request model.""" - pass + model_id: UUID = Field( + title="The ID of the model containing version", + ) class ModelVersionResponseModel( @@ -81,6 +79,10 @@ class ModelVersionResponseModel( ): """Model Version response model.""" + model: "ModelResponseModel" = Field( + title="The model containing version", + ) + @staticmethod def _fetch_artifacts_from_list( artifacts: Dict[str, UUID] @@ -124,18 +126,18 @@ def set_stage(self, stage: ModelStages): class ModelVersionFilterModel(WorkspaceScopedFilterModel): """Filter Model for Model Version.""" - model_name: str = Field( - description="Name of the Model", + model_id: Optional[Union[str, UUID]] = Field( + description="The ID of the Model", ) model_version_name: Optional[str] = Field( default=None, - description="Name of the Model Version", + description="The name of the Model Version", ) workspace_id: Optional[Union[UUID, str]] = Field( - default=None, description="Workspace of the Model Version" + default=None, description="The workspace of the Model Version" ) user_id: Optional[Union[UUID, str]] = Field( - default=None, description="User of the Model Version" + default=None, description="The user of the Model Version" ) diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 61e0d9785f8..270167a04ec 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -18,12 +18,14 @@ from fastapi import APIRouter, Depends, Security -from zenml.constants import API, MODELS, VERSION_1 +from zenml.constants import API, MODEL_VERSIONS, MODELS, VERSION_1 from zenml.enums import PermissionType from zenml.models import ( ModelFilterModel, ModelResponseModel, ModelUpdateModel, + ModelVersionFilterModel, + ModelVersionResponseModel, ) from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize @@ -140,98 +142,74 @@ def update_model( # Model Versions ################# -# router = APIRouter( -# prefix=API + VERSION_1 + MODELS, -# tags=["models"], -# responses={401: error_response}, -# ) - -# @router.get( -# "", -# response_model=Page[ModelResponseModel], -# responses={401: error_response, 404: error_response, 422: error_response}, -# ) -# @handle_exceptions -# def list_models( -# model_filter_model: ModelFilterModel = Depends( -# make_dependable(ModelFilterModel) -# ), -# _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -# ) -> Page[ModelResponseModel]: -# """Get models according to query filters. - -# Args: -# model_filter_model: Filter model used for pagination, sorting, -# filtering - - -# Returns: -# The models according to query filters. -# """ -# return zen_store().list_models( -# model_filter_model=model_filter_model, -# ) - - -# @router.delete( -# "/{model_name_or_id}", -# responses={401: error_response, 404: error_response, 422: error_response}, -# ) -# @handle_exceptions -# def delete_model( -# model_name_or_id: Union[str, UUID], -# _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -# ) -> None: -# """Delete a model by name or ID. - -# Args: -# model_name_or_id: The name or ID of the model to delete. -# """ -# zen_store().delete_model(model_name_or_id) - - -# @router.get( -# "/{model_name_or_id}", -# response_model=ModelResponseModel, -# responses={401: error_response, 404: error_response, 422: error_response}, -# ) -# @handle_exceptions -# def get_model( -# model_name_or_id: Union[str, UUID], -# _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -# ) -> ModelResponseModel: -# """Get a model by name or ID. - -# Args: -# model_name_or_id: The name or ID of the model to get. - -# Returns: -# The model with the given name or ID. -# """ -# return zen_store().get_model(model_name_or_id) - - -# @router.put( -# "/{model_id}", -# response_model=ModelResponseModel, -# responses={401: error_response, 404: error_response, 422: error_response}, -# ) -# @handle_exceptions -# def update_model( -# model_id: UUID, -# model_update: ModelUpdateModel, -# _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -# ) -> ModelResponseModel: -# """Updates a model. - -# Args: -# model_id: Name of the stack. -# model_update: Stack to use for the update. - -# Returns: -# The updated model. -# """ -# return zen_store().update_model( -# model_id=model_id, -# model_update=model_update, -# ) + +@router.get( + "/{model_name_or_id}" + MODEL_VERSIONS, + response_model=Page[ModelVersionResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_model_versions( + model_version_filter_model: ModelVersionFilterModel = Depends( + make_dependable(ModelVersionFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionResponseModel]: + """Get model versions according to query filters. + + Args: + model_version_filter_model: Filter model used for pagination, sorting, + filtering + + Returns: + The model versions according to query filters. + """ + return zen_store().list_model_versions( + model_version_filter_model=model_version_filter_model, + ) + + +@router.delete( + "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model_version( + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Delete a model by name or ID. + + Args: + model_name_or_id: The name or ID of the model containing version. + model_version_name_or_id: The name or ID of the model version to delete. + """ + zen_store().delete_model_version( + model_name_or_id, model_version_name_or_id + ) + + +@router.get( + "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}", + response_model=ModelVersionResponseModel, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def get_model_version( + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> ModelVersionResponseModel: + """Get a model version by name or ID. + + Args: + model_name_or_id: The name or ID of the model containing version. + model_version_name_or_id: The name or ID of the model version to get. + + Returns: + The model version with the given name or ID. + """ + return zen_store().get_model_version( + model_name_or_id, model_version_name_or_id + ) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 622dd1ee213..49486c1a74f 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -21,6 +21,7 @@ API, CODE_REPOSITORIES, GET_OR_CREATE, + MODEL_VERSIONS, MODELS, PIPELINE_BUILDS, PIPELINE_DEPLOYMENTS, @@ -51,6 +52,9 @@ ModelFilterModel, ModelRequestModel, ModelResponseModel, + ModelVersionFilterModel, + ModelVersionRequestModel, + ModelVersionResponseModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -1221,3 +1225,85 @@ def list_workspace_models( return zen_store().list_models( model_filter_model=model_filter_model, ) + + +@router.post( + WORKSPACES + + "/{workspace_name_or_id}" + + MODELS + + "/{model_name_or_id}" + + MODEL_VERSIONS, + response_model=ModelVersionResponseModel, + responses={401: error_response, 409: error_response, 422: error_response}, +) +@handle_exceptions +def create_model_version( + workspace_name_or_id: Union[str, UUID], + model_name_or_id: Union[str, UUID], + model_version: ModelVersionRequestModel, + auth_context: AuthContext = Security( + authorize, scopes=[PermissionType.WRITE] + ), +) -> ModelVersionResponseModel: + """Create a new model version. + + Args: + model_name_or_id: Name or ID of the model. + workspace_name_or_id: Name or ID of the workspace. + model_version: The model to create. + auth_context: Authentication context. + + Returns: + The created model version. + + Raises: + IllegalOperationError: If the workspace or user specified in the + model version does not match the current workspace or authenticated + user. + """ + workspace = zen_store().get_workspace(workspace_name_or_id) + + if model_version.workspace != workspace.id: + raise IllegalOperationError( + "Creating model versions outside of the workspace scope " + f"of this endpoint `{workspace_name_or_id}` is " + f"not supported." + ) + if model_version.user != auth_context.user.id: + raise IllegalOperationError( + "Creating models for a user other than yourself " + "is not supported." + ) + mv = zen_store().create_model_version(model_version) + return mv + + +@router.get( + WORKSPACES + "/{workspace_name_or_id}" + MODEL_VERSIONS, + response_model=Page[ModelVersionResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_workspace_model_versions( + workspace_name_or_id: Union[str, UUID], + model_version_filter_model: ModelVersionFilterModel = Depends( + make_dependable(ModelVersionFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionResponseModel]: + """Get models according to query filters. + + Args: + workspace_name_or_id: Name or ID of the workspace. + model_version_filter_model: Filter model used for pagination, sorting, + filtering + + + Returns: + The model versions according to query filters. + """ + workspace_id = zen_store().get_workspace(workspace_name_or_id).id + model_version_filter_model.set_scope_workspace(workspace_id) + return zen_store().list_model_versions( + model_version_filter_model=model_version_filter_model, + ) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 16c8aef6a20..de09f565f3b 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2378,16 +2378,14 @@ def create_model_version( ) -> ModelVersionResponseModel: """Creates a new model version. Args: - model: the Model Version to be created. + model_version: the Model Version to be created. Returns: The newly created model version. - Raises: - EntityExistsError: If a workspace with the given name already exists. """ return self._create_workspace_scoped_resource( resource=model_version, response_model=ModelVersionResponseModel, - route=MODEL_VERSIONS, + route=f"{MODELS}/{model_version.model_id}{MODEL_VERSIONS}", ) def delete_model_version( @@ -2399,8 +2397,8 @@ def delete_model_version( model_version_name: name of the model version to be deleted. """ self._delete_resource( - resource_id=(model_name_or_id, model_version_name), - route=MODEL_VERSIONS, + resource_id=model_version_name, + route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", ) def get_model_version( @@ -2416,8 +2414,8 @@ def get_model_version( The model version of interest. """ return self._get_resource( - resource_id=(model_name_or_id, model_version_name), - route=MODEL_VERSIONS, + resource_id=model_version_name, + route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", response_model=ModelVersionResponseModel, ) @@ -2434,8 +2432,8 @@ def list_model_versions( """ return self._list_paginated_resources( - route=MODEL_VERSIONS, - response_model=ModelVersionFilterModel, + route=f"{MODELS}/{model_version_filter_model.model_id}{MODEL_VERSIONS}", + response_model=ModelVersionResponseModel, filter_model=model_version_filter_model, ) diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index ffd4954d010..a82f8ec948a 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -227,27 +227,27 @@ def to_model(self) -> ModelVersionResponseModel: workspace=self.workspace.to_model(), created=self.created, updated=self.updated, - model_id=self.model_id, + model=self.model.to_model(), version=self.version, description=self.description, stage=self.stage, - _model_objects={ + model_objects={ al.name: al.artifact_id for al in self.objects_links if al.artifact_id is not None and al.is_model_object }, - _deployments={ + deployments={ al.name: al.artifact_id for al in self.objects_links if al.artifact_id is not None and al.is_deployment }, - _artifact_objects={ + artifact_objects={ al.name: al.artifact_id for al in self.objects_links if al.artifact_id is not None and not (al.is_deployment or al.is_model_object) }, - _pipeline_runs=[ + pipeline_runs=[ al.artifact_id for al in self.objects_links if al.pipeline_run_id is not None diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 4da73269615..8ed49eed2d8 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5530,7 +5530,7 @@ def create_model_version( ) -> ModelVersionResponseModel: """Creates a new model version. Args: - model: the Model Version to be created. + model_version: the Model Version to be created. Returns: The newly created model version. Raises: @@ -5539,10 +5539,9 @@ def create_model_version( with Session(self.engine) as session: model = self.get_model(model_version.model_id) existing_model_version = session.exec( - select(ModelVersionSchema).where( - ModelVersionSchema.version == model_version.version - and ModelVersionSchema.model_id == model.id - ) + select(ModelVersionSchema) + .where(ModelVersionSchema.model_id == model.id) + .where(ModelVersionSchema.version == model_version.version) ).first() if existing_model_version is not None: raise EntityExistsError( @@ -5556,7 +5555,8 @@ def create_model_version( session.add(model_version_schema) session.commit() - return ModelVersionSchema.to_model(model_version_schema) + mv = ModelVersionSchema.to_model(model_version_schema) + return mv def get_model_version( self, @@ -5569,15 +5569,21 @@ def get_model_version( model_version_name_or_id: name or id of the model version to be retrieved. Returns: The model version of interest. + Raises: + KeyError: specified ID or name not found. """ with Session(self.engine) as session: model = self.get_model(model_name_or_id) model_version = session.exec( - select(ModelVersionSchema).where( - ModelVersionSchema.version == model_version_name - and ModelVersionSchema.model_id == model - ) + select(ModelVersionSchema) + .where(ModelVersionSchema.model_id == model.id) + .where(ModelVersionSchema.version == model_version_name) ).first() + if model_version is None: + raise KeyError( + f"Unable to get model version with name `{model_version_name}`: " + f"No model version with this name found." + ) return ModelVersionSchema.to_model(model_version) def list_model_versions( @@ -5607,14 +5613,20 @@ def delete_model_version( Args: model_name_or_id: name or id of the model containing the model version. model_version_name: name of the model version to be deleted. + Raises: + KeyError: specified ID or name not found. """ with Session(self.engine) as session: model = self.get_model(model_name_or_id) model_version = session.exec( - select(ModelVersionSchema).where( - ModelVersionSchema.version == model_version_name - and ModelVersionSchema.model_id == model + select(ModelVersionSchema) + .where(ModelVersionSchema.model_id == model.id) + .where(ModelVersionSchema.version == model_version_name) + ).first() + if model_version is None: + raise KeyError( + f"Unable to delete model version with name `{model_version_name}`: " + f"No model version with this name found." ) - ) session.delete(model_version) session.commit() diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index fc95ce37492..d1a1d2562b3 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -1754,7 +1754,7 @@ def create_model_version( ) -> ModelVersionResponseModel: """Creates a new model version. Args: - model: the Model Version to be created. + model_version: the Model Version to be created. Returns: The newly created model version. Raises: @@ -1769,6 +1769,8 @@ def delete_model_version( Args: model_name_or_id: name or id of the model containing the model version. model_version_name: name of the model version to be deleted. + Raises: + KeyError: specified ID or name not found. """ @abstractmethod @@ -1783,6 +1785,8 @@ def get_model_version( model_version_name_or_id: name or id of the model version to be retrieved. Returns: The model version of interest. + Raises: + KeyError: specified ID or name not found. """ @abstractmethod diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index b8c11ddf16f..a333163319d 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -16,6 +16,7 @@ from contextlib import ExitStack as does_not_raise from datetime import datetime from typing import Dict, List, Optional, Tuple +from uuid import uuid4 import pytest from pydantic import SecretStr @@ -25,6 +26,7 @@ CodeRepositoryContext, ComponentContext, CrudTestConfig, + ModelVersionContext, PipelineRunContext, RoleContext, ServiceConnectorContext, @@ -50,6 +52,8 @@ ArtifactFilterModel, ComponentFilterModel, ComponentUpdateModel, + ModelVersionFilterModel, + ModelVersionRequestModel, PipelineRunFilterModel, RoleFilterModel, RoleRequestModel, @@ -2423,3 +2427,151 @@ def test_connector_validation(): secrets=secrets, ): pass + + +################# +# Models +################# + + +def test_model_version_create_pass(): + with ModelVersionContext() as model: + zs = Client().zen_store + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model_id=model.id, + version="great one", + ) + ) + + +def test_model_version_create_duplicated(): + with ModelVersionContext() as model: + zs = Client().zen_store + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model_id=model.id, + version="great one", + ) + ) + with pytest.raises(EntityExistsError): + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model_id=model.id, + version="great one", + ) + ) + + +def test_model_version_create_no_model(): + with ModelVersionContext() as model: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model_id=uuid4(), + version="great one", + ) + ) + + +def test_model_version_get_not_found(): + with ModelVersionContext() as model: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.get_model_version( + model_name_or_id=model.id, model_version_name="1.0.0" + ) + + +def test_model_version_get_found(): + with ModelVersionContext() as model: + zs = Client().zen_store + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model_id=model.id, + version="great one", + ) + ) + zs.get_model_version( + model_name_or_id=model.id, + model_version_name="great one", + ) + + +def test_model_version_list_empty(): + with ModelVersionContext() as model: + zs = Client().zen_store + mvs = zs.list_model_versions( + ModelVersionFilterModel(model_id=model.id) + ) + assert len(mvs) == 0 + + +def test_model_version_list_not_empty(): + with ModelVersionContext() as model: + zs = Client().zen_store + mv1 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model_id=model.id, + version="great one", + ) + ) + mv2 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model_id=model.id, + version="and yet another one", + ) + ) + mvs = zs.list_model_versions( + ModelVersionFilterModel(model_id=model.id) + ) + assert len(mvs) == 2 + assert mv1 in mvs + assert mv2 in mvs + + +def test_model_version_delete_not_found(): + with ModelVersionContext() as model: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.delete_model_version( + model_name_or_id=model.id, + model_version_name="1.0.0", + ) + + +def test_model_version_delete_found(): + with ModelVersionContext() as model: + zs = Client().zen_store + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model_id=model.id, + version="great one", + ) + ) + zs.delete_model_version( + model_name_or_id=model.id, + model_version_name="great one", + ) + with pytest.raises(KeyError): + zs.get_model_version( + model_name_or_id=model.id, + model_version_name="great one", + ) diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index 89a3f3f87b2..bf6bd15cf3a 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -509,6 +509,53 @@ def __exit__(self, exc_type, exc_value, exc_traceback): pass +class ModelVersionContext: + def __init__(self): + self.workspace = "workspace" + self.user = "su" + self.model = "su_model" + self.del_ws = False + self.del_user = False + self.del_model = False + + def __enter__(self): + zs = Client().zen_store + try: + ws = zs.get_workspace(self.workspace) + except: + ws = zs.create_workspace( + WorkspaceRequestModel(name=self.workspace) + ) + self.del_ws = True + try: + user = zs.get_user(self.user) + except: + user = zs.create_user(UserRequestModel(name=self.user)) + self.del_user = True + try: + model = zs.get_model(self.model) + except: + model = zs.create_model( + ModelRequestModel( + name=self.model, user=user.id, workspace=ws.id + ) + ) + self.del_model = True + return model + + def __exit__(self, exc_type, exc_value, exc_traceback): + zs = Client().zen_store + if self.del_model: + print("del_model") + zs.delete_model(self.model) + if self.del_user: + print("del_user") + zs.delete_user(self.user) + if self.del_ws: + print("del_ws") + zs.delete_workspace(self.workspace) + + class CatClawMarks(AuthenticationConfig): """Cat claw marks authentication credentials.""" @@ -837,6 +884,7 @@ def update_method( entity_name="model", ) + # step_run_crud_test_config = CrudTestConfig( # create_model=StepRunRequestModel( # name=sample_name("sample_step_run"), From 13159a411dd4d0046ce1d2171cc9eff4bd584411 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 12 Sep 2023 13:11:48 +0200 Subject: [PATCH 21/40] refactor --- src/zenml/models/model_models.py | 2 +- src/zenml/zen_stores/rest_zen_store.py | 2 +- src/zenml/zen_stores/schemas/model_schemas.py | 2 +- src/zenml/zen_stores/sql_zen_store.py | 2 +- .../functional/zen_stores/test_zen_store.py | 16 ++++++++-------- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index f7dbb72162e..01b655f746f 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -68,7 +68,7 @@ class ModelVersionRequestModel( ): """Model Version request model.""" - model_id: UUID = Field( + model: UUID = Field( title="The ID of the model containing version", ) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index de09f565f3b..056fb41447f 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2385,7 +2385,7 @@ def create_model_version( return self._create_workspace_scoped_resource( resource=model_version, response_model=ModelVersionResponseModel, - route=f"{MODELS}/{model_version.model_id}{MODEL_VERSIONS}", + route=f"{MODELS}/{model_version.model}{MODEL_VERSIONS}", ) def delete_model_version( diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index a82f8ec948a..3922dfda18c 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -209,7 +209,7 @@ def from_request( return cls( workspace_id=model_version_request.workspace, user_id=model_version_request.user, - model_id=model_version_request.model_id, + model_id=model_version_request.model, version=model_version_request.version, description=model_version_request.description, stage=model_version_request.stage, diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 8ed49eed2d8..da97c74b852 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5537,7 +5537,7 @@ def create_model_version( EntityExistsError: If a workspace with the given name already exists. """ with Session(self.engine) as session: - model = self.get_model(model_version.model_id) + model = self.get_model(model_version.model) existing_model_version = session.exec( select(ModelVersionSchema) .where(ModelVersionSchema.model_id == model.id) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index a333163319d..178d7905ea1 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -2441,7 +2441,7 @@ def test_model_version_create_pass(): ModelVersionRequestModel( user=model.user.id, workspace=model.workspace.id, - model_id=model.id, + model=model.id, version="great one", ) ) @@ -2454,7 +2454,7 @@ def test_model_version_create_duplicated(): ModelVersionRequestModel( user=model.user.id, workspace=model.workspace.id, - model_id=model.id, + model=model.id, version="great one", ) ) @@ -2463,7 +2463,7 @@ def test_model_version_create_duplicated(): ModelVersionRequestModel( user=model.user.id, workspace=model.workspace.id, - model_id=model.id, + model=model.id, version="great one", ) ) @@ -2477,7 +2477,7 @@ def test_model_version_create_no_model(): ModelVersionRequestModel( user=model.user.id, workspace=model.workspace.id, - model_id=uuid4(), + model=uuid4(), version="great one", ) ) @@ -2499,7 +2499,7 @@ def test_model_version_get_found(): ModelVersionRequestModel( user=model.user.id, workspace=model.workspace.id, - model_id=model.id, + model=model.id, version="great one", ) ) @@ -2525,7 +2525,7 @@ def test_model_version_list_not_empty(): ModelVersionRequestModel( user=model.user.id, workspace=model.workspace.id, - model_id=model.id, + model=model.id, version="great one", ) ) @@ -2533,7 +2533,7 @@ def test_model_version_list_not_empty(): ModelVersionRequestModel( user=model.user.id, workspace=model.workspace.id, - model_id=model.id, + model=model.id, version="and yet another one", ) ) @@ -2562,7 +2562,7 @@ def test_model_version_delete_found(): ModelVersionRequestModel( user=model.user.id, workspace=model.workspace.id, - model_id=model.id, + model=model.id, version="great one", ) ) From 7907b4e3b9af06df30088eecef77cc2852df5b79 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 12 Sep 2023 14:34:49 +0200 Subject: [PATCH 22/40] add stage transition --- src/zenml/model/model_stages.py | 4 +- src/zenml/models/__init__.py | 2 + src/zenml/models/model_models.py | 39 ++++- .../zen_server/routers/models_endpoints.py | 26 ++++ src/zenml/zen_stores/rest_zen_store.py | 31 +++- src/zenml/zen_stores/sql_zen_store.py | 68 +++++++++ src/zenml/zen_stores/zen_store_interface.py | 22 ++- .../functional/zen_stores/test_zen_store.py | 143 +++++++++++++++++- 8 files changed, 320 insertions(+), 15 deletions(-) diff --git a/src/zenml/model/model_stages.py b/src/zenml/model/model_stages.py index 5c53793ce4d..f1bb7192985 100644 --- a/src/zenml/model/model_stages.py +++ b/src/zenml/model/model_stages.py @@ -17,8 +17,10 @@ class ModelStages(StrEnum): + """All possible stages of a Model Version.""" + NONE = "none" - STAGING = "starting" + STAGING = "staging" PRODUCTION = "production" ARCHIVED = "archived" # technical stages diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index 1f6f035b435..c61dd78aef6 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -161,6 +161,7 @@ ModelVersionLinkRequestModel, ModelVersionLinkResponseModel, ModelVersionFilterModel, + ModelVersionUpdateModel, ) ComponentResponseModel.update_forward_refs( @@ -415,6 +416,7 @@ "ModelVersionFilterModel", "ModelVersionRequestModel", "ModelVersionResponseModel", + "ModelVersionUpdateModel", "ModelVersionLinkBaseModel", "ModelVersionLinkRequestModel", "ModelVersionLinkResponseModel", diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 01b655f746f..d76df6bfbf3 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -98,25 +98,38 @@ def _fetch_artifacts_from_list( @property def model_objects(self) -> Dict[str, ArtifactResponseModel]: + """Get all model objects linked to this version.""" return self._fetch_artifacts_from_list(self._model_objects) @property def artifact_objects(self) -> Dict[str, ArtifactResponseModel]: + """Get all artifacts linked to this version.""" return self._fetch_artifacts_from_list(self._artifact_objects) @property def deployments(self) -> Dict[str, ArtifactResponseModel]: + """Get all deployments linked to this version.""" return self._fetch_artifacts_from_list(self._deployments) @property def pipeline_runs(self) -> List[PipelineRunResponseModel]: + """Get all pipeline runs linked to this version.""" from zenml.client import Client - return [Client().get_run(pr) for pr in self._pipeline_runs] + return [Client().get_pipeline_run(pr) for pr in self._pipeline_runs] - def set_stage(self, stage: ModelStages): - """Sets Model Version to a desired stage.""" - pass + def set_stage(self, stage: ModelStages, force: bool = False): + """Sets this Model Version to a desired stage.""" + from zenml.client import Client + + return Client().zen_store.update_model_version( + model_version_id=self.id, + model_version_update_model=ModelVersionUpdateModel( + model=self.model.id, + stage=stage, + force=force, + ), + ) # TODO in https://zenml.atlassian.net/browse/OSS-2433 # def generate_model_card(self, template_name: str) -> str: @@ -126,10 +139,10 @@ def set_stage(self, stage: ModelStages): class ModelVersionFilterModel(WorkspaceScopedFilterModel): """Filter Model for Model Version.""" - model_id: Optional[Union[str, UUID]] = Field( + model: Optional[Union[str, UUID]] = Field( description="The ID of the Model", ) - model_version_name: Optional[str] = Field( + version: Optional[str] = Field( default=None, description="The name of the Model Version", ) @@ -141,6 +154,20 @@ class ModelVersionFilterModel(WorkspaceScopedFilterModel): ) +class ModelVersionUpdateModel(BaseModel): + model: UUID = Field( + title="The ID of the model containing version", + ) + stage: ModelStages = Field( + title="Target model version stage to be set", + ) + force: bool = Field( + title="Whether existing model version in target stage should be silently archived " + "or an error should be raised.", + default=False, + ) + + class ModelVersionLinkBaseModel(BaseModel): """Model version links base model.""" diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 270167a04ec..3bd87623dd0 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -26,6 +26,7 @@ ModelUpdateModel, ModelVersionFilterModel, ModelVersionResponseModel, + ModelVersionUpdateModel, ) from zenml.models.page_model import Page from zenml.zen_server.auth import AuthContext, authorize @@ -213,3 +214,28 @@ def get_model_version( return zen_store().get_model_version( model_name_or_id, model_version_name_or_id ) + + +@router.put( + "/{model_id}" + MODEL_VERSIONS + "/{model_version_id}", + response_model=ModelVersionResponseModel, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def update_model( + model_version_id: UUID, + model_version_update_model: ModelVersionUpdateModel, + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> ModelVersionResponseModel: + """Get all model versions by filter. + Args: + model_version_id: The ID of model version to be updated. + model_version_update_model: The model version to be updated. + + Returns: + An updated model version. + """ + return zen_store().update_model_version( + model_version_id=model_version_id, + model_version_update_model=model_version_update_model, + ) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 056fb41447f..63a75996b2e 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -103,6 +103,7 @@ ModelVersionFilterModel, ModelVersionRequestModel, ModelVersionResponseModel, + ModelVersionUpdateModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -2377,8 +2378,10 @@ def create_model_version( self, model_version: ModelVersionRequestModel ) -> ModelVersionResponseModel: """Creates a new model version. + Args: model_version: the Model Version to be created. + Returns: The newly created model version. """ @@ -2392,6 +2395,7 @@ def delete_model_version( self, model_name_or_id: Union[str, UUID], model_version_name: str ) -> None: """Deletes a model version. + Args: model_name_or_id: name or id of the model containing the model version. model_version_name: name of the model version to be deleted. @@ -2407,9 +2411,11 @@ def get_model_version( model_version_name: str, ) -> ModelVersionResponseModel: """Get an existing model version. + Args: model_name_or_id: name or id of the model containing the model version. model_version_name_or_id: name or id of the model version to be retrieved. + Returns: The model version of interest. """ @@ -2424,19 +2430,42 @@ def list_model_versions( model_version_filter_model: ModelVersionFilterModel, ) -> Page[ModelVersionResponseModel]: """Get all model versions by filter. + Args: model_version_filter_model: All filter parameters including pagination params. + Returns: A page of all model versions. """ return self._list_paginated_resources( - route=f"{MODELS}/{model_version_filter_model.model_id}{MODEL_VERSIONS}", + route=f"{MODELS}/{model_version_filter_model.model}{MODEL_VERSIONS}", response_model=ModelVersionResponseModel, filter_model=model_version_filter_model, ) + def update_model_version( + self, + model_version_id: UUID, + model_version_update_model: ModelVersionUpdateModel, + ) -> ModelVersionResponseModel: + """Get all model versions by filter. + Args: + model_version_id: The ID of model version to be updated. + model_version_update_model: The model version to be updated. + + Returns: + An updated model version. + + """ + return self._update_resource( + resource_id=model_version_id, + resource_update=model_version_update_model, + route=f"{MODELS}/{model_version_update_model.model}{MODEL_VERSIONS}", + response_model=ModelVersionResponseModel, + ) + # ======================= # Internal helper methods # ======================= diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index da97c74b852..0bd681352f3 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -76,6 +76,7 @@ ) from zenml.io import fileio from zenml.logger import get_console_handler, get_logger, get_logging_level +from zenml.model import ModelStages from zenml.models import ( ArtifactFilterModel, ArtifactRequestModel, @@ -101,6 +102,7 @@ ModelVersionFilterModel, ModelVersionRequestModel, ModelVersionResponseModel, + ModelVersionUpdateModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -5630,3 +5632,69 @@ def delete_model_version( ) session.delete(model_version) session.commit() + + def update_model_version( + self, + model_version_id: UUID, + model_version_update_model: ModelVersionUpdateModel, + ) -> ModelVersionResponseModel: + """Get all model versions by filter. + Args: + model_version_id: The ID of model version to be updated. + model_version_update_model: The model version to be updated. + + Returns: + An updated model version. + + Raises: + KeyError: If the model version not found + RuntimeError: If there is a model version with target stage, but `force` flag is off + """ + with Session(self.engine) as session: + existing_model_version = session.exec( + select(ModelVersionSchema) + .where( + ModelVersionSchema.model_id + == model_version_update_model.model + ) + .where(ModelVersionSchema.id == model_version_id) + ).first() + + if not existing_model_version: + raise KeyError(f"Model version {model_version_id} not found.") + + existing_model_version_in_target_stage = session.exec( + select(ModelVersionSchema) + .where( + ModelVersionSchema.model_id + == model_version_update_model.model + ) + .where( + ModelVersionSchema.stage + == model_version_update_model.stage.value + ) + ).first() + + if existing_model_version_in_target_stage is not None: + if not model_version_update_model.force: + raise RuntimeError( + f"Model version {existing_model_version_in_target_stage.version} is " + f"in {model_version_update_model.stage.value}, but `force` flag is False." + ) + else: + existing_model_version_in_target_stage.stage = ( + ModelStages.ARCHIVED.value + ) + session.add(existing_model_version_in_target_stage) + session.commit() + session.refresh(existing_model_version_in_target_stage) + + logger.info( + f"Model version {existing_model_version_in_target_stage.version} has been set to {ModelStages.ARCHIVED.value}." + ) + existing_model_version.stage = model_version_update_model.stage + session.add(existing_model_version) + session.commit() + session.refresh(existing_model_version) + + return existing_model_version.to_model() diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index d1a1d2562b3..5ebf682cec8 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -39,6 +39,7 @@ ModelVersionFilterModel, ModelVersionRequestModel, ModelVersionResponseModel, + ModelVersionUpdateModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineBuildResponseModel, @@ -1782,7 +1783,7 @@ def get_model_version( """Get an existing model version. Args: model_name_or_id: name or id of the model containing the model version. - model_version_name_or_id: name or id of the model version to be retrieved. + model_version_name: name or id of the model version to be retrieved. Returns: The model version of interest. Raises: @@ -1801,3 +1802,22 @@ def list_model_versions( Returns: A page of all model versions. """ + + @abstractmethod + def update_model_version( + self, + model_version_id: UUID, + model_version_update_model: ModelVersionUpdateModel, + ) -> ModelVersionResponseModel: + """Get all model versions by filter. + Args: + model_version_id: The ID of model version to be updated. + model_version_update_model: The model version to be updated. + + Returns: + An updated model version. + + Raises: + KeyError: If the model version not found + RuntimeError: If there is a model version with target stage, but `force` flag is off + """ diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 178d7905ea1..cf88c2fd59d 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -54,6 +54,7 @@ ComponentUpdateModel, ModelVersionFilterModel, ModelVersionRequestModel, + ModelVersionUpdateModel, PipelineRunFilterModel, RoleFilterModel, RoleRequestModel, @@ -2512,9 +2513,7 @@ def test_model_version_get_found(): def test_model_version_list_empty(): with ModelVersionContext() as model: zs = Client().zen_store - mvs = zs.list_model_versions( - ModelVersionFilterModel(model_id=model.id) - ) + mvs = zs.list_model_versions(ModelVersionFilterModel(model=model.id)) assert len(mvs) == 0 @@ -2537,9 +2536,7 @@ def test_model_version_list_not_empty(): version="and yet another one", ) ) - mvs = zs.list_model_versions( - ModelVersionFilterModel(model_id=model.id) - ) + mvs = zs.list_model_versions(ModelVersionFilterModel(model=model.id)) assert len(mvs) == 2 assert mv1 in mvs assert mv2 in mvs @@ -2575,3 +2572,137 @@ def test_model_version_delete_found(): model_name_or_id=model.id, model_version_name="great one", ) + + +def test_model_version_update_not_found(): + with ModelVersionContext() as model: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.update_model_version( + model_version_id=uuid4(), + model_version_update_model=ModelVersionUpdateModel( + model=model.id, + stage="staging", + force=False, + ), + ) + + +def test_model_version_update_not_forced(): + with ModelVersionContext() as model: + zs = Client().zen_store + mv1 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + mv2 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="yet another one", + ) + ) + zs.update_model_version( + model_version_id=mv1.id, + model_version_update_model=ModelVersionUpdateModel( + model=model.id, + stage="staging", + force=False, + ), + ) + with pytest.raises(RuntimeError): + zs.update_model_version( + model_version_id=mv2.id, + model_version_update_model=ModelVersionUpdateModel( + model=model.id, + stage="staging", + force=False, + ), + ) + + +def test_model_version_update_forced(): + with ModelVersionContext() as model: + zs = Client().zen_store + mv1 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + mv2 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="yet another one", + ) + ) + zs.update_model_version( + model_version_id=mv1.id, + model_version_update_model=ModelVersionUpdateModel( + model=model.id, + stage="staging", + force=False, + ), + ) + assert ( + zs.get_model_version( + model_name_or_id=model.id, model_version_name=mv1.version + ).stage + == "staging" + ) + zs.update_model_version( + model_version_id=mv2.id, + model_version_update_model=ModelVersionUpdateModel( + model=model.id, + stage="staging", + force=True, + ), + ) + + assert ( + zs.get_model_version( + model_name_or_id=model.id, model_version_name=mv1.version + ).stage + == "archived" + ) + assert ( + zs.get_model_version( + model_name_or_id=model.id, model_version_name=mv2.version + ).stage + == "staging" + ) + + +def test_model_version_update_public_interface(): + with ModelVersionContext() as model: + zs = Client().zen_store + mv1 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + assert ( + zs.get_model_version( + model_name_or_id=model.id, model_version_name=mv1.version + ).stage + is None + ) + mv1.set_stage("staging") + assert ( + zs.get_model_version( + model_name_or_id=model.id, model_version_name=mv1.version + ).stage + == "staging" + ) From 1a083b0c6038e78e3b8d504d86389809c1e62277 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 12 Sep 2023 14:39:31 +0200 Subject: [PATCH 23/40] add update interface --- src/zenml/zen_stores/schemas/model_schemas.py | 16 ++++++++++++++++ src/zenml/zen_stores/sql_zen_store.py | 6 ++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 3922dfda18c..74993ceb8b3 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -254,6 +254,22 @@ def to_model(self) -> ModelVersionResponseModel: ], ) + def update( + self, + target_stage: str, + ) -> "ModelVersionSchema": + """Updates a `ModelVersionSchema` to a target stage. + + Args: + target_stage: The stage to be updated. + + Returns: + The updated `ModelVersionSchema`. + """ + self.stage = target_stage + self.updated = datetime.utcnow() + return self + class ModelVersionLinkSchema(NamedSchema, table=True): """SQL Model for linking of Model Versions and Artifacts or Pipeline Runs M:M.""" diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 0bd681352f3..d5c11f5ba05 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5682,7 +5682,7 @@ def update_model_version( f"in {model_version_update_model.stage.value}, but `force` flag is False." ) else: - existing_model_version_in_target_stage.stage = ( + existing_model_version_in_target_stage.update( ModelStages.ARCHIVED.value ) session.add(existing_model_version_in_target_stage) @@ -5692,7 +5692,9 @@ def update_model_version( logger.info( f"Model version {existing_model_version_in_target_stage.version} has been set to {ModelStages.ARCHIVED.value}." ) - existing_model_version.stage = model_version_update_model.stage + existing_model_version.update( + model_version_update_model.stage.value + ) session.add(existing_model_version) session.commit() session.refresh(existing_model_version) From 571026f904614892ea8b8dc3ed8828e482e33866 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 12 Sep 2023 17:28:54 +0200 Subject: [PATCH 24/40] add model version links --- src/zenml/constants.py | 1 + src/zenml/models/__init__.py | 2 + src/zenml/models/model_models.py | 61 +++++- .../zen_server/routers/models_endpoints.py | 73 ++++++- .../routers/workspaces_endpoints.py | 101 +++++++++- ...8b82e9253a9_add_model_version_and_links.py | 7 + src/zenml/zen_stores/rest_zen_store.py | 68 ++++++- src/zenml/zen_stores/schemas/model_schemas.py | 27 ++- src/zenml/zen_stores/sql_zen_store.py | 172 +++++++++++++++- src/zenml/zen_stores/zen_store_interface.py | 64 +++++- .../functional/zen_stores/test_zen_store.py | 188 +++++++++++++++++- .../functional/zen_stores/utils.py | 30 ++- 12 files changed, 746 insertions(+), 48 deletions(-) diff --git a/src/zenml/constants.py b/src/zenml/constants.py index f36ef7db778..46df4fac5b9 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -223,6 +223,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: SERVICE_CONNECTOR_CLIENT = "/client" MODELS = "/models" MODEL_VERSIONS = "/model_versions" +MODEL_VERSION_LINKS = "/model_version_links" # mandatory stack component attributes MANDATORY_COMPONENT_ATTRIBUTES = ["name", "uuid"] diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index c61dd78aef6..de8ce8e1f12 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -158,6 +158,7 @@ ModelVersionResponseModel, ModelVersionRequestModel, ModelVersionLinkBaseModel, + ModelVersionLinkFilterModel, ModelVersionLinkRequestModel, ModelVersionLinkResponseModel, ModelVersionFilterModel, @@ -418,6 +419,7 @@ "ModelVersionResponseModel", "ModelVersionUpdateModel", "ModelVersionLinkBaseModel", + "ModelVersionLinkFilterModel", "ModelVersionLinkRequestModel", "ModelVersionLinkResponseModel", ] diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index d76df6bfbf3..1f97a21c21b 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -139,10 +139,10 @@ def set_stage(self, stage: ModelStages, force: bool = False): class ModelVersionFilterModel(WorkspaceScopedFilterModel): """Filter Model for Model Version.""" - model: Optional[Union[str, UUID]] = Field( + model_id: Union[str, UUID] = Field( description="The ID of the Model", ) - version: Optional[str] = Field( + version: Optional[Union[str, UUID]] = Field( default=None, description="The name of the Model Version", ) @@ -175,21 +175,22 @@ class ModelVersionLinkBaseModel(BaseModel): title="The name of the artifact inside model version.", max_length=STR_FIELD_MAX_LENGTH, ) - artifact_id: Optional[UUID] - pipeline_run_id: Optional[UUID] - model_version_id: UUID + artifact: Optional[UUID] + pipeline_run: Optional[UUID] + model: UUID + model_version: UUID is_model_object: bool = False is_deployment: bool = False - @validator("model_version_id") + @validator("model_version") def validate_links(cls, model_version_id, values): - artifact_id = values.get("artifact_id", None) - pipeline_run_id = values.get("pipeline_run_id", None) - if (artifact_id is None and pipeline_run_id is None) or ( - artifact_id is not None and pipeline_run_id is not None + artifact = values.get("artifact", None) + pipeline_run = values.get("pipeline_run", None) + if (artifact is None and pipeline_run is None) or ( + artifact is not None and pipeline_run is not None ): raise ValueError( - "You must provide only `artifact_id` or only `pipeline_run_id`." + "You must provide only `artifact` or only `pipeline_run`." ) return model_version_id @@ -206,6 +207,44 @@ class ModelVersionLinkResponseModel( """Model version links response model.""" +class ModelVersionLinkFilterModel(WorkspaceScopedFilterModel): + """Model version links filter model.""" + + model_id: Union[str, UUID] = Field( + description="The name or ID of the Model", + ) + model_version_id: Union[str, UUID] = Field( + default=None, + description="The name or ID of the Model Version", + ) + name: Optional[str] = Field( + title="The name of the artifact inside model version.", + max_length=STR_FIELD_MAX_LENGTH, + ) + workspace_id: Optional[Union[UUID, str]] = Field( + default=None, description="The workspace of the Model Version" + ) + user_id: Optional[Union[UUID, str]] = Field( + default=None, description="The user of the Model Version" + ) + only_artifacts: bool = False + only_model_objects: bool = False + only_deployments: bool = False + only_pipeline_runs: bool = False + + @validator("only_pipeline_runs") + def validate_flags(cls, only_pipeline_runs, values): + s = int(only_pipeline_runs) + s += int(values.get("only_artifacts", False)) + s += int(values.get("only_model_objects", False)) + s += int(values.get("only_deployments", False)) + if s > 1: + raise ValueError( + "Only one of the selection flags can be used at once." + ) + return only_pipeline_runs + + class ModelConfigBaseModel(BaseModel): """Model Config base model.""" diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 3bd87623dd0..1a18e622106 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -18,13 +18,21 @@ from fastapi import APIRouter, Depends, Security -from zenml.constants import API, MODEL_VERSIONS, MODELS, VERSION_1 +from zenml.constants import ( + API, + MODEL_VERSION_LINKS, + MODEL_VERSIONS, + MODELS, + VERSION_1, +) from zenml.enums import PermissionType from zenml.models import ( ModelFilterModel, ModelResponseModel, ModelUpdateModel, ModelVersionFilterModel, + ModelVersionLinkFilterModel, + ModelVersionLinkResponseModel, ModelVersionResponseModel, ModelVersionUpdateModel, ) @@ -239,3 +247,66 @@ def update_model( model_version_id=model_version_id, model_version_update_model=model_version_update_model, ) + + +###################### +# Model Version Links +###################### + + +@router.get( + "/{model_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + MODEL_VERSION_LINKS, + response_model=Page[ModelVersionLinkResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_model_version_links( + model_version_link_filter_model: ModelVersionLinkFilterModel = Depends( + make_dependable(ModelVersionLinkFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionLinkResponseModel]: + """Get model version links according to query filters. + + Args: + model_version_link_filter_model: Filter model used for pagination, sorting, + filtering + + Returns: + The model version links according to query filters. + """ + return zen_store().list_model_version_links( + model_version_link_filter_model=model_version_link_filter_model, + ) + + +@router.delete( + "/{model_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + MODEL_VERSION_LINKS + + "/{model_version_link_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model_version_link( + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_link_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Deletes a model version link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_link_name_or_id: name or ID of the model version link to be deleted. + """ + zen_store().delete_model_version_link( + model_name_or_id, + model_version_name_or_id, + model_version_link_name_or_id, + ) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 49486c1a74f..8e849a0a57b 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -21,6 +21,7 @@ API, CODE_REPOSITORIES, GET_OR_CREATE, + MODEL_VERSION_LINKS, MODEL_VERSIONS, MODELS, PIPELINE_BUILDS, @@ -53,6 +54,9 @@ ModelRequestModel, ModelResponseModel, ModelVersionFilterModel, + ModelVersionLinkFilterModel, + ModelVersionLinkRequestModel, + ModelVersionLinkResponseModel, ModelVersionRequestModel, ModelVersionResponseModel, PipelineBuildFilterModel, @@ -1250,7 +1254,7 @@ def create_model_version( Args: model_name_or_id: Name or ID of the model. workspace_name_or_id: Name or ID of the workspace. - model_version: The model to create. + model_version: The model version to create. auth_context: Authentication context. Returns: @@ -1307,3 +1311,98 @@ def list_workspace_model_versions( return zen_store().list_model_versions( model_version_filter_model=model_version_filter_model, ) + + +@router.post( + WORKSPACES + + "/{workspace_name_or_id}" + + MODELS + + "/{model_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + MODEL_VERSION_LINKS, + response_model=ModelVersionLinkResponseModel, + responses={401: error_response, 409: error_response, 422: error_response}, +) +@handle_exceptions +def create_model_version_link( + workspace_name_or_id: Union[str, UUID], + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_link: ModelVersionLinkRequestModel, + auth_context: AuthContext = Security( + authorize, scopes=[PermissionType.WRITE] + ), +) -> ModelVersionLinkResponseModel: + """Create a new model version link. + + Args: + model_name_or_id: Name or ID of the model. + workspace_name_or_id: Name or ID of the workspace. + model_version_name_or_id: Name or ID of the model version. + model_version_link: The model version link to create. + auth_context: Authentication context. + + Returns: + The created model version link. + + Raises: + IllegalOperationError: If the workspace or user specified in the + model version does not match the current workspace or authenticated + user. + """ + workspace = zen_store().get_workspace(workspace_name_or_id) + + if model_version_link.workspace != workspace.id: + raise IllegalOperationError( + "Creating model versions outside of the workspace scope " + f"of this endpoint `{workspace_name_or_id}` is " + f"not supported." + ) + if model_version_link.user != auth_context.user.id: + raise IllegalOperationError( + "Creating models for a user other than yourself " + "is not supported." + ) + mv = zen_store().create_model_version_link(model_version_link) + return mv + + +@router.get( + WORKSPACES + + "/{workspace_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + MODEL_VERSION_LINKS, + response_model=Page[ModelVersionLinkResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_workspace_model_version_links( + workspace_name_or_id: Union[str, UUID], + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_link_filter_model: ModelVersionLinkFilterModel = Depends( + make_dependable(ModelVersionLinkFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionLinkResponseModel]: + """Get models according to query filters. + + Args: + model_name_or_id: Name or ID of the model. + workspace_name_or_id: Name or ID of the workspace. + model_version_name_or_id: Name or ID of the model version. + model_version_link: The model version link to create. + model_version_link_filter_model: Filter model used for pagination, sorting, + filtering + + + Returns: + The model version links according to query filters. + """ + workspace_id = zen_store().get_workspace(workspace_name_or_id).id + model_version_link_filter_model.set_scope_workspace(workspace_id) + return zen_store().list_model_version_links( + model_version_link_filter_model=model_version_link_filter_model, + ) diff --git a/src/zenml/zen_stores/migrations/versions/e8b82e9253a9_add_model_version_and_links.py b/src/zenml/zen_stores/migrations/versions/e8b82e9253a9_add_model_version_and_links.py index 7dd1a6dabad..87a7d77d634 100644 --- a/src/zenml/zen_stores/migrations/versions/e8b82e9253a9_add_model_version_and_links.py +++ b/src/zenml/zen_stores/migrations/versions/e8b82e9253a9_add_model_version_and_links.py @@ -61,6 +61,7 @@ def upgrade() -> None: sa.Column( "model_version_id", sqlmodel.sql.sqltypes.GUID(), nullable=False ), + sa.Column("model_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), sa.Column("artifact_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), sa.Column( "pipeline_run_id", sqlmodel.sql.sqltypes.GUID(), nullable=True @@ -77,6 +78,12 @@ def upgrade() -> None: name="fk_model_version_links_artifact_id_artifact", ondelete="CASCADE", ), + sa.ForeignKeyConstraint( + ["model_id"], + ["model.id"], + name="fk_model_version_links_model_id_model", + ondelete="CASCADE", + ), sa.ForeignKeyConstraint( ["model_version_id"], ["model_version.id"], diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 63a75996b2e..bee8dbde207 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -49,6 +49,7 @@ GET_OR_CREATE, INFO, LOGIN, + MODEL_VERSION_LINKS, MODEL_VERSIONS, MODELS, PIPELINE_BUILDS, @@ -101,6 +102,9 @@ ModelResponseModel, ModelUpdateModel, ModelVersionFilterModel, + ModelVersionLinkFilterModel, + ModelVersionLinkRequestModel, + ModelVersionLinkResponseModel, ModelVersionRequestModel, ModelVersionResponseModel, ModelVersionUpdateModel, @@ -2408,7 +2412,7 @@ def delete_model_version( def get_model_version( self, model_name_or_id: Union[str, UUID], - model_version_name: str, + model_version_name_or_id: Union[str, UUID], ) -> ModelVersionResponseModel: """Get an existing model version. @@ -2420,7 +2424,7 @@ def get_model_version( The model version of interest. """ return self._get_resource( - resource_id=model_version_name, + resource_id=model_version_name_or_id, route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", response_model=ModelVersionResponseModel, ) @@ -2440,7 +2444,7 @@ def list_model_versions( """ return self._list_paginated_resources( - route=f"{MODELS}/{model_version_filter_model.model}{MODEL_VERSIONS}", + route=f"{MODELS}/{model_version_filter_model.model_id}{MODEL_VERSIONS}", response_model=ModelVersionResponseModel, filter_model=model_version_filter_model, ) @@ -2466,6 +2470,64 @@ def update_model_version( response_model=ModelVersionResponseModel, ) + ####################### + # Model Versions Links + ####################### + + def create_model_version_link( + self, model_version_link: ModelVersionLinkRequestModel + ) -> ModelVersionLinkResponseModel: + """Creates a new model version link. + + Args: + model_version_link: the Model Version Link to be created. + + Returns: + The newly created model version link. + """ + return self._create_workspace_scoped_resource( + resource=model_version_link, + response_model=ModelVersionLinkResponseModel, + route=f"{MODELS}/{model_version_link.model}{MODEL_VERSIONS}/{model_version_link.model_version}{MODEL_VERSION_LINKS}", + ) + + def list_model_version_links( + self, + model_version_link_filter_model: ModelVersionLinkFilterModel, + ) -> Page[ModelVersionLinkResponseModel]: + """Get all model version links by filter. + + Args: + model_version_link_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model version links. + """ + return self._list_paginated_resources( + route=f"{MODELS}/{model_version_link_filter_model.model_id}{MODEL_VERSIONS}/{model_version_link_filter_model.model_version_id}{MODEL_VERSION_LINKS}", + response_model=ModelVersionLinkResponseModel, + filter_model=model_version_link_filter_model, + ) + + def delete_model_version_link( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_link_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_link_name_or_id: name or ID of the model version link to be deleted. + """ + self._delete_resource( + resource_id=model_version_link_name_or_id, + route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}/{model_version_name_or_id}{MODEL_VERSION_LINKS}", + ) + # ======================= # Internal helper methods # ======================= diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 74993ceb8b3..dc0e8e9bd12 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -76,6 +76,10 @@ class ModelSchema(NamedSchema, table=True): back_populates="model", sa_relationship_kwargs={"cascade": "delete"}, ) + objects_links: List["ModelVersionLinkSchema"] = Relationship( + back_populates="model", + sa_relationship_kwargs={"cascade": "delete"}, + ) @classmethod def from_request(cls, model_request: ModelRequestModel) -> "ModelSchema": @@ -300,6 +304,15 @@ class ModelVersionLinkSchema(NamedSchema, table=True): back_populates="model_version_links" ) + model_id: UUID = build_foreign_key_field( + source=__tablename__, + target=ModelSchema.__tablename__, + source_column="model_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + model: "ModelSchema" = Relationship(back_populates="objects_links") model_version_id: UUID = build_foreign_key_field( source=__tablename__, target=ModelVersionSchema.__tablename__, @@ -353,9 +366,10 @@ def from_request( name=model_version_artifact_request.name, workspace_id=model_version_artifact_request.workspace, user_id=model_version_artifact_request.user, - model_version_id=model_version_artifact_request.model_version_id, - artifact_id=model_version_artifact_request.artifact_id, - pipeline_run_id=model_version_artifact_request.pipeline_run_id, + model_id=model_version_artifact_request.model, + model_version_id=model_version_artifact_request.model_version, + artifact_id=model_version_artifact_request.artifact, + pipeline_run_id=model_version_artifact_request.pipeline_run, is_model_object=model_version_artifact_request.is_model_object, is_deployment=model_version_artifact_request.is_deployment, ) @@ -373,9 +387,10 @@ def to_model(self) -> ModelVersionLinkResponseModel: workspace=self.workspace.to_model(), created=self.created, updated=self.updated, - model_version_id=self.model_version_id, - artifact_id=self.artifact_id, - pipeline_run_id=self.pipeline_run_id, + model=self.model_id, + model_version=self.model_version_id, + artifact=self.artifact_id, + pipeline_run=self.pipeline_run_id, is_model_object=self.is_model_object, is_deployment=self.is_deployment, ) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index d5c11f5ba05..fe4ced2d120 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -100,6 +100,9 @@ ModelResponseModel, ModelUpdateModel, ModelVersionFilterModel, + ModelVersionLinkFilterModel, + ModelVersionLinkRequestModel, + ModelVersionLinkResponseModel, ModelVersionRequestModel, ModelVersionResponseModel, ModelVersionUpdateModel, @@ -199,6 +202,7 @@ FlavorSchema, IdentitySchema, ModelSchema, + ModelVersionLinkSchema, ModelVersionSchema, NamedSchema, PipelineBuildSchema, @@ -5563,7 +5567,7 @@ def create_model_version( def get_model_version( self, model_name_or_id: Union[str, UUID], - model_version_name: str, + model_version_name_or_id: Union[str, UUID], ) -> ModelVersionResponseModel: """Get an existing model version. Args: @@ -5576,14 +5580,22 @@ def get_model_version( """ with Session(self.engine) as session: model = self.get_model(model_name_or_id) - model_version = session.exec( - select(ModelVersionSchema) - .where(ModelVersionSchema.model_id == model.id) - .where(ModelVersionSchema.version == model_version_name) - ).first() + query = select(ModelVersionSchema).where( + ModelVersionSchema.model_id == model.id + ) + try: + UUID(str(model_version_name_or_id)) + query = query.where( + ModelVersionSchema.id == model_version_name_or_id + ) + except: + query = query.where( + ModelVersionSchema.version == model_version_name_or_id + ) + model_version = session.exec(query).first() if model_version is None: raise KeyError( - f"Unable to get model version with name `{model_version_name}`: " + f"Unable to get model version with name `{model_version_name_or_id}`: " f"No model version with this name found." ) return ModelVersionSchema.to_model(model_version) @@ -5700,3 +5712,149 @@ def update_model_version( session.refresh(existing_model_version) return existing_model_version.to_model() + + ####################### + # Model Versions Links + ####################### + + def create_model_version_link( + self, model_version_link: ModelVersionLinkRequestModel + ) -> ModelVersionLinkResponseModel: + """Creates a new model version link. + + Args: + model_version_link: the Model Version Link to be created. + + Returns: + The newly created model version link. + + Raises: + EntityExistsError: If a workspace with the given name already exists. + """ + with Session(self.engine) as session: + existing_model_version_link = session.exec( + select(ModelVersionLinkSchema) + .where( + ModelVersionLinkSchema.model_version_id + == model_version_link.model_version + ) + .where(ModelVersionLinkSchema.name == model_version_link.name) + ).first() + if existing_model_version_link is not None: + raise EntityExistsError( + f"Unable to create model version link {existing_model_version_link.name}: " + f"A model version link with this name already exists in {model_version_link.model_version} model version." + ) + + model_version_link_schema = ModelVersionLinkSchema.from_request( + model_version_link + ) + session.add(model_version_link_schema) + + session.commit() + mvl = ModelVersionLinkSchema.to_model(model_version_link_schema) + return mvl + + def list_model_version_links( + self, + model_version_link_filter_model: ModelVersionLinkFilterModel, + ) -> Page[ModelVersionLinkResponseModel]: + """Get all model version links by filter. + + Args: + model_version_link_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model version links. + """ + with Session(self.engine) as session: + if model_version_link_filter_model.only_artifacts: + query = ( + select(ModelVersionLinkSchema) + .where(ModelVersionLinkSchema.is_model_object == False) + .where(ModelVersionLinkSchema.is_deployment == False) + .where(ModelVersionLinkSchema.pipeline_run == None) + .where(ModelVersionLinkSchema.artifact != None) + ) + elif model_version_link_filter_model.only_deployments: + query = ( + select(ModelVersionLinkSchema) + .where(ModelVersionLinkSchema.is_deployment) + .where(ModelVersionLinkSchema.is_model_object == False) + .where(ModelVersionLinkSchema.pipeline_run == None) + .where(ModelVersionLinkSchema.artifact != None) + ) + elif model_version_link_filter_model.only_model_objects: + query = ( + select(ModelVersionLinkSchema) + .where(ModelVersionLinkSchema.is_model_object) + .where(ModelVersionLinkSchema.is_deployment == False) + .where(ModelVersionLinkSchema.pipeline_run == None) + .where(ModelVersionLinkSchema.artifact != None) + ) + elif model_version_link_filter_model.only_pipeline_runs: + query = ( + select(ModelVersionLinkSchema) + .where(ModelVersionLinkSchema.is_model_object == False) + .where(ModelVersionLinkSchema.is_deployment == False) + .where(ModelVersionLinkSchema.pipeline_run != None) + .where(ModelVersionLinkSchema.artifact == None) + ) + else: + query = select(ModelVersionLinkSchema) + model_version_link_filter_model.only_artifacts = None + model_version_link_filter_model.only_deployments = None + model_version_link_filter_model.only_model_objects = None + model_version_link_filter_model.only_pipeline_runs = None + return self.filter_and_paginate( + session=session, + query=query, + table=ModelVersionLinkSchema, + filter_model=model_version_link_filter_model, + ) + + def delete_model_version_link( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_link_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_link_name_or_id: name or ID of the model version link to be deleted. + + Raises: + KeyError: specified ID or name not found. + """ + with Session(self.engine) as session: + self.get_model(model_name_or_id) + model_version = self.get_model_version( + model_name_or_id, model_version_name_or_id + ) + query = select(ModelVersionLinkSchema).where( + ModelVersionLinkSchema.model_version_id == model_version.id + ) + try: + UUID(str(model_version_link_name_or_id)) + query = query.where( + ModelVersionLinkSchema.id == model_version_link_name_or_id + ) + except: + query = query.where( + ModelVersionLinkSchema.name + == model_version_link_name_or_id + ) + + model_version_link = session.exec(query).first() + if model_version_link is None: + raise KeyError( + f"Unable to delete model version link with name `{model_version_link_name_or_id}`: " + f"No model version link with this name found." + ) + + session.delete(model_version_link) + session.commit() diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 5ebf682cec8..1c9cba7e018 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -37,6 +37,9 @@ ModelResponseModel, ModelUpdateModel, ModelVersionFilterModel, + ModelVersionLinkFilterModel, + ModelVersionLinkRequestModel, + ModelVersionLinkResponseModel, ModelVersionRequestModel, ModelVersionResponseModel, ModelVersionUpdateModel, @@ -1690,7 +1693,7 @@ def create_model(self, model: ModelRequestModel) -> ModelResponseModel: The newly created model. Raises: - EntityExistsError: If a workspace with the given name already exists. + EntityExistsError: If a model with the given name already exists. """ @abstractmethod @@ -1759,7 +1762,7 @@ def create_model_version( Returns: The newly created model version. Raises: - EntityExistsError: If a workspace with the given name already exists. + EntityExistsError: If a model version with the given name already exists. """ @abstractmethod @@ -1778,12 +1781,12 @@ def delete_model_version( def get_model_version( self, model_name_or_id: Union[str, UUID], - model_version_name: str, + model_version_name_or_id: Union[str, UUID], ) -> ModelVersionResponseModel: """Get an existing model version. Args: model_name_or_id: name or id of the model containing the model version. - model_version_name: name or id of the model version to be retrieved. + model_version_name_or_id: name or id of the model version to be retrieved. Returns: The model version of interest. Raises: @@ -1821,3 +1824,56 @@ def update_model_version( KeyError: If the model version not found RuntimeError: If there is a model version with target stage, but `force` flag is off """ + + ####################### + # Model Versions Links + ####################### + + @abstractmethod + def create_model_version_link( + self, model_version_link: ModelVersionLinkRequestModel + ) -> ModelVersionLinkResponseModel: + """Creates a new model version link. + + Args: + model_version_link: the Model Version Link to be created. + + Returns: + The newly created model version link. + + Raises: + EntityExistsError: If a workspace with the given name already exists. + """ + + @abstractmethod + def list_model_version_links( + self, + model_version_link_filter_model: ModelVersionLinkFilterModel, + ) -> Page[ModelVersionLinkResponseModel]: + """Get all model version links by filter. + + Args: + model_version_link_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model version links. + """ + + @abstractmethod + def delete_model_version_link( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_link_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_link_name_or_id: name or ID of the model version link to be deleted. + + Raises: + KeyError: specified ID or name not found. + """ diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index cf88c2fd59d..23d2e4832d3 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -53,6 +53,8 @@ ComponentFilterModel, ComponentUpdateModel, ModelVersionFilterModel, + ModelVersionLinkFilterModel, + ModelVersionLinkRequestModel, ModelVersionRequestModel, ModelVersionUpdateModel, PipelineRunFilterModel, @@ -2489,7 +2491,7 @@ def test_model_version_get_not_found(): zs = Client().zen_store with pytest.raises(KeyError): zs.get_model_version( - model_name_or_id=model.id, model_version_name="1.0.0" + model_name_or_id=model.id, model_version_name_or_id="1.0.0" ) @@ -2506,14 +2508,16 @@ def test_model_version_get_found(): ) zs.get_model_version( model_name_or_id=model.id, - model_version_name="great one", + model_version_name_or_id="great one", ) def test_model_version_list_empty(): with ModelVersionContext() as model: zs = Client().zen_store - mvs = zs.list_model_versions(ModelVersionFilterModel(model=model.id)) + mvs = zs.list_model_versions( + ModelVersionFilterModel(model_id=model.id) + ) assert len(mvs) == 0 @@ -2536,7 +2540,9 @@ def test_model_version_list_not_empty(): version="and yet another one", ) ) - mvs = zs.list_model_versions(ModelVersionFilterModel(model=model.id)) + mvs = zs.list_model_versions( + ModelVersionFilterModel(model_id=model.id) + ) assert len(mvs) == 2 assert mv1 in mvs assert mv2 in mvs @@ -2570,7 +2576,7 @@ def test_model_version_delete_found(): with pytest.raises(KeyError): zs.get_model_version( model_name_or_id=model.id, - model_version_name="great one", + model_version_name_or_id="great one", ) @@ -2655,7 +2661,7 @@ def test_model_version_update_forced(): ) assert ( zs.get_model_version( - model_name_or_id=model.id, model_version_name=mv1.version + model_name_or_id=model.id, model_version_name_or_id=mv1.version ).stage == "staging" ) @@ -2670,13 +2676,13 @@ def test_model_version_update_forced(): assert ( zs.get_model_version( - model_name_or_id=model.id, model_version_name=mv1.version + model_name_or_id=model.id, model_version_name_or_id=mv1.version ).stage == "archived" ) assert ( zs.get_model_version( - model_name_or_id=model.id, model_version_name=mv2.version + model_name_or_id=model.id, model_version_name_or_id=mv2.version ).stage == "staging" ) @@ -2695,14 +2701,176 @@ def test_model_version_update_public_interface(): ) assert ( zs.get_model_version( - model_name_or_id=model.id, model_version_name=mv1.version + model_name_or_id=model.id, model_version_name_or_id=mv1.version ).stage is None ) mv1.set_stage("staging") assert ( zs.get_model_version( - model_name_or_id=model.id, model_version_name=mv1.version + model_name_or_id=model.id, model_version_name_or_id=mv1.version ).stage == "staging" ) + + +def test_model_version_link_create_pass(): + with ModelVersionContext(True) as model_version: + zs = Client().zen_store + zs.create_model_version_link( + ModelVersionLinkRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + artifact=uuid4(), + ) + ) + + +def test_model_version_link_create_duplicated(): + with ModelVersionContext(True) as model_version: + zs = Client().zen_store + zs.create_model_version_link( + ModelVersionLinkRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + artifact=uuid4(), + ) + ) + + with pytest.raises(EntityExistsError): + zs.create_model_version_link( + ModelVersionLinkRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + artifact=uuid4(), + ) + ) + + +def test_model_version_link_delete_found(): + with ModelVersionContext(True) as model_version: + zs = Client().zen_store + zs.create_model_version_link( + ModelVersionLinkRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + artifact=uuid4(), + ) + ) + zs.delete_model_version_link( + model_version.model.id, model_version.id, "link" + ) + l = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(l) == 0 + + +def test_model_version_link_delete_not_found(): + with ModelVersionContext(True) as model_version: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.delete_model_version_link( + model_version.model.id, model_version.id, "link" + ) + + +def test_model_version_link_list_empty(): + with ModelVersionContext(True) as model_version: + zs = Client().zen_store + l = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(l) == 0 + + +def test_model_version_link_list_populated(): + with ModelVersionContext(True) as model_version: + zs = Client().zen_store + l = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(l) == 0 + for n, mo, dep, pr in [ + ("link1", False, False, False), + ("link2", True, False, False), + ("link3", False, True, False), + ("link4", False, False, True), + ]: + zs.create_model_version_link( + ModelVersionLinkRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name=n, + artifact=uuid4() if not pr else None, + pipeline_run=uuid4() if pr else None, + is_model_object=mo, + is_deployment=dep, + ) + ) + l = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(l) == 4 + + l = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + only_artifacts=True, + ) + ) + assert len(l) == 1 and l[0].name == "link1" + + l = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + only_model_objects=True, + ) + ) + assert len(l) == 1 and l[0].name == "link2" + + l = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + only_deployments=True, + ) + ) + assert len(l) == 1 and l[0].name == "link3" + + l = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + only_pipeline_runs=True, + ) + ) + assert len(l) == 1 and l[0].name == "link4" diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index bf6bd15cf3a..8d416219ce7 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -45,6 +45,7 @@ ModelFilterModel, ModelRequestModel, ModelUpdateModel, + ModelVersionRequestModel, PipelineBuildFilterModel, PipelineBuildRequestModel, PipelineDeploymentFilterModel, @@ -510,13 +511,17 @@ def __exit__(self, exc_type, exc_value, exc_traceback): class ModelVersionContext: - def __init__(self): + def __init__(self, create_version: bool = False): self.workspace = "workspace" self.user = "su" self.model = "su_model" + self.model_version = "2.0.0" self.del_ws = False self.del_user = False self.del_model = False + self.del_mv = False + + self.create_version = create_version def __enter__(self): zs = Client().zen_store @@ -541,18 +546,33 @@ def __enter__(self): ) ) self.del_model = True - return model + if self.create_version: + try: + mv = zs.get_model_version(self.model, self.model_version) + except: + mv = zs.create_model_version( + ModelVersionRequestModel( + user=user.id, + workspace=ws.id, + model=model.id, + version=self.model_version, + ) + ) + self.del_mv = True + if self.create_version: + return mv + else: + return model def __exit__(self, exc_type, exc_value, exc_traceback): zs = Client().zen_store + if self.del_mv: + zs.delete_model_version(self.model, self.model_version) if self.del_model: - print("del_model") zs.delete_model(self.model) if self.del_user: - print("del_user") zs.delete_user(self.user) if self.del_ws: - print("del_ws") zs.delete_workspace(self.workspace) From ad66d2ecab4ec3f17ff4c4e22b5fb1b94944cdae Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 12 Sep 2023 17:59:23 +0200 Subject: [PATCH 25/40] lint --- src/zenml/models/model_models.py | 60 +++++++--- .../zen_server/routers/models_endpoints.py | 3 +- .../routers/workspaces_endpoints.py | 2 - src/zenml/zen_stores/rest_zen_store.py | 2 +- src/zenml/zen_stores/sql_zen_store.py | 112 ++++++++++++++---- src/zenml/zen_stores/zen_store_interface.py | 14 +++ .../functional/zen_stores/test_zen_store.py | 32 ++--- .../functional/zen_stores/utils.py | 8 +- 8 files changed, 170 insertions(+), 63 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 1f97a21c21b..999dccd548a 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Model implementation to support Model WatchTower feature.""" -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from uuid import UUID from pydantic import BaseModel, Field, validator @@ -98,28 +98,53 @@ def _fetch_artifacts_from_list( @property def model_objects(self) -> Dict[str, ArtifactResponseModel]: - """Get all model objects linked to this version.""" + """Get all model objects linked to this version. + + Returns: + Dictionary of Model Objects as ArtifactResponseModel + """ return self._fetch_artifacts_from_list(self._model_objects) @property def artifact_objects(self) -> Dict[str, ArtifactResponseModel]: - """Get all artifacts linked to this version.""" + """Get all artifacts linked to this version. + + Returns: + Dictionary of Artifact Objects as ArtifactResponseModel + """ return self._fetch_artifacts_from_list(self._artifact_objects) @property def deployments(self) -> Dict[str, ArtifactResponseModel]: - """Get all deployments linked to this version.""" + """Get all deployments linked to this version. + + Returns: + Dictionary of Deployments as ArtifactResponseModel + """ return self._fetch_artifacts_from_list(self._deployments) @property def pipeline_runs(self) -> List[PipelineRunResponseModel]: - """Get all pipeline runs linked to this version.""" + """Get all pipeline runs linked to this version. + + Returns: + List of Pipeline Runs as PipelineRunResponseModel + """ from zenml.client import Client return [Client().get_pipeline_run(pr) for pr in self._pipeline_runs] - def set_stage(self, stage: ModelStages, force: bool = False): - """Sets this Model Version to a desired stage.""" + def set_stage( + self, stage: ModelStages, force: bool = False + ) -> "ModelVersionResponseModel": + """Sets this Model Version to a desired stage. + + Args: + stage: the target stage for model version. + force: whether to force archiving of current model version in target stage or raise. + + Returns: + Dictionary of Model Objects as model_version_name_or_id""" from zenml.client import Client return Client().zen_store.update_model_version( @@ -155,6 +180,8 @@ class ModelVersionFilterModel(WorkspaceScopedFilterModel): class ModelVersionUpdateModel(BaseModel): + """Update Model for Model Version.""" + model: UUID = Field( title="The ID of the model containing version", ) @@ -183,7 +210,9 @@ class ModelVersionLinkBaseModel(BaseModel): is_deployment: bool = False @validator("model_version") - def validate_links(cls, model_version_id, values): + def _validate_links( + cls, model_version: UUID, values: Dict[str, Any] + ) -> UUID: artifact = values.get("artifact", None) pipeline_run = values.get("pipeline_run", None) if (artifact is None and pipeline_run is None) or ( @@ -192,7 +221,7 @@ def validate_links(cls, model_version_id, values): raise ValueError( "You must provide only `artifact` or only `pipeline_run`." ) - return model_version_id + return model_version class ModelVersionLinkRequestModel( @@ -214,7 +243,6 @@ class ModelVersionLinkFilterModel(WorkspaceScopedFilterModel): description="The name or ID of the Model", ) model_version_id: Union[str, UUID] = Field( - default=None, description="The name or ID of the Model Version", ) name: Optional[str] = Field( @@ -227,13 +255,15 @@ class ModelVersionLinkFilterModel(WorkspaceScopedFilterModel): user_id: Optional[Union[UUID, str]] = Field( default=None, description="The user of the Model Version" ) - only_artifacts: bool = False - only_model_objects: bool = False - only_deployments: bool = False - only_pipeline_runs: bool = False + only_artifacts: Optional[bool] = False + only_model_objects: Optional[bool] = False + only_deployments: Optional[bool] = False + only_pipeline_runs: Optional[bool] = False @validator("only_pipeline_runs") - def validate_flags(cls, only_pipeline_runs, values): + def _validate_flags( + cls, only_pipeline_runs: bool, values: Dict[str, Any] + ) -> bool: s = int(only_pipeline_runs) s += int(values.get("only_artifacts", False)) s += int(values.get("only_model_objects", False)) diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 1a18e622106..2a386d2ebac 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -230,12 +230,13 @@ def get_model_version( responses={401: error_response, 404: error_response, 422: error_response}, ) @handle_exceptions -def update_model( +def update_model_version( model_version_id: UUID, model_version_update_model: ModelVersionUpdateModel, _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), ) -> ModelVersionResponseModel: """Get all model versions by filter. + Args: model_version_id: The ID of model version to be updated. model_version_update_model: The model version to be updated. diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 8e849a0a57b..39da18affca 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1393,11 +1393,9 @@ def list_workspace_model_version_links( model_name_or_id: Name or ID of the model. workspace_name_or_id: Name or ID of the workspace. model_version_name_or_id: Name or ID of the model version. - model_version_link: The model version link to create. model_version_link_filter_model: Filter model used for pagination, sorting, filtering - Returns: The model version links according to query filters. """ diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index bee8dbde207..04a024e2abc 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2442,7 +2442,6 @@ def list_model_versions( Returns: A page of all model versions. """ - return self._list_paginated_resources( route=f"{MODELS}/{model_version_filter_model.model_id}{MODEL_VERSIONS}", response_model=ModelVersionResponseModel, @@ -2455,6 +2454,7 @@ def update_model_version( model_version_update_model: ModelVersionUpdateModel, ) -> ModelVersionResponseModel: """Get all model versions by filter. + Args: model_version_id: The ID of model version to be updated. model_version_update_model: The model version to be updated. diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index fe4ced2d120..ee64030a052 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -5535,10 +5535,13 @@ def create_model_version( self, model_version: ModelVersionRequestModel ) -> ModelVersionResponseModel: """Creates a new model version. + Args: model_version: the Model Version to be created. + Returns: The newly created model version. + Raises: EntityExistsError: If a workspace with the given name already exists. """ @@ -5570,11 +5573,14 @@ def get_model_version( model_version_name_or_id: Union[str, UUID], ) -> ModelVersionResponseModel: """Get an existing model version. + Args: model_name_or_id: name or id of the model containing the model version. model_version_name_or_id: name or id of the model version to be retrieved. + Returns: The model version of interest. + Raises: KeyError: specified ID or name not found. """ @@ -5588,7 +5594,7 @@ def get_model_version( query = query.where( ModelVersionSchema.id == model_version_name_or_id ) - except: + except ValueError: query = query.where( ModelVersionSchema.version == model_version_name_or_id ) @@ -5605,9 +5611,11 @@ def list_model_versions( model_version_filter_model: ModelVersionFilterModel, ) -> Page[ModelVersionResponseModel]: """Get all model versions by filter. + Args: model_version_filter_model: All filter parameters including pagination params. + Returns: A page of all model versions. """ @@ -5621,25 +5629,37 @@ def list_model_versions( ) def delete_model_version( - self, model_name_or_id: Union[str, UUID], model_version_name: str + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version. + Args: model_name_or_id: name or id of the model containing the model version. - model_version_name: name of the model version to be deleted. + model_version_name_or_id: name or id of the model version to be deleted. + Raises: KeyError: specified ID or name not found. """ with Session(self.engine) as session: model = self.get_model(model_name_or_id) - model_version = session.exec( - select(ModelVersionSchema) - .where(ModelVersionSchema.model_id == model.id) - .where(ModelVersionSchema.version == model_version_name) - ).first() + query = select(ModelVersionSchema).where( + ModelVersionSchema.model_id == model.id + ) + try: + UUID(str(model_version_name_or_id)) + query = query.where( + ModelVersionSchema.id == model_version_name_or_id + ) + except ValueError: + query = query.where( + ModelVersionSchema.version == model_version_name_or_id + ) + model_version = session.exec(query).first() if model_version is None: raise KeyError( - f"Unable to delete model version with name `{model_version_name}`: " + f"Unable to delete model version with name `{model_version_name_or_id}`: " f"No model version with this name found." ) session.delete(model_version) @@ -5651,6 +5671,7 @@ def update_model_version( model_version_update_model: ModelVersionUpdateModel, ) -> ModelVersionResponseModel: """Get all model versions by filter. + Args: model_version_id: The ID of model version to be updated. model_version_update_model: The model version to be updated. @@ -5769,37 +5790,80 @@ def list_model_version_links( A page of all model version links. """ with Session(self.engine) as session: + # issue: https://github.com/tiangolo/sqlmodel/issues/109 if model_version_link_filter_model.only_artifacts: query = ( select(ModelVersionLinkSchema) - .where(ModelVersionLinkSchema.is_model_object == False) - .where(ModelVersionLinkSchema.is_deployment == False) - .where(ModelVersionLinkSchema.pipeline_run == None) - .where(ModelVersionLinkSchema.artifact != None) + .where( + ModelVersionLinkSchema.is_model_object + == False # noqa: E712 + ) + .where( + ModelVersionLinkSchema.is_deployment + == False # noqa: E712 + ) + .where( + ModelVersionLinkSchema.pipeline_run + == None # noqa: E712, E711 + ) + .where( + ModelVersionLinkSchema.artifact + != None # noqa: E712, E711 + ) ) elif model_version_link_filter_model.only_deployments: query = ( select(ModelVersionLinkSchema) .where(ModelVersionLinkSchema.is_deployment) - .where(ModelVersionLinkSchema.is_model_object == False) - .where(ModelVersionLinkSchema.pipeline_run == None) - .where(ModelVersionLinkSchema.artifact != None) + .where( + ModelVersionLinkSchema.is_model_object + == False # noqa: E712 + ) + .where( + ModelVersionLinkSchema.pipeline_run + == None # noqa: E712, E711 + ) + .where( + ModelVersionLinkSchema.artifact + != None # noqa: E712, E711 + ) ) elif model_version_link_filter_model.only_model_objects: query = ( select(ModelVersionLinkSchema) .where(ModelVersionLinkSchema.is_model_object) - .where(ModelVersionLinkSchema.is_deployment == False) - .where(ModelVersionLinkSchema.pipeline_run == None) - .where(ModelVersionLinkSchema.artifact != None) + .where( + ModelVersionLinkSchema.is_deployment + == False # noqa: E712 + ) + .where( + ModelVersionLinkSchema.pipeline_run + == None # noqa: E712, E711 + ) + .where( + ModelVersionLinkSchema.artifact + != None # noqa: E712, E711 + ) ) elif model_version_link_filter_model.only_pipeline_runs: query = ( select(ModelVersionLinkSchema) - .where(ModelVersionLinkSchema.is_model_object == False) - .where(ModelVersionLinkSchema.is_deployment == False) - .where(ModelVersionLinkSchema.pipeline_run != None) - .where(ModelVersionLinkSchema.artifact == None) + .where( + ModelVersionLinkSchema.is_model_object + == False # noqa: E712 + ) + .where( + ModelVersionLinkSchema.is_deployment + == False # noqa: E712 + ) + .where( + ModelVersionLinkSchema.pipeline_run + != None # noqa: E712, E711 + ) + .where( + ModelVersionLinkSchema.artifact + == None # noqa: E712, E711 + ) ) else: query = select(ModelVersionLinkSchema) @@ -5843,7 +5907,7 @@ def delete_model_version_link( query = query.where( ModelVersionLinkSchema.id == model_version_link_name_or_id ) - except: + except ValueError: query = query.where( ModelVersionLinkSchema.name == model_version_link_name_or_id diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 1c9cba7e018..33630082602 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -1731,6 +1731,9 @@ def get_model( Returns: The model of interest. + + Raises: + KeyError: specified ID or name not found. """ @abstractmethod @@ -1757,10 +1760,13 @@ def create_model_version( self, model_version: ModelVersionRequestModel ) -> ModelVersionResponseModel: """Creates a new model version. + Args: model_version: the Model Version to be created. + Returns: The newly created model version. + Raises: EntityExistsError: If a model version with the given name already exists. """ @@ -1770,9 +1776,11 @@ def delete_model_version( self, model_name_or_id: Union[str, UUID], model_version_name: str ) -> None: """Deletes a model version. + Args: model_name_or_id: name or id of the model containing the model version. model_version_name: name of the model version to be deleted. + Raises: KeyError: specified ID or name not found. """ @@ -1784,11 +1792,14 @@ def get_model_version( model_version_name_or_id: Union[str, UUID], ) -> ModelVersionResponseModel: """Get an existing model version. + Args: model_name_or_id: name or id of the model containing the model version. model_version_name_or_id: name or id of the model version to be retrieved. + Returns: The model version of interest. + Raises: KeyError: specified ID or name not found. """ @@ -1799,9 +1810,11 @@ def list_model_versions( model_version_filter_model: ModelVersionFilterModel, ) -> Page[ModelVersionResponseModel]: """Get all model versions by filter. + Args: model_version_filter_model: All filter parameters including pagination params. + Returns: A page of all model versions. """ @@ -1813,6 +1826,7 @@ def update_model_version( model_version_update_model: ModelVersionUpdateModel, ) -> ModelVersionResponseModel: """Get all model versions by filter. + Args: model_version_id: The ID of model version to be updated. model_version_update_model: The model version to be updated. diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 23d2e4832d3..733a8978e5b 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -2772,13 +2772,13 @@ def test_model_version_link_delete_found(): zs.delete_model_version_link( model_version.model.id, model_version.id, "link" ) - l = zs.list_model_version_links( + mvls = zs.list_model_version_links( ModelVersionLinkFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, ) ) - assert len(l) == 0 + assert len(mvls) == 0 def test_model_version_link_delete_not_found(): @@ -2793,25 +2793,25 @@ def test_model_version_link_delete_not_found(): def test_model_version_link_list_empty(): with ModelVersionContext(True) as model_version: zs = Client().zen_store - l = zs.list_model_version_links( + mvls = zs.list_model_version_links( ModelVersionLinkFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, ) ) - assert len(l) == 0 + assert len(mvls) == 0 def test_model_version_link_list_populated(): with ModelVersionContext(True) as model_version: zs = Client().zen_store - l = zs.list_model_version_links( + mvls = zs.list_model_version_links( ModelVersionLinkFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, ) ) - assert len(l) == 0 + assert len(mvls) == 0 for n, mo, dep, pr in [ ("link1", False, False, False), ("link2", True, False, False), @@ -2831,46 +2831,46 @@ def test_model_version_link_list_populated(): is_deployment=dep, ) ) - l = zs.list_model_version_links( + mvls = zs.list_model_version_links( ModelVersionLinkFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, ) ) - assert len(l) == 4 + assert len(mvls) == 4 - l = zs.list_model_version_links( + mvls = zs.list_model_version_links( ModelVersionLinkFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, only_artifacts=True, ) ) - assert len(l) == 1 and l[0].name == "link1" + assert len(mvls) == 1 and mvls[0].name == "link1" - l = zs.list_model_version_links( + mvls = zs.list_model_version_links( ModelVersionLinkFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, only_model_objects=True, ) ) - assert len(l) == 1 and l[0].name == "link2" + assert len(mvls) == 1 and mvls[0].name == "link2" - l = zs.list_model_version_links( + mvls = zs.list_model_version_links( ModelVersionLinkFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, only_deployments=True, ) ) - assert len(l) == 1 and l[0].name == "link3" + assert len(mvls) == 1 and mvls[0].name == "link3" - l = zs.list_model_version_links( + mvls = zs.list_model_version_links( ModelVersionLinkFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, only_pipeline_runs=True, ) ) - assert len(l) == 1 and l[0].name == "link4" + assert len(mvls) == 1 and mvls[0].name == "link4" diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index 8d416219ce7..f51cb2099a0 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -527,19 +527,19 @@ def __enter__(self): zs = Client().zen_store try: ws = zs.get_workspace(self.workspace) - except: + except KeyError: ws = zs.create_workspace( WorkspaceRequestModel(name=self.workspace) ) self.del_ws = True try: user = zs.get_user(self.user) - except: + except KeyError: user = zs.create_user(UserRequestModel(name=self.user)) self.del_user = True try: model = zs.get_model(self.model) - except: + except KeyError: model = zs.create_model( ModelRequestModel( name=self.model, user=user.id, workspace=ws.id @@ -549,7 +549,7 @@ def __enter__(self): if self.create_version: try: mv = zs.get_model_version(self.model, self.model_version) - except: + except KeyError: mv = zs.create_model_version( ModelVersionRequestModel( user=user.id, From 876730d933e783849f16138a8aac16dc005582e7 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 13 Sep 2023 09:54:58 +0200 Subject: [PATCH 26/40] fix crud tests --- tests/integration/functional/zen_stores/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index 89a3f3f87b2..4c0ba0046be 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -832,6 +832,7 @@ def update_method( ), update_model=ModelUpdateModel( name=sample_name("updated_sample_service_connector"), + description="new_description", ), filter_model=ModelFilterModel, entity_name="model", From 07c26eb6548acdf9d84c61b8947610384ed83352 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 13 Sep 2023 13:51:30 +0200 Subject: [PATCH 27/40] fix alembic branching --- .../migrations/versions/3b68abe58f44_add_model_entity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_stores/migrations/versions/3b68abe58f44_add_model_entity.py b/src/zenml/zen_stores/migrations/versions/3b68abe58f44_add_model_entity.py index 541699e23bf..5dca2e3a061 100644 --- a/src/zenml/zen_stores/migrations/versions/3b68abe58f44_add_model_entity.py +++ b/src/zenml/zen_stores/migrations/versions/3b68abe58f44_add_model_entity.py @@ -1,7 +1,7 @@ """add model entity [3b68abe58f44]. Revision ID: 3b68abe58f44 -Revises: 0.44.1 +Revises: 0.44.2 Create Date: 2023-09-11 07:53:18.641081 """ @@ -11,7 +11,7 @@ # revision identifiers, used by Alembic. revision = "3b68abe58f44" -down_revision = "0.44.1" +down_revision = "0.44.2" branch_labels = None depends_on = None From 4e38be737b95449aafa148cb049f7753e02a577e Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Wed, 13 Sep 2023 14:28:56 +0200 Subject: [PATCH 28/40] patch azure --- src/zenml/integrations/azure/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/zenml/integrations/azure/__init__.py b/src/zenml/integrations/azure/__init__.py index 2f35b0de2b0..6f5d07df58d 100644 --- a/src/zenml/integrations/azure/__init__.py +++ b/src/zenml/integrations/azure/__init__.py @@ -41,6 +41,7 @@ class AzureIntegration(Integration): "azure-identity==1.10.0", "azureml-core==1.48.0", "azure-mgmt-containerservice>=20.0.0", + "azure-storage-blob==12.17.0", # temporary fix for https://github.com/Azure/azure-sdk-for-python/issues/32056 "kubernetes", ] From 8d14736a28f214b9074e5d7cfe4458d5d6bdd4c9 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Thu, 14 Sep 2023 08:15:14 +0200 Subject: [PATCH 29/40] lint --- src/zenml/models/model_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 471dbfb6dc0..9a6bcd32d69 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -20,7 +20,6 @@ from zenml.model import ModelStages from zenml.models.artifact_models import ArtifactResponseModel - from zenml.models.base_models import ( WorkspaceScopedRequestModel, WorkspaceScopedResponseModel, @@ -145,7 +144,8 @@ def set_stage( force: whether to force archiving of current model version in target stage or raise. Returns: - Dictionary of Model Objects as model_version_name_or_id""" + Dictionary of Model Objects as model_version_name_or_id + """ from zenml.client import Client return Client().zen_store.update_model_version( From 574ef0342b27f23a34b1bf38e8a9717b14bdbb80 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Thu, 14 Sep 2023 08:58:36 +0200 Subject: [PATCH 30/40] use zenml StrEnum --- src/zenml/model/model_stages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/model/model_stages.py b/src/zenml/model/model_stages.py index f1bb7192985..04feaab5ac3 100644 --- a/src/zenml/model/model_stages.py +++ b/src/zenml/model/model_stages.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """ModelStages lists supported stages of a Model Version.""" -from enum import StrEnum +from zenml.utils.enum_utils import StrEnum class ModelStages(StrEnum): From 02bd3cea6c74f8d546edcbf90a022a93d45c4551 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Thu, 14 Sep 2023 10:56:54 +0200 Subject: [PATCH 31/40] fix param name --- src/zenml/zen_stores/rest_zen_store.py | 8 +++++--- src/zenml/zen_stores/zen_store_interface.py | 6 ++++-- tests/integration/functional/zen_stores/test_zen_store.py | 4 ++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 04a024e2abc..160b7c6e5d2 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -2396,16 +2396,18 @@ def create_model_version( ) def delete_model_version( - self, model_name_or_id: Union[str, UUID], model_version_name: str + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version. Args: model_name_or_id: name or id of the model containing the model version. - model_version_name: name of the model version to be deleted. + model_version_name_or_id: name or id of the model version to be deleted. """ self._delete_resource( - resource_id=model_version_name, + resource_id=model_version_name_or_id, route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}", ) diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 33630082602..6cd459d20f6 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -1773,13 +1773,15 @@ def create_model_version( @abstractmethod def delete_model_version( - self, model_name_or_id: Union[str, UUID], model_version_name: str + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], ) -> None: """Deletes a model version. Args: model_name_or_id: name or id of the model containing the model version. - model_version_name: name of the model version to be deleted. + model_version_name_or_id: name or id of the model version to be deleted. Raises: KeyError: specified ID or name not found. diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 733a8978e5b..bad3fbe54e9 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -2554,7 +2554,7 @@ def test_model_version_delete_not_found(): with pytest.raises(KeyError): zs.delete_model_version( model_name_or_id=model.id, - model_version_name="1.0.0", + model_version_name_or_id="1.0.0", ) @@ -2571,7 +2571,7 @@ def test_model_version_delete_found(): ) zs.delete_model_version( model_name_or_id=model.id, - model_version_name="great one", + model_version_name_or_id="great one", ) with pytest.raises(KeyError): zs.get_model_version( From efed59bda5e085f9688ec85d3c1a664b4d893caa Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Thu, 14 Sep 2023 16:22:21 +0200 Subject: [PATCH 32/40] fix tests in docker --- tests/integration/functional/zen_stores/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index a9a52313d7e..465d386030d 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -512,8 +512,9 @@ def __exit__(self, exc_type, exc_value, exc_traceback): class ModelVersionContext: def __init__(self, create_version: bool = False): - self.workspace = "workspace" - self.user = "su" + client = Client() + self.workspace = client.active_workspace.id + self.user = client.active_user.id self.model = "su_model" self.model_version = "2.0.0" self.del_ws = False From 8160116f009df92191afca9e41f330aba2159f5d Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Fri, 15 Sep 2023 12:16:45 +0200 Subject: [PATCH 33/40] fix tests for mysql --- .../functional/zen_stores/test_zen_store.py | 708 +++++++++--------- .../functional/zen_stores/utils.py | 37 +- 2 files changed, 393 insertions(+), 352 deletions(-) diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index bad3fbe54e9..9d75289d92b 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -40,6 +40,7 @@ StubLocalRepositoryContext, ) from zenml.client import Client +from zenml.config.pipeline_configurations import PipelineConfiguration from zenml.enums import SecretScope, StackComponentType, StoreType from zenml.exceptions import ( DoesNotExistException, @@ -58,6 +59,7 @@ ModelVersionRequestModel, ModelVersionUpdateModel, PipelineRunFilterModel, + PipelineRunRequestModel, RoleFilterModel, RoleRequestModel, RoleUpdateModel, @@ -2437,31 +2439,22 @@ def test_connector_validation(): ################# -def test_model_version_create_pass(): - with ModelVersionContext() as model: - zs = Client().zen_store - zs.create_model_version( - ModelVersionRequestModel( - user=model.user.id, - workspace=model.workspace.id, - model=model.id, - version="great one", +class TestModelVersion: + def test_model_version_create_pass(self): + with ModelVersionContext() as model: + zs = Client().zen_store + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) ) - ) - -def test_model_version_create_duplicated(): - with ModelVersionContext() as model: - zs = Client().zen_store - zs.create_model_version( - ModelVersionRequestModel( - user=model.user.id, - workspace=model.workspace.id, - model=model.id, - version="great one", - ) - ) - with pytest.raises(EntityExistsError): + def test_model_version_create_duplicated(self): + with ModelVersionContext() as model: + zs = Client().zen_store zs.create_model_version( ModelVersionRequestModel( user=model.user.id, @@ -2470,280 +2463,277 @@ def test_model_version_create_duplicated(): version="great one", ) ) + with pytest.raises(EntityExistsError): + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + def test_model_version_create_no_model(self): + with ModelVersionContext() as model: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=uuid4(), + version="great one", + ) + ) -def test_model_version_create_no_model(): - with ModelVersionContext() as model: - zs = Client().zen_store - with pytest.raises(KeyError): + def test_model_version_get_not_found(self): + with ModelVersionContext() as model: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.get_model_version( + model_name_or_id=model.id, model_version_name_or_id="1.0.0" + ) + + def test_model_version_get_found(self): + with ModelVersionContext() as model: + zs = Client().zen_store zs.create_model_version( ModelVersionRequestModel( user=model.user.id, workspace=model.workspace.id, - model=uuid4(), + model=model.id, version="great one", ) ) - - -def test_model_version_get_not_found(): - with ModelVersionContext() as model: - zs = Client().zen_store - with pytest.raises(KeyError): zs.get_model_version( - model_name_or_id=model.id, model_version_name_or_id="1.0.0" + model_name_or_id=model.id, + model_version_name_or_id="great one", ) - -def test_model_version_get_found(): - with ModelVersionContext() as model: - zs = Client().zen_store - zs.create_model_version( - ModelVersionRequestModel( - user=model.user.id, - workspace=model.workspace.id, - model=model.id, - version="great one", + def test_model_version_list_empty(self): + with ModelVersionContext() as model: + zs = Client().zen_store + mvs = zs.list_model_versions( + ModelVersionFilterModel(model_id=model.id) ) - ) - zs.get_model_version( - model_name_or_id=model.id, - model_version_name_or_id="great one", - ) + assert len(mvs) == 0 - -def test_model_version_list_empty(): - with ModelVersionContext() as model: - zs = Client().zen_store - mvs = zs.list_model_versions( - ModelVersionFilterModel(model_id=model.id) - ) - assert len(mvs) == 0 - - -def test_model_version_list_not_empty(): - with ModelVersionContext() as model: - zs = Client().zen_store - mv1 = zs.create_model_version( - ModelVersionRequestModel( - user=model.user.id, - workspace=model.workspace.id, - model=model.id, - version="great one", + def test_model_version_list_not_empty(self): + with ModelVersionContext() as model: + zs = Client().zen_store + mv1 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) ) - ) - mv2 = zs.create_model_version( - ModelVersionRequestModel( - user=model.user.id, - workspace=model.workspace.id, - model=model.id, - version="and yet another one", + mv2 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="and yet another one", + ) ) - ) - mvs = zs.list_model_versions( - ModelVersionFilterModel(model_id=model.id) - ) - assert len(mvs) == 2 - assert mv1 in mvs - assert mv2 in mvs - - -def test_model_version_delete_not_found(): - with ModelVersionContext() as model: - zs = Client().zen_store - with pytest.raises(KeyError): - zs.delete_model_version( - model_name_or_id=model.id, - model_version_name_or_id="1.0.0", + mvs = zs.list_model_versions( + ModelVersionFilterModel(model_id=model.id) ) + assert len(mvs) == 2 + assert mv1 in mvs + assert mv2 in mvs + def test_model_version_delete_not_found(self): + with ModelVersionContext() as model: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.delete_model_version( + model_name_or_id=model.id, + model_version_name_or_id="1.0.0", + ) -def test_model_version_delete_found(): - with ModelVersionContext() as model: - zs = Client().zen_store - zs.create_model_version( - ModelVersionRequestModel( - user=model.user.id, - workspace=model.workspace.id, - model=model.id, - version="great one", + def test_model_version_delete_found(self): + with ModelVersionContext() as model: + zs = Client().zen_store + zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) ) - ) - zs.delete_model_version( - model_name_or_id=model.id, - model_version_name_or_id="great one", - ) - with pytest.raises(KeyError): - zs.get_model_version( + zs.delete_model_version( model_name_or_id=model.id, model_version_name_or_id="great one", ) + with pytest.raises(KeyError): + zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id="great one", + ) + def test_model_version_update_not_found(self): + with ModelVersionContext() as model: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.update_model_version( + model_version_id=uuid4(), + model_version_update_model=ModelVersionUpdateModel( + model=model.id, + stage="staging", + force=False, + ), + ) -def test_model_version_update_not_found(): - with ModelVersionContext() as model: - zs = Client().zen_store - with pytest.raises(KeyError): + def test_model_version_update_not_forced(self): + with ModelVersionContext() as model: + zs = Client().zen_store + mv1 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + mv2 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="yet another one", + ) + ) zs.update_model_version( - model_version_id=uuid4(), + model_version_id=mv1.id, model_version_update_model=ModelVersionUpdateModel( model=model.id, stage="staging", force=False, ), ) + with pytest.raises(RuntimeError): + zs.update_model_version( + model_version_id=mv2.id, + model_version_update_model=ModelVersionUpdateModel( + model=model.id, + stage="staging", + force=False, + ), + ) - -def test_model_version_update_not_forced(): - with ModelVersionContext() as model: - zs = Client().zen_store - mv1 = zs.create_model_version( - ModelVersionRequestModel( - user=model.user.id, - workspace=model.workspace.id, - model=model.id, - version="great one", + def test_model_version_update_forced(self): + with ModelVersionContext() as model: + zs = Client().zen_store + mv1 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) ) - ) - mv2 = zs.create_model_version( - ModelVersionRequestModel( - user=model.user.id, - workspace=model.workspace.id, - model=model.id, - version="yet another one", + mv2 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="yet another one", + ) ) - ) - zs.update_model_version( - model_version_id=mv1.id, - model_version_update_model=ModelVersionUpdateModel( - model=model.id, - stage="staging", - force=False, - ), - ) - with pytest.raises(RuntimeError): zs.update_model_version( - model_version_id=mv2.id, + model_version_id=mv1.id, model_version_update_model=ModelVersionUpdateModel( model=model.id, stage="staging", force=False, ), ) - - -def test_model_version_update_forced(): - with ModelVersionContext() as model: - zs = Client().zen_store - mv1 = zs.create_model_version( - ModelVersionRequestModel( - user=model.user.id, - workspace=model.workspace.id, - model=model.id, - version="great one", + assert ( + zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id=mv1.version, + ).stage + == "staging" ) - ) - mv2 = zs.create_model_version( - ModelVersionRequestModel( - user=model.user.id, - workspace=model.workspace.id, - model=model.id, - version="yet another one", + zs.update_model_version( + model_version_id=mv2.id, + model_version_update_model=ModelVersionUpdateModel( + model=model.id, + stage="staging", + force=True, + ), ) - ) - zs.update_model_version( - model_version_id=mv1.id, - model_version_update_model=ModelVersionUpdateModel( - model=model.id, - stage="staging", - force=False, - ), - ) - assert ( - zs.get_model_version( - model_name_or_id=model.id, model_version_name_or_id=mv1.version - ).stage - == "staging" - ) - zs.update_model_version( - model_version_id=mv2.id, - model_version_update_model=ModelVersionUpdateModel( - model=model.id, - stage="staging", - force=True, - ), - ) - assert ( - zs.get_model_version( - model_name_or_id=model.id, model_version_name_or_id=mv1.version - ).stage - == "archived" - ) - assert ( - zs.get_model_version( - model_name_or_id=model.id, model_version_name_or_id=mv2.version - ).stage - == "staging" - ) - - -def test_model_version_update_public_interface(): - with ModelVersionContext() as model: - zs = Client().zen_store - mv1 = zs.create_model_version( - ModelVersionRequestModel( - user=model.user.id, - workspace=model.workspace.id, - model=model.id, - version="great one", + assert ( + zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id=mv1.version, + ).stage + == "archived" + ) + assert ( + zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id=mv2.version, + ).stage + == "staging" ) - ) - assert ( - zs.get_model_version( - model_name_or_id=model.id, model_version_name_or_id=mv1.version - ).stage - is None - ) - mv1.set_stage("staging") - assert ( - zs.get_model_version( - model_name_or_id=model.id, model_version_name_or_id=mv1.version - ).stage - == "staging" - ) - -def test_model_version_link_create_pass(): - with ModelVersionContext(True) as model_version: - zs = Client().zen_store - zs.create_model_version_link( - ModelVersionLinkRequestModel( - user=model_version.user.id, - workspace=model_version.workspace.id, - model=model_version.model.id, - model_version=model_version.id, - name="link", - artifact=uuid4(), + def test_model_version_update_public_interface(self): + with ModelVersionContext() as model: + zs = Client().zen_store + mv1 = zs.create_model_version( + ModelVersionRequestModel( + user=model.user.id, + workspace=model.workspace.id, + model=model.id, + version="great one", + ) + ) + assert ( + zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id=mv1.version, + ).stage + is None + ) + mv1.set_stage("staging") + assert ( + zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id=mv1.version, + ).stage + == "staging" ) - ) -def test_model_version_link_create_duplicated(): - with ModelVersionContext(True) as model_version: - zs = Client().zen_store - zs.create_model_version_link( - ModelVersionLinkRequestModel( - user=model_version.user.id, - workspace=model_version.workspace.id, - model=model_version.model.id, - model_version=model_version.id, - name="link", - artifact=uuid4(), +class TestModelVersionLink: + def test_model_version_link_create_pass(self): + with ModelVersionContext(True, create_artifact=True) as ( + model_version, + artifact, + ): + zs = Client().zen_store + zs.create_model_version_link( + ModelVersionLinkRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + artifact=artifact.id, + ) ) - ) - with pytest.raises(EntityExistsError): + def test_model_version_link_create_duplicated(self): + with ModelVersionContext(True, create_artifact=True) as ( + model_version, + artifact, + ): + zs = Client().zen_store zs.create_model_version_link( ModelVersionLinkRequestModel( user=model_version.user.id, @@ -2751,126 +2741,154 @@ def test_model_version_link_create_duplicated(): model=model_version.model.id, model_version=model_version.id, name="link", - artifact=uuid4(), + artifact=artifact.id, ) ) + with pytest.raises(EntityExistsError): + zs.create_model_version_link( + ModelVersionLinkRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + artifact=artifact.id, + ) + ) -def test_model_version_link_delete_found(): - with ModelVersionContext(True) as model_version: - zs = Client().zen_store - zs.create_model_version_link( - ModelVersionLinkRequestModel( - user=model_version.user.id, - workspace=model_version.workspace.id, - model=model_version.model.id, - model_version=model_version.id, - name="link", - artifact=uuid4(), - ) - ) - zs.delete_model_version_link( - model_version.model.id, model_version.id, "link" - ) - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( - model_id=model_version.model.id, - model_version_id=model_version.id, + def test_model_version_link_delete_found(self): + with ModelVersionContext(True, create_artifact=True) as ( + model_version, + artifact, + ): + zs = Client().zen_store + zs.create_model_version_link( + ModelVersionLinkRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + artifact=artifact.id, + ) ) - ) - assert len(mvls) == 0 - - -def test_model_version_link_delete_not_found(): - with ModelVersionContext(True) as model_version: - zs = Client().zen_store - with pytest.raises(KeyError): zs.delete_model_version_link( model_version.model.id, model_version.id, "link" ) - - -def test_model_version_link_list_empty(): - with ModelVersionContext(True) as model_version: - zs = Client().zen_store - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( - model_id=model_version.model.id, - model_version_id=model_version.id, + mvls = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) ) - ) - assert len(mvls) == 0 + assert len(mvls) == 0 + def test_model_version_link_delete_not_found(self): + with ModelVersionContext(True) as model_version: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.delete_model_version_link( + model_version.model.id, model_version.id, "link" + ) -def test_model_version_link_list_populated(): - with ModelVersionContext(True) as model_version: - zs = Client().zen_store - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( - model_id=model_version.model.id, - model_version_id=model_version.id, + def test_model_version_link_list_empty(self): + with ModelVersionContext(True) as model_version: + zs = Client().zen_store + mvls = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) ) - ) - assert len(mvls) == 0 - for n, mo, dep, pr in [ - ("link1", False, False, False), - ("link2", True, False, False), - ("link3", False, True, False), - ("link4", False, False, True), - ]: - zs.create_model_version_link( - ModelVersionLinkRequestModel( - user=model_version.user.id, - workspace=model_version.workspace.id, - model=model_version.model.id, - model_version=model_version.id, - name=n, - artifact=uuid4() if not pr else None, - pipeline_run=uuid4() if pr else None, - is_model_object=mo, - is_deployment=dep, + assert len(mvls) == 0 + + def test_model_version_link_list_populated(self): + with ModelVersionContext(True, create_artifact=True) as ( + model_version, + artifact, + ): + zs = Client().zen_store + mvls = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, ) ) - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( - model_id=model_version.model.id, - model_version_id=model_version.id, + assert len(mvls) == 0 + for n, mo, dep, pr in [ + ("link1", False, False, False), + ("link2", True, False, False), + ("link3", False, True, False), + ("link4", False, False, True), + ]: + if pr: + pipeline_run = zs.create_run( + PipelineRunRequestModel( + id=uuid.uuid4(), + name=sample_name("sample_pipeline_run"), + status="running", + config=PipelineConfiguration(name="aria_pipeline"), + user=model_version.user.id, + workspace=model_version.workspace.id, + ) + ) + zs.create_model_version_link( + ModelVersionLinkRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name=n, + artifact=artifact.id if not pr else None, + pipeline_run=pipeline_run.id if pr else None, + is_model_object=mo, + is_deployment=dep, + ) + ) + mvls = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) ) - ) - assert len(mvls) == 4 + assert len(mvls) == 4 - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( - model_id=model_version.model.id, - model_version_id=model_version.id, - only_artifacts=True, + mvls = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + only_artifacts=True, + ) ) - ) - assert len(mvls) == 1 and mvls[0].name == "link1" + assert len(mvls) == 1 and mvls[0].name == "link1" - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( - model_id=model_version.model.id, - model_version_id=model_version.id, - only_model_objects=True, + mvls = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + only_model_objects=True, + ) ) - ) - assert len(mvls) == 1 and mvls[0].name == "link2" + assert len(mvls) == 1 and mvls[0].name == "link2" - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( - model_id=model_version.model.id, - model_version_id=model_version.id, - only_deployments=True, + mvls = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + only_deployments=True, + ) ) - ) - assert len(mvls) == 1 and mvls[0].name == "link3" + assert len(mvls) == 1 and mvls[0].name == "link3" - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( - model_id=model_version.model.id, - model_version_id=model_version.id, - only_pipeline_runs=True, + mvls = zs.list_model_version_links( + ModelVersionLinkFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + only_pipeline_runs=True, + ) ) - ) - assert len(mvls) == 1 and mvls[0].name == "link4" + assert len(mvls) == 1 and mvls[0].name == "link4" + + if pr: + zs.delete_run(pipeline_run.id) diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index 465d386030d..f9fe5549ce3 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -511,7 +511,9 @@ def __exit__(self, exc_type, exc_value, exc_traceback): class ModelVersionContext: - def __init__(self, create_version: bool = False): + def __init__( + self, create_version: bool = False, create_artifact: bool = False + ): client = Client() self.workspace = client.active_workspace.id self.user = client.active_user.id @@ -520,9 +522,10 @@ def __init__(self, create_version: bool = False): self.del_ws = False self.del_user = False self.del_model = False - self.del_mv = False self.create_version = create_version + self.create_artifact = create_artifact + self.artifact = None def __enter__(self): zs = Client().zen_store @@ -559,18 +562,38 @@ def __enter__(self): version=self.model_version, ) ) - self.del_mv = True - if self.create_version: - return mv + + if self.create_artifact: + artifact = zs.create_artifact( + ArtifactRequestModel( + name=sample_name("sample_artifact"), + data_type="module.class", + materializer="module.class", + type=ArtifactType.DATA, + uri="", + user=user.id, + workspace=ws.id, + ) + ) + self.artifact = artifact + if self.create_version: + return mv, self.artifact + else: + return model, self.artifact else: - return model + if self.create_version: + return mv + else: + return model def __exit__(self, exc_type, exc_value, exc_traceback): zs = Client().zen_store - if self.del_mv: + if self.create_version: zs.delete_model_version(self.model, self.model_version) if self.del_model: zs.delete_model(self.model) + if self.create_artifact: + zs.delete_artifact(self.artifact.id) if self.del_user: zs.delete_user(self.user) if self.del_ws: From 180295be87a8eb37e3a732eff88027f8b69e8aeb Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Fri, 15 Sep 2023 12:34:24 +0200 Subject: [PATCH 34/40] rename artifact ids variables --- src/zenml/models/model_models.py | 40 +++++++++---------- src/zenml/zen_stores/schemas/model_schemas.py | 10 ++--- .../functional/zen_stores/test_zen_store.py | 32 +++++++++++++++ 3 files changed, 57 insertions(+), 25 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 9a6bcd32d69..d2b4e749f6b 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -44,22 +44,6 @@ class ModelVersionBaseModel(BaseModel): title="The stage of the model version", max_length=STR_FIELD_MAX_LENGTH, ) - _model_objects: Dict[str, UUID] = Field( - title="Model Objects linked to the model version", - default={}, - ) - _artifact_objects: Dict[str, UUID] = Field( - title="Artifacts linked to the model version", - default={}, - ) - _deployments: Dict[str, UUID] = Field( - title="Deployments linked to the model version", - default={}, - ) - _pipeline_runs: List[UUID] = Field( - title="Pipeline runs linked to the model version", - default=[], - ) class ModelVersionRequestModel( @@ -82,6 +66,22 @@ class ModelVersionResponseModel( model: "ModelResponseModel" = Field( title="The model containing version", ) + model_object_ids: Dict[str, UUID] = Field( + title="Model Objects linked to the model version", + default={}, + ) + artifact_object_ids: Dict[str, UUID] = Field( + title="Artifacts linked to the model version", + default={}, + ) + deployment_ids: Dict[str, UUID] = Field( + title="Deployments linked to the model version", + default={}, + ) + pipeline_run_ids: List[UUID] = Field( + title="Pipeline runs linked to the model version", + default=[], + ) @staticmethod def _fetch_artifacts_from_list( @@ -103,7 +103,7 @@ def model_objects(self) -> Dict[str, ArtifactResponseModel]: Returns: Dictionary of Model Objects as ArtifactResponseModel """ - return self._fetch_artifacts_from_list(self._model_objects) + return self._fetch_artifacts_from_list(self.model_object_ids) @property def artifact_objects(self) -> Dict[str, ArtifactResponseModel]: @@ -112,7 +112,7 @@ def artifact_objects(self) -> Dict[str, ArtifactResponseModel]: Returns: Dictionary of Artifact Objects as ArtifactResponseModel """ - return self._fetch_artifacts_from_list(self._artifact_objects) + return self._fetch_artifacts_from_list(self.artifact_object_ids) @property def deployments(self) -> Dict[str, ArtifactResponseModel]: @@ -121,7 +121,7 @@ def deployments(self) -> Dict[str, ArtifactResponseModel]: Returns: Dictionary of Deployments as ArtifactResponseModel """ - return self._fetch_artifacts_from_list(self._deployments) + return self._fetch_artifacts_from_list(self.deployment_ids) @property def pipeline_runs(self) -> List[PipelineRunResponseModel]: @@ -132,7 +132,7 @@ def pipeline_runs(self) -> List[PipelineRunResponseModel]: """ from zenml.client import Client - return [Client().get_pipeline_run(pr) for pr in self._pipeline_runs] + return [Client().get_pipeline_run(pr) for pr in self.pipeline_run_ids] def set_stage( self, stage: ModelStages, force: bool = False diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index dc0e8e9bd12..a1d23137277 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -235,24 +235,24 @@ def to_model(self) -> ModelVersionResponseModel: version=self.version, description=self.description, stage=self.stage, - model_objects={ + model_object_ids={ al.name: al.artifact_id for al in self.objects_links if al.artifact_id is not None and al.is_model_object }, - deployments={ + deployment_ids={ al.name: al.artifact_id for al in self.objects_links if al.artifact_id is not None and al.is_deployment }, - artifact_objects={ + artifact_object_ids={ al.name: al.artifact_id for al in self.objects_links if al.artifact_id is not None and not (al.is_deployment or al.is_model_object) }, - pipeline_runs=[ - al.artifact_id + pipeline_run_ids=[ + al.pipeline_run_id for al in self.objects_links if al.pipeline_run_id is not None ], diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 9d75289d92b..638bc3b67a5 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -51,6 +51,7 @@ from zenml.logging.step_logging import prepare_logs_uri from zenml.models import ( ArtifactFilterModel, + ArtifactResponseModel, ComponentFilterModel, ComponentUpdateModel, ModelVersionFilterModel, @@ -60,6 +61,7 @@ ModelVersionUpdateModel, PipelineRunFilterModel, PipelineRunRequestModel, + PipelineRunResponseModel, RoleFilterModel, RoleRequestModel, RoleUpdateModel, @@ -2890,5 +2892,35 @@ def test_model_version_link_list_populated(self): ) assert len(mvls) == 1 and mvls[0].name == "link4" + mv = zs.get_model_version( + model_name_or_id=model_version.model.id, + model_version_name_or_id=model_version.id, + ) + + assert len(mv.model_object_ids) == 1 + assert len(mv.artifact_object_ids) == 1 + assert len(mv.deployment_ids) == 1 + assert len(mv.pipeline_run_ids) == 1 + + assert isinstance( + mv.model_objects["link2"], + ArtifactResponseModel, + ) + assert isinstance( + mv.artifact_objects["link1"], + ArtifactResponseModel, + ) + assert isinstance( + mv.deployments["link3"], + ArtifactResponseModel, + ) + assert isinstance( + mv.pipeline_runs[0], + PipelineRunResponseModel, + ) + + assert mv.pipeline_runs[0].id == pipeline_run.id + assert mv.model_objects["link2"].id == artifact.id + if pr: zs.delete_run(pipeline_run.id) From 6db61b12cd043f70f99e22cb7b11d18127977cc4 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Fri, 15 Sep 2023 12:35:43 +0200 Subject: [PATCH 35/40] reorder methods --- .../zen_server/routers/models_endpoints.py | 76 +++++++++---------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 2a386d2ebac..7de11e78655 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -83,23 +83,6 @@ def list_models( ) -@router.delete( - "/{model_name_or_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_model( - model_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> None: - """Delete a model by name or ID. - - Args: - model_name_or_id: The name or ID of the model to delete. - """ - zen_store().delete_model(model_name_or_id) - - @router.get( "/{model_name_or_id}", response_model=ModelResponseModel, @@ -147,6 +130,23 @@ def update_model( ) +@router.delete( + "/{model_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model( + model_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Delete a model by name or ID. + + Args: + model_name_or_id: The name or ID of the model to delete. + """ + zen_store().delete_model(model_name_or_id) + + ################# # Model Versions ################# @@ -178,27 +178,6 @@ def list_model_versions( ) -@router.delete( - "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}", - responses={401: error_response, 404: error_response, 422: error_response}, -) -@handle_exceptions -def delete_model_version( - model_name_or_id: Union[str, UUID], - model_version_name_or_id: Union[str, UUID], - _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), -) -> None: - """Delete a model by name or ID. - - Args: - model_name_or_id: The name or ID of the model containing version. - model_version_name_or_id: The name or ID of the model version to delete. - """ - zen_store().delete_model_version( - model_name_or_id, model_version_name_or_id - ) - - @router.get( "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}", response_model=ModelVersionResponseModel, @@ -250,6 +229,27 @@ def update_model_version( ) +@router.delete( + "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model_version( + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Delete a model by name or ID. + + Args: + model_name_or_id: The name or ID of the model containing version. + model_version_name_or_id: The name or ID of the model version to delete. + """ + zen_store().delete_model_version( + model_name_or_id, model_version_name_or_id + ) + + ###################### # Model Version Links ###################### From 2c8c08368d1cf2dad7f1a889d3d072ef0d0d6962 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Fri, 15 Sep 2023 13:50:38 +0200 Subject: [PATCH 36/40] add direct getters --- src/zenml/models/model_models.py | 97 +++++++++++++++---- src/zenml/zen_stores/schemas/model_schemas.py | 6 +- .../functional/zen_stores/test_zen_store.py | 11 ++- 3 files changed, 89 insertions(+), 25 deletions(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index d2b4e749f6b..1238489035c 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -78,24 +78,11 @@ class ModelVersionResponseModel( title="Deployments linked to the model version", default={}, ) - pipeline_run_ids: List[UUID] = Field( + pipeline_run_ids: Dict[str, UUID] = Field( title="Pipeline runs linked to the model version", default=[], ) - @staticmethod - def _fetch_artifacts_from_list( - artifacts: Dict[str, UUID] - ) -> Dict[str, ArtifactResponseModel]: - from zenml.client import Client - - if artifacts: - return { - name: Client().get_artifact(a) for name, a in artifacts.items() - } - else: - return {} - @property def model_objects(self) -> Dict[str, ArtifactResponseModel]: """Get all model objects linked to this version. @@ -103,7 +90,12 @@ def model_objects(self) -> Dict[str, ArtifactResponseModel]: Returns: Dictionary of Model Objects as ArtifactResponseModel """ - return self._fetch_artifacts_from_list(self.model_object_ids) + from zenml.client import Client + + return { + name: Client().get_artifact(a) + for name, a in self.model_object_ids.items() + } @property def artifact_objects(self) -> Dict[str, ArtifactResponseModel]: @@ -112,7 +104,12 @@ def artifact_objects(self) -> Dict[str, ArtifactResponseModel]: Returns: Dictionary of Artifact Objects as ArtifactResponseModel """ - return self._fetch_artifacts_from_list(self.artifact_object_ids) + from zenml.client import Client + + return { + name: Client().get_artifact(a) + for name, a in self.artifact_object_ids.items() + } @property def deployments(self) -> Dict[str, ArtifactResponseModel]: @@ -121,18 +118,78 @@ def deployments(self) -> Dict[str, ArtifactResponseModel]: Returns: Dictionary of Deployments as ArtifactResponseModel """ - return self._fetch_artifacts_from_list(self.deployment_ids) + from zenml.client import Client + + return { + name: Client().get_artifact(a) + for name, a in self.deployment_ids.items() + } @property - def pipeline_runs(self) -> List[PipelineRunResponseModel]: + def pipeline_runs(self) -> Dict[str, PipelineRunResponseModel]: """Get all pipeline runs linked to this version. Returns: - List of Pipeline Runs as PipelineRunResponseModel + Dictionary of Pipeline Runs as PipelineRunResponseModel + """ + from zenml.client import Client + + return { + name: Client().get_pipeline_run(pr) + for name, pr in self.pipeline_run_ids.items() + } + + def get_model_object(self, name: str) -> ArtifactResponseModel: + """Get model object linked to this version. + + Args: + name: The name of the model object to retrieve. + + Returns: + Model Object as ArtifactResponseModel + """ + from zenml.client import Client + + return Client().get_artifact(self.model_object_ids[name]) + + def get_artifact_object(self, name: str) -> ArtifactResponseModel: + """Get artifact linked to this version. + + Args: + name: The name of the artifact to retrieve. + + Returns: + Artifact Object as ArtifactResponseModel + """ + from zenml.client import Client + + return Client().get_artifact(self.artifact_object_ids[name]) + + def get_deployment(self, name: str) -> ArtifactResponseModel: + """Get deployment linked to this version. + + Args: + name: The name of the deployment to retrieve. + + Returns: + Deployment as ArtifactResponseModel + """ + from zenml.client import Client + + return Client().get_artifact(self.deployment_ids[name]) + + def get_pipeline_run(self, name: str) -> PipelineRunResponseModel: + """Get pipeline run linked to this version. + + Args: + name: The name of the pipeline run to retrieve. + + Returns: + PipelineRun as PipelineRunResponseModel """ from zenml.client import Client - return [Client().get_pipeline_run(pr) for pr in self.pipeline_run_ids] + return Client().get_pipeline_run(self.pipeline_run_ids[name]) def set_stage( self, stage: ModelStages, force: bool = False diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index a1d23137277..0b4da5877a1 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -251,11 +251,11 @@ def to_model(self) -> ModelVersionResponseModel: if al.artifact_id is not None and not (al.is_deployment or al.is_model_object) }, - pipeline_run_ids=[ - al.pipeline_run_id + pipeline_run_ids={ + al.name: al.pipeline_run_id for al in self.objects_links if al.pipeline_run_id is not None - ], + }, ) def update( diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 638bc3b67a5..64d729f2d70 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -2915,12 +2915,19 @@ def test_model_version_link_list_populated(self): ArtifactResponseModel, ) assert isinstance( - mv.pipeline_runs[0], + mv.pipeline_runs["link4"], PipelineRunResponseModel, ) - assert mv.pipeline_runs[0].id == pipeline_run.id + assert mv.pipeline_runs["link4"].id == pipeline_run.id assert mv.model_objects["link2"].id == artifact.id + assert mv.get_model_object("link2") == mv.model_objects["link2"] + assert ( + mv.get_artifact_object("link1") == mv.artifact_objects["link1"] + ) + assert mv.get_deployment("link3") == mv.deployments["link3"] + assert mv.get_pipeline_run("link4") == mv.pipeline_runs["link4"] + if pr: zs.delete_run(pipeline_run.id) From 243477166424ee09734ff95bbb60051ed3abdb1e Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Fri, 15 Sep 2023 14:26:03 +0200 Subject: [PATCH 37/40] lint --- src/zenml/models/model_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 1238489035c..6a586be6145 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -80,7 +80,7 @@ class ModelVersionResponseModel( ) pipeline_run_ids: Dict[str, UUID] = Field( title="Pipeline runs linked to the model version", - default=[], + default={}, ) @property From 567a1bd53bd55d61ba8510f947bbd2ff8bc6cb71 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Fri, 15 Sep 2023 18:12:15 +0200 Subject: [PATCH 38/40] split links into 2 tables --- src/zenml/constants.py | 1 - src/zenml/models/__init__.py | 42 +-- src/zenml/models/model_models.py | 104 +++--- .../zen_server/routers/models_endpoints.py | 114 ++++-- .../routers/workspaces_endpoints.py | 158 +++++++-- ...d9599a008d_add_model_version_and_links.py} | 73 +++- src/zenml/zen_stores/rest_zen_store.py | 121 +++++-- src/zenml/zen_stores/schemas/__init__.py | 6 +- .../zen_stores/schemas/artifact_schemas.py | 6 +- src/zenml/zen_stores/schemas/model_schemas.py | 167 +++++++-- .../schemas/pipeline_run_schemas.py | 8 +- src/zenml/zen_stores/schemas/user_schemas.py | 12 +- .../zen_stores/schemas/workspace_schemas.py | 13 +- src/zenml/zen_stores/sql_zen_store.py | 313 +++++++++++------ src/zenml/zen_stores/zen_store_interface.py | 101 ++++-- .../functional/zen_stores/test_zen_store.py | 327 +++++++++++++----- .../functional/zen_stores/utils.py | 68 ++-- 17 files changed, 1183 insertions(+), 451 deletions(-) rename src/zenml/zen_stores/migrations/versions/{e8b82e9253a9_add_model_version_and_links.py => cdd9599a008d_add_model_version_and_links.py} (63%) diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 46df4fac5b9..f36ef7db778 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -223,7 +223,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int: SERVICE_CONNECTOR_CLIENT = "/client" MODELS = "/models" MODEL_VERSIONS = "/model_versions" -MODEL_VERSION_LINKS = "/model_version_links" # mandatory stack component attributes MANDATORY_COMPONENT_ATTRIBUTES = ["name", "uuid"] diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index de8ce8e1f12..10ed29ae571 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -151,16 +151,17 @@ ModelResponseModel, ModelRequestModel, ModelUpdateModel, - ModelConfigBaseModel, - ModelConfigResponseModel, - ModelConfigRequestModel, ModelVersionBaseModel, ModelVersionResponseModel, ModelVersionRequestModel, - ModelVersionLinkBaseModel, - ModelVersionLinkFilterModel, - ModelVersionLinkRequestModel, - ModelVersionLinkResponseModel, + ModelVersionArtifactBaseModel, + ModelVersionArtifactFilterModel, + ModelVersionArtifactRequestModel, + ModelVersionArtifactResponseModel, + ModelVersionPipelineRunBaseModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunRequestModel, + ModelVersionPipelineRunResponseModel, ModelVersionFilterModel, ModelVersionUpdateModel, ) @@ -283,32 +284,29 @@ WorkspaceResponseModel=WorkspaceResponseModel, ) -ModelConfigRequestModel.update_forward_refs( +ModelVersionRequestModel.update_forward_refs( UserResponseModel=UserResponseModel, WorkspaceResponseModel=WorkspaceResponseModel, ) -ModelConfigResponseModel.update_forward_refs( +ModelVersionResponseModel.update_forward_refs( UserResponseModel=UserResponseModel, WorkspaceResponseModel=WorkspaceResponseModel, ) -ModelVersionRequestModel.update_forward_refs( +ModelVersionArtifactRequestModel.update_forward_refs( UserResponseModel=UserResponseModel, WorkspaceResponseModel=WorkspaceResponseModel, ) - -ModelVersionResponseModel.update_forward_refs( +ModelVersionArtifactResponseModel.update_forward_refs( UserResponseModel=UserResponseModel, WorkspaceResponseModel=WorkspaceResponseModel, ) - -ModelVersionLinkRequestModel.update_forward_refs( +ModelVersionPipelineRunRequestModel.update_forward_refs( UserResponseModel=UserResponseModel, WorkspaceResponseModel=WorkspaceResponseModel, ) - -ModelVersionLinkResponseModel.update_forward_refs( +ModelVersionPipelineRunResponseModel.update_forward_refs( UserResponseModel=UserResponseModel, WorkspaceResponseModel=WorkspaceResponseModel, ) @@ -418,8 +416,12 @@ "ModelVersionRequestModel", "ModelVersionResponseModel", "ModelVersionUpdateModel", - "ModelVersionLinkBaseModel", - "ModelVersionLinkFilterModel", - "ModelVersionLinkRequestModel", - "ModelVersionLinkResponseModel", + "ModelVersionArtifactBaseModel", + "ModelVersionArtifactFilterModel", + "ModelVersionArtifactRequestModel", + "ModelVersionArtifactResponseModel", + "ModelVersionPipelineRunBaseModel", + "ModelVersionPipelineRunFilterModel", + "ModelVersionPipelineRunRequestModel", + "ModelVersionPipelineRunResponseModel", ] diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index 6a586be6145..543de9ca721 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -253,49 +253,45 @@ class ModelVersionUpdateModel(BaseModel): ) -class ModelVersionLinkBaseModel(BaseModel): - """Model version links base model.""" +class ModelVersionArtifactBaseModel(BaseModel): + """Model version links with artifact base model.""" - name: str = Field( + name: Optional[str] = Field( title="The name of the artifact inside model version.", max_length=STR_FIELD_MAX_LENGTH, ) - artifact: Optional[UUID] - pipeline_run: Optional[UUID] + artifact: UUID model: UUID model_version: UUID is_model_object: bool = False is_deployment: bool = False - @validator("model_version") - def _validate_links( - cls, model_version: UUID, values: Dict[str, Any] - ) -> UUID: - artifact = values.get("artifact", None) - pipeline_run = values.get("pipeline_run", None) - if (artifact is None and pipeline_run is None) or ( - artifact is not None and pipeline_run is not None - ): + @validator("is_deployment") + def _validate_is_deployment( + cls, is_deployment: bool, values: Dict[str, Any] + ) -> bool: + is_model_object = values.get("is_model_object", False) + if is_model_object and is_deployment: raise ValueError( - "You must provide only `artifact` or only `pipeline_run`." + "Artifact cannot be a model object and deployment at the same time." ) - return model_version + return is_deployment -class ModelVersionLinkRequestModel( - ModelVersionLinkBaseModel, WorkspaceScopedRequestModel +class ModelVersionArtifactRequestModel( + ModelVersionArtifactBaseModel, WorkspaceScopedRequestModel ): - """Model version links request model.""" + """Model version link with artifact request model.""" -class ModelVersionLinkResponseModel( - ModelVersionLinkBaseModel, WorkspaceScopedResponseModel +class ModelVersionArtifactResponseModel( + ModelVersionArtifactBaseModel, WorkspaceScopedResponseModel ): - """Model version links response model.""" + """Model version link with artifact response model.""" -class ModelVersionLinkFilterModel(WorkspaceScopedFilterModel): - """Model version links filter model.""" +class ModelVersionArtifactFilterModel(WorkspaceScopedFilterModel): + """Model version pipeline run links filter model.""" model_id: Union[str, UUID] = Field( description="The name or ID of the Model", @@ -316,45 +312,47 @@ class ModelVersionLinkFilterModel(WorkspaceScopedFilterModel): only_artifacts: Optional[bool] = False only_model_objects: Optional[bool] = False only_deployments: Optional[bool] = False - only_pipeline_runs: Optional[bool] = False - - @validator("only_pipeline_runs") - def _validate_flags( - cls, only_pipeline_runs: bool, values: Dict[str, Any] - ) -> bool: - s = int(only_pipeline_runs) - s += int(values.get("only_artifacts", False)) - s += int(values.get("only_model_objects", False)) - s += int(values.get("only_deployments", False)) - if s > 1: - raise ValueError( - "Only one of the selection flags can be used at once." - ) - return only_pipeline_runs -class ModelConfigBaseModel(BaseModel): - """Model Config base model.""" +class ModelVersionPipelineRunBaseModel(BaseModel): + """Model version links with pipeline run base model.""" - pass + name: Optional[str] = Field( + title="The name of the pipeline run inside model version.", + max_length=STR_FIELD_MAX_LENGTH, + ) + pipeline_run: UUID + model: UUID + model_version: UUID -class ModelConfigRequestModel( - ModelConfigBaseModel, - WorkspaceScopedRequestModel, +class ModelVersionPipelineRunRequestModel( + ModelVersionPipelineRunBaseModel, WorkspaceScopedRequestModel ): - """Model Config request model.""" - - pass + """Model version link with pipeline run request model.""" -class ModelConfigResponseModel( - ModelConfigBaseModel, - WorkspaceScopedResponseModel, +class ModelVersionPipelineRunResponseModel( + ModelVersionPipelineRunBaseModel, WorkspaceScopedResponseModel ): - """Model Config response model.""" + """Model version link with pipeline run response model.""" - pass + +class ModelVersionPipelineRunFilterModel(WorkspaceScopedFilterModel): + """Model version pipeline run links filter model.""" + + model_id: Union[str, UUID] = Field( + description="The name or ID of the Model", + ) + model_version_id: Union[str, UUID] = Field( + description="The name or ID of the Model Version", + ) + workspace_id: Optional[Union[UUID, str]] = Field( + default=None, description="The workspace of the Model Version" + ) + user_id: Optional[Union[UUID, str]] = Field( + default=None, description="The user of the Model Version" + ) class ModelBaseModel(BaseModel): diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 7de11e78655..e1790b27daf 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -20,9 +20,10 @@ from zenml.constants import ( API, - MODEL_VERSION_LINKS, + ARTIFACTS, MODEL_VERSIONS, MODELS, + RUNS, VERSION_1, ) from zenml.enums import PermissionType @@ -30,9 +31,11 @@ ModelFilterModel, ModelResponseModel, ModelUpdateModel, + ModelVersionArtifactFilterModel, + ModelVersionArtifactResponseModel, ModelVersionFilterModel, - ModelVersionLinkFilterModel, - ModelVersionLinkResponseModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunResponseModel, ModelVersionResponseModel, ModelVersionUpdateModel, ) @@ -250,37 +253,37 @@ def delete_model_version( ) -###################### -# Model Version Links -###################### +########################## +# Model Version Artifacts +########################## @router.get( "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}" - + MODEL_VERSION_LINKS, - response_model=Page[ModelVersionLinkResponseModel], + + ARTIFACTS, + response_model=Page[ModelVersionArtifactResponseModel], responses={401: error_response, 404: error_response, 422: error_response}, ) @handle_exceptions -def list_model_version_links( - model_version_link_filter_model: ModelVersionLinkFilterModel = Depends( - make_dependable(ModelVersionLinkFilterModel) +def list_model_version_artifact_links( + model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( + make_dependable(ModelVersionArtifactFilterModel) ), _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> Page[ModelVersionLinkResponseModel]: - """Get model version links according to query filters. +) -> Page[ModelVersionArtifactResponseModel]: + """Get model version to artifact links according to query filters. Args: - model_version_link_filter_model: Filter model used for pagination, sorting, + model_version_artifact_link_filter_model: Filter model used for pagination, sorting, filtering Returns: - The model version links according to query filters. + The model version to artifact links according to query filters. """ - return zen_store().list_model_version_links( - model_version_link_filter_model=model_version_link_filter_model, + return zen_store().list_model_version_artifact_links( + model_version_artifact_link_filter_model=model_version_artifact_link_filter_model, ) @@ -288,15 +291,15 @@ def list_model_version_links( "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}" - + MODEL_VERSION_LINKS - + "/{model_version_link_name_or_id}", + + ARTIFACTS + + "/{model_version_artifact_link_name_or_id}", responses={401: error_response, 404: error_response, 422: error_response}, ) @handle_exceptions -def delete_model_version_link( +def delete_model_version_artifact_link( model_name_or_id: Union[str, UUID], model_version_name_or_id: Union[str, UUID], - model_version_link_name_or_id: Union[str, UUID], + model_version_artifact_link_name_or_id: Union[str, UUID], _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), ) -> None: """Deletes a model version link. @@ -304,10 +307,73 @@ def delete_model_version_link( Args: model_name_or_id: name or ID of the model containing the model version. model_version_name_or_id: name or ID of the model version containing the link. - model_version_link_name_or_id: name or ID of the model version link to be deleted. + model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. """ - zen_store().delete_model_version_link( + zen_store().delete_model_version_artifact_link( model_name_or_id, model_version_name_or_id, - model_version_link_name_or_id, + model_version_artifact_link_name_or_id, + ) + + +############################## +# Model Version Pipeline Runs +############################## + + +@router.get( + "/{model_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + RUNS, + response_model=Page[ModelVersionPipelineRunResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_model_version_pipeline_run_links( + model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends( + make_dependable(ModelVersionPipelineRunFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionPipelineRunResponseModel]: + """Get model version to pipeline run links according to query filters. + + Args: + model_version_pipeline_run_link_filter_model: Filter model used for pagination, sorting, + filtering + + Returns: + The model version to pipeline run links according to query filters. + """ + return zen_store().list_model_version_pipeline_run_links( + model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model, + ) + + +@router.delete( + "/{model_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + RUNS + + "/{model_version_pipeline_run_link_name_or_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_model_version_pipeline_run_link( + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_pipeline_run_link_name_or_id: Union[str, UUID], + _: AuthContext = Security(authorize, scopes=[PermissionType.WRITE]), +) -> None: + """Deletes a model version link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_pipeline_run_link_name_or_id: name or ID of the model version link to be deleted. + """ + zen_store().delete_model_version_pipeline_run_link( + model_name_or_id, + model_version_name_or_id, + model_version_pipeline_run_link_name_or_id, ) diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 39da18affca..b6ce0edcd6c 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -19,9 +19,9 @@ from zenml.constants import ( API, + ARTIFACTS, CODE_REPOSITORIES, GET_OR_CREATE, - MODEL_VERSION_LINKS, MODEL_VERSIONS, MODELS, PIPELINE_BUILDS, @@ -53,10 +53,13 @@ ModelFilterModel, ModelRequestModel, ModelResponseModel, + ModelVersionArtifactFilterModel, + ModelVersionArtifactRequestModel, + ModelVersionArtifactResponseModel, ModelVersionFilterModel, - ModelVersionLinkFilterModel, - ModelVersionLinkRequestModel, - ModelVersionLinkResponseModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunRequestModel, + ModelVersionPipelineRunResponseModel, ModelVersionRequestModel, ModelVersionResponseModel, PipelineBuildFilterModel, @@ -1295,7 +1298,7 @@ def list_workspace_model_versions( ), _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), ) -> Page[ModelVersionResponseModel]: - """Get models according to query filters. + """Get model versions according to query filters. Args: workspace_name_or_id: Name or ID of the workspace. @@ -1320,31 +1323,31 @@ def list_workspace_model_versions( + "/{model_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}" - + MODEL_VERSION_LINKS, - response_model=ModelVersionLinkResponseModel, + + ARTIFACTS, + response_model=ModelVersionArtifactResponseModel, responses={401: error_response, 409: error_response, 422: error_response}, ) @handle_exceptions -def create_model_version_link( +def create_model_version_artifact_link( workspace_name_or_id: Union[str, UUID], model_name_or_id: Union[str, UUID], model_version_name_or_id: Union[str, UUID], - model_version_link: ModelVersionLinkRequestModel, + model_version_artifact_link: ModelVersionArtifactRequestModel, auth_context: AuthContext = Security( authorize, scopes=[PermissionType.WRITE] ), -) -> ModelVersionLinkResponseModel: - """Create a new model version link. +) -> ModelVersionArtifactResponseModel: + """Create a new model version to artifact link. Args: model_name_or_id: Name or ID of the model. workspace_name_or_id: Name or ID of the workspace. model_version_name_or_id: Name or ID of the model version. - model_version_link: The model version link to create. + model_version_artifact_link: The model version to artifact link to create. auth_context: Authentication context. Returns: - The created model version link. + The created model version to artifact link. Raises: IllegalOperationError: If the workspace or user specified in the @@ -1353,18 +1356,115 @@ def create_model_version_link( """ workspace = zen_store().get_workspace(workspace_name_or_id) - if model_version_link.workspace != workspace.id: + if model_version_artifact_link.workspace != workspace.id: + raise IllegalOperationError( + "Creating model version to artifact links outside of the workspace scope " + f"of this endpoint `{workspace_name_or_id}` is " + f"not supported." + ) + if model_version_artifact_link.user != auth_context.user.id: + raise IllegalOperationError( + "Creating model to artifact links for a user other than yourself " + "is not supported." + ) + mv = zen_store().create_model_version_artifact_link( + model_version_artifact_link + ) + return mv + + +@router.get( + WORKSPACES + + "/{workspace_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + ARTIFACTS, + response_model=Page[ModelVersionArtifactResponseModel], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_workspace_model_version_artifact_links( + workspace_name_or_id: Union[str, UUID], + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel = Depends( + make_dependable(ModelVersionArtifactFilterModel) + ), + _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), +) -> Page[ModelVersionArtifactResponseModel]: + """Get model version to artifact links according to query filters. + + Args: + model_name_or_id: Name or ID of the model. + workspace_name_or_id: Name or ID of the workspace. + model_version_name_or_id: Name or ID of the model version. + model_version_artifact_link_filter_model: Filter model used for pagination, sorting, + filtering + + Returns: + The model version to artifact links according to query filters. + """ + workspace_id = zen_store().get_workspace(workspace_name_or_id).id + model_version_artifact_link_filter_model.set_scope_workspace(workspace_id) + return zen_store().list_model_version_artifact_links( + model_version_artifact_link_filter_model=model_version_artifact_link_filter_model, + ) + + +@router.post( + WORKSPACES + + "/{workspace_name_or_id}" + + MODELS + + "/{model_name_or_id}" + + MODEL_VERSIONS + + "/{model_version_name_or_id}" + + RUNS, + response_model=ModelVersionPipelineRunResponseModel, + responses={401: error_response, 409: error_response, 422: error_response}, +) +@handle_exceptions +def create_model_version_pipeline_run_link( + workspace_name_or_id: Union[str, UUID], + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_pipeline_run_link: ModelVersionPipelineRunRequestModel, + auth_context: AuthContext = Security( + authorize, scopes=[PermissionType.WRITE] + ), +) -> ModelVersionPipelineRunResponseModel: + """Create a new model version to pipeline run link. + + Args: + model_name_or_id: Name or ID of the model. + workspace_name_or_id: Name or ID of the workspace. + model_version_name_or_id: Name or ID of the model version. + model_version_pipeline_run_link: The model version to pipeline run link to create. + auth_context: Authentication context. + + Returns: + The created model version to pipeline run link. + + Raises: + IllegalOperationError: If the workspace or user specified in the + model version does not match the current workspace or authenticated + user. + """ + workspace = zen_store().get_workspace(workspace_name_or_id) + + if model_version_pipeline_run_link.workspace != workspace.id: raise IllegalOperationError( "Creating model versions outside of the workspace scope " f"of this endpoint `{workspace_name_or_id}` is " f"not supported." ) - if model_version_link.user != auth_context.user.id: + if model_version_pipeline_run_link.user != auth_context.user.id: raise IllegalOperationError( "Creating models for a user other than yourself " "is not supported." ) - mv = zen_store().create_model_version_link(model_version_link) + mv = zen_store().create_model_version_pipeline_run_link( + model_version_pipeline_run_link + ) return mv @@ -1373,34 +1473,36 @@ def create_model_version_link( + "/{workspace_name_or_id}" + MODEL_VERSIONS + "/{model_version_name_or_id}" - + MODEL_VERSION_LINKS, - response_model=Page[ModelVersionLinkResponseModel], + + RUNS, + response_model=Page[ModelVersionPipelineRunResponseModel], responses={401: error_response, 404: error_response, 422: error_response}, ) @handle_exceptions -def list_workspace_model_version_links( +def list_workspace_model_version_pipeline_run_links( workspace_name_or_id: Union[str, UUID], model_name_or_id: Union[str, UUID], model_version_name_or_id: Union[str, UUID], - model_version_link_filter_model: ModelVersionLinkFilterModel = Depends( - make_dependable(ModelVersionLinkFilterModel) + model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel = Depends( + make_dependable(ModelVersionPipelineRunResponseModel) ), _: AuthContext = Security(authorize, scopes=[PermissionType.READ]), -) -> Page[ModelVersionLinkResponseModel]: - """Get models according to query filters. +) -> Page[ModelVersionPipelineRunResponseModel]: + """Get model version to pipeline links according to query filters. Args: model_name_or_id: Name or ID of the model. workspace_name_or_id: Name or ID of the workspace. model_version_name_or_id: Name or ID of the model version. - model_version_link_filter_model: Filter model used for pagination, sorting, + model_version_pipeline_run_link_filter_model: Filter model used for pagination, sorting, filtering Returns: - The model version links according to query filters. + The model version to pipeline run links according to query filters. """ workspace_id = zen_store().get_workspace(workspace_name_or_id).id - model_version_link_filter_model.set_scope_workspace(workspace_id) - return zen_store().list_model_version_links( - model_version_link_filter_model=model_version_link_filter_model, + model_version_pipeline_run_link_filter_model.set_scope_workspace( + workspace_id + ) + return zen_store().list_model_version_pipeline_run_links( + model_version_pipeline_run_link_filter_model=model_version_pipeline_run_link_filter_model, ) diff --git a/src/zenml/zen_stores/migrations/versions/e8b82e9253a9_add_model_version_and_links.py b/src/zenml/zen_stores/migrations/versions/cdd9599a008d_add_model_version_and_links.py similarity index 63% rename from src/zenml/zen_stores/migrations/versions/e8b82e9253a9_add_model_version_and_links.py rename to src/zenml/zen_stores/migrations/versions/cdd9599a008d_add_model_version_and_links.py index 87a7d77d634..e694bb0f574 100644 --- a/src/zenml/zen_stores/migrations/versions/e8b82e9253a9_add_model_version_and_links.py +++ b/src/zenml/zen_stores/migrations/versions/cdd9599a008d_add_model_version_and_links.py @@ -1,8 +1,8 @@ -"""add model_version and links [e8b82e9253a9]. +"""add model_version and links [cdd9599a008d]. -Revision ID: e8b82e9253a9 +Revision ID: cdd9599a008d Revises: 3b68abe58f44 -Create Date: 2023-09-11 18:05:43.367994 +Create Date: 2023-09-15 17:53:23.963414 """ import sqlalchemy as sa @@ -10,7 +10,7 @@ from alembic import op # revision identifiers, used by Alembic. -revision = "e8b82e9253a9" +revision = "cdd9599a008d" down_revision = "3b68abe58f44" branch_labels = None depends_on = None @@ -53,19 +53,16 @@ def upgrade() -> None: sa.PrimaryKeyConstraint("id"), ) op.create_table( - "model_version_links", + "model_versions_artifacts", sa.Column( "workspace_id", sqlmodel.sql.sqltypes.GUID(), nullable=False ), sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("model_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), sa.Column( "model_version_id", sqlmodel.sql.sqltypes.GUID(), nullable=False ), - sa.Column("model_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), sa.Column("artifact_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), - sa.Column( - "pipeline_run_id", sqlmodel.sql.sqltypes.GUID(), nullable=True - ), sa.Column("is_model_object", sa.BOOLEAN(), nullable=True), sa.Column("is_deployment", sa.BOOLEAN(), nullable=True), sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), @@ -75,37 +72,80 @@ def upgrade() -> None: sa.ForeignKeyConstraint( ["artifact_id"], ["artifact.id"], - name="fk_model_version_links_artifact_id_artifact", + name="fk_model_versions_artifacts_artifact_id_artifact", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["model_id"], + ["model.id"], + name="fk_model_versions_artifacts_model_id_model", + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["model_version_id"], + ["model_version.id"], + name="fk_model_versions_artifacts_model_version_id_model_version", ondelete="CASCADE", ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + name="fk_model_versions_artifacts_user_id_user", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.id"], + name="fk_model_versions_artifacts_workspace_id_workspace", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "model_versions_runs", + sa.Column( + "workspace_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("model_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column( + "model_version_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column( + "pipeline_run_id", sqlmodel.sql.sqltypes.GUID(), nullable=True + ), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("updated", sa.DateTime(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.ForeignKeyConstraint( ["model_id"], ["model.id"], - name="fk_model_version_links_model_id_model", + name="fk_model_versions_runs_model_id_model", ondelete="CASCADE", ), sa.ForeignKeyConstraint( ["model_version_id"], ["model_version.id"], - name="fk_model_version_links_model_version_id_model_version", + name="fk_model_versions_runs_model_version_id_model_version", ondelete="CASCADE", ), sa.ForeignKeyConstraint( ["pipeline_run_id"], ["pipeline_run.id"], - name="fk_model_version_links_run_id_pipeline_run", + name="fk_model_versions_runs_run_id_pipeline_run", ondelete="CASCADE", ), sa.ForeignKeyConstraint( ["user_id"], ["user.id"], - name="fk_model_version_links_user_id_user", + name="fk_model_versions_runs_user_id_user", ondelete="SET NULL", ), sa.ForeignKeyConstraint( ["workspace_id"], ["workspace.id"], - name="fk_model_version_links_workspace_id_workspace", + name="fk_model_versions_runs_workspace_id_workspace", ondelete="CASCADE", ), sa.PrimaryKeyConstraint("id"), @@ -116,6 +156,7 @@ def upgrade() -> None: def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("model_version_links") + op.drop_table("model_versions_runs") + op.drop_table("model_versions_artifacts") op.drop_table("model_version") # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 160b7c6e5d2..4ebbd308b03 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -49,7 +49,6 @@ GET_OR_CREATE, INFO, LOGIN, - MODEL_VERSION_LINKS, MODEL_VERSIONS, MODELS, PIPELINE_BUILDS, @@ -101,10 +100,13 @@ ModelRequestModel, ModelResponseModel, ModelUpdateModel, + ModelVersionArtifactFilterModel, + ModelVersionArtifactRequestModel, + ModelVersionArtifactResponseModel, ModelVersionFilterModel, - ModelVersionLinkFilterModel, - ModelVersionLinkRequestModel, - ModelVersionLinkResponseModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunRequestModel, + ModelVersionPipelineRunResponseModel, ModelVersionRequestModel, ModelVersionResponseModel, ModelVersionUpdateModel, @@ -2472,62 +2474,121 @@ def update_model_version( response_model=ModelVersionResponseModel, ) - ####################### - # Model Versions Links - ####################### + ########################### + # Model Versions Artifacts + ########################### - def create_model_version_link( - self, model_version_link: ModelVersionLinkRequestModel - ) -> ModelVersionLinkResponseModel: + def create_model_version_artifact_link( + self, model_version_artifact_link: ModelVersionArtifactRequestModel + ) -> ModelVersionArtifactResponseModel: """Creates a new model version link. Args: - model_version_link: the Model Version Link to be created. + model_version_artifact_link: the Model Version to Artifact Link to be created. Returns: - The newly created model version link. + The newly created model version to artifact link. """ return self._create_workspace_scoped_resource( - resource=model_version_link, - response_model=ModelVersionLinkResponseModel, - route=f"{MODELS}/{model_version_link.model}{MODEL_VERSIONS}/{model_version_link.model_version}{MODEL_VERSION_LINKS}", + resource=model_version_artifact_link, + response_model=ModelVersionArtifactResponseModel, + route=f"{MODELS}/{model_version_artifact_link.model}{MODEL_VERSIONS}/{model_version_artifact_link.model_version}{ARTIFACTS}", ) - def list_model_version_links( + def list_model_version_artifact_links( self, - model_version_link_filter_model: ModelVersionLinkFilterModel, - ) -> Page[ModelVersionLinkResponseModel]: - """Get all model version links by filter. + model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, + ) -> Page[ModelVersionArtifactResponseModel]: + """Get all model version to artifact links by filter. Args: - model_version_link_filter_model: All filter parameters including pagination + model_version_artifact_link_filter_model: All filter parameters including pagination params. Returns: - A page of all model version links. + A page of all model version to artifact links. """ return self._list_paginated_resources( - route=f"{MODELS}/{model_version_link_filter_model.model_id}{MODEL_VERSIONS}/{model_version_link_filter_model.model_version_id}{MODEL_VERSION_LINKS}", - response_model=ModelVersionLinkResponseModel, - filter_model=model_version_link_filter_model, + route=f"{MODELS}/{model_version_artifact_link_filter_model.model_id}{MODEL_VERSIONS}/{model_version_artifact_link_filter_model.model_version_id}{ARTIFACTS}", + response_model=ModelVersionArtifactResponseModel, + filter_model=model_version_artifact_link_filter_model, ) - def delete_model_version_link( + def delete_model_version_artifact_link( self, model_name_or_id: Union[str, UUID], model_version_name_or_id: Union[str, UUID], - model_version_link_name_or_id: Union[str, UUID], + model_version_artifact_link_name_or_id: Union[str, UUID], ) -> None: - """Deletes a model version link. + """Deletes a model version to artifact link. Args: model_name_or_id: name or ID of the model containing the model version. model_version_name_or_id: name or ID of the model version containing the link. - model_version_link_name_or_id: name or ID of the model version link to be deleted. + model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. """ self._delete_resource( - resource_id=model_version_link_name_or_id, - route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}/{model_version_name_or_id}{MODEL_VERSION_LINKS}", + resource_id=model_version_artifact_link_name_or_id, + route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}/{model_version_name_or_id}{ARTIFACTS}", + ) + + ############################### + # Model Versions Pipeline Runs + ############################### + + def create_model_version_pipeline_run_link( + self, + model_version_pipeline_run_link: ModelVersionPipelineRunRequestModel, + ) -> ModelVersionPipelineRunResponseModel: + """Creates a new model version to pipeline run link. + + Args: + model_version_pipeline_run_link: the Model Version to Pipeline Run Link to be created. + + Returns: + The newly created model version to pipeline run link. + """ + return self._create_workspace_scoped_resource( + resource=model_version_pipeline_run_link, + response_model=ModelVersionPipelineRunResponseModel, + route=f"{MODELS}/{model_version_pipeline_run_link.model}{MODEL_VERSIONS}/{model_version_pipeline_run_link.model_version}{RUNS}", + ) + + def list_model_version_pipeline_run_links( + self, + model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, + ) -> Page[ModelVersionPipelineRunResponseModel]: + """Get all model version to pipeline run links by filter. + + Args: + model_version_pipeline_run_link_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model version to pipeline run links. + """ + return self._list_paginated_resources( + route=f"{MODELS}/{model_version_pipeline_run_link_filter_model.model_id}{MODEL_VERSIONS}/{model_version_pipeline_run_link_filter_model.model_version_id}{RUNS}", + response_model=ModelVersionPipelineRunResponseModel, + filter_model=model_version_pipeline_run_link_filter_model, + ) + + def delete_model_version_pipeline_run_link( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_pipeline_run_link_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version to pipeline run link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_pipeline_run_link_name_or_id: name or ID of the model version to pipeline run link to be deleted. + """ + self._delete_resource( + resource_id=model_version_pipeline_run_link_name_or_id, + route=f"{MODELS}/{model_name_or_id}{MODEL_VERSIONS}/{model_version_name_or_id}{RUNS}", ) # ======================= diff --git a/src/zenml/zen_stores/schemas/__init__.py b/src/zenml/zen_stores/schemas/__init__.py index 3dc9b8d29cc..69dfa63e92a 100644 --- a/src/zenml/zen_stores/schemas/__init__.py +++ b/src/zenml/zen_stores/schemas/__init__.py @@ -60,7 +60,8 @@ from zenml.zen_stores.schemas.model_schemas import ( ModelSchema, ModelVersionSchema, - ModelVersionLinkSchema, + ModelVersionArtifactSchema, + ModelVersionPipelineRunSchema, ) __all__ = [ @@ -98,5 +99,6 @@ "LogsSchema", "ModelSchema", "ModelVersionSchema", - "ModelVersionLinkSchema", + "ModelVersionArtifactSchema", + "ModelVersionPipelineRunSchema", ] diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index c7c980bfad6..bd55820fe19 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -33,7 +33,7 @@ if TYPE_CHECKING: from zenml.zen_stores.schemas.model_schemas import ( - ModelVersionLinkSchema, + ModelVersionArtifactSchema, ) from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema from zenml.zen_stores.schemas.step_run_schemas import ( @@ -97,7 +97,9 @@ class ArtifactSchema(NamedSchema, table=True): back_populates="artifact", sa_relationship_kwargs={"cascade": "delete"}, ) - model_version_links: List["ModelVersionLinkSchema"] = Relationship( + model_versions_artifacts_links: List[ + "ModelVersionArtifactSchema" + ] = Relationship( back_populates="artifact", sa_relationship_kwargs={"cascade": "delete"}, ) diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 0b4da5877a1..2f88c457cf6 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -26,8 +26,10 @@ ModelRequestModel, ModelResponseModel, ModelUpdateModel, - ModelVersionLinkRequestModel, - ModelVersionLinkResponseModel, + ModelVersionArtifactRequestModel, + ModelVersionArtifactResponseModel, + ModelVersionPipelineRunRequestModel, + ModelVersionPipelineRunResponseModel, ModelVersionRequestModel, ModelVersionResponseModel, ) @@ -76,7 +78,11 @@ class ModelSchema(NamedSchema, table=True): back_populates="model", sa_relationship_kwargs={"cascade": "delete"}, ) - objects_links: List["ModelVersionLinkSchema"] = Relationship( + artifact_links: List["ModelVersionArtifactSchema"] = Relationship( + back_populates="model", + sa_relationship_kwargs={"cascade": "delete"}, + ) + pipeline_run_links: List["ModelVersionPipelineRunSchema"] = Relationship( back_populates="model", sa_relationship_kwargs={"cascade": "delete"}, ) @@ -189,7 +195,11 @@ class ModelVersionSchema(BaseSchema, table=True): nullable=False, ) model: "ModelSchema" = Relationship(back_populates="model_versions") - objects_links: List["ModelVersionLinkSchema"] = Relationship( + artifact_links: List["ModelVersionArtifactSchema"] = Relationship( + back_populates="model_version", + sa_relationship_kwargs={"cascade": "delete"}, + ) + pipeline_run_links: List["ModelVersionPipelineRunSchema"] = Relationship( back_populates="model_version", sa_relationship_kwargs={"cascade": "delete"}, ) @@ -237,24 +247,22 @@ def to_model(self) -> ModelVersionResponseModel: stage=self.stage, model_object_ids={ al.name: al.artifact_id - for al in self.objects_links + for al in self.artifact_links if al.artifact_id is not None and al.is_model_object }, deployment_ids={ al.name: al.artifact_id - for al in self.objects_links + for al in self.artifact_links if al.artifact_id is not None and al.is_deployment }, artifact_object_ids={ al.name: al.artifact_id - for al in self.objects_links + for al in self.artifact_links if al.artifact_id is not None and not (al.is_deployment or al.is_model_object) }, pipeline_run_ids={ - al.name: al.pipeline_run_id - for al in self.objects_links - if al.pipeline_run_id is not None + al.name: al.pipeline_run_id for al in self.pipeline_run_links }, ) @@ -275,10 +283,10 @@ def update( return self -class ModelVersionLinkSchema(NamedSchema, table=True): - """SQL Model for linking of Model Versions and Artifacts or Pipeline Runs M:M.""" +class ModelVersionArtifactSchema(NamedSchema, table=True): + """SQL Model for linking of Model Versions and Artifacts M:M.""" - __tablename__ = "model_version_links" + __tablename__ = "model_versions_artifacts" workspace_id: UUID = build_foreign_key_field( source=__tablename__, @@ -289,7 +297,7 @@ class ModelVersionLinkSchema(NamedSchema, table=True): nullable=False, ) workspace: "WorkspaceSchema" = Relationship( - back_populates="model_version_links" + back_populates="model_versions_artifacts_links" ) user_id: Optional[UUID] = build_foreign_key_field( @@ -301,7 +309,7 @@ class ModelVersionLinkSchema(NamedSchema, table=True): nullable=True, ) user: Optional["UserSchema"] = Relationship( - back_populates="model_version_links" + back_populates="model_versions_artifacts_links" ) model_id: UUID = build_foreign_key_field( @@ -312,7 +320,7 @@ class ModelVersionLinkSchema(NamedSchema, table=True): ondelete="CASCADE", nullable=False, ) - model: "ModelSchema" = Relationship(back_populates="objects_links") + model: "ModelSchema" = Relationship(back_populates="artifact_links") model_version_id: UUID = build_foreign_key_field( source=__tablename__, target=ModelVersionSchema.__tablename__, @@ -322,7 +330,7 @@ class ModelVersionLinkSchema(NamedSchema, table=True): nullable=False, ) model_version: "ModelVersionSchema" = Relationship( - back_populates="objects_links" + back_populates="artifact_links" ) artifact_id: Optional[UUID] = build_foreign_key_field( source=__tablename__, @@ -333,18 +341,7 @@ class ModelVersionLinkSchema(NamedSchema, table=True): nullable=True, ) artifact: Optional["ArtifactSchema"] = Relationship( - back_populates="model_version_links" - ) - pipeline_run_id: Optional[UUID] = build_foreign_key_field( - source=__tablename__, - target=PipelineRunSchema.__tablename__, - source_column="run_id", - target_column="id", - ondelete="CASCADE", - nullable=True, - ) - pipeline_run: Optional["PipelineRunSchema"] = Relationship( - back_populates="model_version_links" + back_populates="model_versions_artifacts_links" ) is_model_object: bool = Field(sa_column=Column(BOOLEAN, nullable=True)) @@ -352,8 +349,8 @@ class ModelVersionLinkSchema(NamedSchema, table=True): @classmethod def from_request( - cls, model_version_artifact_request: ModelVersionLinkRequestModel - ) -> "ModelVersionLinkSchema": + cls, model_version_artifact_request: ModelVersionArtifactRequestModel + ) -> "ModelVersionArtifactSchema": """Convert an `ModelVersionArtifactRequestModel` to an `ModelVersionArtifactSchema`. Args: @@ -369,18 +366,17 @@ def from_request( model_id=model_version_artifact_request.model, model_version_id=model_version_artifact_request.model_version, artifact_id=model_version_artifact_request.artifact, - pipeline_run_id=model_version_artifact_request.pipeline_run, is_model_object=model_version_artifact_request.is_model_object, is_deployment=model_version_artifact_request.is_deployment, ) - def to_model(self) -> ModelVersionLinkResponseModel: + def to_model(self) -> ModelVersionArtifactResponseModel: """Convert an `ModelVersionArtifactSchema` to an `ModelVersionArtifactResponseModel`. Returns: The created `ModelVersionArtifactResponseModel`. """ - return ModelVersionLinkResponseModel( + return ModelVersionArtifactResponseModel( id=self.id, name=self.name, user=self.user.to_model() if self.user else None, @@ -390,7 +386,108 @@ def to_model(self) -> ModelVersionLinkResponseModel: model=self.model_id, model_version=self.model_version_id, artifact=self.artifact_id, - pipeline_run=self.pipeline_run_id, is_model_object=self.is_model_object, is_deployment=self.is_deployment, ) + + +class ModelVersionPipelineRunSchema(NamedSchema, table=True): + """SQL Model for linking of Model Versions and Pipeline Runs M:M.""" + + __tablename__ = "model_versions_runs" + + workspace_id: UUID = build_foreign_key_field( + source=__tablename__, + target=WorkspaceSchema.__tablename__, + source_column="workspace_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + workspace: "WorkspaceSchema" = Relationship( + back_populates="model_versions_pipeline_runs_links" + ) + + user_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=UserSchema.__tablename__, + source_column="user_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + user: Optional["UserSchema"] = Relationship( + back_populates="model_versions_pipeline_runs_links" + ) + + model_id: UUID = build_foreign_key_field( + source=__tablename__, + target=ModelSchema.__tablename__, + source_column="model_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + model: "ModelSchema" = Relationship(back_populates="pipeline_run_links") + model_version_id: UUID = build_foreign_key_field( + source=__tablename__, + target=ModelVersionSchema.__tablename__, + source_column="model_version_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + model_version: "ModelVersionSchema" = Relationship( + back_populates="pipeline_run_links" + ) + pipeline_run_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=PipelineRunSchema.__tablename__, + source_column="run_id", + target_column="id", + ondelete="CASCADE", + nullable=True, + ) + pipeline_run: Optional["PipelineRunSchema"] = Relationship( + back_populates="model_versions_pipeline_runs_links" + ) + + @classmethod + def from_request( + cls, + model_version_pipeline_run_request: ModelVersionPipelineRunRequestModel, + ) -> "ModelVersionPipelineRunSchema": + """Convert an `ModelVersionPipelineRunRequestModel` to an `ModelVersionPipelineRunSchema`. + + Args: + model_version_pipeline_run_request: The request link to convert. + + Returns: + The converted schema. + """ + return cls( + workspace_id=model_version_pipeline_run_request.workspace, + user_id=model_version_pipeline_run_request.user, + name=model_version_pipeline_run_request.name, + model_id=model_version_pipeline_run_request.model, + model_version_id=model_version_pipeline_run_request.model_version, + pipeline_run_id=model_version_pipeline_run_request.pipeline_run, + ) + + def to_model(self) -> ModelVersionPipelineRunResponseModel: + """Convert an `ModelVersionPipelineRunSchema` to an `ModelVersionPipelineRunResponseModel`. + + Returns: + The created `ModelVersionPipelineRunResponseModel`. + """ + return ModelVersionPipelineRunResponseModel( + id=self.id, + name=self.name, + user=self.user.to_model() if self.user else None, + workspace=self.workspace.to_model(), + created=self.created, + updated=self.updated, + model=self.model_id, + model_version=self.model_version_id, + pipeline_run=self.pipeline_run_id, + ) diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 54c74a5a77f..629b44a0341 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -43,7 +43,9 @@ if TYPE_CHECKING: from zenml.zen_stores.schemas.logs_schemas import LogsSchema - from zenml.zen_stores.schemas.model_schemas import ModelVersionLinkSchema + from zenml.zen_stores.schemas.model_schemas import ( + ModelVersionPipelineRunSchema, + ) from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema @@ -158,7 +160,9 @@ class PipelineRunSchema(NamedSchema, table=True): back_populates="pipeline_run", sa_relationship_kwargs={"cascade": "delete", "uselist": False}, ) - model_version_links: List["ModelVersionLinkSchema"] = Relationship( + model_versions_pipeline_runs_links: List[ + "ModelVersionPipelineRunSchema" + ] = Relationship( back_populates="pipeline_run", sa_relationship_kwargs={"cascade": "delete"}, ) diff --git a/src/zenml/zen_stores/schemas/user_schemas.py b/src/zenml/zen_stores/schemas/user_schemas.py index e11c65b1cac..6be3c152509 100644 --- a/src/zenml/zen_stores/schemas/user_schemas.py +++ b/src/zenml/zen_stores/schemas/user_schemas.py @@ -28,7 +28,8 @@ CodeRepositorySchema, FlavorSchema, ModelSchema, - ModelVersionLinkSchema, + ModelVersionArtifactSchema, + ModelVersionPipelineRunSchema, ModelVersionSchema, PipelineBuildSchema, PipelineDeploymentSchema, @@ -100,9 +101,12 @@ class UserSchema(NamedSchema, table=True): model_versions: List["ModelVersionSchema"] = Relationship( back_populates="user", ) - model_version_links: List["ModelVersionLinkSchema"] = Relationship( - back_populates="user", - ) + model_versions_artifacts_links: List[ + "ModelVersionArtifactSchema" + ] = Relationship(back_populates="user") + model_versions_pipeline_runs_links: List[ + "ModelVersionPipelineRunSchema" + ] = Relationship(back_populates="user") @classmethod def from_request(cls, model: UserRequestModel) -> "UserSchema": diff --git a/src/zenml/zen_stores/schemas/workspace_schemas.py b/src/zenml/zen_stores/schemas/workspace_schemas.py index ff2647f67ad..f0eb0792d5e 100644 --- a/src/zenml/zen_stores/schemas/workspace_schemas.py +++ b/src/zenml/zen_stores/schemas/workspace_schemas.py @@ -30,7 +30,8 @@ CodeRepositorySchema, FlavorSchema, ModelSchema, - ModelVersionLinkSchema, + ModelVersionArtifactSchema, + ModelVersionPipelineRunSchema, ModelVersionSchema, PipelineBuildSchema, PipelineDeploymentSchema, @@ -127,7 +128,15 @@ class WorkspaceSchema(NamedSchema, table=True): back_populates="workspace", sa_relationship_kwargs={"cascade": "delete"}, ) - model_version_links: List["ModelVersionLinkSchema"] = Relationship( + model_versions_artifacts_links: List[ + "ModelVersionArtifactSchema" + ] = Relationship( + back_populates="workspace", + sa_relationship_kwargs={"cascade": "delete"}, + ) + model_versions_pipeline_runs_links: List[ + "ModelVersionPipelineRunSchema" + ] = Relationship( back_populates="workspace", sa_relationship_kwargs={"cascade": "delete"}, ) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index ee64030a052..ca50eff4a8d 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -99,10 +99,13 @@ ModelRequestModel, ModelResponseModel, ModelUpdateModel, + ModelVersionArtifactFilterModel, + ModelVersionArtifactRequestModel, + ModelVersionArtifactResponseModel, ModelVersionFilterModel, - ModelVersionLinkFilterModel, - ModelVersionLinkRequestModel, - ModelVersionLinkResponseModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunRequestModel, + ModelVersionPipelineRunResponseModel, ModelVersionRequestModel, ModelVersionResponseModel, ModelVersionUpdateModel, @@ -202,7 +205,8 @@ FlavorSchema, IdentitySchema, ModelSchema, - ModelVersionLinkSchema, + ModelVersionArtifactSchema, + ModelVersionPipelineRunSchema, ModelVersionSchema, NamedSchema, PipelineBuildSchema, @@ -5734,162 +5738,145 @@ def update_model_version( return existing_model_version.to_model() - ####################### - # Model Versions Links - ####################### + ########################### + # Model Versions Artifacts + ########################### - def create_model_version_link( - self, model_version_link: ModelVersionLinkRequestModel - ) -> ModelVersionLinkResponseModel: + def create_model_version_artifact_link( + self, model_version_artifact_link: ModelVersionArtifactRequestModel + ) -> ModelVersionArtifactResponseModel: """Creates a new model version link. Args: - model_version_link: the Model Version Link to be created. + model_version_artifact_link: the Model Version to Artifact Link to be created. Returns: - The newly created model version link. + The newly created model version to artifact link. Raises: - EntityExistsError: If a workspace with the given name already exists. + EntityExistsError: If a link with the given name already exists. """ with Session(self.engine) as session: - existing_model_version_link = session.exec( - select(ModelVersionLinkSchema) + existing_model_version_artifact_link = session.exec( + select(ModelVersionArtifactSchema) + .where( + ModelVersionArtifactSchema.model_version_id + == model_version_artifact_link.model_version + ) .where( - ModelVersionLinkSchema.model_version_id - == model_version_link.model_version + or_( + ModelVersionArtifactSchema.name + == model_version_artifact_link.name, + ModelVersionArtifactSchema.artifact_id + == model_version_artifact_link.artifact, + ) ) - .where(ModelVersionLinkSchema.name == model_version_link.name) ).first() - if existing_model_version_link is not None: + if existing_model_version_artifact_link is not None: raise EntityExistsError( - f"Unable to create model version link {existing_model_version_link.name}: " - f"A model version link with this name already exists in {model_version_link.model_version} model version." + f"Unable to create model version link {existing_model_version_artifact_link.name}: " + f"A model version link with this name already exists in {existing_model_version_artifact_link.model_version} model version." ) - model_version_link_schema = ModelVersionLinkSchema.from_request( - model_version_link + if model_version_artifact_link.name is None: + model_version_artifact_link.name = self.get_artifact( + model_version_artifact_link.artifact + ).name + + model_version_artifact_link_schema = ( + ModelVersionArtifactSchema.from_request( + model_version_artifact_link + ) ) - session.add(model_version_link_schema) + session.add(model_version_artifact_link_schema) session.commit() - mvl = ModelVersionLinkSchema.to_model(model_version_link_schema) + mvl = ModelVersionArtifactSchema.to_model( + model_version_artifact_link_schema + ) return mvl - def list_model_version_links( + def list_model_version_artifact_links( self, - model_version_link_filter_model: ModelVersionLinkFilterModel, - ) -> Page[ModelVersionLinkResponseModel]: - """Get all model version links by filter. + model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, + ) -> Page[ModelVersionArtifactResponseModel]: + """Get all model version to artifact links by filter. Args: - model_version_link_filter_model: All filter parameters including pagination + model_version_artifact_link_filter_model: All filter parameters including pagination params. Returns: - A page of all model version links. + A page of all model version to artifact links. """ with Session(self.engine) as session: # issue: https://github.com/tiangolo/sqlmodel/issues/109 - if model_version_link_filter_model.only_artifacts: + if model_version_artifact_link_filter_model.only_artifacts: query = ( - select(ModelVersionLinkSchema) - .where( - ModelVersionLinkSchema.is_model_object - == False # noqa: E712 - ) + select(ModelVersionArtifactSchema) .where( - ModelVersionLinkSchema.is_deployment + ModelVersionArtifactSchema.is_model_object == False # noqa: E712 ) .where( - ModelVersionLinkSchema.pipeline_run - == None # noqa: E712, E711 - ) - .where( - ModelVersionLinkSchema.artifact - != None # noqa: E712, E711 - ) - ) - elif model_version_link_filter_model.only_deployments: - query = ( - select(ModelVersionLinkSchema) - .where(ModelVersionLinkSchema.is_deployment) - .where( - ModelVersionLinkSchema.is_model_object + ModelVersionArtifactSchema.is_deployment == False # noqa: E712 ) .where( - ModelVersionLinkSchema.pipeline_run - == None # noqa: E712, E711 - ) - .where( - ModelVersionLinkSchema.artifact + ModelVersionArtifactSchema.artifact != None # noqa: E712, E711 ) ) - elif model_version_link_filter_model.only_model_objects: + elif model_version_artifact_link_filter_model.only_deployments: query = ( - select(ModelVersionLinkSchema) - .where(ModelVersionLinkSchema.is_model_object) + select(ModelVersionArtifactSchema) + .where(ModelVersionArtifactSchema.is_deployment) .where( - ModelVersionLinkSchema.is_deployment + ModelVersionArtifactSchema.is_model_object == False # noqa: E712 ) .where( - ModelVersionLinkSchema.pipeline_run - == None # noqa: E712, E711 - ) - .where( - ModelVersionLinkSchema.artifact + ModelVersionArtifactSchema.artifact != None # noqa: E712, E711 ) ) - elif model_version_link_filter_model.only_pipeline_runs: + elif model_version_artifact_link_filter_model.only_model_objects: query = ( - select(ModelVersionLinkSchema) - .where( - ModelVersionLinkSchema.is_model_object - == False # noqa: E712 - ) + select(ModelVersionArtifactSchema) + .where(ModelVersionArtifactSchema.is_model_object) .where( - ModelVersionLinkSchema.is_deployment + ModelVersionArtifactSchema.is_deployment == False # noqa: E712 ) .where( - ModelVersionLinkSchema.pipeline_run + ModelVersionArtifactSchema.artifact != None # noqa: E712, E711 ) - .where( - ModelVersionLinkSchema.artifact - == None # noqa: E712, E711 - ) ) else: - query = select(ModelVersionLinkSchema) - model_version_link_filter_model.only_artifacts = None - model_version_link_filter_model.only_deployments = None - model_version_link_filter_model.only_model_objects = None - model_version_link_filter_model.only_pipeline_runs = None + query = select(ModelVersionArtifactSchema) + model_version_artifact_link_filter_model.only_artifacts = None + model_version_artifact_link_filter_model.only_deployments = None + model_version_artifact_link_filter_model.only_model_objects = None return self.filter_and_paginate( session=session, query=query, - table=ModelVersionLinkSchema, - filter_model=model_version_link_filter_model, + table=ModelVersionArtifactSchema, + filter_model=model_version_artifact_link_filter_model, ) - def delete_model_version_link( + def delete_model_version_artifact_link( self, model_name_or_id: Union[str, UUID], model_version_name_or_id: Union[str, UUID], - model_version_link_name_or_id: Union[str, UUID], + model_version_artifact_link_name_or_id: Union[str, UUID], ) -> None: - """Deletes a model version link. + """Deletes a model version to artifact link. Args: model_name_or_id: name or ID of the model containing the model version. model_version_name_or_id: name or ID of the model version containing the link. - model_version_link_name_or_id: name or ID of the model version link to be deleted. + model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. Raises: KeyError: specified ID or name not found. @@ -5899,26 +5886,154 @@ def delete_model_version_link( model_version = self.get_model_version( model_name_or_id, model_version_name_or_id ) - query = select(ModelVersionLinkSchema).where( - ModelVersionLinkSchema.model_version_id == model_version.id + query = select(ModelVersionArtifactSchema).where( + ModelVersionArtifactSchema.model_version_id == model_version.id + ) + try: + UUID(str(model_version_artifact_link_name_or_id)) + query = query.where( + ModelVersionArtifactSchema.id + == model_version_artifact_link_name_or_id + ) + except ValueError: + query = query.where( + ModelVersionArtifactSchema.name + == model_version_artifact_link_name_or_id + ) + + model_version_artifact_link = session.exec(query).first() + if model_version_artifact_link is None: + raise KeyError( + f"Unable to delete model version link with name `{model_version_artifact_link_name_or_id}`: " + f"No model version link with this name found." + ) + + session.delete(model_version_artifact_link) + session.commit() + + ############################### + # Model Versions Pipeline Runs + ############################### + + def create_model_version_pipeline_run_link( + self, + model_version_pipeline_run_link: ModelVersionPipelineRunRequestModel, + ) -> ModelVersionPipelineRunResponseModel: + """Creates a new model version to pipeline run link. + + Args: + model_version_pipeline_run_link: the Model Version to Pipeline Run Link to be created. + + Returns: + The newly created model version to pipeline run link. + + Raises: + EntityExistsError: If a link with the given ID already exists. + """ + with Session(self.engine) as session: + existing_model_version_pipeline_run_link = session.exec( + select(ModelVersionPipelineRunSchema) + .where( + ModelVersionPipelineRunSchema.model_version_id + == model_version_pipeline_run_link.model_version + ) + .where( + or_( + ModelVersionPipelineRunSchema.pipeline_run_id + == model_version_pipeline_run_link.pipeline_run, + ModelVersionPipelineRunSchema.name + == model_version_pipeline_run_link.name, + ) + ) + ).first() + if existing_model_version_pipeline_run_link is not None: + raise EntityExistsError( + f"Unable to create model version link {existing_model_version_pipeline_run_link.name}: " + f"A model version link with this name already exists in {existing_model_version_pipeline_run_link.model_version} model version." + ) + + if model_version_pipeline_run_link.name is None: + model_version_pipeline_run_link.name = self.get_run( + model_version_pipeline_run_link.pipeline_run + ).name + + model_version_pipeline_run_link_schema = ( + ModelVersionPipelineRunSchema.from_request( + model_version_pipeline_run_link + ) + ) + session.add(model_version_pipeline_run_link_schema) + + session.commit() + mvl = ModelVersionPipelineRunSchema.to_model( + model_version_pipeline_run_link_schema + ) + return mvl + + def list_model_version_pipeline_run_links( + self, + model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, + ) -> Page[ModelVersionPipelineRunResponseModel]: + """Get all model version to pipeline run links by filter. + + Args: + model_version_pipeline_run_link_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model version to pipeline run links. + """ + with Session(self.engine) as session: + return self.filter_and_paginate( + session=session, + query=select(ModelVersionPipelineRunSchema), + table=ModelVersionPipelineRunSchema, + filter_model=model_version_pipeline_run_link_filter_model, + ) + + def delete_model_version_pipeline_run_link( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_pipeline_run_link_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version to pipeline run link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_pipeline_run_link_name_or_id: name or ID of the model version to pipeline run link to be deleted. + + Raises: + KeyError: specified ID not found. + """ + with Session(self.engine) as session: + self.get_model(model_name_or_id) + model_version = self.get_model_version( + model_name_or_id, model_version_name_or_id + ) + query = select(ModelVersionPipelineRunSchema).where( + ModelVersionPipelineRunSchema.model_version_id + == model_version.id ) try: - UUID(str(model_version_link_name_or_id)) + UUID(str(model_version_pipeline_run_link_name_or_id)) query = query.where( - ModelVersionLinkSchema.id == model_version_link_name_or_id + ModelVersionPipelineRunSchema.id + == model_version_pipeline_run_link_name_or_id ) except ValueError: query = query.where( - ModelVersionLinkSchema.name - == model_version_link_name_or_id + ModelVersionPipelineRunSchema.name + == model_version_pipeline_run_link_name_or_id ) - model_version_link = session.exec(query).first() - if model_version_link is None: + model_version_pipeline_run_link = session.exec(query).first() + if model_version_pipeline_run_link is None: raise KeyError( - f"Unable to delete model version link with name `{model_version_link_name_or_id}`: " + f"Unable to delete model version link with name `{model_version_pipeline_run_link_name_or_id}`: " f"No model version link with this name found." ) - session.delete(model_version_link) + session.delete(model_version_pipeline_run_link) session.commit() diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 6cd459d20f6..e5bdc518c20 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -36,10 +36,13 @@ ModelRequestModel, ModelResponseModel, ModelUpdateModel, + ModelVersionArtifactFilterModel, + ModelVersionArtifactRequestModel, + ModelVersionArtifactResponseModel, ModelVersionFilterModel, - ModelVersionLinkFilterModel, - ModelVersionLinkRequestModel, - ModelVersionLinkResponseModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunRequestModel, + ModelVersionPipelineRunResponseModel, ModelVersionRequestModel, ModelVersionResponseModel, ModelVersionUpdateModel, @@ -1841,55 +1844,109 @@ def update_model_version( RuntimeError: If there is a model version with target stage, but `force` flag is off """ - ####################### - # Model Versions Links - ####################### + ########################### + # Model Versions Artifacts + ########################### @abstractmethod - def create_model_version_link( - self, model_version_link: ModelVersionLinkRequestModel - ) -> ModelVersionLinkResponseModel: + def create_model_version_artifact_link( + self, model_version_artifact_link: ModelVersionArtifactRequestModel + ) -> ModelVersionArtifactResponseModel: """Creates a new model version link. Args: - model_version_link: the Model Version Link to be created. + model_version_artifact_link: the Model Version to Artifact Link to be created. Returns: - The newly created model version link. + The newly created model version to artifact link. Raises: - EntityExistsError: If a workspace with the given name already exists. + EntityExistsError: If a link with the given name already exists. """ @abstractmethod - def list_model_version_links( + def list_model_version_artifact_links( self, - model_version_link_filter_model: ModelVersionLinkFilterModel, - ) -> Page[ModelVersionLinkResponseModel]: - """Get all model version links by filter. + model_version_artifact_link_filter_model: ModelVersionArtifactFilterModel, + ) -> Page[ModelVersionArtifactResponseModel]: + """Get all model version to artifact links by filter. Args: - model_version_link_filter_model: All filter parameters including pagination + model_version_artifact_link_filter_model: All filter parameters including pagination params. Returns: - A page of all model version links. + A page of all model version to artifact links. """ @abstractmethod - def delete_model_version_link( + def delete_model_version_artifact_link( self, model_name_or_id: Union[str, UUID], model_version_name_or_id: Union[str, UUID], - model_version_link_name_or_id: Union[str, UUID], + model_version_artifact_link_name_or_id: Union[str, UUID], ) -> None: - """Deletes a model version link. + """Deletes a model version to artifact link. Args: model_name_or_id: name or ID of the model containing the model version. model_version_name_or_id: name or ID of the model version containing the link. - model_version_link_name_or_id: name or ID of the model version link to be deleted. + model_version_artifact_link_name_or_id: name or ID of the model version to artifact link to be deleted. Raises: KeyError: specified ID or name not found. """ + + ############################### + # Model Versions Pipeline Runs + ############################### + + @abstractmethod + def create_model_version_pipeline_run_link( + self, + model_version_pipeline_run_link: ModelVersionPipelineRunRequestModel, + ) -> ModelVersionPipelineRunResponseModel: + """Creates a new model version to pipeline run link. + + Args: + model_version_pipeline_run_link: the Model Version to Pipeline Run Link to be created. + + Returns: + The newly created model version to pipeline run link. + + Raises: + EntityExistsError: If a link with the given ID already exists. + """ + + @abstractmethod + def list_model_version_pipeline_run_links( + self, + model_version_pipeline_run_link_filter_model: ModelVersionPipelineRunFilterModel, + ) -> Page[ModelVersionPipelineRunResponseModel]: + """Get all model version to pipeline run links by filter. + + Args: + model_version_pipeline_run_link_filter_model: All filter parameters including pagination + params. + + Returns: + A page of all model version to pipeline run links. + """ + + @abstractmethod + def delete_model_version_pipeline_run_link( + self, + model_name_or_id: Union[str, UUID], + model_version_name_or_id: Union[str, UUID], + model_version_pipeline_run_link_name_or_id: Union[str, UUID], + ) -> None: + """Deletes a model version to pipeline run link. + + Args: + model_name_or_id: name or ID of the model containing the model version. + model_version_name_or_id: name or ID of the model version containing the link. + model_version_pipeline_run_link_name_or_id: name or ID of the model version to pipeline run link to be deleted. + + Raises: + KeyError: specified ID not found. + """ diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index 64d729f2d70..04cbf55d373 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -40,7 +40,6 @@ StubLocalRepositoryContext, ) from zenml.client import Client -from zenml.config.pipeline_configurations import PipelineConfiguration from zenml.enums import SecretScope, StackComponentType, StoreType from zenml.exceptions import ( DoesNotExistException, @@ -54,13 +53,14 @@ ArtifactResponseModel, ComponentFilterModel, ComponentUpdateModel, + ModelVersionArtifactFilterModel, + ModelVersionArtifactRequestModel, ModelVersionFilterModel, - ModelVersionLinkFilterModel, - ModelVersionLinkRequestModel, + ModelVersionPipelineRunFilterModel, + ModelVersionPipelineRunRequestModel, ModelVersionRequestModel, ModelVersionUpdateModel, PipelineRunFilterModel, - PipelineRunRequestModel, PipelineRunResponseModel, RoleFilterModel, RoleRequestModel, @@ -2712,152 +2712,155 @@ def test_model_version_update_public_interface(self): ) -class TestModelVersionLink: - def test_model_version_link_create_pass(self): - with ModelVersionContext(True, create_artifact=True) as ( +class TestModelVersionArtifactLinks: + def test_link_create_pass(self): + with ModelVersionContext(True, create_artifacts=1) as ( model_version, - artifact, + artifacts, ): zs = Client().zen_store - zs.create_model_version_link( - ModelVersionLinkRequestModel( + zs.create_model_version_artifact_link( + ModelVersionArtifactRequestModel( user=model_version.user.id, workspace=model_version.workspace.id, model=model_version.model.id, model_version=model_version.id, name="link", - artifact=artifact.id, + artifact=artifacts[0].id, ) ) - def test_model_version_link_create_duplicated(self): - with ModelVersionContext(True, create_artifact=True) as ( + def test_link_create_duplicated(self): + with ModelVersionContext(True, create_artifacts=1) as ( model_version, - artifact, + artifacts, ): zs = Client().zen_store - zs.create_model_version_link( - ModelVersionLinkRequestModel( + zs.create_model_version_artifact_link( + ModelVersionArtifactRequestModel( user=model_version.user.id, workspace=model_version.workspace.id, model=model_version.model.id, model_version=model_version.id, name="link", - artifact=artifact.id, + artifact=artifacts[0].id, ) ) - + # name collision with pytest.raises(EntityExistsError): - zs.create_model_version_link( - ModelVersionLinkRequestModel( + zs.create_model_version_artifact_link( + ModelVersionArtifactRequestModel( user=model_version.user.id, workspace=model_version.workspace.id, model=model_version.model.id, model_version=model_version.id, name="link", - artifact=artifact.id, + artifact=uuid4(), + ) + ) + # id collision + with pytest.raises(EntityExistsError): + zs.create_model_version_artifact_link( + ModelVersionArtifactRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link2", + artifact=artifacts[0].id, ) ) - def test_model_version_link_delete_found(self): - with ModelVersionContext(True, create_artifact=True) as ( + def test_link_delete_found(self): + with ModelVersionContext(True, create_artifacts=1) as ( model_version, - artifact, + artifacts, ): zs = Client().zen_store - zs.create_model_version_link( - ModelVersionLinkRequestModel( + zs.create_model_version_artifact_link( + ModelVersionArtifactRequestModel( user=model_version.user.id, workspace=model_version.workspace.id, model=model_version.model.id, model_version=model_version.id, name="link", - artifact=artifact.id, + artifact=artifacts[0].id, ) ) - zs.delete_model_version_link( - model_version.model.id, model_version.id, "link" + zs.delete_model_version_artifact_link( + model_name_or_id=model_version.model.id, + model_version_name_or_id=model_version.id, + model_version_artifact_link_name_or_id="link", ) - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( + mvls = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, ) ) assert len(mvls) == 0 - def test_model_version_link_delete_not_found(self): + def test_link_delete_not_found(self): with ModelVersionContext(True) as model_version: zs = Client().zen_store with pytest.raises(KeyError): - zs.delete_model_version_link( - model_version.model.id, model_version.id, "link" + zs.delete_model_version_artifact_link( + model_name_or_id=model_version.model.id, + model_version_name_or_id=model_version.id, + model_version_artifact_link_name_or_id="link", ) - def test_model_version_link_list_empty(self): + def test_link_list_empty(self): with ModelVersionContext(True) as model_version: zs = Client().zen_store - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( + mvls = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, ) ) assert len(mvls) == 0 - def test_model_version_link_list_populated(self): - with ModelVersionContext(True, create_artifact=True) as ( + def test_link_list_populated(self): + with ModelVersionContext(True, create_artifacts=3) as ( model_version, - artifact, + artifacts, ): zs = Client().zen_store - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( + mvls = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, ) ) assert len(mvls) == 0 - for n, mo, dep, pr in [ - ("link1", False, False, False), - ("link2", True, False, False), - ("link3", False, True, False), - ("link4", False, False, True), + for n, mo, dep, artifact in [ + ("link1", False, False, artifacts[0]), + ("link2", True, False, artifacts[1]), + ("link3", False, True, artifacts[2]), ]: - if pr: - pipeline_run = zs.create_run( - PipelineRunRequestModel( - id=uuid.uuid4(), - name=sample_name("sample_pipeline_run"), - status="running", - config=PipelineConfiguration(name="aria_pipeline"), - user=model_version.user.id, - workspace=model_version.workspace.id, - ) - ) - zs.create_model_version_link( - ModelVersionLinkRequestModel( + zs.create_model_version_artifact_link( + ModelVersionArtifactRequestModel( user=model_version.user.id, workspace=model_version.workspace.id, model=model_version.model.id, model_version=model_version.id, name=n, - artifact=artifact.id if not pr else None, - pipeline_run=pipeline_run.id if pr else None, + artifact=artifact.id, is_model_object=mo, is_deployment=dep, ) ) - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( + mvls = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, ) ) - assert len(mvls) == 4 + assert len(mvls) == 3 - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( + mvls = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, only_artifacts=True, @@ -2865,8 +2868,8 @@ def test_model_version_link_list_populated(self): ) assert len(mvls) == 1 and mvls[0].name == "link1" - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( + mvls = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, only_model_objects=True, @@ -2874,8 +2877,8 @@ def test_model_version_link_list_populated(self): ) assert len(mvls) == 1 and mvls[0].name == "link2" - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( + mvls = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( model_id=model_version.model.id, model_version_id=model_version.id, only_deployments=True, @@ -2883,15 +2886,6 @@ def test_model_version_link_list_populated(self): ) assert len(mvls) == 1 and mvls[0].name == "link3" - mvls = zs.list_model_version_links( - ModelVersionLinkFilterModel( - model_id=model_version.model.id, - model_version_id=model_version.id, - only_pipeline_runs=True, - ) - ) - assert len(mvls) == 1 and mvls[0].name == "link4" - mv = zs.get_model_version( model_name_or_id=model_version.model.id, model_version_name_or_id=model_version.id, @@ -2900,7 +2894,6 @@ def test_model_version_link_list_populated(self): assert len(mv.model_object_ids) == 1 assert len(mv.artifact_object_ids) == 1 assert len(mv.deployment_ids) == 1 - assert len(mv.pipeline_run_ids) == 1 assert isinstance( mv.model_objects["link2"], @@ -2914,20 +2907,174 @@ def test_model_version_link_list_populated(self): mv.deployments["link3"], ArtifactResponseModel, ) - assert isinstance( - mv.pipeline_runs["link4"], - PipelineRunResponseModel, - ) - assert mv.pipeline_runs["link4"].id == pipeline_run.id - assert mv.model_objects["link2"].id == artifact.id + assert mv.model_objects["link2"].id == artifacts[1].id assert mv.get_model_object("link2") == mv.model_objects["link2"] assert ( mv.get_artifact_object("link1") == mv.artifact_objects["link1"] ) assert mv.get_deployment("link3") == mv.deployments["link3"] - assert mv.get_pipeline_run("link4") == mv.pipeline_runs["link4"] - if pr: - zs.delete_run(pipeline_run.id) + +class TestModelVersionPipelineRunLinks: + def test_link_create_pass(self): + with ModelVersionContext(True, create_prs=1) as ( + model_version, + prs, + ): + zs = Client().zen_store + zs.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + pipeline_run=prs[0].id, + ) + ) + + def test_link_create_duplicated(self): + with ModelVersionContext(True, create_prs=1) as ( + model_version, + prs, + ): + zs = Client().zen_store + zs.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + pipeline_run=prs[0].id, + ) + ) + # name collision + with pytest.raises(EntityExistsError): + zs.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + pipeline_run=uuid4(), + ) + ) + # id collision + with pytest.raises(EntityExistsError): + zs.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + pipeline_run=prs[0].id, + ) + ) + + def test_link_delete_found(self): + with ModelVersionContext(True, create_prs=1) as ( + model_version, + prs, + ): + zs = Client().zen_store + zs.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name="link", + pipeline_run=prs[0].id, + ) + ) + zs.delete_model_version_pipeline_run_link( + model_version.model.id, model_version.id, "link" + ) + mvls = zs.list_model_version_pipeline_run_links( + ModelVersionPipelineRunFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(mvls) == 0 + + def test_link_delete_not_found(self): + with ModelVersionContext(True) as model_version: + zs = Client().zen_store + with pytest.raises(KeyError): + zs.delete_model_version_pipeline_run_link( + model_version.model.id, model_version.id, "link" + ) + + def test_link_list_empty(self): + with ModelVersionContext(True) as model_version: + zs = Client().zen_store + mvls = zs.list_model_version_pipeline_run_links( + ModelVersionPipelineRunFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(mvls) == 0 + + def test_link_list_populated(self): + with ModelVersionContext(True, create_prs=2) as ( + model_version, + prs, + ): + zs = Client().zen_store + mvls = zs.list_model_version_pipeline_run_links( + ModelVersionPipelineRunFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(mvls) == 0 + for n, pr in zip(["link4", None], prs): + zs.create_model_version_pipeline_run_link( + ModelVersionPipelineRunRequestModel( + user=model_version.user.id, + workspace=model_version.workspace.id, + model=model_version.model.id, + model_version=model_version.id, + name=n, + pipeline_run=pr.id, + ) + ) + mvls = zs.list_model_version_pipeline_run_links( + ModelVersionPipelineRunFilterModel( + model_id=model_version.model.id, + model_version_id=model_version.id, + ) + ) + assert len(mvls) == 2 + + mv = zs.get_model_version( + model_name_or_id=model_version.model.id, + model_version_name_or_id=model_version.id, + ) + + assert len(mv.pipeline_run_ids) == 2 + + assert isinstance( + mv.pipeline_runs["link4"], + PipelineRunResponseModel, + ) + assert isinstance( + mv.pipeline_runs[prs[1].name], + PipelineRunResponseModel, + ) + + assert mv.pipeline_runs["link4"].id == prs[0].id + assert mv.pipeline_runs[prs[1].name].id == prs[1].id + + assert mv.get_pipeline_run("link4") == mv.pipeline_runs["link4"] + assert ( + mv.get_pipeline_run(prs[1].name) + == mv.pipeline_runs[prs[1].name] + ) diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py index f9fe5549ce3..00f04863f62 100644 --- a/tests/integration/functional/zen_stores/utils.py +++ b/tests/integration/functional/zen_stores/utils.py @@ -512,7 +512,10 @@ def __exit__(self, exc_type, exc_value, exc_traceback): class ModelVersionContext: def __init__( - self, create_version: bool = False, create_artifact: bool = False + self, + create_version: bool = False, + create_artifacts: int = 0, + create_prs: int = 0, ): client = Client() self.workspace = client.active_workspace.id @@ -524,8 +527,10 @@ def __init__( self.del_model = False self.create_version = create_version - self.create_artifact = create_artifact - self.artifact = None + self.create_artifacts = create_artifacts + self.artifacts = [] + self.create_prs = create_prs + self.prs = [] def __enter__(self): zs = Client().zen_store @@ -563,26 +568,45 @@ def __enter__(self): ) ) - if self.create_artifact: - artifact = zs.create_artifact( - ArtifactRequestModel( - name=sample_name("sample_artifact"), - data_type="module.class", - materializer="module.class", - type=ArtifactType.DATA, - uri="", - user=user.id, - workspace=ws.id, + for _ in range(self.create_artifacts): + self.artifacts.append( + zs.create_artifact( + ArtifactRequestModel( + name=sample_name("sample_artifact"), + data_type="module.class", + materializer="module.class", + type=ArtifactType.DATA, + uri="", + user=user.id, + workspace=ws.id, + ) + ) + ) + for _ in range(self.create_prs): + self.prs.append( + zs.create_run( + PipelineRunRequestModel( + id=uuid.uuid4(), + name=sample_name("sample_pipeline_run"), + status="running", + config=PipelineConfiguration(name="aria_pipeline"), + user=user.id, + workspace=ws.id, + ) ) ) - self.artifact = artifact - if self.create_version: - return mv, self.artifact + if self.create_version: + if self.create_artifacts: + return mv, self.artifacts + if self.create_prs: + return mv, self.prs else: - return model, self.artifact - else: - if self.create_version: return mv + else: + if self.create_artifacts: + return model, self.artifacts + if self.create_prs: + return model, self.prs else: return model @@ -592,8 +616,10 @@ def __exit__(self, exc_type, exc_value, exc_traceback): zs.delete_model_version(self.model, self.model_version) if self.del_model: zs.delete_model(self.model) - if self.create_artifact: - zs.delete_artifact(self.artifact.id) + for artifact in self.artifacts: + zs.delete_artifact(artifact.id) + for run in self.prs: + zs.delete_run(run.id) if self.del_user: zs.delete_user(self.user) if self.del_ws: From c887b0bb132a1a40d229e9645af3c02914ba765c Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Fri, 15 Sep 2023 22:05:40 +0200 Subject: [PATCH 39/40] lint --- src/zenml/models/model_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index b01ee88bdcc..543de9ca721 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -20,7 +20,6 @@ from zenml.model import ModelStages from zenml.models.artifact_models import ArtifactResponseModel - from zenml.models.base_models import ( WorkspaceScopedRequestModel, WorkspaceScopedResponseModel, From 9b0df95e92a28d699d8dccb8676f2c8288b5ba02 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:44:05 +0200 Subject: [PATCH 40/40] pr comments --- src/zenml/zen_server/routers/models_endpoints.py | 2 +- src/zenml/zen_server/routers/workspaces_endpoints.py | 1 - src/zenml/zen_stores/schemas/model_schemas.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index e1790b27daf..f5c08bd49e9 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -340,7 +340,7 @@ def list_model_version_pipeline_run_links( Args: model_version_pipeline_run_link_filter_model: Filter model used for pagination, sorting, - filtering + and filtering Returns: The model version to pipeline run links according to query filters. diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index b6ce0edcd6c..cc300ccb368 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -1305,7 +1305,6 @@ def list_workspace_model_versions( model_version_filter_model: Filter model used for pagination, sorting, filtering - Returns: The model versions according to query filters. """ diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 2f88c457cf6..56ff415a30c 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -351,7 +351,7 @@ class ModelVersionArtifactSchema(NamedSchema, table=True): def from_request( cls, model_version_artifact_request: ModelVersionArtifactRequestModel ) -> "ModelVersionArtifactSchema": - """Convert an `ModelVersionArtifactRequestModel` to an `ModelVersionArtifactSchema`. + """Convert an `ModelVersionArtifactRequestModel` to a `ModelVersionArtifactSchema`. Args: model_version_artifact_request: The request link to convert.