feat: use static KV cache when available
Some models, like Gemma and Llama, support a static KV cache in
transformers. For those models it is possible to use this feature, leading to
much higher performance.
tengomucho committed Mar 13, 2024
1 parent 1aa45ab commit 0a085a6
Showing 2 changed files with 108 additions and 12 deletions.
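
Before the diff itself, here is a minimal standalone sketch of the detection and setup pattern this commit relies on. It is not part of the commit: the model id, batch size and cache length are placeholders, and _setup_cache is a private transformers hook (available for Gemma and Llama in the transformers version this commit targets), mirrored from the hunks below.

from transformers import AutoModelForCausalLM, StaticCache

# Placeholder values for illustration only; the generator takes these from its config.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
batch_size, max_cache_len = 4, 1024

# Same probe as the commit: models without the private _setup_cache hook
# fall back to the slower dynamic cache.
use_static_cache = getattr(model, "_setup_cache", False) is not False
if use_static_cache:
    # Pre-allocate a fixed-size KV cache once, before the first prefill.
    model._setup_cache(StaticCache, batch_size, max_cache_len)
else:
    # Dynamic path: keep feeding outputs.past_key_values back into the model.
    past_key_values = None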
@@ -8,7 +8,7 @@
import torch
import torch_xla.core.xla_model as xm
from loguru import logger
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from transformers import AutoTokenizer, PreTrainedTokenizerBase, StaticCache
from transformers.generation import GenerationConfig

from .modeling import TpuModelForCausalLM
@@ -140,6 +140,10 @@ def generation_config(self) -> GenerationConfig:
def generated_tokens(self) -> int:
return self._generated_tokens

@property
def cur_position(self) -> int:
return self._next_text_token_start

def assign(self, request: Request, generation_config: GenerationConfig):
"""Assign a request to a slot.
@@ -294,6 +298,14 @@ def __init__(
self.special_tokens = self.tokenizer.all_special_ids
self.slots = [Slot(i, tokenizer, self.model.device) for i in range(self.model.config.batch_size)]
self.past_key_values = None
# _setup_cache is only available on some models (e.g. Gemma and Llama). For those it is possible to set up
# a static cache; for the others it is not.
self.use_static_cache = True
if getattr(self.model, "_setup_cache", False) is False:
logger.warning(
f"Static cache not available for {self.model.__class__.__name__}. Performance will be affected"
)
self.use_static_cache = False

@property
def info(self) -> InfoResponse:
@@ -388,10 +400,17 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
position_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
position_ids = position_ids[:, -input_ids.shape[-1] :]

# Pause previously active slots during generation.
# The KV cache of paused slots will be prefilled during generation but new tokens
# will be ignored, as they have already been generated and sent back in the last decode.
generation, next_batch = self._generate_token(batch.id, input_ids, attention_mask=attention_mask, position_ids=position_ids)
extra_args = {}
if self.use_static_cache:
self.model._setup_cache(StaticCache, len(self.slots), self.model.config.sequence_length)
extra_args["cache_position"] = torch.arange(seq_length, device=self.model.device)
else:
# Reset/clear KV cache
self.past_key_values = None
generation, next_batch = self._generate_token(
batch.id, input_ids, attention_mask=attention_mask, position_ids=position_ids, **extra_args
)

# Reactivate previously active slots for the next decode, and append
# back their next token.
for slot, next_token in zip(active_slots, next_tokens):
@@ -437,24 +456,31 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
)
# input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
input_ids[i, 0] = slot.next_token
position_ids[i, 0] = slot.generated_tokens
position_ids[i, 0] = slot.cur_position
if input_ids is None:
raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")

return self._generate_token(next_batch_id, input_ids, position_ids=position_ids)
extra_args = {}
if self.use_static_cache:
extra_args["cache_position"] = position_ids.max().unsqueeze(0)
else:
extra_args["past_key_values"] = self.past_key_values
return self._generate_token(next_batch_id, input_ids, position_ids=position_ids, **extra_args)

def _generate_token(
self, next_batch_id: int, input_ids: torch.LongTensor, **forward_extra_params) -> Tuple[List[Generation], CachedBatch]:
self, next_batch_id: int, input_ids: torch.LongTensor, **forward_extra_params
) -> Tuple[List[Generation], CachedBatch]:
# Add barrier to allow next graph step to always be the same
xm.mark_step()
# Forward
outputs = self.model(
input_ids,
past_key_values=self.past_key_values,
return_dict=True,
use_cache=True,
**forward_extra_params,
)
# Save KV cache
self.past_key_values = outputs.past_key_values
if not self.use_static_cache:
# Save KV cache
self.past_key_values = outputs.past_key_values
# Barrier for XLA model
xm.mark_step(wait=False)

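To summarize the two paths spliced into prefill, decode and _generate_token above: with a static cache the generator passes explicit cache_position tensors (the full prompt range at prefill, a single position at decode) and never stores past_key_values; without it, past_key_values keeps being carried between calls. A self-contained sketch of that cache_position arithmetic, with placeholder sizes not taken from the commit:

import torch

# Placeholder sizes for illustration; the generator derives them from the batch and slots.
seq_length = 8                           # prompt length at prefill
position_ids = torch.tensor([[8], [8]])  # next position for each of two active slots at decode

# Prefill: cache positions cover the whole prompt, 0 .. seq_length - 1.
prefill_cache_position = torch.arange(seq_length)

# Decode: one new token per step; the commit uses the batch-wide maximum of the
# per-slot positions as the single cache position.
decode_cache_position = position_ids.max().unsqueeze(0)

print(prefill_cache_position)  # tensor([0, 1, 2, 3, 4, 5, 6, 7])
print(decode_cache_position)   # tensor([8])

These are the values _generate_token forwards to the model through **forward_extra_params; only the non-static path saves outputs.past_key_values afterwards.
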
70 changes: 70 additions & 0 deletions text-generation-inference/tests/test_generator_gemma.py
@@ -0,0 +1,70 @@
import pytest
import os
from text_generation_server.generator import TpuGenerator
from text_generation_server.model import fetch_model
from text_generation_server.pb.generate_pb2 import (
    Batch,
    NextTokenChooserParameters,
    Request,
    StoppingCriteriaParameters,
)


MODEL_ID = "google/gemma-2b"
BATCH_SIZE = 4
SEQUENCE_LENGTH = 1024


@pytest.fixture(scope="module")
def model_path():
    # Add variables to environment so they can be used in TpuModelForCausalLM
    os.environ["HF_BATCH_SIZE"] = str(BATCH_SIZE)
    os.environ["HF_SEQUENCE_LENGTH"] = str(SEQUENCE_LENGTH)
    path = fetch_model(MODEL_ID)
    return path


def create_request(
    id: int,
    inputs: str,
    max_new_tokens=20,
    do_sample: bool = False,
    top_k: int = 50,
    top_p: float = 0.9,
    temperature: float = 1.0,
    seed: int = 0,
    repetition_penalty: float = 1.0,
):
    parameters = NextTokenChooserParameters(
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        seed=seed,
        repetition_penalty=repetition_penalty,
    )
    stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
    return Request(id=id, inputs=inputs, parameters=parameters, stopping_parameters=stopping_parameters)


def test_decode_single(model_path):
    input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
    max_new_tokens = 20
    generated_text = "\n\nThe first thing I noticed was the smell of the rain. It was a smell I had never"

    generator = TpuGenerator.from_pretrained(model_path)
    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
    batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
    generations, next_batch = generator.prefill(batch)
    # We already generated one token: call decode max_new_tokens - 1 times
    for i in range(max_new_tokens - 1):
        assert next_batch.size == 1
        assert next_batch.max_tokens == 1024
        assert len(generations) == 1
        assert len(generations[0].tokens.ids) == 1
        generations, next_batch = generator.decode([next_batch])
    assert next_batch is None
    assert len(generations) == 1
    output = generations[0].generated_text
    assert output.generated_tokens == max_new_tokens
    assert output.finish_reason == 0
    assert output.text == generated_text
