Skip to content

Commit

Permalink
[fix] fix issue with BloomBlock due to transformers upgrade (#2640)
Browse files (browse the repository at this point in the history)
  • Loading branch information
siddvenk authored Dec 17, 2024
1 parent 96a0efd commit bffb5a0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 24 deletions.
28 changes: 6 additions & 22 deletions engines/python/setup/djl_python/seq_scheduler/lm_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Tuple, Union

import torch
from transformers import DynamicCache


class LMBlock(ABC):
Expand Down Expand Up @@ -107,34 +108,17 @@ def forward(self, input_ids: torch.tensor, position_ids: torch.tensor,

# Pre-process
if past_key_values is not None:
_, num_head, seq_len, kv_dim = past_key_values[0][0].shape
new_kv_list = []
for k, v in past_key_values:
k_new = torch.permute(
k.view(batch_size * num_head, seq_len, kv_dim), (0, 2, 1))
v_new = v.view(batch_size * num_head, seq_len, kv_dim)
new_kv_list.append((k_new, v_new))
past_key_values = tuple(new_kv_list)
cache = DynamicCache.from_legacy_cache(past_key_values)
else:
cache = DynamicCache()

# Forward
output = self.model.forward(input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
past_key_values=cache,
**self.config)
past_key_values = output.past_key_values

# Post-process
_, kv_dim, seq_len = past_key_values[0][0].shape
new_kv_list = []
for k, v in past_key_values:
k_new = torch.permute(k, (0, 2, 1)).view(batch_size, -1, seq_len,
kv_dim)
v_new = v.view(batch_size, -1, seq_len, kv_dim)
new_kv_list.append((k_new, v_new))
past_key_values = tuple(new_kv_list)
output.past_key_values = past_key_values

output.past_key_values = output.past_key_values.to_legacy_cache()
return output


Expand Down
4 changes: 2 additions & 2 deletions engines/python/setup/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def run(self):
requirements = ['psutil', 'packaging', 'wheel']

test_requirements = [
'numpy<2', 'requests', 'Pillow', 'transformers==4.43.4', 'torch',
'einops', 'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf',
'numpy<2', 'requests', 'Pillow', 'transformers', 'torch', 'einops',
'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf',
'pydantic>=2.0', "objgraph"
]

Expand Down

0 comments on commit bffb5a0

Please sign in to comment.