feat: use static KV cache when available
Some models, like Gemma and Llama, support a static KV cache in transformers. For these models, the generator can use this feature, which leads to much higher decoding performance.
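For context, opting in to the static KV cache at the transformers level is usually a one-line setting on the generation config. The snippet below is a minimal, illustrative sketch against plain transformers, not code from this commit; it assumes a transformers version and model that support cache_implementation = "static", and access to google/gemma-2b.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)

# Opt in to the static KV cache: key/value tensors are pre-allocated to a
# fixed shape instead of growing at every decoding step, so the decode graph
# keeps a constant shape and can be compiled once and reused.
model.generation_config.cache_implementation = "static"

inputs = tokenizer("It was a bright cold day in April", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))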
1 parent 1aa45ab · commit 0a085a6
Showing 2 changed files with 108 additions and 12 deletions.
@@ -0,0 +1,70 @@
import pytest
import os
from text_generation_server.generator import TpuGenerator
from text_generation_server.model import fetch_model
from text_generation_server.pb.generate_pb2 import (
    Batch,
    NextTokenChooserParameters,
    Request,
    StoppingCriteriaParameters,
)


MODEL_ID = "google/gemma-2b"
BATCH_SIZE = 4
SEQUENCE_LENGTH = 1024


@pytest.fixture(scope="module")
def model_path():
    # Add variables to environment so they can be used in TpuModelForCausalLM
    os.environ["HF_BATCH_SIZE"] = str(BATCH_SIZE)
    os.environ["HF_SEQUENCE_LENGTH"] = str(SEQUENCE_LENGTH)
    path = fetch_model(MODEL_ID)
    return path


def create_request(
    id: int,
    inputs: str,
    max_new_tokens=20,
    do_sample: bool = False,
    top_k: int = 50,
    top_p: float = 0.9,
    temperature: float = 1.0,
    seed: int = 0,
    repetition_penalty: float = 1.0,
):
    parameters = NextTokenChooserParameters(
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        seed=seed,
        repetition_penalty=repetition_penalty,
    )
    stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
    return Request(id=id, inputs=inputs, parameters=parameters, stopping_parameters=stopping_parameters)

def test_decode_single(model_path):
    input_text = "It was a bright cold day in April, and the clocks were striking thirteen."
    max_new_tokens = 20
    generated_text = "\n\nThe first thing I noticed was the smell of the rain. It was a smell I had never"

    generator = TpuGenerator.from_pretrained(model_path)
    request = create_request(id=0, inputs=input_text, max_new_tokens=max_new_tokens, do_sample=False)
    batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
    generations, next_batch = generator.prefill(batch)
    # We already generated one token: call decode max_new_tokens - 1 times
    for _ in range(max_new_tokens - 1):
        assert next_batch.size == 1
        assert next_batch.max_tokens == 1024
        assert len(generations) == 1
        assert len(generations[0].tokens.ids) == 1
        generations, next_batch = generator.decode([next_batch])
    assert next_batch is None
    assert len(generations) == 1
    output = generations[0].generated_text
    assert output.generated_tokens == max_new_tokens
    assert output.finish_reason == 0
    assert output.text == generated_text