From 57421208e5b6b7637fc54a13ade71f8cab91c258 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 5 Jan 2024 11:31:51 +0100 Subject: [PATCH 1/2] Allow filtering artifacts with/without custom names --- src/zenml/client.py | 3 +++ src/zenml/models/v2/core/artifact.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index ab19542e5a5..bfcaed57930 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -2608,6 +2608,7 @@ def list_artifacts( created: Optional[Union[datetime, str]] = None, updated: Optional[Union[datetime, str]] = None, name: Optional[str] = None, + has_custom_name: Optional[bool] = None, ) -> Page[ArtifactResponse]: """Get a list of artifacts. @@ -2620,6 +2621,7 @@ def list_artifacts( created: Use to filter by time of creation updated: Use the last updated date for filtering name: The name of the artifact to filter by. + has_custom_name: Filter artifact with/without custom names. Returns: A list of artifacts. @@ -2633,6 +2635,7 @@ def list_artifacts( created=created, updated=updated, name=name, + has_custom_name=has_custom_name, ) return self.zen_store.list_artifacts(artifact_filter_model) diff --git a/src/zenml/models/v2/core/artifact.py b/src/zenml/models/v2/core/artifact.py index f39ef4595bb..f73b3199d78 100644 --- a/src/zenml/models/v2/core/artifact.py +++ b/src/zenml/models/v2/core/artifact.py @@ -136,4 +136,5 @@ def versions(self) -> Dict[str, "ArtifactVersionResponse"]: class ArtifactFilter(BaseFilter): """Model to enable advanced filtering of artifacts.""" - name: Optional[str] + name: Optional[str] = None + has_custom_name: Optional[bool] = None From 58b750dccdb145920ce41aa2ee45b7c5db1f2034 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Fri, 5 Jan 2024 11:53:27 +0100 Subject: [PATCH 2/2] Also allow filtering for artifact versions and model artifact versions --- src/zenml/client.py | 8 +++++++- src/zenml/models/v2/core/artifact_version.py | 12 ++++++++++++ src/zenml/models/v2/core/model_version_artifact.py | 12 ++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/zenml/client.py b/src/zenml/client.py index bfcaed57930..74ce394f940 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -2621,7 +2621,7 @@ def list_artifacts( created: Use to filter by time of creation updated: Use the last updated date for filtering name: The name of the artifact to filter by. - has_custom_name: Filter artifact with/without custom names. + has_custom_name: Filter artifacts with/without custom names. Returns: A list of artifacts. @@ -2727,6 +2727,7 @@ def list_artifact_versions( workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, only_unused: Optional[bool] = False, + has_custom_name: Optional[bool] = None, ) -> Page[ArtifactVersionResponse]: """Get a list of artifact versions. @@ -2751,6 +2752,7 @@ def list_artifact_versions( user_id: The id of the user to filter by. only_unused: Only return artifact versions that are not used in any pipeline runs. + has_custom_name: Filter artifacts with/without custom names. Returns: A list of artifact versions. @@ -2775,6 +2777,7 @@ def list_artifact_versions( workspace_id=workspace_id, user_id=user_id, only_unused=only_unused, + has_custom_name=has_custom_name, ) artifact_version_filter_model.set_scope_workspace( self.active_workspace.id @@ -4849,6 +4852,7 @@ def list_model_version_artifact_links( only_data_artifacts: Optional[bool] = None, only_model_artifacts: Optional[bool] = None, only_deployment_artifacts: Optional[bool] = None, + has_custom_name: Optional[bool] = None, ) -> Page[ModelVersionArtifactResponse]: """Get model version to artifact links by filter in Model Control Plane. @@ -4868,6 +4872,7 @@ def list_model_version_artifact_links( only_data_artifacts: Use to filter by data artifacts only_model_artifacts: Use to filter by model artifacts only_deployment_artifacts: Use to filter by deployment artifacts + has_custom_name: Filter artifacts with/without custom names. Returns: A page of all model version to artifact links. @@ -4889,6 +4894,7 @@ def list_model_version_artifact_links( only_data_artifacts=only_data_artifacts, only_model_artifacts=only_model_artifacts, only_deployment_artifacts=only_deployment_artifacts, + has_custom_name=has_custom_name, ) ) diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index bd35ca36adb..5b6820b18d4 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -356,6 +356,7 @@ class ArtifactVersionFilter(WorkspaceScopedFilter): *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, "name", "only_unused", + "has_custom_name", ] artifact_id: Optional[Union[UUID, str]] = Field( default=None, @@ -401,6 +402,10 @@ class ArtifactVersionFilter(WorkspaceScopedFilter): only_unused: Optional[bool] = Field( default=False, description="Filter only for unused artifacts" ) + has_custom_name: Optional[bool] = Field( + default=None, + description="Filter only artifacts with/without custom names.", + ) def get_custom_filters( self, @@ -448,4 +453,11 @@ def get_custom_filters( ) custom_filters.append(unused_filter) + if self.has_custom_name is not None: + custom_name_filter = and_( # type: ignore[type-var] + ArtifactVersionSchema.artifact_id == ArtifactSchema.id, + ArtifactSchema.has_custom_name == self.has_custom_name, + ) + custom_filters.append(custom_name_filter) + return custom_filters diff --git a/src/zenml/models/v2/core/model_version_artifact.py b/src/zenml/models/v2/core/model_version_artifact.py index a7d35c52dc0..4e1cf7f1811 100644 --- a/src/zenml/models/v2/core/model_version_artifact.py +++ b/src/zenml/models/v2/core/model_version_artifact.py @@ -142,12 +142,14 @@ class ModelVersionArtifactFilter(WorkspaceScopedFilter): "only_data_artifacts", "only_model_artifacts", "only_deployment_artifacts", + "has_custom_name", ] CLI_EXCLUDE_FIELDS = [ *WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS, "only_data_artifacts", "only_model_artifacts", "only_deployment_artifacts", + "has_custom_name", "model_id", "model_version_id", "user_id", @@ -178,6 +180,7 @@ class ModelVersionArtifactFilter(WorkspaceScopedFilter): only_data_artifacts: Optional[bool] = False only_model_artifacts: Optional[bool] = False only_deployment_artifacts: Optional[bool] = False + has_custom_name: Optional[bool] = None def get_custom_filters( self, @@ -233,4 +236,13 @@ def get_custom_filters( ) custom_filters.append(deployment_artifact_filter) + if self.has_custom_name is not None: + custom_name_filter = and_( # type: ignore[type-var] + ModelVersionArtifactSchema.artifact_version_id + == ArtifactVersionSchema.id, + ArtifactVersionSchema.artifact_id == ArtifactSchema.id, + ArtifactSchema.has_custom_name == self.has_custom_name, + ) + custom_filters.append(custom_name_filter) + return custom_filters