Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model Control Plane] parallel running versions support #1859

Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
50bc06a
linkage of pipeline runs
avishniakov Sep 28, 2023
0ebf539
add docstring
avishniakov Sep 28, 2023
14f1172
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Oct 2, 2023
7e5e1fc
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Oct 2, 2023
9947808
tricky bug
avishniakov Oct 2, 2023
c82878c
full linkage on consumption
avishniakov Oct 2, 2023
86bdf05
lint
avishniakov Oct 2, 2023
ba01176
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Oct 2, 2023
f8dc84b
use client
avishniakov Oct 2, 2023
3d0318c
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Oct 2, 2023
63e3a1b
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Oct 2, 2023
5b9804c
Auto-update of E2E template
actions-user Oct 2, 2023
37dc0cd
update outdated test
avishniakov Oct 2, 2023
d04980f
Merge branch 'feature/OSS-2463-add-runs-linkage-to-model-versions' of…
avishniakov Oct 2, 2023
99abf71
create new version if any step touching it is executed and it was req…
avishniakov Oct 4, 2023
b7f91b2
refactor `_link_pipeline_run_to_model`
avishniakov Oct 4, 2023
e36bd2d
add few more tests
avishniakov Oct 4, 2023
c8ca951
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Oct 4, 2023
511885c
parallel execution of model versions
avishniakov Oct 4, 2023
74addcf
add version number
avishniakov Oct 5, 2023
2f164e8
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Oct 5, 2023
4b6bb7a
improve readability
avishniakov Oct 5, 2023
7d18a05
protect from misuse
avishniakov Oct 5, 2023
88c2952
extend `ArtifactConfig.model_version`
avishniakov Oct 5, 2023
af1d741
align model config docstrings
avishniakov Oct 5, 2023
12b38be
stabilize parallelized test
avishniakov Oct 5, 2023
1998d85
rework test as subprocess calls
avishniakov Oct 6, 2023
2e23ffd
skip subprocess test on windows
avishniakov Oct 6, 2023
20dc2ce
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Oct 6, 2023
3d059f7
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Oct 6, 2023
3295ac2
after merge mess
avishniakov Oct 6, 2023
8b6c50c
update tests flow based on develop
avishniakov Oct 7, 2023
8e549eb
proper handle __latest__ mv in REST
avishniakov Oct 7, 2023
95dc017
fix get model version endpoint
avishniakov Oct 8, 2023
fca8789
Merge branch 'feature/OSS-2300-model-watch-tower-v0.1' into feature/O…
avishniakov Oct 10, 2023
b2c715f
simplify user-facing interface
avishniakov Oct 10, 2023
532192d
fix test annotation
avishniakov Oct 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/zenml/artifacts/external_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ class ExternalArtifact(ExternalArtifactConfiguration):

