Skip to content

Commit

Permalink
Bugfix for artifacts coming from a different artifact store (#2928)
Browse files Browse the repository at this point in the history
* first draft of the artifact store solution

* fixes in error message

* review changes

* renaming the context manager

* review comments

* fixing the utils

* fixing the test fixture

* removed unused import

* added a small test checking the register calls
  • Loading branch information
bcdurak authored Aug 27, 2024
1 parent 35813b1 commit 8b8a6af
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 15 deletions.
19 changes: 13 additions & 6 deletions src/zenml/artifacts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,12 +710,19 @@ def _get_artifact_store_from_response_or_from_active_stack(
"BaseArtifactStore",
StackComponent.from_model(artifact_store_model),
)
except (KeyError, ImportError):
logger.warning(
"Unable to restore artifact store while trying to load artifact "
"`%s`. If this artifact is stored in a remote artifact store, "
"this might lead to issues when trying to load the artifact.",
artifact.id,
except KeyError:
raise RuntimeError(
"Unable to fetch the artifact store with id: "
f"'{artifact.artifact_store_id}'. Check whether the artifact "
"store still exists and you have the right permissions to "
"access it."
)
except ImportError:
raise RuntimeError(
"Unable to load the implementation of the artifact store with"
f"id: '{artifact.artifact_store_id}'. Please make sure that "
"the environment that you are loading this artifact from "
"has the right dependencies."
)
return Client().active_stack.artifact_store

Expand Down
16 changes: 13 additions & 3 deletions src/zenml/orchestrators/step_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,14 +445,24 @@ def _load_input_artifact(
# we use the datatype of the stored artifact
data_type = source_utils.load(artifact.data_type)

from zenml.orchestrators.utils import (
register_artifact_store_filesystem,
)

materializer_class: Type[BaseMaterializer] = (
source_utils.load_and_validate_class(
artifact.materializer, expected_class=BaseMaterializer
)
)
materializer: BaseMaterializer = materializer_class(artifact.uri)
materializer.validate_type_compatibility(data_type)
return materializer.load(data_type=data_type)

with register_artifact_store_filesystem(
artifact.artifact_store_id
) as target_artifact_store:
materializer: BaseMaterializer = materializer_class(
uri=artifact.uri, artifact_store=target_artifact_store
)
materializer.validate_type_compatibility(data_type)
return materializer.load(data_type=data_type)

def _validate_outputs(
self,
Expand Down
107 changes: 105 additions & 2 deletions src/zenml/orchestrators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# permissions and limitations under the License.
"""Utility functions for the orchestrator."""

import os
import random
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, cast
from uuid import UUID

from zenml.client import Client
Expand All @@ -25,17 +26,21 @@
from zenml.constants import (
ENV_ZENML_ACTIVE_STACK_ID,
ENV_ZENML_ACTIVE_WORKSPACE_ID,
ENV_ZENML_SERVER,
ENV_ZENML_STORE_PREFIX,
PIPELINE_API_TOKEN_EXPIRES_MINUTES,
)
from zenml.enums import StoreType
from zenml.enums import StackComponentType, StoreType
from zenml.exceptions import StepContextError
from zenml.logger import get_logger
from zenml.model.utils import link_artifact_config_to_model
from zenml.models.v2.core.step_run import StepRunRequest
from zenml.new.steps.step_context import get_step_context
from zenml.stack import StackComponent
from zenml.utils.string_utils import format_name_template

if TYPE_CHECKING:
from zenml.artifact_stores.base_artifact_store import BaseArtifactStore
from zenml.artifacts.external_artifact_config import (
ExternalArtifactConfiguration,
)
Expand Down Expand Up @@ -302,3 +307,101 @@ def _get_model_versions_from_artifacts(
else:
break
return models


class register_artifact_store_filesystem:
    """Context manager for the artifact_store/filesystem_registry dependency.

    Even though it is rare, sometimes we bump into cases where we are trying
    to load artifacts that belong to an artifact store which is different
    from the active artifact store.

    In cases like this, we will try to instantiate the target artifact store
    by creating the corresponding artifact store Python object, which ends up
    registering the right filesystem in the filesystem registry.

    The problem is, the keys in the filesystem registry are schemes (such as
    "s3://" or "gcs://"). If we have two artifact stores with the same set of
    supported schemes, we might end up overwriting the filesystem that
    belongs to the active artifact store (and its authentication). That's why
    we have to re-instantiate the active artifact store again, so the correct
    filesystem will be restored.
    """

    def __init__(self, target_artifact_store_id: Optional[UUID]) -> None:
        """Initialization of the context manager.

        Args:
            target_artifact_store_id: The ID of the artifact store to load.
                If `None`, the active artifact store is used instead.
        """
        self.target_artifact_store_id = target_artifact_store_id

    def __enter__(self) -> "BaseArtifactStore":
        """Entering the context manager.

        It creates an instance of the target artifact store to register the
        correct filesystem in the registry.

        Returns:
            The target artifact store object, or the active artifact store
            if no target ID was given.

        Raises:
            RuntimeError: If the target artifact store can not be fetched or
                initiated due to missing dependencies.
        """
        try:
            if self.target_artifact_store_id is not None:
                if (
                    Client().active_stack.artifact_store.id
                    != self.target_artifact_store_id
                ):
                    # The separating spaces are part of the literals: the
                    # adjacent strings are concatenated verbatim.
                    get_logger(__name__).debug(
                        "Trying to use the artifact store with ID: "
                        f"'{self.target_artifact_store_id}' "
                        "which is currently not the active artifact store."
                    )

                artifact_store_model_response = Client().get_stack_component(
                    component_type=StackComponentType.ARTIFACT_STORE,
                    name_id_or_prefix=self.target_artifact_store_id,
                )
                # Instantiating the component is what registers its
                # filesystem in the registry.
                return cast(
                    "BaseArtifactStore",
                    StackComponent.from_model(artifact_store_model_response),
                )
            else:
                return Client().active_stack.artifact_store

        except KeyError as e:
            # Chain the original error so the root cause stays visible.
            raise RuntimeError(
                "Unable to fetch the artifact store with id: "
                f"'{self.target_artifact_store_id}'. Check whether the "
                "artifact store still exists and you have the right "
                "permissions to access it."
            ) from e
        except ImportError as e:
            raise RuntimeError(
                "Unable to load the implementation of the artifact store "
                f"with id: '{self.target_artifact_store_id}'. Please make "
                "sure that the environment that you are loading this "
                "artifact from has the right dependencies."
            ) from e

    def __exit__(
        self,
        exc_type: Optional[Any],
        exc_value: Optional[Any],
        traceback: Optional[Any],
    ) -> None:
        """Set it back to the original state.

        Re-registers the filesystem of the active artifact store, unless the
        `ENV_ZENML_SERVER` environment variable is set. Exceptions raised
        inside the `with` body are not suppressed.

        Args:
            exc_type: The class of the exception
            exc_value: The instance of the exception
            traceback: The traceback of the exception
        """
        if ENV_ZENML_SERVER not in os.environ:
            # As we exit the handler, we have to re-register the filesystem
            # that belongs to the active artifact store as it may have been
            # overwritten.
            Client().active_stack.artifact_store._register()
5 changes: 2 additions & 3 deletions tests/unit/artifacts/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import os
import shutil
import tempfile
from uuid import uuid4

import numpy as np
import pytest
Expand All @@ -33,7 +32,7 @@


@pytest.fixture
def model_artifact(mocker):
def model_artifact(mocker, clean_client: "Client"):
return mocker.Mock(
spec=ArtifactVersionResponse,
id="123",
Expand All @@ -45,7 +44,7 @@ def model_artifact(mocker):
uri="gs://my-bucket/model.joblib",
data_type="path/to/model/class",
materializer="path/to/materializer/class",
artifact_store_id=uuid4(),
artifact_store_id=clean_client.active_stack.artifact_store.id,
)


Expand Down
31 changes: 30 additions & 1 deletion tests/unit/orchestrators/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
from zenml.orchestrators.utils import is_setting_enabled
from unittest import mock

from zenml.enums import StackComponentType
from zenml.orchestrators.utils import (
is_setting_enabled,
register_artifact_store_filesystem,
)


def test_is_setting_enabled():
Expand Down Expand Up @@ -97,3 +103,26 @@ def test_is_setting_enabled():
)
is False
)


def test_register_artifact_store_filesystem(clean_client):
    """Checks the `_register` calls made around the context manager."""
    register_target = (
        "zenml.artifact_stores.base_artifact_store.BaseArtifactStore._register"
    )
    with mock.patch(register_target) as register:
        # Accessing the active artifact store accounts for the first
        # `_register` call.
        _ = clean_client.active_stack.artifact_store
        assert register.call_count == 1

        second_store = clean_client.create_stack_component(
            name="new_local_artifact_store",
            flavor="local",
            component_type=StackComponentType.ARTIFACT_STORE,
            configuration={"path": ""},
        )
        second_store_id = second_store.id

        with register_artifact_store_filesystem(second_store_id):
            # On entry, the second store gets instantiated, which adds
            # one more `_register` call.
            assert register.call_count == 2

        # On exit, the active artifact store is registered once more.
        assert register.call_count == 3

0 comments on commit 8b8a6af

Please sign in to comment.