Skip to content

Commit

Permalink
[NeMo-UX] Fix a serialization bug that prevents users from moving che…
Browse files Browse the repository at this point in the history
…ckpoints (NVIDIA#9939)

* perfor serialization using relative paths to allow users to move checkpoints after they're saved

Signed-off-by: ashors1 <ashors@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: ashors1 <ashors1@users.noreply.github.com>

* remove unused import

Signed-off-by: ashors1 <ashors@nvidia.com>

* fix artifact load

Signed-off-by: ashors1 <ashors@nvidia.com>

* fix path artifact

Signed-off-by: ashors1 <ashors@nvidia.com>

* remove unused import

Signed-off-by: ashors1 <ashors@nvidia.com>

---------

Signed-off-by: ashors1 <ashors@nvidia.com>
Signed-off-by: ashors1 <ashors1@users.noreply.github.com>
Co-authored-by: ashors1 <ashors1@users.noreply.github.com>
  • Loading branch information
2 people authored and WoodieDudy committed Aug 26, 2024
1 parent 18bae50 commit 44fc6d4
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 75 deletions.
50 changes: 1 addition & 49 deletions nemo/lightning/io/api.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,13 @@
import json
from pathlib import Path
from pydoc import locate
from typing import Any, Callable, Optional, Type, TypeVar

import fiddle as fdl
import pytorch_lightning as pl
from fiddle._src.experimental import serialization

from nemo.lightning.io.mixin import ConnectorMixin, ConnT, ModelConnector, track_io
from nemo.lightning.io.mixin import ConnectorMixin, ConnT, ModelConnector, load
from nemo.lightning.io.pl import TrainerContext

CkptType = TypeVar("CkptType")


def load(path: Path, output_type: Type[CkptType] = Any) -> CkptType:
"""
Loads a configuration from a pickle file and constructs an object of the specified type.
Args:
path (Path): The path to the pickle file or directory containing 'io.pkl'.
output_type (Type[CkptType]): The type of the object to be constructed from the loaded data.
Returns
-------
CkptType: An instance of the specified type constructed from the loaded configuration.
Raises
------
FileNotFoundError: If the specified file does not exist.
Example:
loaded_model = load("/path/to/model", output_type=MyModel)
"""
del output_type # Just for type-hint

_path = Path(path)
if hasattr(_path, 'is_dir') and _path.is_dir():
_path = Path(_path) / "io.json"
elif hasattr(_path, 'isdir') and _path.isdir:
_path = Path(_path) / "io.json"

if not _path.is_file():
raise FileNotFoundError(f"No such file: '{_path}'")

## add IO functionality to custom objects present in the json file
with open(_path) as f:
j = json.load(f)
for obj, val in j["objects"].items():
clss = ".".join([val["type"]["module"], val["type"]["name"]])
if not serialization.find_node_traverser(locate(clss)):
track_io(locate(clss))

with open(_path, "rb") as f:
config = serialization.load_json(f.read())

return fdl.build(config)


def load_context(path: Path) -> TrainerContext:
"""
Expand Down
2 changes: 1 addition & 1 deletion nemo/lightning/io/artifact/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self, attr: str, required: bool = True):
self.required = required

@abstractmethod
def dump(self, value: ValueT, path: Path) -> ValueT:
def dump(self, value: ValueT, absolute_dir: Path, relative_dir: Path) -> ValueT:
pass

@abstractmethod
Expand Down
15 changes: 8 additions & 7 deletions nemo/lightning/io/artifact/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,25 @@


class PathArtifact(Artifact[Path]):
def dump(self, value: Path, path: Path) -> Path:
new_value = copy_file(value, path)
def dump(self, value: Path, absolute_dir: Path, relative_dir: Path) -> Path:
new_value = copy_file(value, absolute_dir, relative_dir)
return new_value

def load(self, path: Path) -> Path:
return path


class FileArtifact(Artifact[str]):
def dump(self, value: str, path: Path) -> str:
new_value = copy_file(value, path)
def dump(self, value: str, absolute_dir: Path, relative_dir: Path) -> str:
new_value = copy_file(value, absolute_dir, relative_dir)
return str(new_value)

def load(self, path: str) -> str:
return path


def copy_file(src: Union[Path, str], dst: Union[Path, str]):
output = Path(dst) / Path(src).name
def copy_file(src: Union[Path, str], path: Union[Path, str], relative_dst: Union[Path, str]):
relative_path = Path(relative_dst) / Path(src).name
output = Path(path) / relative_path
shutil.copy2(src, output)
return output
return relative_path
8 changes: 4 additions & 4 deletions nemo/lightning/io/artifact/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@


class PickleArtifact(Artifact[Any]):
def dump(self, value: Any, path: Path) -> Path:
file = self.file_path(path)
with open(file, "wb") as f:
def dump(self, absolute_dir: Path, relative_dir: Path) -> Path:
relative_file = self.file_path(relative_dir)
with open(Path(absolute_dir) / relative_file, "wb") as f:
dump(value, f)

return file
return relative_file

def load(self, path: Path) -> Any:
with open(self.file_path(path), "rb") as f:
Expand Down
103 changes: 89 additions & 14 deletions nemo/lightning/io/mixin.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import functools
import inspect
import json
import shutil
import threading
import types
import uuid
from copy import deepcopy
from dataclasses import is_dataclass
from pathlib import Path
from pydoc import locate
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union

import fiddle as fdl
import fiddle._src.experimental.dataclasses as fdl_dc
from cloudpickle import dump, load
from cloudpickle import dump
from cloudpickle import load as pickle_load
from fiddle._src.experimental import serialization
from typing_extensions import Self

Expand All @@ -21,6 +24,7 @@
from nemo.lightning.io.fdl_torch import enable as _enable_ext

ConnT = TypeVar('ConnT', bound=ModelConnector)
CkptType = TypeVar("CkptType")
_enable_ext()


Expand Down Expand Up @@ -136,21 +140,24 @@ def io_dump(self, output: Path):
will be stored.
"""
output_path = Path(output)
artifacts_dir = output_path / "artifacts"
local_artifacts_dir = "artifacts"
artifacts_dir = output_path / local_artifacts_dir
artifacts_dir.mkdir(parents=True, exist_ok=True)

# Store artifacts directory in thread-local storage
_thread_local.artifacts_dir = artifacts_dir
_thread_local.local_artifacts_dir = local_artifacts_dir
_thread_local.output_path = output_path

config_path = output_path / "io.json"
with open(config_path, "w") as f:
io = deepcopy(self.__io__)
_artifact_transform(io, artifacts_dir)
_artifact_transform_save(io, output_path, local_artifacts_dir)
json = serialization.dump_json(io)
f.write(json)

# Clear thread-local storage after io_dump is complete
del _thread_local.artifacts_dir
del _thread_local.local_artifacts_dir
del _thread_local.output_path

# Check if artifacts directory is empty and delete if so
if not any(artifacts_dir.iterdir()):
Expand Down Expand Up @@ -481,23 +488,28 @@ def _io_flatten_object(instance):
try:
serialization.dump_json(instance.__io__)
except (serialization.UnserializableValueError, AttributeError) as e:
if not hasattr(_thread_local, "artifacts_dir"):
if not hasattr(_thread_local, "local_artifacts_dir") or not hasattr(_thread_local, "output_path"):
raise e

artifact_dir = _thread_local.artifacts_dir
artifact_path = artifact_dir / f"{uuid.uuid4()}"
local_artifact_path = Path(_thread_local.local_artifacts_dir) / f"{uuid.uuid4()}"
output_path = _thread_local.output_path
artifact_path = output_path / local_artifact_path
with open(artifact_path, "wb") as f:
dump(getattr(instance, "__io__", instance), f)
return (str(artifact_path),), None
return (str(local_artifact_path),), None

return instance.__io__.__flatten__()


def _io_unflatten_object(values, metadata):

assert hasattr(_thread_local, "output_dir")
output_dir = _thread_local.output_dir

if len(values) == 1:
pickle_path = values[0]
with open(pickle_path, "rb") as f:
return load(f)
with open(Path(output_dir) / pickle_path, "rb") as f:
return pickle_load(f)

return fdl.Config.__unflatten__(values, metadata)

Expand All @@ -511,19 +523,82 @@ def _io_path_elements_fn(x):
return x.__io__.__path_elements__()


def _artifact_transform(cfg: fdl.Config, output_path: Path):
def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "artifacts"):
for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []):
current_val = getattr(cfg, artifact.attr)
if current_val is None:
if artifact.required:
raise ValueError(f"Artifact '{artifact.attr}' is required but not provided")
continue
new_val = artifact.dump(current_val, output_path)
## dump artifact and return the relative path
new_val = artifact.dump(current_val, output_path, relative_dir)
setattr(cfg, artifact.attr, new_val)