Example:
```
from zenml import step, pipeline, ExternalArtifact
from zenml import step, pipeline
from zenml.artifacts.external_artifact import ExternalArtifact
import numpy as np

@step
Expand All @@ -99,7 +100,8 @@ def _validate_all(cls, values: Dict[str, Any]) -> Dict[str, Any]:
pipeline_name = values.get("pipeline_name", None)
artifact_name = values.get("artifact_name", None)
model_name = values.get("model_name", None)
model_version = values.get("model_version", None)
model_version_name = values.get("model_version_name", None)
model_version_number = values.get("model_version_number", None)
Comment on lines -102 to +104
Copy link
Contributor

@fa9r fa9r Oct 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't fully understand why this is needed. Feels like model_version_name is the model version and model_version_number is the version of the model version. If that's correct then I think this needs to be redesigned since multi-layered versioning is too confusing in my opinion

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This supports what @stefannica asked for.
Model Version is a container entity, which is represented by a display name (model_version_name) and continuous number (model_version_number). A display name is what the user can control, while a number is an auto-incremental integer assigned to it during creation. This change gives the user the flexibility to fetch the model version either by display name or by number.
These two are split into separate args, cause I thought it might be too confusing to understand why a "1" would be fetched by name, while 1 would be fetched by number while keeping it mixed in all internal methods.

model_artifact_name = values.get("model_artifact_name", None)

if (value is not None) + (id is not None) + (
Expand All @@ -116,7 +118,10 @@ def _validate_all(cls, values: Dict[str, Any]) -> Dict[str, Any]:
value,
id,
pipeline_name or artifact_name,
model_name or model_version or model_artifact_name,
model_name
or model_version_name
or model_artifact_name
or model_version_number,
]
):
raise ValueError(
Expand Down Expand Up @@ -192,7 +197,8 @@ def config(self) -> ExternalArtifactConfiguration:
pipeline_name=self.pipeline_name,
artifact_name=self.artifact_name,
model_name=self.model_name,
model_version=self.model_version,
model_version_name=self.model_version_name,
model_version_number=self.model_version_number,
model_artifact_name=self.model_artifact_name,
model_artifact_version=self.model_artifact_version,
model_artifact_pipeline_name=self.model_artifact_pipeline_name,
Expand Down
40 changes: 32 additions & 8 deletions src/zenml/artifacts/external_artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""External artifact definition."""
from typing import TYPE_CHECKING, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union
from uuid import UUID

from pydantic import BaseModel
from pydantic import BaseModel, root_validator

from zenml.config.source import Source
from zenml.enums import ModelStages
Expand All @@ -42,12 +42,33 @@ class ExternalArtifactConfiguration(BaseModel):
pipeline_name: Optional[str] = None
artifact_name: Optional[str] = None
model_name: Optional[str] = None
model_version: Optional[str] = None
model_artifact_name: Optional[Union[str, ModelStages]] = None
model_version_name: Optional[Union[str, ModelStages]] = None
model_version_number: Optional[int] = None
model_artifact_name: Optional[str] = None
model_artifact_version: Optional[str] = None
model_artifact_pipeline_name: Optional[str] = None
model_artifact_step_name: Optional[str] = None

@root_validator
def _validate_all(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validates all fields.

Args:
values: The values dict used to instantiate the model.

Returns:
The validated values dict.
"""
if values.get("model_version_number", None) and values.get(
"model_version_name", None
):
logger.warning(
"`model_version_number` has higher priority then `model_version_name`."
"Setting `model_version_name` to `None`."
)
values["model_version_name"] = None
return values
avishniakov marked this conversation as resolved.
Show resolved Hide resolved

def _get_artifact_from_pipeline_run(self) -> "ArtifactResponseModel":
"""Get artifact from pipeline run.

Expand Down Expand Up @@ -103,10 +124,13 @@ def _get_artifact_from_model(
"@pipeline definitions."
)
self.model_name = model_config.name
self.model_version = model_config.version
self.model_version_name = model_config.version_name
self.model_version_number = model_config.version_number

_model_config = ModelConfig(
name=self.model_name, version=self.model_version
name=self.model_name,
version_name=self.model_version_name,
version_number=self.model_version_number,
)
model_version = _model_config._get_model_version()

Expand All @@ -127,13 +151,13 @@ def _get_artifact_from_model(
if response is None:
raise RuntimeError(
f"Artifact with name `{self.model_artifact_name}` was not found "
f"in model `{self.model_name}` version `{self.model_version}`. "
f"in model `{self.model_name}` version `{self.model_version_name}`. "
"Please check your inputs and try again."
)

return response

def get_artifact(
def get_artifact_id(
self, model_config: Optional["ModelConfig"] = None
) -> UUID:
"""Get the artifact.
Expand Down
8 changes: 5 additions & 3 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5079,21 +5079,23 @@ def delete_model_version(
def get_model_version(
self,
model_name_or_id: Union[str, UUID],
model_version_name_or_id: Union[str, UUID, ModelStages] = "__latest__",
model_version_name_or_number_or_id: Optional[
Union[str, int, UUID, ModelStages]
] = None,
) -> ModelVersionResponseModel:
"""Get an existing model version from Model WatchTower.

Args:
model_name_or_id: name or id of the model containing the model version.
model_version_name_or_id: name, id or stage of the model version to be retrieved.
model_version_name_or_number_or_id: name, id, stage or number of the model version to be retrieved.
If skipped latest version will be retrieved.

Returns:
The model version of interest.
"""
return self.zen_store.get_model_version(
model_name_or_id=model_name_or_id,
model_version_name_or_id=model_version_name_or_id,
model_version_name_or_number_or_id=model_version_name_or_number_or_id,
)

def list_model_versions(
Expand Down
1 change: 1 addition & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,4 @@ def handle_int_env_var(var: str, default: int = 0) -> int:

# Model WatchTower constants
RUNNING_MODEL_VERSION = "running"
LATEST_MODEL_VERSION_PLACEHOLDER = "__latest__"
36 changes: 23 additions & 13 deletions src/zenml/model/artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ class ArtifactConfig(BaseModel):
"""Used to link a generic Artifact to the model version.

model_name: The name of the model to link artifact to.
model_version_name: The name of the model version to link artifact to.
It can be exact version ("23"), stage (ModelStages.PRODUCTION) or None
for the latest version.
model_version: The identifier of the model version to link artifact to.
It can be exact version ("23"), exact version number (42), stage
(ModelStages.PRODUCTION) or None for the latest version.
model_stage: The stage of the model version to link artifact to.
artifact_name: The override name of a link instead of an artifact name.
overwrite: Whether to overwrite an existing link or create new versions.
"""

model_name: Optional[str]
model_version_name: Optional[Union[ModelStages, str]]
model_version: Optional[Union[ModelStages, str, int]]
artifact_name: Optional[str]
overwrite: bool = False

Expand All @@ -56,6 +56,11 @@ class ArtifactConfig(BaseModel):
IS_MODEL_ARTIFACT: ClassVar[bool] = False
IS_DEPLOYMENT_ARTIFACT: ClassVar[bool] = False

class Config:
"""Config class for ArtifactConfig."""

smart_union = True

@property
def _model_config(self) -> "ModelConfig":
"""Property that returns the model configuration.
Expand All @@ -80,7 +85,12 @@ def _model_config(self) -> "ModelConfig":

on_the_fly_config = ModelConfig(
name=self.model_name,
version=self.model_version_name,
version_name=self.model_version
if not isinstance(self.model_version, int)
else None,
version_number=self.model_version
if isinstance(self.model_version, int)
else None,
create_new_model_version=False,
)
return on_the_fly_config
Expand All @@ -95,7 +105,7 @@ def _model_config(self) -> "ModelConfig":
return model_config

@property
def model(self) -> "ModelResponseModel":
def _model(self) -> "ModelResponseModel":
"""Get the `ModelResponseModel`.

Returns:
Expand All @@ -104,7 +114,7 @@ def model(self) -> "ModelResponseModel":
return self._model_config.get_or_create_model()

@property
def model_version(self) -> "ModelVersionResponseModel":
def _model_version(self) -> "ModelVersionResponseModel":
"""Get the `ModelVersionResponseModel`.

Returns:
Expand Down Expand Up @@ -143,8 +153,8 @@ def _link_to_model_version(
workspace=client.active_workspace.id,
name=artifact_name,
artifact=artifact_uuid,
model=self.model.id,
model_version=self.model_version.id,
model=self._model.id,
model_version=self._model_version.id,
is_model_object=is_model_object,
is_deployment=is_deployment,
overwrite=self.overwrite,
Expand All @@ -158,8 +168,8 @@ def _link_to_model_version(
user_id=client.active_user.id,
workspace_id=client.active_workspace.id,
name=artifact_name,
model_id=self.model.id,
model_version_id=self.model_version.id,
model_id=self._model.id,
model_version_id=self._model_version.id,
only_artifacts=not (is_model_object or is_deployment),
only_deployments=is_deployment,
only_model_objects=is_model_object,
Expand All @@ -172,8 +182,8 @@ def _link_to_model_version(
f"Existing artifact link(s) `{artifact_name}` found and will be deleted."
)
client.zen_store.delete_model_version_artifact_link(
model_name_or_id=self.model.id,
model_version_name_or_id=self.model_version.id,
model_name_or_id=self._model.id,
model_version_name_or_id=self._model_version.id,
model_version_artifact_link_name_or_id=artifact_name,
)
else:
Expand Down
72 changes: 26 additions & 46 deletions src/zenml/model/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
# permissions and limitations under the License.
"""ModelConfig user facing interface to pass into pipeline or step."""

from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Optional

from pydantic import PrivateAttr, validator
from pydantic import PrivateAttr

from zenml.enums import ModelStages
from zenml.exceptions import EntityExistsError
from zenml.logger import get_logger
from zenml.models.model_base_model import ModelConfigModel
Expand All @@ -34,7 +33,9 @@
class ModelConfig(ModelConfigModel):
"""ModelConfig class to pass into pipeline or step to set it into a model context.

version: points model context to a specific version or stage.
version_name: points model context to a specific version or stage.
version_number: points model context to a specific version number.
version_description: The description of the model version.
create_new_model_version: Whether to create a new model version during execution
save_models_to_registry: Whether to save all ModelArtifacts to Model Registry,
if available in active stack.
Expand All @@ -46,42 +47,6 @@ class ModelConfig(ModelConfigModel):
default=None
)

@validator("create_new_model_version")
def _validate_create_new_model_version(
cls, create_new_model_version: bool, values: Dict[str, Any]
) -> bool:
from zenml.constants import RUNNING_MODEL_VERSION

if create_new_model_version:
version = values.get("version", RUNNING_MODEL_VERSION)
if version != RUNNING_MODEL_VERSION and version is not None:
raise ValueError(
"`version` cannot be used with `create_new_model_version`."
)
values["version"] = RUNNING_MODEL_VERSION
return create_new_model_version

@validator("delete_new_version_on_failure")
def _validate_recovery(
cls, delete_new_version_on_failure: bool, values: Dict[str, Any]
) -> bool:
if not delete_new_version_on_failure:
if not values.get("create_new_model_version", False):
logger.warning(
"Using `delete_new_version_on_failure=False` without `create_new_model_version=True` has no effect."
)
return delete_new_version_on_failure

@validator("version")
def _validate_version(
cls, version: Union[str, ModelStages]
) -> Union[str, ModelStages]:
if version in [stage.value for stage in ModelStages]:
logger.info(
f"`version` `{version}` matches one of the possible `ModelStages`, model will be fetched using stage."
)
return version

def get_or_create_model(self) -> "ModelResponseModel":
"""This method should get or create a model from Model WatchTower.

Expand Down Expand Up @@ -140,22 +105,21 @@ def _create_model_version(
return self._model_version

from zenml.client import Client
from zenml.constants import RUNNING_MODEL_VERSION
from zenml.models.model_models import ModelVersionRequestModel

zenml_client = Client()
self.version = RUNNING_MODEL_VERSION
model_version_request = ModelVersionRequestModel(
user=zenml_client.active_user.id,
workspace=zenml_client.active_workspace.id,
version=self.version,
name=self.version_name,
description=self.version_description,
model=model.id,
)
mv_request = ModelVersionRequestModel.parse_obj(model_version_request)
try:
mv = zenml_client.get_model_version(
model_name_or_id=self.name,
model_version_name_or_id=self.version,
model_version_name_or_number_or_id=self.version_name,
)
self._model_version = mv
except KeyError:
Expand All @@ -178,7 +142,7 @@ def _get_model_version(self) -> "ModelVersionResponseModel":
from zenml.client import Client

zenml_client = Client()
if self.version is None:
if self.version_name is None and self.version_number is None:
# raise if not found
self._model_version = zenml_client.get_model_version(
model_name_or_id=self.name
Expand All @@ -188,7 +152,8 @@ def _get_model_version(self) -> "ModelVersionResponseModel":
# raise if not found
self._model_version = zenml_client.get_model_version(
model_name_or_id=self.name,
model_version_name_or_id=self.version,
model_version_name_or_number_or_id=self.version_number
or self.version_name,
)
return self._model_version

Expand Down Expand Up @@ -217,3 +182,18 @@ def get_or_create_model_version(self) -> "ModelVersionResponseModel":
else:
mv = self._get_model_version()
return mv

def _merge_with_config(self, model_config: ModelConfigModel) -> None:
self.license = self.license or model_config.license
self.description = self.description or model_config.description
self.audience = self.audience or model_config.audience
self.use_cases = self.use_cases or model_config.use_cases
self.limitations = self.limitations or model_config.limitations
self.trade_offs = self.trade_offs or model_config.trade_offs
self.ethic = self.ethic or model_config.ethic
if model_config.tags is not None:
self.tags = (self.tags or []) + model_config.tags

self.delete_new_version_on_failure &= (
model_config.delete_new_version_on_failure
)
Loading
Loading