Add weights_only flag to torchtune checkpointer (#642)
kartikayk authored Apr 2, 2024
1 parent e7e310a commit 5f865b0
Showing 3 changed files with 44 additions and 4 deletions.
28 changes: 28 additions & 0 deletions tests/torchtune/utils/test_checkpointer.py
@@ -292,3 +292,31 @@ def test_save_load_checkpoint_multiple_file(

         assert len(output_state_dict_1.keys()) + 1 == len(orig_state_dict_1.keys())
         assert len(output_state_dict_2.keys()) + 1 == len(orig_state_dict_2.keys())
+
+
+class TestCheckpointerUtils:
+    @pytest.fixture
+    def model_checkpoint(self, tmp_path):
+        """
+        Fixture which creates a checkpoint file for testing checkpointer utils.
+        """
+        checkpoint_file = tmp_path / "model_checkpoint_01.pt"
+
+        state_dict = {
+            "token_embeddings.weight": torch.ones(1, 10),
+            "output.weight": torch.ones(1, 10),
+        }
+
+        torch.save(state_dict, checkpoint_file)
+
+        return checkpoint_file
+
+    @pytest.mark.parametrize("weights_only", [True, False])
+    def test_safe_torch_load(self, model_checkpoint, weights_only):
+        state_dict = safe_torch_load(Path(model_checkpoint), weights_only)
+
+        assert "token_embeddings.weight" in state_dict
+        assert "output.weight" in state_dict
+
+        assert state_dict["token_embeddings.weight"].shape[1] == 10
+        assert state_dict["output.weight"].shape[0] == 1
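
For reference, the behavior this test pins down can be reproduced outside pytest. A minimal sketch, assuming safe_torch_load is importable from the module path touched by this commit; the temporary-directory handling is illustrative rather than taken from the test suite:

import tempfile
from pathlib import Path

import torch

from torchtune.utils._checkpointing._checkpointer_utils import safe_torch_load

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt = Path(tmp_dir) / "model_checkpoint_01.pt"
    torch.save({"token_embeddings.weight": torch.ones(1, 10)}, ckpt)

    # A plain tensor state dict loads under either setting of the flag;
    # weights_only only matters once non-tensor objects enter the picture.
    for weights_only in (True, False):
        state_dict = safe_torch_load(ckpt, weights_only=weights_only)
        assert state_dict["token_embeddings.weight"].shape == (1, 10)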
13 changes: 11 additions & 2 deletions torchtune/utils/_checkpointing/_checkpointer.py
@@ -148,7 +148,7 @@ def __init__(
                 )
             self._recipe_checkpoint = get_path(self._checkpoint_dir, recipe_checkpoint)
 
-    def load_checkpoint(self) -> Dict[str, Any]:
+    def load_checkpoint(self, weights_only: bool = True) -> Dict[str, Any]:
         """
         Load TorchTune checkpoint from file. Currently only loading from a single file is supported.
@@ -162,9 +162,18 @@ def load_checkpoint(self) -> Dict[str, Any]:
"optimizer": ...,
...
}
Args:
weights_only (bool): flag passed down to torch.load. We expose this, because quantized models
cannot be loaded with weights_only=True
Returns:
Dict[str, Any]: state_dict from the input checkpoint
"""
state_dict: Dict[str:Any] = {}
state_dict[utils.MODEL_KEY] = safe_torch_load(self._checkpoint_path)
state_dict[utils.MODEL_KEY] = safe_torch_load(
self._checkpoint_path, weights_only=weights_only
)

if self._adapter_checkpoint:
adapter_state_dict = safe_torch_load(self._adapter_checkpoint)
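
In recipe code the new flag simply flows through load_checkpoint. A hedged sketch of the intended call pattern; checkpointer stands in for an already-constructed instance of the checkpointer class patched above and is not defined in this commit:

# Default: restricted unpickling, fine for plain tensor checkpoints.
state_dict = checkpointer.load_checkpoint()

# Quantized checkpoints pickle objects that the weights_only=True
# unpickler rejects, so opt out explicitly. Only do this for
# checkpoints that come from a trusted source.
state_dict = checkpointer.load_checkpoint(weights_only=False)
model_state = state_dict[utils.MODEL_KEY]  # MODEL_KEY as used in the diff above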
7 changes: 5 additions & 2 deletions torchtune/utils/_checkpointing/_checkpointer_utils.py
@@ -47,15 +47,18 @@ def get_path(input_dir: Path, filename: str, missing_ok: bool = False) -> Path:
     return file_path
 
 
-def safe_torch_load(checkpoint_path: Path) -> Dict[str, Any]:
+def safe_torch_load(checkpoint_path: Path, weights_only: bool = True) -> Dict[str, Any]:
     """
     Utility to load a checkpoint file in a safe manner.
     """
     try:
         # convert the path into a string since pathlib Path and mmap don't work
         # well together
         state_dict = torch.load(
-            str(checkpoint_path), map_location="cpu", mmap=True, weights_only=True
+            str(checkpoint_path),
+            map_location="cpu",
+            mmap=True,
+            weights_only=weights_only,
         )
     except Exception as e:
         raise ValueError(f"Unable to load checkpoint from {checkpoint_path}. ") from e
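
To see why the flag is needed at all: torch.load with weights_only=True uses a restricted unpickler that accepts tensors and a small set of primitive containers, so a checkpoint that pickles arbitrary Python objects, as some quantization workflows do, fails until weights_only=False is passed. A runnable sketch with a stand-in object (CustomQuantState is hypothetical, not a torchtune type):

import torch

class CustomQuantState:
    # Stand-in for a non-tensor object embedded in a checkpoint.
    def __init__(self, scale: float):
        self.scale = scale

torch.save({"w": torch.ones(2, 2), "qstate": CustomQuantState(0.1)}, "ckpt.pt")

try:
    torch.load("ckpt.pt", map_location="cpu", weights_only=True)
except Exception as e:
    print(f"weights_only=True rejected the checkpoint: {e}")

# Full unpickling succeeds; only appropriate for trusted checkpoints.
state_dict = torch.load("ckpt.pt", map_location="cpu", weights_only=False)
print(state_dict["qstate"].scale)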