for attr in dir(cfg):
try:
if isinstance(getattr(cfg, attr), fdl.Config):
_artifact_transform(getattr(cfg, attr), output_path=output_path)
_artifact_transform_save(getattr(cfg, attr), output_path=output_path, relative_dir=relative_dir)
except ValueError:
pass


def _artifact_transform_load(cfg: fdl.Config, path: Path):
for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []):
current_val = getattr(cfg, artifact.attr)
## replace local path with absolute one
new_val = str(Path(path) / current_val)
setattr(cfg, artifact.attr, new_val)

for attr in dir(cfg):
try:
if isinstance(getattr(cfg, attr), fdl.Config):
_artifact_transform_load(getattr(cfg, attr), path=path)
except ValueError:
pass


def load(path: Path, output_type: Type[CkptType] = Any) -> CkptType:
"""
Loads a configuration from a pickle file and constructs an object of the specified type.
Args:
path (Path): The path to the pickle file or directory containing 'io.pkl'.
output_type (Type[CkptType]): The type of the object to be constructed from the loaded data.
Returns
-------
CkptType: An instance of the specified type constructed from the loaded configuration.
Raises
------
FileNotFoundError: If the specified file does not exist.
Example:
loaded_model = load("/path/to/model", output_type=MyModel)
"""
del output_type # Just for type-hint

_path = Path(path)
_thread_local.output_dir = _path

if hasattr(_path, 'is_dir') and _path.is_dir():
_path = Path(_path) / "io.json"
elif hasattr(_path, 'isdir') and _path.isdir:
_path = Path(_path) / "io.json"

if not _path.is_file():
raise FileNotFoundError(f"No such file: '{_path}'")

## add IO functionality to custom objects present in the json file
with open(_path) as f:
j = json.load(f)
for obj, val in j["objects"].items():
clss = ".".join([val["type"]["module"], val["type"]["name"]])
if not serialization.find_node_traverser(locate(clss)):
track_io(locate(clss))

with open(_path, "rb") as f:
config = serialization.load_json(f.read())
_artifact_transform_load(config, path)

return fdl.build(config)

0 comments on commit 44fc6d4

Please sign in to comment.