[Model WatchTower] Delete running versions on failed pipelines with delete_new_version_on_failure option #1825

Merged
104 commits
320d27c
big bang commit
avishniakov Sep 8, 2023
5e72e3d
typo
avishniakov Sep 8, 2023
18b5344
Apply suggestions from code review
avishniakov Sep 8, 2023
b32a1d3
add Alembic
avishniakov Sep 11, 2023
661729d
lint
avishniakov Sep 11, 2023
6594e16
mypy
avishniakov Sep 11, 2023
b65910a
Merge branch 'develop' into feature/OSS-2417-model-class
avishniakov Sep 11, 2023
353bff8
darglint
avishniakov Sep 11, 2023
7374bd4
wip
avishniakov Sep 11, 2023
a856526
wip
avishniakov Sep 11, 2023
9908cbe
wip
avishniakov Sep 11, 2023
b75b17e
wip
avishniakov Sep 11, 2023
e15ad92
add endpoints
avishniakov Sep 11, 2023
04b6033
Merge branch 'feature/OSS-2417-model-class' into feature/OSS-2419-add…
avishniakov Sep 11, 2023
8590eff
Merge branch 'feature/OSS-2417-model-class' into feature/OSS-2418-mod…
avishniakov Sep 11, 2023
a49d2a9
add ModelStages
avishniakov Sep 11, 2023
077c6ee
wip
avishniakov Sep 11, 2023
59f4732
work with client
avishniakov Sep 11, 2023
9222d21
handle tags
avishniakov Sep 11, 2023
67c1286
fix integrations
avishniakov Sep 11, 2023
4a00124
move list around
avishniakov Sep 11, 2023
f18ca6f
Merge branch 'feature/OSS-2417-model-class' into feature/OSS-2418-mod…
avishniakov Sep 11, 2023
3df40d5
update db schema
avishniakov Sep 11, 2023
6bb853f
wip
avishniakov Sep 11, 2023
f8764a7
lint
avishniakov Sep 12, 2023
7b3c60e
Merge branch 'feature/OSS-2417-model-class' into feature/OSS-2418-mod…
avishniakov Sep 12, 2023
1205d17
sync with model branch
avishniakov Sep 12, 2023
e4c5ee0
wip
avishniakov Sep 12, 2023
13159a4
refactor
avishniakov Sep 12, 2023
7907b4e
add stage transition
avishniakov Sep 12, 2023
1a083b0
add update interface
avishniakov Sep 12, 2023
571026f
add model version links
avishniakov Sep 12, 2023
ad66d2e
lint
avishniakov Sep 12, 2023
876730d
fix crud tests
avishniakov Sep 13, 2023
ecb3379
Merge branch 'develop' into feature/OSS-2417-model-class
avishniakov Sep 13, 2023
07c26eb
fix alembic branching
avishniakov Sep 13, 2023
efdab85
Merge branch 'feature/OSS-2417-model-class' into feature/OSS-2418-mod…
avishniakov Sep 13, 2023
4e38be7
patch azure
avishniakov Sep 13, 2023
8b130bd
Merge branch 'feature/OSS-2417-model-class' into feature/OSS-2418-mod…
avishniakov Sep 13, 2023
79e2c09
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
fa9r Sep 13, 2023
d0a3c49
Merge branch 'feature/OSS-2417-model-class' into feature/OSS-2418-mod…
avishniakov Sep 13, 2023
9aa006d
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Sep 14, 2023
8d14736
lint
avishniakov Sep 14, 2023
574ef03
use zenml StrEnum
avishniakov Sep 14, 2023
02bd3ce
fix param name
avishniakov Sep 14, 2023
4c1cc42
Merge branch 'feature/OSS-2418-modelversion' into feature/OSS-2419-ad…
avishniakov Sep 14, 2023
aa7175f
ModelConfig implementation
avishniakov Sep 14, 2023
ec0bc88
start testing
avishniakov Sep 14, 2023
efed59b
fix tests in docker
avishniakov Sep 14, 2023
ba6611b
Merge branch 'feature/OSS-2418-modelversion' into feature/OSS-2419-ad…
avishniakov Sep 15, 2023
144e905
more tests
avishniakov Sep 15, 2023
c980036
Merge branch 'develop' into feature/OSS-2418-modelversion
avishniakov Sep 15, 2023
8160116
fix tests for mysql
avishniakov Sep 15, 2023
180295b
rename artifact ids variables
avishniakov Sep 15, 2023
6db61b1
reorder methods
avishniakov Sep 15, 2023
2c8c083
add direct getters
avishniakov Sep 15, 2023
2434771
lint
avishniakov Sep 15, 2023
567a1bd
split links into 2 tables
avishniakov Sep 15, 2023
75ca942
Merge branch 'develop' into feature/OSS-2418-modelversion
avishniakov Sep 15, 2023
d211921
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Sep 15, 2023
c887b0b
lint
avishniakov Sep 15, 2023
23f5ecc
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Sep 18, 2023
aac07fc
Merge branch 'feature/OSS-2418-modelversion' into feature/OSS-2419-ad…
avishniakov Sep 18, 2023
9b0df95
pr comments
avishniakov Sep 18, 2023
2b3aec5
Merge branch 'feature/OSS-2418-modelversion' into feature/OSS-2419-ad…
avishniakov Sep 18, 2023
ba9fd6a
add ModelConfig to step deco
avishniakov Sep 18, 2023
71baba2
Revert "add ModelConfig to step deco"
avishniakov Sep 18, 2023
97bcb67
add ModelConfig to step deco
avishniakov Sep 18, 2023
eb41965
lint
avishniakov Sep 18, 2023
15ff020
Merge branch 'feature/OSS-2419-add-modelconfig' into feature/OSS-2429…
avishniakov Sep 18, 2023
55441c9
wip
avishniakov Sep 18, 2023
5723c4b
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Sep 18, 2023
0e9c613
Add model_config to @pipeline and @step
avishniakov Sep 18, 2023
1089e49
Merge branch 'feature/OSS-2419-add-modelconfig' into feature/OSS-2429…
avishniakov Sep 18, 2023
ff14a77
add artifact config and necessary registrations around it
avishniakov Sep 19, 2023
913915c
fix bug with `Output` annotation
avishniakov Sep 20, 2023
fa75a45
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Sep 20, 2023
ab992ea
Merge branch 'feature/OSS-2419-add-modelconfig' into feature/OSS-2429…
avishniakov Sep 20, 2023
523f6e9
Merge branch 'feature/OSS-2429-add-modelconfig-to-step-deco' into fea…
avishniakov Sep 20, 2023
02df15f
add docstring
avishniakov Sep 20, 2023
b815912
Merge branch 'feature/OSS-2429-add-modelconfig-to-step-deco' of https…
avishniakov Sep 20, 2023
258c31b
Merge branch 'feature/OSS-2429-add-modelconfig-to-step-deco' into fea…
avishniakov Sep 20, 2023
774be98
delete running versions on fail
avishniakov Sep 20, 2023
6b1096c
add deletion test
avishniakov Sep 20, 2023
cd9dab3
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Sep 26, 2023
ea0a8d7
don't return running without recovery
avishniakov Sep 26, 2023
561beea
remove base_model.py
avishniakov Sep 26, 2023
c5e9776
PR comments
avishniakov Sep 26, 2023
8f29956
remove merged migration
avishniakov Sep 26, 2023
5d00b72
PR comments
avishniakov Sep 26, 2023
4fb0691
improve model_config warnings
avishniakov Sep 26, 2023
d0ba552
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Sep 26, 2023
4d9df58
clean up merge mess
avishniakov Sep 26, 2023
ac03dfa
Merge branch 'feature/OSS-2427-full-version-management-in-context' of…
avishniakov Sep 26, 2023
1297429
bandit is too strict
avishniakov Sep 26, 2023
717bebe
typos
avishniakov Sep 27, 2023
b79a76f
clean up merge mess
avishniakov Sep 27, 2023
16515a1
remove not relevant asserts
avishniakov Sep 27, 2023
94aef1f
improve docs
avishniakov Sep 27, 2023
e0ac1ce
improve docs
avishniakov Sep 27, 2023
6843ee6
rely on deployment in `get_new_version_requests`
avishniakov Sep 27, 2023
e094237
stabilize tests in random order
avishniakov Sep 27, 2023
cacc3e2
improve docs
avishniakov Sep 27, 2023
bdbced1
stabilize tests
avishniakov Sep 28, 2023
35 changes: 20 additions & 15 deletions src/zenml/model/model_config.py
@@ -38,7 +38,7 @@ class ModelConfig(ModelConfigModel):
create_new_model_version: Whether to create a new model version during execution
save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
if available in active stack.
recovery: Whether to keep failed runs with new versions for later recovery from it.
delete_new_version_on_failure: Whether to delete new model versions created by failed runs (set to False to keep them for later recovery).
"""

_model: Optional["ModelResponseModel"] = PrivateAttr(default=None)
@@ -61,16 +61,16 @@ def _validate_create_new_model_version(
values["version"] = RUNNING_MODEL_VERSION
return create_new_model_version

@validator("recovery")
@validator("delete_new_version_on_failure")
def _validate_recovery(
cls, recovery: bool, values: Dict[str, Any]
cls, delete_new_version_on_failure: bool, values: Dict[str, Any]
) -> bool:
if recovery:
if not delete_new_version_on_failure:
if not values.get("create_new_model_version", False):
logger.warning(
"Using `recovery` flag without `create_new_model_version=True` has no effect."
"Using `delete_new_version_on_failure=False` without `create_new_model_version=True` has no effect."
)
return recovery
return delete_new_version_on_failure
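
The condition checked by the renamed validator can be sketched with the standard library alone. This is a minimal illustration, not the ZenML code: `check_flags` and the logger name are hypothetical, and the real check runs inside a pydantic validator.

```python
import logging

# Hypothetical logger name, used only for this sketch.
logger = logging.getLogger("model_config_sketch")


def check_flags(
    create_new_model_version: bool, delete_new_version_on_failure: bool
) -> bool:
    """Warn when keeping failed versions is requested but none are created.

    Returns True when a warning was logged, False otherwise.
    """
    if not delete_new_version_on_failure and not create_new_model_version:
        logger.warning(
            "Using `delete_new_version_on_failure=False` without "
            "`create_new_model_version=True` has no effect."
        )
        return True
    return False
```

The warning fires only for the one combination that cannot do anything: asking to keep a new version when no new version is being created.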

@validator("version")
def _validate_version(
@@ -157,15 +157,12 @@ def _create_model_version(
)
mv_request = ModelVersionRequestModel.parse_obj(model_version_request)
try:
self._model_version = zenml_client.zen_store.get_model_version(
mv = zenml_client.zen_store.get_model_version(
model_name_or_id=self.name,
model_version_name_or_id=self.version,
)
self._model_version = mv
except KeyError:
if self.recovery:
logger.warning(
f"Recovery mode: No `{self.version}` model version found."
)
self._model_version = zenml_client.zen_store.create_model_version(
model_version=mv_request
)
@@ -201,10 +198,18 @@ def _get_model_version(
def get_or_create_model_version(self) -> "ModelVersionResponseModel":
"""This method should get or create a model and a model version from Model WatchTower.

New model is created implicitly, if missing, otherwise fetched.

New version will be created if `create_new_model_version`, otherwise
will try to fetch based on `model_version`.
A new model is created implicitly if missing; otherwise the existing model is fetched. The model
name is controlled by the `name` parameter.

The model version returned by this method is resolved based on the model configuration:
- If a model version is left over from a previous failed run that had
`delete_new_version_on_failure` set to False, and `create_new_model_version` is True,
the leftover model version is reused.
- Otherwise, if `create_new_model_version` is True, a new model version is created.
- If `create_new_model_version` is False, a model version is fetched based on `version`:
- If `version` is not set, the latest model version is fetched.
- If `version` is set to a string, the model version with the matching version is fetched.
- If `version` is set to a `ModelStage`, the model version with the matching stage is fetched.

Returns:
The model version based on configuration.
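
The resolution order described in this docstring can be sketched as a pure function. This is a simplified illustration under stated assumptions: `resolve`, the `RUNNING` placeholder value, and the oldest-to-newest ordering of `existing` are all hypothetical, and lookup by stage is left out.

```python
from typing import List, Optional, Tuple

# Placeholder value: the diff only shows the constant name RUNNING_MODEL_VERSION.
RUNNING = "running"


def resolve(
    existing: List[str],
    create_new_model_version: bool,
    delete_new_version_on_failure: bool,
    version: Optional[str] = None,
) -> Tuple[str, str]:
    """Return (action, version) following the docstring's resolution order.

    `existing` is assumed to be ordered oldest to newest.
    """
    if create_new_model_version:
        # A leftover version is reused only when failed runs keep their versions.
        if not delete_new_version_on_failure and RUNNING in existing:
            return ("reuse", RUNNING)
        return ("create", RUNNING)
    if version is None:
        if not existing:
            raise KeyError("no model versions exist")
        return ("fetch", existing[-1])
    if version not in existing:
        raise KeyError(version)
    return ("fetch", version)
```

For example, `resolve(["1", "running"], True, False)` reuses the leftover version, while the same call with `delete_new_version_on_failure=True` creates a fresh one.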
4 changes: 2 additions & 2 deletions src/zenml/models/model_base_model.py
@@ -68,7 +68,7 @@ class ModelConfigModel(ModelBaseModel):
create_new_model_version: Whether to create a new model version during execution
save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
if available in active stack.
recovery: Whether to keep new model versions from failed runs for later recovery.
delete_new_version_on_failure: Whether to delete new model versions created by failed runs (set to False to keep them for later recovery).
"""

version: Optional[Union[ModelStages, str]] = Field(
@@ -78,4 +78,4 @@ class ModelConfigModel(ModelBaseModel):
)
create_new_model_version: bool = False
save_models_to_registry: bool = True
recovery: bool = False
delete_new_version_on_failure: bool = True
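
The rename also inverts the flag: the old default `recovery=False` corresponds to the new default `delete_new_version_on_failure=True`. A caller holding old-style settings in a plain dict could translate them roughly as follows. `migrate_recovery_flag` is a hypothetical helper for illustration and is not part of this PR.

```python
from typing import Any, Dict


def migrate_recovery_flag(config: Dict[str, Any]) -> Dict[str, Any]:
    """Translate a legacy `recovery` key into `delete_new_version_on_failure`.

    The input dict is left unmodified; an explicit new-style key wins over
    the legacy one if both are present.
    """
    migrated = dict(config)
    if "recovery" in migrated:
        recovery = migrated.pop("recovery")
        # Keeping versions for recovery means NOT deleting them on failure.
        migrated.setdefault("delete_new_version_on_failure", not recovery)
    return migrated
```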
84 changes: 54 additions & 30 deletions src/zenml/models/model_models.py
@@ -68,15 +68,15 @@ class ModelVersionResponseModel(
model: "ModelResponseModel" = Field(
title="The model containing version",
)
model_object_ids: Dict[str, UUID] = Field(
model_object_ids: Dict[str, Dict[str, UUID]] = Field(
title="Model Objects linked to the model version",
default={},
)
artifact_object_ids: Dict[str, UUID] = Field(
artifact_object_ids: Dict[str, Dict[str, UUID]] = Field(
title="Artifacts linked to the model version",
default={},
)
deployment_ids: Dict[str, UUID] = Field(
deployment_ids: Dict[str, Dict[str, UUID]] = Field(
title="Deployments linked to the model version",
default={},
)
@@ -86,45 +86,54 @@ class ModelVersionResponseModel(
)

@property
def model_objects(self) -> Dict[str, ArtifactResponseModel]:
"""Get all model objects linked to this version.
def model_objects(self) -> Dict[str, Dict[str, ArtifactResponseModel]]:
"""Get all model objects linked to this model version.

Returns:
Dictionary of Model Objects as ArtifactResponseModel
Dictionary of Model Objects with versions as Dict[str, ArtifactResponseModel]
"""
from zenml.client import Client

return {
name: Client().get_artifact(a)
for name, a in self.model_object_ids.items()
name: {
version: Client().get_artifact(a)
for version, a in self.model_object_ids[name].items()
}
for name in self.model_object_ids
}

@property
def artifact_objects(self) -> Dict[str, ArtifactResponseModel]:
"""Get all artifacts linked to this version.
def artifacts(self) -> Dict[str, Dict[str, ArtifactResponseModel]]:
"""Get all artifacts linked to this model version.

Returns:
Dictionary of Artifact Objects as ArtifactResponseModel
Dictionary of Artifacts with versions as Dict[str, ArtifactResponseModel]
"""
from zenml.client import Client

return {
name: Client().get_artifact(a)
for name, a in self.artifact_object_ids.items()
name: {
version: Client().get_artifact(a)
for version, a in self.artifact_object_ids[name].items()
}
for name in self.artifact_object_ids
}

@property
def deployments(self) -> Dict[str, ArtifactResponseModel]:
"""Get all deployments linked to this version.
def deployments(self) -> Dict[str, Dict[str, ArtifactResponseModel]]:
"""Get all deployments linked to this model version.

Returns:
Dictionary of Deployments as ArtifactResponseModel
Dictionary of Deployments with versions as Dict[str, ArtifactResponseModel]
"""
from zenml.client import Client

return {
name: Client().get_artifact(a)
for name, a in self.deployment_ids.items()
name: {
version: Client().get_artifact(a)
for version, a in self.deployment_ids[name].items()
}
for name in self.deployment_ids
}
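
The three properties above share one shape change: each `name -> UUID` mapping becomes `name -> version -> UUID`, and every UUID is resolved to an object. A minimal sketch of that pattern, with a hypothetical `fetch` callable standing in for `Client().get_artifact`:

```python
from typing import Callable, Dict, TypeVar
from uuid import UUID

T = TypeVar("T")


def resolve_nested(
    ids: Dict[str, Dict[str, UUID]], fetch: Callable[[UUID], T]
) -> Dict[str, Dict[str, T]]:
    """Resolve every UUID in a name -> version -> UUID mapping via `fetch`."""
    return {
        name: {version: fetch(uid) for version, uid in versions.items()}
        for name, versions in ids.items()
    }
```

Iterating over `ids.items()` in the outer comprehension avoids the second lookup `ids[name]` that the diff performs; the result is the same.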

@property
@@ -141,44 +150,59 @@ def pipeline_runs(self) -> Dict[str, PipelineRunResponseModel]:
for name, pr in self.pipeline_run_ids.items()
}

def get_model_object(self, name: str) -> ArtifactResponseModel:
"""Get model object linked to this version.
def get_model_object(
self, name: str, version: Optional[str] = None
) -> ArtifactResponseModel:
"""Get model object linked to this model version.

Args:
name: The name of the model object to retrieve.
version: The version of the model object to retrieve (None for latest/non-versioned)

Returns:
Model Object as ArtifactResponseModel
Specific version of Model Object
"""
from zenml.client import Client

return Client().get_artifact(self.model_object_ids[name])
if version is None:
version = max(self.model_object_ids[name].keys())
return Client().get_artifact(self.model_object_ids[name][version])

def get_artifact_object(self, name: str) -> ArtifactResponseModel:
"""Get artifact linked to this version.
def get_artifact_object(
self, name: str, version: Optional[str] = None
) -> ArtifactResponseModel:
"""Get artifact linked to this model version.

Args:
name: The name of the artifact to retrieve.
version: The version of the artifact to retrieve (None for latest/non-versioned)

Returns:
Artifact Object as ArtifactResponseModel
Specific version of Artifact
"""
from zenml.client import Client

return Client().get_artifact(self.artifact_object_ids[name])
if version is None:
version = max(self.artifact_object_ids[name].keys())
return Client().get_artifact(self.artifact_object_ids[name][version])

def get_deployment(self, name: str) -> ArtifactResponseModel:
"""Get deployment linked to this version.
def get_deployment(
self, name: str, version: Optional[str] = None
) -> ArtifactResponseModel:
"""Get deployment linked to this model version.

Args:
name: The name of the deployment to retrieve.
version: The version of the deployment to retrieve (None for latest/non-versioned)

Returns:
Deployment as ArtifactResponseModel
Specific version of Deployment
"""
from zenml.client import Client

return Client().get_artifact(self.deployment_ids[name])
if version is None:
version = max(self.deployment_ids[name].keys())
return Client().get_artifact(self.deployment_ids[name][version])
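
One point worth checking in the three getters: the version keys are strings, so `max(...)` compares them lexicographically. The sketch below contrasts that with a numeric comparison. It assumes the keys are numeric strings, which the diff does not confirm; both function names are hypothetical.

```python
from typing import Dict, Optional
from uuid import UUID


def pick_version(
    ids_by_version: Dict[str, UUID], version: Optional[str] = None
) -> UUID:
    """Mirror the diff: default to the lexicographically largest key."""
    if version is None:
        version = max(ids_by_version.keys())
    return ids_by_version[version]


def pick_version_numeric(
    ids_by_version: Dict[str, UUID], version: Optional[str] = None
) -> UUID:
    """Variant that compares keys as integers (assumes numeric strings)."""
    if version is None:
        version = max(ids_by_version.keys(), key=int)
    return ids_by_version[version]
```

With keys `"9"` and `"10"`, the first function picks `"9"` and the second picks `"10"`. If versions can exceed a single digit, the numeric variant (or another explicit ordering) may be closer to what "latest" is meant to return.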

def get_pipeline_run(self, name: str) -> PipelineRunResponseModel:
"""Get pipeline run linked to this version.