diff --git a/src/zenml/model/model_config.py b/src/zenml/model/model_config.py index 883c031598b..7385eaa0610 100644 --- a/src/zenml/model/model_config.py +++ b/src/zenml/model/model_config.py @@ -38,7 +38,7 @@ class ModelConfig(ModelConfigModel): create_new_model_version: Whether to create a new model version during execution save_models_to_registry: Whether to save all ModelArtifacts to Model Registry, if available in active stack. - recovery: Whether to keep failed runs with new versions for later recovery from it. + delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it. """ _model: Optional["ModelResponseModel"] = PrivateAttr(default=None) @@ -61,16 +61,16 @@ def _validate_create_new_model_version( values["version"] = RUNNING_MODEL_VERSION return create_new_model_version - @validator("recovery") + @validator("delete_new_version_on_failure") def _validate_recovery( - cls, recovery: bool, values: Dict[str, Any] + cls, delete_new_version_on_failure: bool, values: Dict[str, Any] ) -> bool: - if recovery: + if not delete_new_version_on_failure: if not values.get("create_new_model_version", False): logger.warning( - "Using `recovery` flag without `create_new_model_version=True` has no effect." + "Using `delete_new_version_on_failure=False` without `create_new_model_version=True` has no effect." ) - return recovery + return delete_new_version_on_failure @validator("version") def _validate_version( @@ -157,15 +157,12 @@ def _create_model_version( ) mv_request = ModelVersionRequestModel.parse_obj(model_version_request) try: - self._model_version = zenml_client.zen_store.get_model_version( + mv = zenml_client.zen_store.get_model_version( model_name_or_id=self.name, model_version_name_or_id=self.version, ) + self._model_version = mv except KeyError: - if self.recovery: - logger.warning( - f"Recovery mode: No `{self.version}` model version found." - ) self._model_version = zenml_client.zen_store.create_model_version( model_version=mv_request ) @@ -201,10 +198,18 @@ def _get_model_version( def get_or_create_model_version(self) -> "ModelVersionResponseModel": """This method should get or create a model and a model version from Model WatchTower. - New model is created implicitly, if missing, otherwise fetched. - - New version will be created if `create_new_model_version`, otherwise - will try to fetch based on `model_version`. + A new model is created implicitly if missing, otherwise existing model is fetched. Model + name is controlled by the `name` parameter. + + Model Version returned by this method is resolved based on model configuration: + - If there is an existing model version leftover from the previous failed run with + `delete_new_version_on_failure` is set to False and `create_new_model_version` is True, + leftover model version will be reused. + - Otherwise if `create_new_model_version` is True, a new model version is created. + - If `create_new_model_version` is False a model version will be fetched based on the version: + - If `version` is not set, the latest model version will be fetched. + - If `version` is set to a string, the model version with the matching version will be fetched. + - If `version` is set to a `ModelStage`, the model version with the matching stage will be fetched. Returns: The model version based on configuration. diff --git a/src/zenml/models/model_base_model.py b/src/zenml/models/model_base_model.py index 59841a93b56..5eb85b5ae5d 100644 --- a/src/zenml/models/model_base_model.py +++ b/src/zenml/models/model_base_model.py @@ -68,7 +68,7 @@ class ModelConfigModel(ModelBaseModel): create_new_model_version: Whether to create a new model version during execution save_models_to_registry: Whether to save all ModelArtifacts to Model Registry, if available in active stack. - recovery: Whether to keep new model versions from failed runs for later recovery. + delete_new_version_on_failure: Whether to delete failed runs with new versions for later recovery from it. """ version: Optional[Union[ModelStages, str]] = Field( @@ -78,4 +78,4 @@ class ModelConfigModel(ModelBaseModel): ) create_new_model_version: bool = False save_models_to_registry: bool = True - recovery: bool = False + delete_new_version_on_failure: bool = True diff --git a/src/zenml/models/model_models.py b/src/zenml/models/model_models.py index b387d2a75b3..1cf1a233a7a 100644 --- a/src/zenml/models/model_models.py +++ b/src/zenml/models/model_models.py @@ -68,15 +68,15 @@ class ModelVersionResponseModel( model: "ModelResponseModel" = Field( title="The model containing version", ) - model_object_ids: Dict[str, UUID] = Field( + model_object_ids: Dict[str, Dict[str, UUID]] = Field( title="Model Objects linked to the model version", default={}, ) - artifact_object_ids: Dict[str, UUID] = Field( + artifact_object_ids: Dict[str, Dict[str, UUID]] = Field( title="Artifacts linked to the model version", default={}, ) - deployment_ids: Dict[str, UUID] = Field( + deployment_ids: Dict[str, Dict[str, UUID]] = Field( title="Deployments linked to the model version", default={}, ) @@ -86,45 +86,54 @@ class ModelVersionResponseModel( ) @property - def model_objects(self) -> Dict[str, ArtifactResponseModel]: - """Get all model objects linked to this version. + def model_objects(self) -> Dict[str, Dict[str, ArtifactResponseModel]]: + """Get all model objects linked to this model version. Returns: - Dictionary of Model Objects as ArtifactResponseModel + Dictionary of Model Objects with versions as Dict[str, ArtifactResponseModel] """ from zenml.client import Client return { - name: Client().get_artifact(a) - for name, a in self.model_object_ids.items() + name: { + version: Client().get_artifact(a) + for version, a in self.model_object_ids[name].items() + } + for name in self.model_object_ids } @property - def artifact_objects(self) -> Dict[str, ArtifactResponseModel]: - """Get all artifacts linked to this version. + def artifacts(self) -> Dict[str, Dict[str, ArtifactResponseModel]]: + """Get all artifacts linked to this model version. Returns: - Dictionary of Artifact Objects as ArtifactResponseModel + Dictionary of Artifacts with versions as Dict[str, ArtifactResponseModel] """ from zenml.client import Client return { - name: Client().get_artifact(a) - for name, a in self.artifact_object_ids.items() + name: { + version: Client().get_artifact(a) + for version, a in self.artifact_object_ids[name].items() + } + for name in self.artifact_object_ids } @property - def deployments(self) -> Dict[str, ArtifactResponseModel]: - """Get all deployments linked to this version. + def deployments(self) -> Dict[str, Dict[str, ArtifactResponseModel]]: + """Get all deployments linked to this model version. Returns: - Dictionary of Deployments as ArtifactResponseModel + Dictionary of Deployments with versions as Dict[str, ArtifactResponseModel] """ from zenml.client import Client return { - name: Client().get_artifact(a) - for name, a in self.deployment_ids.items() + name: { + version: Client().get_artifact(a) + for version, a in self.deployment_ids[name].items() + } + for name in self.deployment_ids } @property @@ -141,44 +150,59 @@ def pipeline_runs(self) -> Dict[str, PipelineRunResponseModel]: for name, pr in self.pipeline_run_ids.items() } - def get_model_object(self, name: str) -> ArtifactResponseModel: - """Get model object linked to this version. + def get_model_object( + self, name: str, version: Optional[str] = None + ) -> ArtifactResponseModel: + """Get model object linked to this model version. Args: name: The name of the model object to retrieve. + version: The version of the model object to retrieve (None for latest/non-versioned) Returns: - Model Object as ArtifactResponseModel + Specific version of Model Object """ from zenml.client import Client - return Client().get_artifact(self.model_object_ids[name]) + if version is None: + version = max(self.model_object_ids[name].keys()) + return Client().get_artifact(self.model_object_ids[name][version]) - def get_artifact_object(self, name: str) -> ArtifactResponseModel: - """Get artifact linked to this version. + def get_artifact_object( + self, name: str, version: Optional[str] = None + ) -> ArtifactResponseModel: + """Get artifact linked to this model version. Args: name: The name of the artifact to retrieve. + version: The version of the model object to retrieve (None for latest/non-versioned) Returns: - Artifact Object as ArtifactResponseModel + Specific version of Artifact """ from zenml.client import Client - return Client().get_artifact(self.artifact_object_ids[name]) + if version is None: + version = max(self.artifact_object_ids[name].keys()) + return Client().get_artifact(self.artifact_object_ids[name][version]) - def get_deployment(self, name: str) -> ArtifactResponseModel: - """Get deployment linked to this version. + def get_deployment( + self, name: str, version: Optional[str] = None + ) -> ArtifactResponseModel: + """Get deployment linked to this model version. Args: name: The name of the deployment to retrieve. + version: The version of the model object to retrieve (None for latest/non-versioned) Returns: - Deployment as ArtifactResponseModel + Specific version of Deployment """ from zenml.client import Client - return Client().get_artifact(self.deployment_ids[name]) + if version is None: + version = max(self.deployment_ids[name].keys()) + return Client().get_artifact(self.deployment_ids[name][version]) def get_pipeline_run(self, name: str) -> PipelineRunResponseModel: """Get pipeline run linked to this version. diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index 58c51d709cf..08a270e384c 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -508,6 +508,9 @@ def _run( unlisted: Whether the pipeline run should be unlisted (not assigned to any pipeline). prevent_build_reuse: Whether to prevent the reuse of a build. + + Raises: + Exception: bypass any exception from pipeline up. """ if constants.SHOULD_PREVENT_PIPELINE_EXECUTION: # An environment variable was set to stop the execution of @@ -590,6 +593,8 @@ def _run( stack = Client().active_stack + new_version_requests = self.get_new_version_requests(deployment) + local_repo_context = ( code_repository_utils.find_active_code_repository() ) @@ -647,6 +652,11 @@ def _run( constants.SHOULD_PREVENT_PIPELINE_EXECUTION = True try: stack.deploy_pipeline(deployment=deployment_model) + except Exception as e: + self.delete_running_versions_without_recovery( + new_version_requests + ) + raise e finally: constants.SHOULD_PREVENT_PIPELINE_EXECUTION = False @@ -657,7 +667,7 @@ def _run( ) if runs.items: - self.register_running_versions(runs.items[0]) + self.register_running_versions(new_version_requests) run_url = dashboard_utils.get_run_url(runs[0]) if run_url: logger.info(f"Dashboard URL: {run_url}") @@ -746,49 +756,108 @@ def log_pipeline_deployment_metadata( except Exception as e: logger.debug(f"Logging pipeline deployment metadata failed: {e}") - def register_running_versions( - self, pipeline: PipelineRunResponseModel - ) -> None: - """Registers the running versions of the models used in the given pipeline run. + def get_new_version_requests( + self, deployment: "PipelineDeploymentBaseModel" + ) -> Dict[str, Dict[str, Any]]: + """Get the running versions of the models that are used in the pipeline run. Args: - pipeline: The pipeline run response model. + deployment: The pipeline deployment configuration. - Raises: - KeyError: No running model version found for @step `model_config`s. + Returns: + A dict of dicts containing requesters of new version and if it should be kept on failure. """ - models_to_register = set() - pipeline_model_name = None - for step_name, step in pipeline.steps.items(): + new_versions_requested: Dict[str, Dict[str, Any]] = {} + all_steps_have_own_config = True + for step in deployment.step_configurations.values(): + step_model_config = step.config.model_config_model + all_steps_have_own_config = ( + all_steps_have_own_config + and step.config.model_config_model is not None + ) if ( - step.config.model_config - and step.config.model_config.create_new_model_version + step_model_config + and step_model_config.create_new_model_version ): - models_to_register.add(step.config.model_config.name) - if ( - pipeline.config.model_config - and pipeline.config.model_config.create_new_model_version - ): - pipeline_model_name = pipeline.config.model_config.name - models_to_register.add(pipeline.config.model_config.name) + model_name = step_model_config.name + new_versions_requested[ + model_name + ] = new_versions_requested.get( + model_name, + {"requesters": [], "delete_new_version_on_failure": True}, + ) + new_versions_requested[model_name]["requesters"].append( + f"Step: {step.config.name}" + ) + new_versions_requested[model_name][ + "delete_new_version_on_failure" + ] &= step_model_config.delete_new_version_on_failure + if not all_steps_have_own_config: + pipeline_model_config = ( + deployment.pipeline_configuration.model_config_model + ) + if ( + pipeline_model_config + and pipeline_model_config.create_new_model_version + ): + new_versions_requested[ + pipeline_model_config.name + ] = new_versions_requested.get( + pipeline_model_config.name, + {"requesters": [], "delete_new_version_on_failure": True}, + ) + new_versions_requested[pipeline_model_config.name][ + "requesters" + ].append(f"Pipeline: {self.name}") + new_versions_requested[pipeline_model_config.name][ + "delete_new_version_on_failure" + ] &= pipeline_model_config.delete_new_version_on_failure + elif deployment.pipeline_configuration.model_config_model is not None: + logger.warning( + f"ModelConfig of pipeline `{self.name}` is overridden in all steps. " + ) + for model_name, data in new_versions_requested.items(): + if len(data["requesters"]) > 1: + logger.warning( + f"New version of model `{model_name}` requested in multiple decorators:\n" + f"{data['requesters']}\n We recommend that `create_new_model_version` is configured " + "only in one place of the pipeline." + ) + + return new_versions_requested + + def register_running_versions( + self, new_version_requests: Dict[str, Dict[str, Any]] + ) -> None: + """Registers the running versions of the models used in the given pipeline run. + + Args: + new_version_requests: Dict of models requesting new versions and their definition points. + """ zs = Client().zen_store - for model_name in models_to_register: - try: + for model_name, requesters in new_version_requests.items(): + if requesters["delete_new_version_on_failure"]: mv = zs.get_model_version( model_name_or_id=model_name, model_version_name_or_id=RUNNING_MODEL_VERSION, ) mv._assign_version_to_running() - except KeyError as e: - if model_name == pipeline_model_name: - logger.warning( - f"Failed to register stable model version of `{model_name}` model. " - f"No `{RUNNING_MODEL_VERSION}` version found. " - "Most probable root cause: you set ModelConfig on pipeline level and " - "override it in all steps inside that pipeline." - ) - else: - raise e + + def delete_running_versions_without_recovery( + self, new_version_requests: Dict[str, Dict[str, Any]] + ) -> None: + """Delete the running versions of the models without `restore` after fail. + + Args: + new_version_requests: Dict of models requesting new versions and their definition points. + """ + zs = Client().zen_store + for model_name, requesters in new_version_requests.items(): + if requesters["delete_new_version_on_failure"]: + zs.delete_model_version( + model_name_or_id=model_name, + model_version_name_or_id=RUNNING_MODEL_VERSION, + ) def get_runs(self, **kwargs: Any) -> List[PipelineRunResponseModel]: """(Deprecated) Get runs of this pipeline. diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 1ce30151fca..565a9c9cf71 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -246,23 +246,34 @@ def to_model(self) -> ModelVersionResponseModel: description=self.description, stage=self.stage, model_object_ids={ - al.name: al.artifact_id + al.name: { + al.version: al.artifact_id + for al in self.artifact_links + if al.is_model_object + } for al in self.artifact_links - if al.artifact_id is not None and al.is_model_object + if al.is_model_object }, deployment_ids={ - al.name: al.artifact_id + al.name: { + al.version: al.artifact_id + for al in self.artifact_links + if al.is_deployment + } for al in self.artifact_links - if al.artifact_id is not None and al.is_deployment + if al.is_deployment }, artifact_object_ids={ - al.name: al.artifact_id + al.name: { + al.version: al.artifact_id + for al in self.artifact_links + if not (al.is_deployment or al.is_model_object) + } for al in self.artifact_links - if al.artifact_id is not None - and not (al.is_deployment or al.is_model_object) + if not (al.is_deployment or al.is_model_object) }, pipeline_run_ids={ - al.name: al.pipeline_run_id for al in self.pipeline_run_links + pr.name: pr.pipeline_run_id for pr in self.pipeline_run_links }, ) diff --git a/tests/integration/functional/model/test_artifact_config.py b/tests/integration/functional/model/test_artifact_config.py index 7a2e76a737d..1be70157f34 100644 --- a/tests/integration/functional/model/test_artifact_config.py +++ b/tests/integration/functional/model/test_artifact_config.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. -from contextlib import contextmanager from typing import Callable, Tuple import pytest from typing_extensions import Annotated +from tests.integration.functional.utils import model_killer from zenml import pipeline, step from zenml.client import Client from zenml.enums import ModelStages @@ -37,15 +37,6 @@ MODEL_NAME = "foo" -@contextmanager -def model_killer(model_name: str = MODEL_NAME): - try: - yield - finally: - zs = Client().zen_store - zs.delete_model(model_name) - - @step(model_config=ModelConfig(name=MODEL_NAME, create_new_model_version=True)) def single_output_step_from_context() -> Annotated[int, ArtifactConfig()]: """Untyped single output linked as Artifact from step context.""" @@ -87,9 +78,6 @@ def test_link_minimalistic(): user = Client().active_user.id ws = Client().active_workspace.id - with pytest.raises(KeyError): - zs.get_model(MODEL_NAME) - simple_pipeline() model = zs.get_model(MODEL_NAME) @@ -105,12 +93,25 @@ def test_link_minimalistic(): ) ) assert links.size == 3 + + one_is_deployment = False + one_is_model_object = False + one_is_artifact = False for link in links: assert link.link_version == 1 assert link.name == "output" - assert not (links[0].is_deployment or links[0].is_model_object) - assert not links[1].is_deployment and links[1].is_model_object - assert links[2].is_deployment and not links[2].is_model_object + one_is_deployment ^= ( + link.is_deployment and not link.is_model_object + ) + one_is_model_object ^= ( + not link.is_deployment and link.is_model_object + ) + one_is_artifact ^= ( + not link.is_deployment and not link.is_model_object + ) + assert one_is_deployment + assert one_is_model_object + assert one_is_artifact @step(model_config=ModelConfig(name=MODEL_NAME, create_new_model_version=True)) @@ -138,9 +139,6 @@ def test_link_multiple_named_outputs(): user = Client().active_user.id ws = Client().active_workspace.id - with pytest.raises(KeyError): - zs.get_model(MODEL_NAME) - multi_named_pipeline() model = zs.get_model(MODEL_NAME) @@ -159,9 +157,7 @@ def test_link_multiple_named_outputs(): assert ( al[0].link_version + al[1].link_version + al[2].link_version == 3 ) - assert al[0].name == "1" - assert al[1].name == "2" - assert al[2].name == "3" + assert {al.name for al in al} == {"1", "2", "3"} @step(model_config=ModelConfig(name=MODEL_NAME, create_new_model_version=True)) @@ -240,60 +236,61 @@ def multi_named_pipeline_from_self(): def test_link_multiple_named_outputs_with_self_context(): """Test multi output linking with context defined in Annotated.""" with model_killer(): - with model_killer("bar"): - zs = Client().zen_store - user = Client().active_user.id - ws = Client().active_workspace.id + zs = Client().zen_store + user = Client().active_user.id + ws = Client().active_workspace.id - # manual creation needed, as we work with specific versions - m1 = ModelConfig( - name=MODEL_NAME, - ).get_or_create_model() - m2 = ModelConfig( - name="bar", - ).get_or_create_model() - - mv1 = zs.create_model_version( - ModelVersionRequestModel( - user=user, - workspace=ws, - version="bar", - model=m1.id, - ) + # manual creation needed, as we work with specific versions + m1 = ModelConfig( + name=MODEL_NAME, + ).get_or_create_model() + m2 = ModelConfig( + name="bar", + ).get_or_create_model() + + mv1 = zs.create_model_version( + ModelVersionRequestModel( + user=user, + workspace=ws, + version="bar", + model=m1.id, ) - mv2 = zs.create_model_version( - ModelVersionRequestModel( - user=user, - workspace=ws, - version="foo", - model=m2.id, - ) + ) + mv2 = zs.create_model_version( + ModelVersionRequestModel( + user=user, + workspace=ws, + version="foo", + model=m2.id, ) + ) - multi_named_pipeline_from_self() + multi_named_pipeline_from_self() - al1 = zs.list_model_version_artifact_links( - ModelVersionArtifactFilterModel( - user_id=user, - workspace_id=ws, - model_id=mv1.model.id, - model_version_id=mv1.id, - ) + al1 = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + user_id=user, + workspace_id=ws, + model_id=mv1.model.id, + model_version_id=mv1.id, ) - al2 = zs.list_model_version_artifact_links( - ModelVersionArtifactFilterModel( - user_id=user, - workspace_id=ws, - model_id=mv2.model.id, - model_version_id=mv2.id, - ) + ) + al2 = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + user_id=user, + workspace_id=ws, + model_id=mv2.model.id, + model_version_id=mv2.id, ) - assert al1.size == 2 - assert al2.size == 1 + ) + assert al1.size == 2 + assert al2.size == 1 - assert al1[0].name == "1" - assert al1[1].name == "2" - assert al2[0].name == "3" + assert {al.name for al in al1} == { + "1", + "2", + } + assert al2[0].name == "3" @step(model_config=ModelConfig(name="step", version="step")) @@ -354,68 +351,66 @@ def multi_named_pipeline_mixed_linkage(): def test_link_multiple_named_outputs_with_mixed_linkage(): """In this test a mixed linkage of artifacts is verified. See steps description.""" - with model_killer("pipe"): - with model_killer("step"): - with model_killer("artifact"): - zs = Client().zen_store - user = Client().active_user.id - ws = Client().active_workspace.id - - # manual creation needed, as we work with specific versions - models = [] - mvs = [] - for n in ["pipe", "step", "artifact"]: - models.append( - ModelConfig( - name=n, - ).get_or_create_model() - ) - mvs.append( - zs.create_model_version( - ModelVersionRequestModel( - user=user, - workspace=ws, - version=n, - model=models[-1].id, - ) - ) + with model_killer(): + zs = Client().zen_store + user = Client().active_user.id + ws = Client().active_workspace.id + + # manual creation needed, as we work with specific versions + models = [] + mvs = [] + for n in ["pipe", "step", "artifact"]: + models.append( + ModelConfig( + name=n, + ).get_or_create_model() + ) + mvs.append( + zs.create_model_version( + ModelVersionRequestModel( + user=user, + workspace=ws, + version=n, + model=models[-1].id, ) + ) + ) - multi_named_pipeline_mixed_linkage() - - artifact_links = [] - for mv in mvs: - artifact_links.append( - zs.list_model_version_artifact_links( - ModelVersionArtifactFilterModel( - user_id=user, - workspace_id=ws, - model_id=mv.model.id, - model_version_id=mv.id, - ) - ) + multi_named_pipeline_mixed_linkage() + + artifact_links = [] + for mv in mvs: + artifact_links.append( + zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + user_id=user, + workspace_id=ws, + model_id=mv.model.id, + model_version_id=mv.id, ) + ) + ) - assert artifact_links[0].size == 3 - assert artifact_links[1].size == 2 - assert artifact_links[2].size == 1 - - assert {al.name for al in artifact_links[0]} == { - "custom_name", - "4", - "output", - } - assert {al.name for al in artifact_links[1]} == { - "2", - "output", - } - assert artifact_links[2][0].name == "3" - assert {al.link_version for al in artifact_links[0]} == { - 1 - }, "some artifacts tracked as higher versions, while all should be version 1" - assert {al.link_version for al in artifact_links[1]} == { - 1 - }, "some artifacts tracked as higher versions, while all should be version 1" + assert artifact_links[0].size == 3 + assert artifact_links[1].size == 2 + assert artifact_links[2].size == 1 + + assert {al.name for al in artifact_links[0]} == { + "custom_name", + "4", + "output", + } + assert {al.name for al in artifact_links[1]} == { + "2", + "output", + } + assert artifact_links[2][0].name == "3" + assert {al.link_version for al in artifact_links[0]} == { + 1 + }, "some artifacts tracked as higher versions, while all should be version 1" + assert {al.link_version for al in artifact_links[1]} == { + 1 + }, "some artifacts tracked as higher versions, while all should be version 1" @step(model_config=ModelConfig(name=MODEL_NAME, version="good_one")) @@ -608,68 +603,67 @@ def simple_pipeline_with_manual_and_implicit_linkage(): def test_link_with_manual_linkage(pipeline: Callable): """Test manual linking by function call in 2 setting: only manual and manual+implicit""" with model_killer(): - with model_killer("bar"): - zs = Client().zen_store - user = Client().active_user.id - ws = Client().active_workspace.id - - # manual creation needed, as we work with specific versions - model = zs.create_model( - ModelRequestModel( - name=MODEL_NAME, - user=user, - workspace=ws, - ) + zs = Client().zen_store + user = Client().active_user.id + ws = Client().active_workspace.id + + # manual creation needed, as we work with specific versions + model = zs.create_model( + ModelRequestModel( + name=MODEL_NAME, + user=user, + workspace=ws, ) - model2 = zs.create_model( - ModelRequestModel( - name="bar", - user=user, - workspace=ws, - ) + ) + model2 = zs.create_model( + ModelRequestModel( + name="bar", + user=user, + workspace=ws, ) - mv = zs.create_model_version( - ModelVersionRequestModel( - user=user, - workspace=ws, - version="good_one", - model=model.id, - ) + ) + mv = zs.create_model_version( + ModelVersionRequestModel( + user=user, + workspace=ws, + version="good_one", + model=model.id, ) - mv2 = zs.create_model_version( - ModelVersionRequestModel( - user=user, - workspace=ws, - version="bar", - model=model2.id, - ) + ) + mv2 = zs.create_model_version( + ModelVersionRequestModel( + user=user, + workspace=ws, + version="bar", + model=model2.id, ) + ) - pipeline() + pipeline() - al1 = zs.list_model_version_artifact_links( - ModelVersionArtifactFilterModel( - user_id=user, - workspace_id=ws, - model_id=model.id, - model_version_id=mv.id, - ) + al1 = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + user_id=user, + workspace_id=ws, + model_id=model.id, + model_version_id=mv.id, ) - assert al1.size == 1 - assert al1[0].link_version == 1 - assert al1[0].name == "1" - - al2 = zs.list_model_version_artifact_links( - ModelVersionArtifactFilterModel( - user_id=user, - workspace_id=ws, - model_id=model2.id, - model_version_id=mv2.id, - ) + ) + assert al1.size == 1 + assert al1[0].link_version == 1 + assert al1[0].name == "1" + + al2 = zs.list_model_version_artifact_links( + ModelVersionArtifactFilterModel( + user_id=user, + workspace_id=ws, + model_id=model2.id, + model_version_id=mv2.id, ) - assert al2.size == 1 - assert al2[0].link_version == 1 - assert al2[0].name == "2" + ) + assert al2.size == 1 + assert al2[0].link_version == 1 + assert al2[0].name == "2" @step diff --git a/tests/integration/functional/model/test_model_config.py b/tests/integration/functional/model/test_model_config.py index 6f9340d4ce6..36bddb15e3f 100644 --- a/tests/integration/functional/model/test_model_config.py +++ b/tests/integration/functional/model/test_model_config.py @@ -175,12 +175,12 @@ def test_init_create_new_version_with_version_fails(self): def test_init_recovery_without_create_new_version_warns(self): """Test that use of `recovery` warn on `create_new_model_version` set to False.""" with mock.patch("zenml.model.model_config.logger.warning") as logger: - ModelConfig(name=MODEL_NAME, recovery=True) + ModelConfig(name=MODEL_NAME, delete_new_version_on_failure=False) logger.assert_called_once() with mock.patch("zenml.model.model_config.logger.warning") as logger: ModelConfig( name=MODEL_NAME, - recovery=True, + delete_new_version_on_failure=False, create_new_model_version=True, ) logger.assert_not_called() @@ -204,7 +204,7 @@ def test_recovery_flow(self): mc = ModelConfig( name=MODEL_NAME, create_new_model_version=True, - recovery=True, + delete_new_version_on_failure=False, ) mv1 = mc.get_or_create_model_version() del mc @@ -212,8 +212,8 @@ def test_recovery_flow(self): mc = ModelConfig( name=MODEL_NAME, create_new_model_version=True, - recovery=True, + delete_new_version_on_failure=False, ) mv2 = mc.get_or_create_model_version() - assert mv1 == mv2 + assert mv1.id == mv2.id diff --git a/tests/integration/functional/steps/test_model_config.py b/tests/integration/functional/steps/test_model_config.py index daa8c08dc43..2a34e6afc8b 100644 --- a/tests/integration/functional/steps/test_model_config.py +++ b/tests/integration/functional/steps/test_model_config.py @@ -12,23 +12,16 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. -from contextlib import contextmanager +from unittest import mock import pytest +from tests.integration.functional.utils import model_killer +from typing_extensions import Annotated from zenml import get_step_context, pipeline, step from zenml.client import Client from zenml.constants import RUNNING_MODEL_VERSION -from zenml.model import ModelConfig - - -@contextmanager -def model_killer(model_name): - try: - yield - finally: - zs = Client().zen_store - zs.delete_model(model_name) +from zenml.model import ArtifactConfig, ModelConfig @step @@ -49,7 +42,7 @@ def _simple_step_pipeline(): ), )() - with model_killer("foo"): + with model_killer(): _simple_step_pipeline() @@ -64,7 +57,7 @@ def test_model_config_passed_to_step_context_via_pipeline(): def _simple_step_pipeline(): _assert_that_model_config_set() - with model_killer("foo"): + with model_killer(): _simple_step_pipeline() @@ -83,7 +76,7 @@ def _simple_step_pipeline(): ), )() - with model_killer("foo"): + with model_killer(): _simple_step_pipeline() @@ -111,10 +104,8 @@ def _simple_step_pipeline(): ), )(name="foobar") - with model_killer("foo"): - with model_killer("bar"): - with model_killer("foobar"): - _simple_step_pipeline() + with model_killer(): + _simple_step_pipeline() @step(model_config=ModelConfig(name="foo", create_new_model_version=True)) @@ -139,33 +130,28 @@ def _this_pipeline_creates_a_version(): _this_step_creates_a_version() _this_step_does_not_create_a_version() - with model_killer("foo"): - with model_killer("bar"): - zs = Client().zen_store - with pytest.raises(KeyError): - zs.get_model("foo") - with pytest.raises(KeyError): - zs.get_model("bar") + with model_killer(): + zs = Client().zen_store - _this_pipeline_creates_a_version() + _this_pipeline_creates_a_version() - foo = zs.get_model("foo") - assert foo.name == "foo" - foo_version = zs.get_model_version("foo") - assert foo_version.version == "1" + foo = zs.get_model("foo") + assert foo.name == "foo" + foo_version = zs.get_model_version("foo") + assert foo_version.version == "1" - bar = zs.get_model("bar") - assert bar.name == "bar" - bar_version = zs.get_model_version("bar") - assert bar_version.version == "1" + bar = zs.get_model("bar") + assert bar.name == "bar" + bar_version = zs.get_model_version("bar") + assert bar_version.version == "1" - _this_pipeline_creates_a_version() + _this_pipeline_creates_a_version() - foo_version = zs.get_model_version("foo") - assert foo_version.version == "2" + foo_version = zs.get_model_version("foo") + assert foo_version.version == "2" - bar_version = zs.get_model_version("bar") - assert bar_version.version == "2" + bar_version = zs.get_model_version("bar") + assert bar_version.version == "2" def test_create_new_version_only_in_step(): @@ -176,10 +162,8 @@ def _this_pipeline_does_not_create_a_version(): _this_step_creates_a_version() _this_step_does_not_create_a_version() - with model_killer("foo"): + with model_killer(): zs = Client().zen_store - with pytest.raises(KeyError): - zs.get_model("foo") _this_pipeline_does_not_create_a_version() @@ -205,10 +189,8 @@ def test_create_new_version_only_in_pipeline(): def _this_pipeline_creates_a_version(): _this_step_does_not_create_a_version() - with model_killer("bar"): + with model_killer(): zs = Client().zen_store - with pytest.raises(KeyError): - zs.get_model("bar") _this_pipeline_creates_a_version() @@ -221,3 +203,183 @@ def _this_pipeline_creates_a_version(): foo_version = zs.get_model_version("bar") assert foo_version.version == "2" + + +@step +def _this_step_produces_output() -> ( + Annotated[int, "data", ArtifactConfig(overwrite=False)] +): + return 1 + + +@step +def _this_step_tries_to_recover(run_number: int): + zs = Client().zen_store + mv = zs.get_model_version( + model_name_or_id="foo", model_version_name_or_id=RUNNING_MODEL_VERSION + ) + assert ( + len(mv.artifact_object_ids["data"]) == run_number + ), "expected AssertionError" + + raise Exception("make pipeline fail") + + +def test_recovery_of_steps(): + """Test that model config can recover states after previous fails.""" + + @pipeline( + name="bar", + enable_cache=False, + model_config=ModelConfig( + name="foo", + create_new_model_version=True, + delete_new_version_on_failure=False, + ), + ) + def _this_pipeline_will_recover(run_number: int): + _this_step_produces_output() + _this_step_tries_to_recover( + run_number, after=["_this_step_produces_output"] + ) + + with model_killer(): + zs = Client().zen_store + + with pytest.raises(Exception, match="make pipeline fail"): + _this_pipeline_will_recover(1) + with pytest.raises(Exception, match="make pipeline fail"): + _this_pipeline_will_recover(2) + with pytest.raises(Exception, match="make pipeline fail"): + _this_pipeline_will_recover(3) + + model = zs.get_model("foo") + mv = zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id=RUNNING_MODEL_VERSION, + ) + assert mv.version == RUNNING_MODEL_VERSION + assert len(mv.artifact_object_ids) == 1 + assert len(mv.artifact_object_ids["data"]) == 3 + + +def test_clean_up_after_failure(): + """Test that hanging `running` versions are cleaned-up after failure.""" + + @pipeline( + name="bar", + enable_cache=False, + model_config=ModelConfig( + name="foo", + create_new_model_version=True, + delete_new_version_on_failure=True, + ), + ) + def _this_pipeline_will_not_recover(run_number: int): + _this_step_produces_output() + _this_step_tries_to_recover( + run_number, after=["_this_step_produces_output"] + ) + + with model_killer(): + zs = Client().zen_store + + with pytest.raises(Exception, match="make pipeline fail"): + _this_pipeline_will_not_recover(1) + with pytest.raises(AssertionError, match="expected AssertionError"): + _this_pipeline_will_not_recover(2) + + model = zs.get_model("foo") + with pytest.raises(KeyError): + zs.get_model_version( + model_name_or_id=model.id, + model_version_name_or_id=RUNNING_MODEL_VERSION, + ) + + +@step(model_config=ModelConfig(name="foo", create_new_model_version=True)) +def _new_version_step(): + return 1 + + +@step +def _no_model_config_step(): + return 1 + + +@pipeline( + enable_cache=False, + model_config=ModelConfig(name="foo", create_new_model_version=True), +) +def _new_version_pipeline_overridden_warns(): + _new_version_step() + + +@pipeline( + enable_cache=False, + model_config=ModelConfig(name="foo", create_new_model_version=True), +) +def _new_version_pipeline_not_warns(): + _no_model_config_step() + + +@pipeline(enable_cache=False) +def _no_new_version_pipeline_not_warns(): + _new_version_step() + + +@pipeline(enable_cache=False) +def _no_new_version_pipeline_warns_on_steps(): + _new_version_step() + _new_version_step() + + +@pipeline( + enable_cache=False, + model_config=ModelConfig(name="foo", create_new_model_version=True), +) +def _new_version_pipeline_warns_on_steps(): + _new_version_step() + _no_model_config_step() + + +@pytest.mark.parametrize( + "pipeline, expected_warning", + [ + ( + _new_version_pipeline_overridden_warns, + "is overridden in all steps", + ), + (_new_version_pipeline_not_warns, ""), + (_no_new_version_pipeline_not_warns, ""), + ( + _no_new_version_pipeline_warns_on_steps, + "`create_new_model_version` is configured only in one", + ), + ( + _new_version_pipeline_warns_on_steps, + "`create_new_model_version` is configured only in one", + ), + ], + ids=[ + "Pipeline with one step, which overrides model_config - warns that pipeline conf is useless.", + "Configuration in pipeline only - not warns.", + "Configuration in step only - not warns.", + "Two steps ask to create new versions - warning to keep it in one place.", + "Pipeline and one of the steps ask to create new versions - warning to keep it in one place.", + ], +) +def test_multiple_definitions_create_new_version_warns( + pipeline, expected_warning +): + """Test that setting conflicting model configurations are raise warnings to user.""" + with model_killer(): + with mock.patch( + "zenml.new.pipelines.pipeline.logger.warning" + ) as logger: + pipeline() + if expected_warning: + logger.assert_called_once() + assert expected_warning in logger.call_args[0][0] + else: + logger.assert_not_called() diff --git a/tests/integration/functional/utils.py b/tests/integration/functional/utils.py index abbdb6fa244..4bd6e56db4d 100644 --- a/tests/integration/functional/utils.py +++ b/tests/integration/functional/utils.py @@ -1,6 +1,24 @@ +from contextlib import contextmanager + +from zenml.client import Client +from zenml.models import ModelFilterModel from zenml.utils.string_utils import random_str def sample_name(prefix: str = "aria") -> str: """Function to get random username.""" return f"{prefix}-{random_str(4)}".lower() + + +@contextmanager +def model_killer(): + try: + yield + finally: + zs = Client().zen_store + models = zs.list_models(ModelFilterModel(size=999)) + for model in models: + try: + zs.delete_model(model.name) + except KeyError: + pass diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py index d4b965247ee..25cf68d7c59 100644 --- a/tests/integration/functional/zen_stores/test_zen_store.py +++ b/tests/integration/functional/zen_stores/test_zen_store.py @@ -3079,7 +3079,7 @@ def test_link_list_empty(self): assert len(mvls) == 0 def test_link_list_populated(self): - with ModelVersionContext(True, create_artifacts=3) as ( + with ModelVersionContext(True, create_artifacts=4) as ( model_version, artifacts, ): @@ -3095,6 +3095,7 @@ def test_link_list_populated(self): ("link1", False, False, artifacts[0]), ("link2", True, False, artifacts[1]), ("link3", False, True, artifacts[2]), + ("link1", False, False, artifacts[3]), ]: zs.create_model_version_artifact_link( ModelVersionArtifactRequestModel( @@ -3116,7 +3117,7 @@ def test_link_list_populated(self): model_version_id=model_version.id, ) ) - assert len(mvls) == 3 + assert len(mvls) == len(artifacts) mvls = zs.list_model_version_artifact_links( ModelVersionArtifactFilterModel( @@ -3125,7 +3126,11 @@ def test_link_list_populated(self): only_artifacts=True, ) ) - assert len(mvls) == 1 and mvls[0].name == "link1" + assert ( + len(mvls) == 2 + and mvls[0].name == "link1" + and mvls[1].name == "link1" + ) mvls = zs.list_model_version_artifact_links( ModelVersionArtifactFilterModel( @@ -3155,25 +3160,40 @@ def test_link_list_populated(self): assert len(mv.deployment_ids) == 1 assert isinstance( - mv.model_objects["link2"], + mv.get_model_object("link2", "1"), ArtifactResponseModel, ) assert isinstance( - mv.artifact_objects["link1"], + mv.get_artifact_object("link1", "1"), ArtifactResponseModel, ) assert isinstance( - mv.deployments["link3"], + mv.get_deployment("link3", "1"), ArtifactResponseModel, ) - assert mv.model_objects["link2"].id == artifacts[1].id + assert mv.model_objects["link2"]["1"].id == artifacts[1].id - assert mv.get_model_object("link2") == mv.model_objects["link2"] assert ( - mv.get_artifact_object("link1") == mv.artifact_objects["link1"] + mv.get_model_object("link2", "1") + == mv.model_objects["link2"]["1"] + ) + assert ( + mv.get_deployment("link3", "1") == mv.deployments["link3"]["1"] + ) + + # check how versioned artifacts retrieved + assert ( + mv.get_artifact_object("link1", "1") + == mv.artifacts["link1"]["1"] + ) + assert ( + mv.get_artifact_object("link1", "2") + == mv.artifacts["link1"]["2"] + ) + assert ( + mv.get_artifact_object("link1") == mv.artifacts["link1"]["2"] ) - assert mv.get_deployment("link3") == mv.deployments["link3"] class TestModelVersionPipelineRunLinks: