From 0a085a68fed4ce56cd23048aa2828080a24392ce Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Wed, 13 Mar 2024 14:32:49 +0000
Subject: [PATCH] feat: use static KV cache when available

Some models, like Gemma and Llama, support static KV cache in transformers.
For these, it is possible to use this feature, leading to much higher
performance.
---
 .../text_generation_server/generator.py       | 50 +++++++++----
 .../tests/test_generator_gemma.py             | 70 +++++++++++++++++++
 2 files changed, 108 insertions(+), 12 deletions(-)
 create mode 100644 text-generation-inference/tests/test_generator_gemma.py

diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index 5dc30e88..2b381a1f 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -8,7 +8,7 @@
 import torch
 import torch_xla.core.xla_model as xm
 from loguru import logger
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from transformers import AutoTokenizer, PreTrainedTokenizerBase, StaticCache
 from transformers.generation import GenerationConfig
 
 from .modeling import TpuModelForCausalLM
@@ -140,6 +140,10 @@ def generation_config(self) -> GenerationConfig:
     def generated_tokens(self) -> int:
         return self._generated_tokens
 
+    @property
+    def cur_position(self) -> int:
+        return self._next_text_token_start
+
     def assign(self, request: Request, generation_config: GenerationConfig):
         """Assign a request to a slot.
 
@@ -294,6 +298,14 @@ def __init__(
         self.special_tokens = self.tokenizer.all_special_ids
         self.slots = [Slot(i, tokenizer, self.model.device) for i in range(self.model.config.batch_size)]
         self.past_key_values = None
+        # _setup_cache is specific to some models (e.g.: Gemma and Llama). In those cases it is possible to setup
+        # a static cache, otherwise it is not.
+        self.use_static_cache = True
+        if getattr(self.model, "_setup_cache", False) is False:
+            logger.warning(
+                f"Static cache not available for {self.model.__class__.__name__}. Performance will be affected"
+            )
+            self.use_static_cache = False
 
     @property
     def info(self) -> InfoResponse:
@@ -388,10 +400,17 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
 
         position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
         position_ids = position_ids[:, -input_ids.shape[-1] :]
-        # Pause previously active slots during generation.
-        # The KV cache of paused slots will be prefilled during generation but new tokens
-        # will be ignored, as they have already been generated and sent back in the last decode.
-        generation, next_batch = self._generate_token(batch.id, input_ids, attention_mask=attention_mask, position_ids=position_ids)
+        extra_args = {}
+        if self.use_static_cache:
+            self.model._setup_cache(StaticCache, len(self.slots), self.model.config.sequence_length)
+            extra_args["cache_position"] = torch.arange(seq_length, device=self.model.device)
+        else:
+            # Reset/clear KV cache
+            self.past_key_values = None
+        generation, next_batch = self._generate_token(
+            batch.id, input_ids, attention_mask=attention_mask, position_ids=position_ids, **extra_args
+        )
+
         # Reactivate previously active slots for the next decode, and append
         # back their next token.
         for slot, next_token in zip(active_slots, next_tokens):
@@ -437,24 +456,31 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBa
                 )
             # input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
             input_ids[i, 0] = slot.next_token
-            position_ids[i, 0] = slot.generated_tokens
+            position_ids[i, 0] = slot.cur_position
         if input_ids is None:
             raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
-
-        return self._generate_token(next_batch_id, input_ids, position_ids=position_ids)
+        extra_args = {}
+        if self.use_static_cache:
+            extra_args["cache_position"] = position_ids.max().unsqueeze(0)
+        else:
+            extra_args["past_key_values"] = self.past_key_values
+        return self._generate_token(next_batch_id, input_ids, position_ids=position_ids, **extra_args)
 
     def _generate_token(
-        self, next_batch_id: int, input_ids: torch.LongTensor, **forward_extra_params) -> Tuple[List[Generation], CachedBatch]:
+        self, next_batch_id: int, input_ids: torch.LongTensor, **forward_extra_params
+    ) -> Tuple[List[Generation], CachedBatch]:
+        # Add barrier to allow next graph step to always be the same
+        xm.mark_step()
         # Forward
         outputs = self.model(
             input_ids,
-            past_key_values=self.past_key_values,
             return_dict=True,
             use_cache=True,
             **forward_extra_params,
         )
-        # Save KV cache
-        self.past_key_values = outputs.past_key_values
+        if not self.use_static_cache:
+            # Save KV cache
+            self.past_key_values = outputs.past_key_values
         # Barrier for XLA model
         xm.mark_step(wait=False)
 
diff --git a/text-generation-inference/tests/test_generator_gemma.py b/text-generation-inference/tests/test_generator_gemma.py
new file mode 100644
index 00000000..9b562b22
--- /dev/null
+++ b/text-generation-inference/tests/test_generator_gemma.py
@@ -0,0 +1,70 @@
+import pytest
+import os
+from text_generation_server.generator import TpuGenerator
+from text_generation_server.model import fetch_model
+from text_generation_server.pb.generate_pb2 import (
+    Batch,
+    NextTokenChooserParameters,
+    Request,
+    StoppingCriteriaParameters,
+)
+
+
+MODEL_ID = "google/gemma-2b"
+BATCH_SIZE = 4
+SEQUENCE_LENGTH = 1024
+
+
+@pytest.fixture(scope="module")
+def model_path():
+    # Add variables to environment so they can be used in TpuModelForCausalLM
+    os.environ["HF_BATCH_SIZE"] = str(BATCH_SIZE)
+    os.environ["HF_SEQUENCE_LENGTH"] = str(SEQUENCE_LENGTH)
+    path = fetch_model(MODEL_ID)
+    return path
+
+
+def create_request(
+    id: int,
+    inputs: str,
+    max_new_tokens=20,
+    do_sample: bool = False,
+    top_k: int = 50,
+    top_p: float = 0.9,
+    temperature: float = 1.0,
+    seed: int = 0,
+    repetition_penalty: float = 1.0,
+):
+    parameters = NextTokenChooserParameters(
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        do_sample=do_sample,
+        seed=seed,
+        repetition_penalty=repetition_penalty,
+    )
+    stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
+    return Request(id=id, inputs=inputs, parameters=parameters, stopping_parameters=stopping_parameters)
+
+
+def test_decode_single(model_path):
+    input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
+    max_new_tokens = 20
+    generated_text = "\n\nThe first thing I noticed was the smell of the rain. It was a smell I had never"
+    generator = TpuGenerator.from_pretrained(model_path)
+    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
+    batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
+    generations, next_batch = generator.prefill(batch)
+    # We already generated one token: call decode max_new_tokens - 1 times
+    for i in range(max_new_tokens - 1):
+        assert next_batch.size == 1
+        assert next_batch.max_tokens == 1024
+        assert len(generations) == 1
+        assert len(generations[0].tokens.ids) == 1
+        generations, next_batch = generator.decode([next_batch])
+    assert next_batch is None
+    assert len(generations) == 1
+    output = generations[0].generated_text
+    assert output.generated_tokens == max_new_tokens
+    assert output.finish_reason == 0
+    assert output.text == generated_text
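
For context, below is a minimal, self-contained sketch of the static KV cache flow this patch builds on, outside of the TGI generator. It assumes a transformers release in which Gemma and Llama expose the private _setup_cache hook used above and accept cache_position in forward(); the model id, prompt, cache length and step count are illustrative choices, not values taken from this patch.

# Illustrative sketch only (not part of the patch): greedy decoding with
# transformers' StaticCache, mirroring the prefill/decode split used above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "google/gemma-2b"  # assumption: any model exposing _setup_cache works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()

input_ids = tokenizer("It was a bright cold day in April", return_tensors="pt").input_ids
seq_length = input_ids.shape[-1]

# Allocate the fixed-size cache once, as the patch does with len(self.slots) and
# self.model.config.sequence_length (batch size 1 and max length 1024 are assumptions here).
model._setup_cache(StaticCache, 1, 1024)

with torch.no_grad():
    # Prefill: cache positions 0..seq_length-1 populate the static cache.
    cache_position = torch.arange(seq_length)
    outputs = model(input_ids, cache_position=cache_position, use_cache=True)
    next_token = outputs.logits[:, -1:].argmax(dim=-1)
    generated = [next_token]

    # Decode: feed one token per step and advance cache_position by one,
    # keeping input shapes constant from step to step.
    for _ in range(19):
        cache_position = cache_position[-1:] + 1
        outputs = model(next_token, cache_position=cache_position, use_cache=True)
        next_token = outputs.logits[:, -1:].argmax(dim=-1)
        generated.append(next_token)

print(tokenizer.decode(torch.cat(generated, dim=-1)[0]))

Because the cache and the per-step inputs keep a fixed shape, each decode step can reuse the same compiled graph; this is also what the xm.mark_step() barrier added in _generate_token is meant to preserve on XLA.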