Skip to content

Commit

Permalink
WIP: Rework notebook code upload
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Aug 8, 2024
1 parent 49cc04b commit 65d4d8d
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 46 deletions.
8 changes: 1 addition & 7 deletions src/zenml/config/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,20 +234,14 @@ class NotebookSource(Source):
"""Source representing an object defined in a notebook.
Attributes:
code_path: Path where the notebook cell code for this source is
uploaded.
replacement_module: Name of the module from which this source should
be loaded in case the code is not running in a notebook.
"""

code_path: Optional[str] = None
replacement_module: Optional[str] = None
artifact_store_id: Optional[UUID] = None
type: SourceType = SourceType.NOTEBOOK

# Private attribute that is used to store the code but should not be
# serialized
_cell_code: Optional[str] = None

@field_validator("type")
@classmethod
def _validate_type(cls, value: SourceType) -> SourceType:
Expand Down
4 changes: 4 additions & 0 deletions src/zenml/materializers/base_materializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def __new__(
associated_type, cls
)

from zenml.utils import notebook_utils

notebook_utils.try_to_save_notebook_cell_code(cls)

return cls


Expand Down
69 changes: 48 additions & 21 deletions src/zenml/new/pipelines/run_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
"""Utility functions for running pipelines."""

import hashlib
import time
from collections import defaultdict
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
from uuid import UUID

from pydantic import BaseModel

from zenml import constants
from zenml.client import Client
from zenml.config.pipeline_run_configuration import PipelineRunConfiguration
from zenml.config.source import SourceType
from zenml.config.source import Source, SourceType
from zenml.config.step_configurations import StepConfigurationUpdate
from zenml.enums import ExecutionStatus, ModelStages
from zenml.logger import get_logger
Expand All @@ -25,7 +26,7 @@
from zenml.new.pipelines.model_utils import NewModelRequest
from zenml.orchestrators.utils import get_run_name
from zenml.stack import Flavor, Stack
from zenml.utils import cloud_utils, code_utils, notebook_utils
from zenml.utils import cloud_utils, code_utils, notebook_utils, source_utils
from zenml.zen_stores.base_zen_store import BaseZenStore

