Skip to content

Commit

Permalink
Allow filtering artifacts with/without custom names (#2226)
Browse files Browse the repository at this point in the history
* Allow filtering artifacts with/without custom names

* Also allow filtering for artifact versions and model artifact versions
  • Loading branch information
schustmi authored Jan 8, 2024
1 parent 8797e0b commit a37246c
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 1 deletion.
9 changes: 9 additions & 0 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2711,6 +2711,7 @@ def list_artifacts(
created: Optional[Union[datetime, str]] = None,
updated: Optional[Union[datetime, str]] = None,
name: Optional[str] = None,
has_custom_name: Optional[bool] = None,
hydrate: bool = False,
) -> Page[ArtifactResponse]:
"""Get a list of artifacts.
Expand All @@ -2724,6 +2725,7 @@ def list_artifacts(
created: Use to filter by time of creation
updated: Use the last updated date for filtering
name: The name of the artifact to filter by.
has_custom_name: Filter artifacts with/without custom names.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Expand All @@ -2739,6 +2741,7 @@ def list_artifacts(
created=created,
updated=updated,
name=name,
has_custom_name=has_custom_name,
)
return self.zen_store.list_artifacts(
artifact_filter_model,
Expand Down Expand Up @@ -2837,6 +2840,7 @@ def list_artifact_versions(
workspace_id: Optional[Union[str, UUID]] = None,
user_id: Optional[Union[str, UUID]] = None,
only_unused: Optional[bool] = False,
has_custom_name: Optional[bool] = None,
hydrate: bool = False,
) -> Page[ArtifactVersionResponse]:
"""Get a list of artifact versions.
Expand All @@ -2862,6 +2866,7 @@ def list_artifact_versions(
user_id: The id of the user to filter by.
only_unused: Only return artifact versions that are not used in
any pipeline runs.
has_custom_name: Filter artifacts with/without custom names.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Expand All @@ -2888,6 +2893,7 @@ def list_artifact_versions(
workspace_id=workspace_id,
user_id=user_id,
only_unused=only_unused,
has_custom_name=has_custom_name,
)
artifact_version_filter_model.set_scope_workspace(
self.active_workspace.id
Expand Down Expand Up @@ -5018,6 +5024,7 @@ def list_model_version_artifact_links(
only_data_artifacts: Optional[bool] = None,
only_model_artifacts: Optional[bool] = None,
only_deployment_artifacts: Optional[bool] = None,
has_custom_name: Optional[bool] = None,
hydrate: bool = False,
) -> Page[ModelVersionArtifactResponse]:
"""Get model version to artifact links by filter in Model Control Plane.
Expand All @@ -5038,6 +5045,7 @@ def list_model_version_artifact_links(
only_data_artifacts: Use to filter by data artifacts
only_model_artifacts: Use to filter by model artifacts
only_deployment_artifacts: Use to filter by deployment artifacts
has_custom_name: Filter artifacts with/without custom names.
hydrate: Flag deciding whether to hydrate the output model(s)
by including metadata fields in the response.
Expand All @@ -5061,6 +5069,7 @@ def list_model_version_artifact_links(
only_data_artifacts=only_data_artifacts,
only_model_artifacts=only_model_artifacts,
only_deployment_artifacts=only_deployment_artifacts,
has_custom_name=has_custom_name,
),
hydrate=hydrate,
)
Expand Down
3 changes: 2 additions & 1 deletion src/zenml/models/v2/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,5 @@ def versions(self) -> Dict[str, "ArtifactVersionResponse"]:
class ArtifactFilter(BaseFilter):
"""Model to enable advanced filtering of artifacts."""

name: Optional[str]
name: Optional[str] = None
has_custom_name: Optional[bool] = None
12 changes: 12 additions & 0 deletions src/zenml/models/v2/core/artifact_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ class ArtifactVersionFilter(WorkspaceScopedFilter):
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
"name",
"only_unused",
"has_custom_name",
]
artifact_id: Optional[Union[UUID, str]] = Field(
default=None,
Expand Down Expand Up @@ -401,6 +402,10 @@ class ArtifactVersionFilter(WorkspaceScopedFilter):
only_unused: Optional[bool] = Field(
default=False, description="Filter only for unused artifacts"
)
has_custom_name: Optional[bool] = Field(
default=None,
description="Filter only artifacts with/without custom names.",
)

def get_custom_filters(
self,
Expand Down Expand Up @@ -448,4 +453,11 @@ def get_custom_filters(
)
custom_filters.append(unused_filter)

if self.has_custom_name is not None:
custom_name_filter = and_( # type: ignore[type-var]
ArtifactVersionSchema.artifact_id == ArtifactSchema.id,
ArtifactSchema.has_custom_name == self.has_custom_name,
)
custom_filters.append(custom_name_filter)

return custom_filters
12 changes: 12 additions & 0 deletions src/zenml/models/v2/core/model_version_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,14 @@ class ModelVersionArtifactFilter(WorkspaceScopedFilter):
"only_data_artifacts",
"only_model_artifacts",
"only_deployment_artifacts",
"has_custom_name",
]
CLI_EXCLUDE_FIELDS = [
*WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS,
"only_data_artifacts",
"only_model_artifacts",
"only_deployment_artifacts",
"has_custom_name",
"model_id",
"model_version_id",
"user_id",
Expand Down Expand Up @@ -178,6 +180,7 @@ class ModelVersionArtifactFilter(WorkspaceScopedFilter):
only_data_artifacts: Optional[bool] = False
only_model_artifacts: Optional[bool] = False
only_deployment_artifacts: Optional[bool] = False
has_custom_name: Optional[bool] = None

def get_custom_filters(
self,
Expand Down Expand Up @@ -233,4 +236,13 @@ def get_custom_filters(
)
custom_filters.append(deployment_artifact_filter)

if self.has_custom_name is not None:
custom_name_filter = and_( # type: ignore[type-var]
ModelVersionArtifactSchema.artifact_version_id
== ArtifactVersionSchema.id,
ArtifactVersionSchema.artifact_id == ArtifactSchema.id,
ArtifactSchema.has_custom_name == self.has_custom_name,
)
custom_filters.append(custom_name_filter)

return custom_filters

0 comments on commit a37246c

Please sign in to comment.