diff --git a/tests/torchtune/utils/test_checkpointer.py b/tests/torchtune/utils/test_checkpointer.py
index d70166f551..4616b07486 100644
--- a/tests/torchtune/utils/test_checkpointer.py
+++ b/tests/torchtune/utils/test_checkpointer.py
@@ -292,3 +292,31 @@ def test_save_load_checkpoint_multiple_file(
 
         assert len(output_state_dict_1.keys()) + 1 == len(orig_state_dict_1.keys())
         assert len(output_state_dict_2.keys()) + 1 == len(orig_state_dict_2.keys())
+
+
+class TestCheckpointerUtils:
+    @pytest.fixture
+    def model_checkpoint(self, tmp_path):
+        """
+        Fixture which creates a checkpoint file for testing checkpointer utils.
+        """
+        checkpoint_file = tmp_path / "model_checkpoint_01.pt"
+
+        state_dict = {
+            "token_embeddings.weight": torch.ones(1, 10),
+            "output.weight": torch.ones(1, 10),
+        }
+
+        torch.save(state_dict, checkpoint_file)
+
+        return checkpoint_file
+
+    @pytest.mark.parametrize("weights_only", [True, False])
+    def test_safe_torch_load(self, model_checkpoint, weights_only):
+        state_dict = safe_torch_load(Path(model_checkpoint), weights_only)
+
+        assert "token_embeddings.weight" in state_dict
+        assert "output.weight" in state_dict
+
+        assert state_dict["token_embeddings.weight"].shape[1] == 10
+        assert state_dict["output.weight"].shape[0] == 1
diff --git a/torchtune/utils/_checkpointing/_checkpointer.py b/torchtune/utils/_checkpointing/_checkpointer.py
index 6c273a11be..4b6486d84c 100644
--- a/torchtune/utils/_checkpointing/_checkpointer.py
+++ b/torchtune/utils/_checkpointing/_checkpointer.py
@@ -148,7 +148,7 @@ def __init__(
         )
         self._recipe_checkpoint = get_path(self._checkpoint_dir, recipe_checkpoint)
 
-    def load_checkpoint(self) -> Dict[str, Any]:
+    def load_checkpoint(self, weights_only: bool = True) -> Dict[str, Any]:
         """
         Load TorchTune checkpoint from file. Currently only loading from a single file is supported.
 
@@ -162,9 +162,18 @@ def load_checkpoint(self) -> Dict[str, Any]:
             "optimizer": ...,
             ...
         }
+
+        Args:
+            weights_only (bool): flag passed down to torch.load. We expose this because
+                quantized models cannot be loaded with weights_only=True.
+
+        Returns:
+            Dict[str, Any]: state_dict from the input checkpoint
         """
         state_dict: Dict[str:Any] = {}
-        state_dict[utils.MODEL_KEY] = safe_torch_load(self._checkpoint_path)
+        state_dict[utils.MODEL_KEY] = safe_torch_load(
+            self._checkpoint_path, weights_only=weights_only
+        )
 
         if self._adapter_checkpoint:
             adapter_state_dict = safe_torch_load(self._adapter_checkpoint)
diff --git a/torchtune/utils/_checkpointing/_checkpointer_utils.py b/torchtune/utils/_checkpointing/_checkpointer_utils.py
index 3f52e8c968..d1ae4d2374 100644
--- a/torchtune/utils/_checkpointing/_checkpointer_utils.py
+++ b/torchtune/utils/_checkpointing/_checkpointer_utils.py
@@ -47,7 +47,7 @@ def get_path(input_dir: Path, filename: str, missing_ok: bool = False) -> Path:
     return file_path
 
 
-def safe_torch_load(checkpoint_path: Path) -> Dict[str, Any]:
+def safe_torch_load(checkpoint_path: Path, weights_only: bool = True) -> Dict[str, Any]:
     """
     Utility to load a checkpoint file in a safe manner.
     """
@@ -55,7 +55,10 @@ def safe_torch_load(checkpoint_path: Path) -> Dict[str, Any]:
         # convert the path into a string since pathlib Path and mmap don't work
         # well together
         state_dict = torch.load(
-            str(checkpoint_path), map_location="cpu", mmap=True, weights_only=True
+            str(checkpoint_path),
+            map_location="cpu",
+            mmap=True,
+            weights_only=weights_only,
         )
     except Exception as e:
         raise ValueError(f"Unable to load checkpoint from {checkpoint_path}. ") from e
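
The change in practice: `safe_torch_load` keeps its safe default (`weights_only=True`) while letting callers opt out for quantized checkpoints. A minimal usage sketch follows; the checkpoint file paths are hypothetical, while the function, parameter, and module names come from the diff above:

    from pathlib import Path

    from torchtune.utils._checkpointing._checkpointer_utils import safe_torch_load

    # Default behavior is unchanged: weights_only=True guards against
    # arbitrary pickled objects in untrusted checkpoint files.
    state_dict = safe_torch_load(Path("/tmp/model_checkpoint_01.pt"))

    # Quantized checkpoints contain objects that weights_only=True refuses
    # to unpickle, so callers opt out explicitly for files they trust.
    quantized_state_dict = safe_torch_load(
        Path("/tmp/quantized_model.pt"), weights_only=False
    )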