if TYPE_CHECKING:
Expand Down Expand Up @@ -383,9 +384,8 @@ def upload_notebook_cell_code_if_necessary(
RuntimeError: If the code for one of the steps that will run out of
process cannot be extracted into a python file.
"""
code_archive = code_utils.CodeArchive(root=None)
should_upload = False
sources_that_require_upload = []
resolved_notebook_sources = source_utils.get_resolved_notebook_sources()

for step in deployment.step_configurations.values():
source = step.spec.source
Expand All @@ -396,7 +396,9 @@ def upload_notebook_cell_code_if_necessary(
or step.config.step_operator
):
should_upload = True
cell_code = getattr(step.spec.source, "_cell_code", None)
cell_code = resolved_notebook_sources.get(
source.import_path, None
)

# Code does not run in-process, which means we need to
# extract the step code into a python file
Expand All @@ -410,20 +412,45 @@ def upload_notebook_cell_code_if_necessary(
"of a notebook."
)

notebook_utils.warn_about_notebook_cell_magic_commands(
cell_code=cell_code
)
if should_upload:
logger.info("Uploading notebook code...")

code_hash = hashlib.sha1(cell_code.encode()).hexdigest() # nosec
module_name = f"extracted_notebook_code_{code_hash}"
file_name = f"{module_name}.py"
code_archive.add_file(source=cell_code, destination=file_name)
for _, cell_code in resolved_notebook_sources.items():
notebook_utils.warn_about_notebook_cell_magic_commands(
cell_code=cell_code
)
module_name = notebook_utils.compute_cell_replacement_module_name(
cell_code=cell_code
)
file_name = f"{module_name}.py"

setattr(step.spec.source, "replacement_module", module_name)
sources_that_require_upload.append(source)
code_utils.upload_notebook_code(
artifact_store=stack.artifact_store,
code=cell_code,
file_name=file_name,
)

if should_upload:
logger.info("Archiving notebook code...")
code_path = code_utils.upload_code_if_necessary(code_archive)
for source in sources_that_require_upload:
setattr(source, "code_path", code_path)
all_deployment_sources = get_all_sources_from_value(deployment)

for source in all_deployment_sources:
if source.type == SourceType.NOTEBOOK:
setattr(source, "artifact_store_id", stack.artifact_store.id)

logger.info("Upload finished.")


def get_all_sources_from_value(value: Any) -> List[Source]:
sources = []
if isinstance(value, Source):
sources.append(value)
elif isinstance(value, BaseModel):
for v in value.__dict__.values():
sources.extend(get_all_sources_from_value(v))
elif isinstance(value, Dict):
for v in value.values():
sources.extend(get_all_sources_from_value(v))
elif isinstance(value, (List, Set, tuple)):
for v in value:
sources.extend(get_all_sources_from_value(v))

return sources
30 changes: 30 additions & 0 deletions src/zenml/utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
if TYPE_CHECKING:
from git.repo.base import Repo

from zenml.artifact_stores import BaseArtifactStore


logger = get_logger(__name__)

Expand Down Expand Up @@ -242,3 +244,31 @@ def download_and_extract_code(code_path: str, extract_dir: str) -> None:

shutil.unpack_archive(filename=download_path, extract_dir=extract_dir)
os.remove(download_path)


def _get_notebook_upload_dir(artifact_store: "BaseArtifactStore") -> str:
return os.path.join(artifact_store.path, "notebook_code")


def upload_notebook_code(
artifact_store: "BaseArtifactStore", code: str, file_name: str
) -> None:
upload_dir = _get_notebook_upload_dir(artifact_store=artifact_store)
fileio.makedirs(upload_dir)
upload_path = os.path.join(upload_dir, file_name)

if not fileio.exists(upload_path):
with fileio.open(upload_path, "wb") as f:
f.write(code.encode())


def download_notebook_code(
artifact_store: "BaseArtifactStore", file_name: str, download_path: str
) -> None:
code_dir = _get_notebook_upload_dir(artifact_store=artifact_store)
code_path = os.path.join(code_dir, file_name)

if not fileio.exists(code_path):
raise RuntimeError(f"Code code at path {code_path} not found.")

fileio.copy(code_path, download_path)
6 changes: 6 additions & 0 deletions src/zenml/utils/notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# permissions and limitations under the License.
"""Notebook utilities."""

import hashlib
from typing import Any, Callable, Optional, TypeVar, Union

from zenml.environment import Environment
Expand Down Expand Up @@ -120,3 +121,8 @@ def warn_about_notebook_cell_magic_commands(cell_code: str) -> None:
"of these lines contain Jupyter notebook magic commands, "
"remove them and try again."
)


def compute_cell_replacement_module_name(cell_code: str) -> str:
code_hash = hashlib.sha1(cell_code.encode()).hexdigest() # nosec
return f"extracted_notebook_code_{code_hash}"
69 changes: 51 additions & 18 deletions src/zenml/utils/source_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,7 @@
from distutils.sysconfig import get_python_lib
from pathlib import Path, PurePath
from types import BuiltinFunctionType, FunctionType, ModuleType
from typing import (
Any,
Callable,
Iterator,
Optional,
Type,
Union,
cast,
)
from typing import Any, Callable, Dict, Iterator, Optional, Type, Union, cast

from zenml.config.source import (
CodeRepositorySource,
Expand Down Expand Up @@ -69,6 +61,7 @@
)

_SHARED_TEMPDIR: Optional[str] = None
_resolved_notebook_sources: Dict[str, str] = {}


def load(source: Union[Source, str]) -> Any:
Expand Down Expand Up @@ -241,9 +234,18 @@ def resolve(
attribute=attribute_name,
type=source_type,
)
# Private attributes are ignored by pydantic if passed in the __init__
# method, so we set this afterwards
source._cell_code = notebook_utils.load_notebook_cell_code(obj)

if cell_code := notebook_utils.load_notebook_cell_code(obj):
global _resolved_notebook_sources

replacement_module = (
notebook_utils.compute_cell_replacement_module_name(
cell_code=cell_code
)
)
source.replacement_module = replacement_module
_resolved_notebook_sources[source.import_path] = cell_code

return source

return Source(
Expand Down Expand Up @@ -586,29 +588,54 @@ def _try_to_load_notebook_source(source: NotebookSource) -> Any:
Returns:
The loaded object.
"""
if not source.code_path or not source.replacement_module:
if not source.replacement_module:
raise RuntimeError(
f"Failed to load {source.import_path}. This object was defined in "
"a notebook and you're trying to load it outside of a notebook. "
"This is currently only enabled for ZenML steps."
"This is currently only enabled for ZenML steps and materializers. "
"To enable this for your custom classes or functions, use the "
"`zenml.utils.notebook_utils.enable_notebook_code_extraction` "
"decorator."
)

extract_dir = _get_shared_temp_dir()
file_path = os.path.join(extract_dir, f"{source.replacement_module}.py")
file_name = f"{source.replacement_module}.py"
file_path = os.path.join(extract_dir, file_name)

if not os.path.exists(file_path):
from zenml.client import Client
from zenml.utils import code_utils

artifact_store = Client().active_stack.artifact_store

if (
source.artifact_store_id
and source.artifact_store_id != artifact_store.id
):
raise RuntimeError(
"Notebook cell code not stored in active artifact store."
)

logger.info(
"Downloading notebook cell content from `%s` to load `%s`.",
source.code_path,
source.import_path,
)

code_utils.download_and_extract_code(
code_path=source.code_path, extract_dir=extract_dir
)
try:
code_utils.download_notebook_code(
artifact_store=Client().artifact_store,
file_name=file_name,
download_path=file_path,
)
except FileNotFoundError:
if not source.artifact_store_id:
raise FileNotFoundError(
"Unable to find notebook code file. This might be becuase "
"the file is stored in a different artifact store."
)

raise
try:
module = _load_module(
module_name=source.replacement_module, import_root=extract_dir
Expand Down Expand Up @@ -734,3 +761,9 @@ def validate_source_class(
return True
else:
return False


def get_resolved_notebook_sources() -> Dict[str, str]:
global _resolved_notebook_sources

return _resolved_notebook_sources.copy()

0 comments on commit 65d4d8d

Please sign in to comment.