From 5dd045e8332c1e614d62dbf16ecc70cf05b58baf Mon Sep 17 00:00:00 2001
From: Salman Mohammadi
Date: Thu, 19 Sep 2024 18:35:57 +0100
Subject: [PATCH] Fix kv-caching and bsz > 1 in eval recipe (#1622)

---
 recipes/configs/eleuther_evaluation.yaml |  1 -
 recipes/eleuther_eval.py                 | 60 ++++++++++++++----------
 tests/recipes/test_eleuther_eval.py      | 15 ++++--
 3 files changed, 44 insertions(+), 32 deletions(-)

diff --git a/recipes/configs/eleuther_evaluation.yaml b/recipes/configs/eleuther_evaluation.yaml
index 10a8c96088..e62fa0219c 100644
--- a/recipes/configs/eleuther_evaluation.yaml
+++ b/recipes/configs/eleuther_evaluation.yaml
@@ -33,7 +33,6 @@ tasks: ["truthfulqa_mc2"]
 limit: null
 max_seq_length: 4096
 batch_size: 8
-# It is recommended to set enable_kv_cache=False for long-context models like Llama3.1
 enable_kv_cache: True
 
 # Quantization specific args
diff --git a/recipes/eleuther_eval.py b/recipes/eleuther_eval.py
index 2fe8db1a29..597f88ecad 100644
--- a/recipes/eleuther_eval.py
+++ b/recipes/eleuther_eval.py
@@ -46,7 +46,6 @@ class _EvalWrapper(HFLM):
         max_seq_length (int): The maximum sequence length to use.
         batch_size (int): The batch size per GPU to use.
         dtype (torch.dtype): dtype for the model caches during generation.
-        enable_kv_cache (bool): Whether to enable KV cache for generation.
     """
 
     def __init__(
@@ -58,7 +57,6 @@ def __init__(
         max_seq_length: int = 4096,
         batch_size: int = 8,
         dtype: torch.dtype = torch.float32,
-        enable_kv_cache: bool = True,
     ):
         super().__init__(pretrained="gpt2", device=str(device))
         self._model = model
@@ -66,7 +64,6 @@ def __init__(
         self._max_seq_length = max_seq_length
         self._batch_size = batch_size
         self._dtype = dtype
-        self._enable_kv_cache = enable_kv_cache
 
     @property
     def model(self):
@@ -92,10 +89,6 @@ def batch_size(self):
     def device(self):
         return self._device
 
-    @property
-    def enable_kv_cache(self):
-        return self._enable_kv_cache
-
     def tok_encode(self, text: str, **kwargs) -> List[int]:
         # Note on add_bos flag: setting to False as this gives better results, for example
         # +1% on truthfulqa_mc2 with a LoRA finetune. lit-gpt also sets this to False,
@@ -131,19 +124,15 @@ def _model_generate(
     ) -> torch.Tensor:
         curr_batch_size = context.size(0)
 
-        if curr_batch_size > 1:
-            raise ValueError(
-                f"Got a batch size of '{curr_batch_size}'. Batch size > 1 is not supported for "
-                "generation. See https://github.com/pytorch/torchtune/issues/1250 for more info."
-            )
-
-        # Setup caches for a given batch size
-        # Technically this is not necessary, but it's a good way to ensure that
-        # the caches won't error on a different batch size. In addition, caches
-        # are not needed for a regular model call, so we just setup here
-        if self.enable_kv_cache:
-            with context.device:
-                self._model.setup_caches(batch_size=curr_batch_size, dtype=self._dtype)
+        # if we've received fewer than self._batch_size samples in the current
+        # batch we need to pad the batch out. here we're padding the end of the
+        # current batch to the correct length. this is because when we use static
+        # KV-caches, the model will expect a fixed batch size for all samples.
+        maybe_padded_context = torch.nn.functional.pad(
+            context,
+            (0, 0, 0, self._batch_size - curr_batch_size),
+            value=self._tokenizer.eos_id,  # pad with one of the tokenizer's stop tokens so generation can stop early
+        )
 
         temperature = generation_kwargs.get("temperature", 0.0)
         do_sample = generation_kwargs.get("do_sample", False)
@@ -156,14 +145,14 @@ def _model_generate(
 
         toks, _ = generation.generate(
             self._model,
-            context,
+            maybe_padded_context,
             max_generated_tokens=self.max_gen_toks,
             temperature=temperature,
             top_k=None,  # do_sample is not supported currently
             stop_tokens=self._tokenizer.stop_tokens,
         )
         self._model.reset_caches()
-        return torch.tensor(toks, dtype=torch.int32)
+        return toks[:curr_batch_size]
 
 
 class EleutherEvalRecipe(EvalRecipeInterface):
@@ -175,7 +164,6 @@ class EleutherEvalRecipe(EvalRecipeInterface):
     Features:
         - Single GPU evaluation. Multi-GPU evaluation is currently not supported.
        - Loading model in fp32 or bf16. Fp16 is currently not supported.
-        - Any task from the EleutherAI eval harness that is *not* free generation
 
     We recommend launching evaluation using the tune CLI:
 
@@ -198,6 +186,9 @@ def setup(self) -> None:
         self._quantization_mode = training.get_quantizer_mode(self._quantizer)
         self._enable_kv_cache = self._cfg.get("enable_kv_cache", True)
 
+        self._batch_size = self._cfg.batch_size
+        self._max_seq_length = self._cfg.get("max_seq_length", 4096)
+
         training.set_seed(seed=self._cfg.seed)
 
         checkpointer = config.instantiate(self._cfg.checkpointer)
@@ -253,10 +244,9 @@ def evaluate(self) -> None:
             self._model,
             self._tokenizer,
             device=self._device,
-            max_seq_length=self._cfg.max_seq_length,
-            batch_size=self._cfg.batch_size,
+            max_seq_length=self._max_seq_length,
+            batch_size=self._batch_size,
             dtype=self._dtype,
-            enable_kv_cache=self._enable_kv_cache,
         )
 
         # Task initialization API changed between v0.4.1 and 0.4.2
@@ -268,6 +258,24 @@ def evaluate(self) -> None:
             task_manager = TaskManager(include_path=self._cfg.get("include_path", None))
             task_dict = get_task_dict(self._tasks, task_manager)
 
+        task_types = set([task.OUTPUT_TYPE for _, task in task_dict.items()])
+        if len(task_types) > 1 and "generate_until" in task_types:
+            raise RuntimeError(
+                "Evaluating on multiple task types where any one task involves "
+                "generation is currently not supported. See the issue below for more info: "
+                "https://github.com/pytorch/torchtune/issues/1621"
+            )
+
+        # Setup caches for a given batch size
+        if self._enable_kv_cache and "generate_until" in task_types:
+            with self._device:
+                self._model.setup_caches(
+                    batch_size=self._batch_size,
+                    dtype=self._dtype,
+                    decoder_max_seq_len=self._max_seq_length
+                    + model_eval_wrapper.max_gen_toks,
+                )
+
         logger.info(f"Running evaluation on {self._tasks} tasks.")
         output = evaluate(
             model_eval_wrapper,
diff --git a/tests/recipes/test_eleuther_eval.py b/tests/recipes/test_eleuther_eval.py
index 1575ca04cc..e89e38c246 100644
--- a/tests/recipes/test_eleuther_eval.py
+++ b/tests/recipes/test_eleuther_eval.py
@@ -21,7 +21,11 @@ class TestEleutherEval:
 
     @pytest.mark.parametrize(
         "eval_name, expected_acc, bsz",
-        [("truthfulqa_gen", 0.1, 1), ("truthfulqa_mc2", 0.3, 8)],
+        [
+            ("truthfulqa_gen", 0.1, 8),
+            ("truthfulqa_gen", 0.1, 1),
+            ("truthfulqa_mc2", 0.4, 8),
+        ],
     )
     @pytest.mark.integration_test
     def test_torchtune_checkpoint_eval_results(
@@ -31,7 +35,8 @@ def test_torchtune_checkpoint_eval_results(
         ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
         ckpt_dir = ckpt_path.parent
 
-        # TODO @joecummings bsz > 1 isn't supported for generation tasks, update test once integrated
+        # explicitly setting limit to an odd number here to ensure generation tasks
+        # work with KV-caching + bsz > 1 - we'll receive batches of size 8, 8, 5
         cmd = f"""
         tune run eleuther_eval \
             --config eleuther_evaluation \
@@ -43,7 +48,7 @@ def test_torchtune_checkpoint_eval_results(
             checkpointer.model_type=LLAMA2 \
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \
             tokenizer.prompt_template=null \
-            limit=10 \
+            limit=21 \
             dtype=fp32 \
             device=cpu \
             tasks=[{eval_name}]\
@@ -62,12 +67,12 @@ def test_torchtune_checkpoint_eval_results(
         # v0.4.2 format
         # | Tasks |Version|Filter|n-shot|Metric|Value | |Stderr|
         # |--------------|------:|------|-----:|------|-----:|---|-----:|
-        # |truthfulqa_mc2| 2|none | 0|acc |0.3469|± |0.1444|
+        # |truthfulqa_mc2| 2|none | 0|acc |0.4497|± |0.1067|
 
         # v0.4.3 format
         # | Tasks |Version|Filter|n-shot|Metric| |Value | |Stderr|
         # |--------------|------:|------|-----:|------|---|-----:|---|-----:|
-        # |truthfulqa_mc2| 2|none | 0|acc |↑ |0.3469|± |0.1444|
+        # |truthfulqa_mc2| 2|none | 0|acc |↑ |0.4497|± |0.1067|
 
         # The below RegEx command will pick up both formats
         search_results = re.search(