Add safe-serialization to FullModelHFCheckpointer (#1096)
jeffrey-fong authored Jun 21, 2024
1 parent ef6e196 commit f200da5
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions torchtune/utils/_checkpointing/_checkpointer.py
@@ -12,6 +12,7 @@
 from typing import Any, Dict, List, Optional, Protocol
 
 import torch
+from safetensors.torch import save_file
 from torchtune import utils
 
 from torchtune.models import convert_weights
@@ -305,6 +306,7 @@ class FullModelHFCheckpointer(_CheckpointerInterface):
         recipe_checkpoint (Optional[str]): Path to the recipe state checkpoint file. Default is None
         resume_from_checkpoint (bool): If True, the checkpointer will load the additional checkpoint files to
             resume training from a previous run. Default is False
+        safe_serialization (bool): If True, the checkpointer will save the checkpoint file using `safetensors`
 
     Raises:
         ValueError: If ``resume_from_checkpoint`` is True but ``recipe_checkpoint`` is None
@@ -319,6 +321,7 @@ def __init__(
         adapter_checkpoint: Optional[str] = None,
         recipe_checkpoint: Optional[str] = None,
         resume_from_checkpoint: bool = False,
+        safe_serialization: bool = False,
     ) -> None:
         self._checkpoint_dir = Path(checkpoint_dir)
         self._checkpoint_paths = self._validate_hf_checkpoint_files(checkpoint_files)
@@ -331,6 +334,7 @@ def __init__(
         self._model_type = ModelType[model_type]
         self._output_dir = Path(output_dir)
         self._resume_from_checkpoint = resume_from_checkpoint
+        self._safe_serialization = safe_serialization
 
         # weight_map contains the state_dict key -> checkpoint file mapping so we can correctly
         # parition the state dict into output checkpoint files. This is updated during checkpoint
@@ -508,10 +512,17 @@ def save_checkpoint(

         # write the partitioned state dicts to the right checkpoint file
         for cpt_idx, model_state_dict in split_state_dicts.items():
-            output_path = Path.joinpath(
-                self._output_dir, f"hf_model_{cpt_idx}_{epoch}"
-            ).with_suffix(".pt")
-            torch.save(model_state_dict, output_path)
+            if not self._safe_serialization:
+                output_path = Path.joinpath(
+                    self._output_dir, f"hf_model_{cpt_idx}_{epoch}"
+                ).with_suffix(".pt")
+                torch.save(model_state_dict, output_path)
+            else:
+                output_path = Path.joinpath(
+                    self._output_dir,
+                    f"model-0{cpt_idx}-of-0{list(split_state_dicts.keys())[-1]}_{epoch}",
+                ).with_suffix(".safetensors")
+                save_file(model_state_dict, output_path)
             logger.info(
                 "Model checkpoint of size "
                 f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
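
As a usage sketch (not part of the commit): constructing the checkpointer with the new flag. The parameter names other than safe_serialization are read off the hunks above; the import path, checkpoint file names, and the "LLAMA2" model type value are illustrative assumptions.

# Sketch only, assuming FullModelHFCheckpointer is re-exported from torchtune.utils
# and that the listed HF shard files exist under checkpoint_dir.
from torchtune.utils import FullModelHFCheckpointer

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/tmp/Llama-2-7b-hf",          # assumed local HF checkpoint
    checkpoint_files=[
        "pytorch_model-00001-of-00002.bin",       # assumed shard names
        "pytorch_model-00002-of-00002.bin",
    ],
    model_type="LLAMA2",                          # assumed ModelType value
    output_dir="/tmp/finetune-output",
    safe_serialization=True,  # new in this commit: shards are written with safetensors' save_file
)

# A training recipe would later call checkpointer.save_checkpoint(...), which with the flag set
# writes model-0{idx}-of-0{n}_{epoch}.safetensors files instead of hf_model_{idx}_{epoch}.pt.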

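The else branch above delegates to safetensors.torch.save_file; its counterpart load_file reads a shard back as a plain dict of tensors. A self-contained round trip, with a file name that merely mimics the model-0{idx}-of-0{n}_{epoch}.safetensors pattern from the diff:

# Minimal round trip through the safetensors API used by the new code path.
import torch
from safetensors.torch import load_file, save_file

state_dict = {"w1": torch.randn(4, 4), "b1": torch.zeros(4)}

# save_file expects a flat dict of contiguous, non-shared tensors
save_file(state_dict, "model-01-of-01_0.safetensors")

reloaded = load_file("model-01-of-01_0.safetensors")
assert torch.equal(state_dict["w1"], reloaded["w1"])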