Align to Transformers 4.41.1 (#45)
* feat: align to transformers 4.41.1

As part of the alignment, the new static cache handling has been taken into account,
and the modeling classes for Llama and Gemma have been aligned as well.

* refactor(generator): rename variable use_static_cache -> _supports_static_cache

This name is aligned to the one used internally in transformers.

* refactor(Gemma): rename TpuGemma -> Gemma

This reverses the renaming that was done when the classes were first imported from
transformers, but it turns out this makes it easier to keep the classes up to date
with the latest changes in the original repo.
tengomucho authored Jun 3, 2024
1 parent 7cf4c6d commit 292bd41
Showing 10 changed files with 239 additions and 287 deletions.
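
For context on the static cache change: in transformers 4.41 the cache is constructed by the caller and passed to the model as past_key_values, rather than being installed on the model via the removed _setup_cache() helper. Below is a minimal, hedged sketch of that flow; the checkpoint name, prompt, and cache length are illustrative assumptions, not taken from this commit.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

    model_id = "google/gemma-2b"  # illustrative checkpoint; not prescribed by this commit
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("Hello", return_tensors="pt")
    input_ids = inputs["input_ids"]

    # The cache is built by the caller and passed explicitly on each forward,
    # instead of being installed on the model via a _setup_cache() helper.
    past_key_values = StaticCache(
        config=model.config,
        max_batch_size=input_ids.shape[0],
        max_cache_len=1024,  # assumed maximum cache length for this sketch
        device=model.device,
        dtype=model.dtype,
    )
    cache_position = torch.arange(input_ids.shape[1], device=model.device)
    logits = model(
        input_ids,
        position_ids=cache_position.unsqueeze(0),
        cache_position=cache_position,
        past_key_values=past_key_values,
        use_cache=True,
        return_dict=False,
    )[0]
    next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
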
23 changes: 14 additions & 9 deletions examples/text-generation/generation_gemma.py
@@ -22,13 +22,14 @@ def sample_greedy(logits):
     return next_token_id
 
 
-def decode_one_tokens(model, cur_token, input_pos, cache_position, step):
+def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
     logits = model(
         cur_token,
         position_ids=input_pos,
         cache_position=cache_position,
         return_dict=False,
         use_cache=True,
+        past_key_values=past_key_values,
     )[0]
     new_token = sample_greedy(logits)
     return new_token
@@ -69,10 +70,14 @@ def main():
     max_cache_length = 1024
     max_new_tokens = 20
 
-    start = time.time()
-    model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)
-    end = time.time()
-    print(f"Model cache setup took {end - start} seconds.")
+    # setup static cache
+    past_key_values = StaticCache(
+        config=model.config,
+        max_batch_size=batch_size,
+        max_cache_len=max_cache_length,
+        device=model.device,
+        dtype=model.dtype,
+    )
     start = time.time()
     cache_position = torch.arange(sequence_length, device=device)
     generated_ids = torch.zeros(
@@ -91,6 +96,7 @@ def main():
         return_dict=False,
         use_cache=True,
         position_ids=pos_ids,
+        past_key_values=past_key_values,
     )[0]
     next_token = sample_greedy(logits)
     xm.mark_step()
@@ -101,14 +107,13 @@ def main():
     pos_ids = pos_ids.max(axis=-1)[0].unsqueeze(1) + 1
 
     model = conditional_compile(model)
-    cache_position = torch.tensor([sequence_length + 1], device=device)
+    cache_position = torch.tensor([sequence_length], device=device)
     decode_times = []
     for i in range(max_new_tokens):
         step_start = time.time()
-        next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position, i)
-        generated_ids[:, cache_position] = next_token
-
+        next_token = decode_one_tokens(model, next_token.clone(), pos_ids, cache_position, past_key_values)
         cache_position += 1
+        generated_ids[:, cache_position] = next_token
         pos_ids += 1
         xm.mark_step()
         step_end = time.time()
6 changes: 3 additions & 3 deletions optimum/tpu/fsdp_v2.py
@@ -81,10 +81,10 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
     model_type = model.config.model_type
     matched_model = False
     if model_type == "gemma":
-        from .modeling_gemma import TpuGemmaForCausalLM
+        from .modeling_gemma import GemmaForCausalLM
 
-        if isinstance(model, TpuGemmaForCausalLM):
-            cls_to_wrap = "TpuGemmaDecoderLayer"
+        if isinstance(model, GemmaForCausalLM):
+            cls_to_wrap = "GemmaDecoderLayer"
             matched_model = True
     elif model_type == "llama":
         from .modeling_llama import LlamaForCausalLM
4 changes: 2 additions & 2 deletions optimum/tpu/modeling.py
@@ -26,9 +26,9 @@
 def config_name_to_class(pretrained_model_name_or_path: str):
     config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
     if config.model_type == "gemma":
-        from .modeling_gemma import TpuGemmaForCausalLM
+        from .modeling_gemma import GemmaForCausalLM
 
-        return TpuGemmaForCausalLM
+        return GemmaForCausalLM
     if config.model_type == "llama":
         from .modeling_llama import LlamaForCausalLM
 
