diff --git a/recipes/eleuther_eval.py b/recipes/eleuther_eval.py
index ce07497899..ad6ba41e74 100644
--- a/recipes/eleuther_eval.py
+++ b/recipes/eleuther_eval.py
@@ -28,6 +28,7 @@
 from torchtune.modules.tokenizers import ModelTokenizer
 from torchtune.modules.transforms import Transform
 from torchtune.recipe_interfaces import EvalRecipeInterface
+from torchtune.training import FullModelTorchTuneCheckpointer
 
 try:
     import lm_eval
@@ -475,13 +476,6 @@ def setup(self, cfg: DictConfig) -> None:
 
         # Load checkpoint
         checkpointer = config.instantiate(cfg.checkpointer)
-        if quantization_mode is None:
-            ckpt_dict = checkpointer.load_checkpoint()
-        else:
-            # weights_only needs to be False when loading a quantized model
-            # currently loading a quantized model is only supported with the
-            # FullModelTorchTuneCheckpointer
-            ckpt_dict = checkpointer.load_checkpoint(weights_only=False)
 
         # Initialize model
         with training.set_default_dtype(self.dtype), self.device:
@@ -489,14 +483,32 @@ def setup(self, cfg: DictConfig) -> None:
 
         # Quantize model if requested
         if quantization_mode is not None:
+            if not isinstance(checkpointer, FullModelTorchTuneCheckpointer):
+                raise ValueError(
+                    "Quantization is only supported for models quantized and saved with the "
+                    "FullModelTorchTuneCheckpointer - please ensure you have quantized your "
+                    "model and are using the quantized weights!"
+                )
+            if "qat" in quantization_mode:
+                raise ValueError(
+                    "You have specified a quantizer with 'QAT' - "
+                    "QAT quantizers should only be used during quantization aware training "
+                    "and when quantizing models. Please use the corresponding post-training "
+                    "quantizer e.g. Int8DynActInt4WeightQuantizer for Int8DynActInt4WeightQATQuantizer."
+                )
             model = quantizer.quantize(model)
             model = model.to(device=self.device, dtype=self.dtype)
-            for k, v in model_state_dict.items():
-                model_state_dict[k] = v.to(self._device)
-            model.load_state_dict(model_state_dict, assign=True)
+            ckpt_dict = checkpointer.load_checkpoint(weights_only=False)[
+                training.MODEL_KEY
+            ]
+            for k, v in ckpt_dict.items():
+                ckpt_dict[k] = v.to(self.device)
+            model.load_state_dict(ckpt_dict, assign=True)
+        else:
+            ckpt_dict = checkpointer.load_checkpoint()[training.MODEL_KEY]
+            model.load_state_dict(ckpt_dict)
 
         # Load model weights into initialized model
-        model.load_state_dict(ckpt_dict[training.MODEL_KEY])
         self.logger.info(f"Model is initialized with precision {self.dtype}.")
 
         # Put model in eval mode.
diff --git a/tests/recipes/test_eleuther_eval.py b/tests/recipes/test_eleuther_eval.py
index 32eaee4b1b..f09daf2309 100644
--- a/tests/recipes/test_eleuther_eval.py
+++ b/tests/recipes/test_eleuther_eval.py
@@ -14,7 +14,7 @@
 import pytest
 
 from tests.common import TUNE_PATH
-from tests.recipes.utils import llama2_test_config
+from tests.recipes.utils import llama2_test_config, write_hf_ckpt_config
 from tests.test_utils import CKPT_MODEL_PATHS
 
 
@@ -126,6 +126,80 @@ def test_eval_recipe_errors_without_lm_eval(self, capsys, monkeypatch, tmpdir):
             in printed_err
         )
 
+    @pytest.mark.integration_test
+    def test_eval_recipe_errors_with_quantization_hf_checkpointer(
+        self, capsys, monkeypatch, tmpdir
+    ):
+        ckpt = "llama2_hf"
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        ckpt_dir = ckpt_path.parent
+
+        # Config file needed for model conversion.
+        write_hf_ckpt_config(ckpt_dir)
+
+        cmd = f"""
+        tune run eleuther_eval \
+            --config eleuther_evaluation \
+            output_dir={tmpdir} \
+            checkpointer=torchtune.training.FullModelHFCheckpointer \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA2 \
+            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
+            tokenizer.prompt_template=null \
+            limit=1 \
+            dtype=fp32 \
+            device=cpu \
+            quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
+            quantizer.groupsize=256 \
+        """.split()
+
+        model_config = llama2_test_config()
+        cmd = cmd + model_config
+
+        monkeypatch.setattr(sys, "argv", cmd)
+        with pytest.raises(
+            ValueError,
+            match="Quantization is only supported for models quantized and saved with the "
+            "FullModelTorchTuneCheckpointer",
+        ):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
+    @pytest.mark.integration_test
+    def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir):
+        ckpt = "llama2_tune"
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        ckpt_dir = ckpt_path.parent
+
+        cmd = f"""
+        tune run eleuther_eval \
+            --config eleuther_evaluation \
+            output_dir={tmpdir} \
+            checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type=LLAMA2 \
+            tokenizer.path=/tmp/test-artifacts/tokenizer.model \
+            tokenizer.prompt_template=null \
+            limit=1 \
+            dtype=fp32 \
+            device=cpu \
+            quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer \
+            quantizer.groupsize=32\
+        """.split()
+
+        model_config = llama2_test_config()
+        cmd = cmd + model_config
+
+        monkeypatch.setattr(sys, "argv", cmd)
+        with pytest.raises(
+            ValueError,
+            match="QAT quantizers should only be used during quantization aware training",
+        ):
+            runpy.run_path(TUNE_PATH, run_name="__main__")
+
     @pytest.mark.integration_test
     def test_eval_recipe_errors_with_generate_until_and_mc_tasks(
         self, caplog, capsys, monkeypatch, tmpdir
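
For reference, the gated loading flow added in recipes/eleuther_eval.py above can be read as a standalone sketch. This is not part of the diff; the helper name load_eval_weights and its signature are hypothetical, while FullModelTorchTuneCheckpointer, training.MODEL_KEY, load_checkpoint(weights_only=False), and quantizer.quantize are taken from the changes themselves.

    # Sketch only -- mirrors the branch added in recipes/eleuther_eval.py above.
    from torchtune import training
    from torchtune.training import FullModelTorchTuneCheckpointer


    def load_eval_weights(model, checkpointer, quantizer, quantization_mode, device):
        """Hypothetical helper: load eval weights, quantizing the model first if requested."""
        if quantization_mode is not None:
            # Quantized weights are only supported from a torchtune-format checkpoint.
            if not isinstance(checkpointer, FullModelTorchTuneCheckpointer):
                raise ValueError("Quantized eval requires FullModelTorchTuneCheckpointer")
            # QAT quantizers are for training/quantizing, not post-training eval.
            if "qat" in quantization_mode:
                raise ValueError("Use the post-training quantizer, not the QAT variant")
            model = quantizer.quantize(model)
            # weights_only=False is needed to deserialize the quantized tensors.
            state_dict = checkpointer.load_checkpoint(weights_only=False)[training.MODEL_KEY]
            state_dict = {k: v.to(device) for k, v in state_dict.items()}
            model.load_state_dict(state_dict, assign=True)
        else:
            state_dict = checkpointer.load_checkpoint()[training.MODEL_KEY]
            model.load_state_dict(state_dict)
        return model

Failing fast on an HF/Meta checkpointer or a QAT quantizer (the two cases the new tests exercise) surfaces a clear ValueError up front, rather than letting the quantized state-dict load fail later with a less obvious mismatch.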