Fix kv-caching and bsz > 1 in eval recipe (#1622)
SalmanMohammadi authored Sep 19, 2024
1 parent cd573f9 commit 5dd045e
Showing 3 changed files with 44 additions and 32 deletions.
1 change: 0 additions & 1 deletion recipes/configs/eleuther_evaluation.yaml
@@ -33,7 +33,6 @@ tasks: ["truthfulqa_mc2"]
limit: null
max_seq_length: 4096
batch_size: 8
# It is recommended to set enable_kv_cache=False for long-context models like Llama3.1
enable_kv_cache: True

# Quantization specific args
60 changes: 34 additions & 26 deletions recipes/eleuther_eval.py
@@ -46,7 +46,6 @@ class _EvalWrapper(HFLM):
max_seq_length (int): The maximum sequence length to use.
batch_size (int): The batch size per GPU to use.
dtype (torch.dtype): dtype for the model caches during generation.
enable_kv_cache (bool): Whether to enable KV cache for generation.
"""

def __init__(
@@ -58,15 +57,13 @@ def __init__(
max_seq_length: int = 4096,
batch_size: int = 8,
dtype: torch.dtype = torch.float32,
enable_kv_cache: bool = True,
):
super().__init__(pretrained="gpt2", device=str(device))
self._model = model
self._tokenizer = tokenizer
self._max_seq_length = max_seq_length
self._batch_size = batch_size
self._dtype = dtype
self._enable_kv_cache = enable_kv_cache

@property
def model(self):
@@ -92,10 +89,6 @@ def batch_size(self):
def device(self):
return self._device

@property
def enable_kv_cache(self):
return self._enable_kv_cache

def tok_encode(self, text: str, **kwargs) -> List[int]:
# Note on add_bos flag: setting to False as this gives better results, for example
# +1% on truthfulqa_mc2 with a LoRA finetune. lit-gpt also sets this to False,
@@ -131,19 +124,15 @@ def _model_generate(
) -> torch.Tensor:
curr_batch_size = context.size(0)

if curr_batch_size > 1:
raise ValueError(
f"Got a batch size of '{curr_batch_size}'. Batch size > 1 is not supported for "
"generation. See https://github.com/pytorch/torchtune/issues/1250 for more info."
)

# Setup caches for a given batch size
# Technically this is not necessary, but it's a good way to ensure that
# the caches won't error on a different batch size. In addition, caches
# are not needed for a regular model call, so we just setup here
if self.enable_kv_cache:
with context.device:
self._model.setup_caches(batch_size=curr_batch_size, dtype=self._dtype)
# if we've received fewer than self._batch_size samples in the current
# batch, we need to pad the batch out. here we're padding the end of the
# current batch to the correct length. this is because when we use static
# KV-caches, the model will expect a fixed batch size for all samples.
maybe_padded_context = torch.nn.functional.pad(
context,
(0, 0, 0, self._batch_size - curr_batch_size),
value=self._tokenizer.eos_id, # pad with one of the tokenizer's stop tokens so generation can stop early
)

temperature = generation_kwargs.get("temperature", 0.0)
do_sample = generation_kwargs.get("do_sample", False)
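For context on the padding call above: torch.nn.functional.pad reads its pad tuple from the last dimension backwards, so (0, 0, 0, n) leaves the sequence dimension alone and appends n rows to the batch dimension. A minimal standalone sketch of the same pattern (shapes and the eos_id value are illustrative, not taken from the recipe):

import torch
import torch.nn.functional as F

batch_size = 8  # the static batch size the KV-caches were set up for
eos_id = 2      # illustrative stop-token id; the recipe uses self._tokenizer.eos_id

# a final, ragged batch: 5 prompts of 4 tokens each
context = torch.randint(3, 100, (5, 4))

# (0, 0, 0, n) = no padding on the sequence dim, n extra rows on the batch dim
padded = F.pad(context, (0, 0, 0, batch_size - context.size(0)), value=eos_id)
print(padded.shape)  # torch.Size([8, 4]); rows 5-7 are all eos_id

# after generation, only the real samples are kept, mirroring
# the recipe's `return toks[:curr_batch_size]`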
@@ -156,14 +145,14 @@

toks, _ = generation.generate(
self._model,
context,
maybe_padded_context,
max_generated_tokens=self.max_gen_toks,
temperature=temperature,
top_k=None, # do_sample is not supported currently
stop_tokens=self._tokenizer.stop_tokens,
)
self._model.reset_caches()
return torch.tensor(toks, dtype=torch.int32)
return toks[:curr_batch_size]


class EleutherEvalRecipe(EvalRecipeInterface):
@@ -175,7 +164,6 @@ class EleutherEvalRecipe(EvalRecipeInterface):
Features:
- Single GPU evaluation. Multi-GPU evaluation is currently not supported.
- Loading model in fp32 or bf16. Fp16 is currently not supported.
- Any task from the EleutherAI eval harness that is *not* free generation
We recommend launching evaluation using the tune CLI:
@@ -198,6 +186,9 @@ def setup(self) -> None:
self._quantization_mode = training.get_quantizer_mode(self._quantizer)
self._enable_kv_cache = self._cfg.get("enable_kv_cache", True)

self._batch_size = self._cfg.batch_size
self._max_seq_length = self._cfg.get("max_seq_length", 4096)

training.set_seed(seed=self._cfg.seed)

checkpointer = config.instantiate(self._cfg.checkpointer)
@@ -253,10 +244,9 @@ def evaluate(self) -> None:
self._model,
self._tokenizer,
device=self._device,
max_seq_length=self._cfg.max_seq_length,
batch_size=self._cfg.batch_size,
max_seq_length=self._max_seq_length,
batch_size=self._batch_size,
dtype=self._dtype,
enable_kv_cache=self._enable_kv_cache,
)

# Task initialization API changed between v0.4.1 and 0.4.2
@@ -268,6 +258,24 @@
task_manager = TaskManager(include_path=self._cfg.get("include_path", None))
task_dict = get_task_dict(self._tasks, task_manager)

task_types = set([task.OUTPUT_TYPE for _, task in task_dict.items()])
if len(task_types) > 1 and "generate_until" in task_types:
raise RuntimeError(
"Evaluating on multiple task types where any one task involves "
"generation is currently not supported. See the issue below for more info: "
"https://github.com/pytorch/torchtune/issues/1621"
)

# Setup caches for a given batch size
if self._enable_kv_cache and "generate_until" in task_types:
with self._device:
self._model.setup_caches(
batch_size=self._batch_size,
dtype=self._dtype,
decoder_max_seq_len=self._max_seq_length
+ model_eval_wrapper.max_gen_toks,
)

logger.info(f"Running evaluation on {self._tasks} tasks.")
output = evaluate(
model_eval_wrapper,
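Setting up the caches once in evaluate(), rather than per call inside _model_generate, fixes the cache shape for the whole run; that is exactly why the ragged final batch has to be padded in _model_generate above. The decoder cache must hold every position the model can attend to: the full prompt plus everything generation may append. A rough sketch of the sizing arithmetic (the max_gen_toks value is illustrative; the recipe reads it from model_eval_wrapper):

max_seq_length = 4096  # longest prompt the recipe accepts (from the config)
max_gen_toks = 256     # illustrative; supplied by the harness wrapper in the recipe

# with static KV-caches every attendable position is allocated up front:
# prompt positions plus all positions that generation may append
decoder_max_seq_len = max_seq_length + max_gen_toks
print(decoder_max_seq_len)  # 4352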
15 changes: 10 additions & 5 deletions tests/recipes/test_eleuther_eval.py
@@ -21,7 +21,11 @@
class TestEleutherEval:
@pytest.mark.parametrize(
"eval_name, expected_acc, bsz",
[("truthfulqa_gen", 0.1, 1), ("truthfulqa_mc2", 0.3, 8)],
[
("truthfulqa_gen", 0.1, 8),
("truthfulqa_gen", 0.1, 1),
("truthfulqa_mc2", 0.4, 8),
],
)
@pytest.mark.integration_test
def test_torchtune_checkpoint_eval_results(
@@ -31,7 +35,8 @@ def test_torchtune_checkpoint_eval_results(
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent

# TODO @joecummings bsz > 1 isn't supported for generation tasks, update test once integrated
# explicitly setting limit to an odd number here to ensure generation tasks
# work with KV-caching + bsz > 1 - we'll receive batches of size 8, 8, 5
cmd = f"""
tune run eleuther_eval \
--config eleuther_evaluation \
@@ -43,7 +48,7 @@
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
tokenizer.prompt_template=null \
limit=10 \
limit=21 \
dtype=fp32 \
device=cpu \
tasks=[{eval_name}]\
@@ -62,12 +67,12 @@
# v0.4.2 format
# | Tasks |Version|Filter|n-shot|Metric|Value | |Stderr|
# |--------------|------:|------|-----:|------|-----:|---|-----:|
# |truthfulqa_mc2| 2|none | 0|acc |0.3469|± |0.1444|
# |truthfulqa_mc2| 2|none | 0|acc |0.4497|± |0.1067|

# v0.4.3 format
# | Tasks |Version|Filter|n-shot|Metric| |Value | |Stderr|
# |--------------|------:|------|-----:|------|---|-----:|---|-----:|
# |truthfulqa_mc2| 2|none | 0|acc |↑ |0.3469|± |0.1444|
# |truthfulqa_mc2| 2|none | 0|acc |↑ |0.4497|± |0.1067|

# The below RegEx command will pick up both formats
search_results = re.search(
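The bump from limit=10 to limit=21 is deliberate: with batch_size=8, 21 samples split into batches of 8, 8, and 5, so the final ragged batch exercises the new padding path in _model_generate. A quick check of the split (a standalone sketch, not recipe code):

limit, batch_size = 21, 8

# batch sizes the harness hands to _model_generate for 21 samples
batch_sizes = [min(batch_size, limit - i) for i in range(0, limit, batch_size)]
print(batch_sizes)  # [8, 8, 5] -- the last batch triggers the padding path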
