support gemma2 in pytorch engine (#1924)
* add gemma2 support

* support gemma2

* eos stop
grimoire authored Jul 5, 2024
1 parent 2e9e88f commit ab5b7ce
Showing 9 changed files with 31 additions and 4 deletions.
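For context, a minimal usage sketch of what this change enables: running a Gemma2 checkpoint on the PyTorch engine. The model ID, parallelism, and sampling settings below are illustrative assumptions, not part of this commit.

```python
# Hedged sketch: serve a Gemma2 model on lmdeploy's PyTorch engine.
# The checkpoint name and generation settings are illustrative assumptions.
from lmdeploy import GenerationConfig, PytorchEngineConfig, pipeline

pipe = pipeline('google/gemma-2-9b-it',
                backend_config=PytorchEngineConfig(tp=1))
responses = pipe(['Hello, who are you?'],
                 gen_config=GenerationConfig(max_new_tokens=128))
print(responses[0].text)
```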
2 changes: 1 addition & 1 deletion README.md
@@ -28,7 +28,7 @@
 
 - \[2024/06\] PyTorch engine support DeepSeek-V2 and several VLMs, such as CogVLM2, Mini-InternVL, LlaVA-Next
 - \[2024/05\] Balance vision model when deploying VLMs with multiple GPUs
-- \[2024/05\] Support 4-bits weight-only quantization and inference on VMLs, such as InternVL v1.5, LLaVa, InternLMXComposer2
+- \[2024/05\] Support 4-bits weight-only quantization and inference on VLMs, such as InternVL v1.5, LLaVa, InternLMXComposer2
 - \[2024/04\] Support Llama3 and more VLMs, such as InternVL v1.1, v1.2, MiniGemini, InternLMXComposer2.
 - \[2024/04\] TurboMind adds online int8/int4 KV cache quantization and inference for all supported devices. Refer [here](docs/en/quantization/kv_quant.md) for detailed guide
 - \[2024/04\] TurboMind latest upgrade boosts GQA, rocketing the [internlm2-20b](https://huggingface.co/internlm/internlm2-20b) model inference to 16+ RPS, about 1.8x faster than vLLM.
1 change: 1 addition & 0 deletions docs/en/supported_models/supported_models.md
@@ -65,3 +65,4 @@ The TurboMind engine doesn't support window attention. Therefore, for models tha…
 | CogVLM2-Chat | 19B | Yes | No | No |
 | LLaVA(1.5,1.6) | 7B-34B | Yes | No | No |
 | InternVL-Chat(v1.5) | 2B-26B | Yes | No | No |
+| Gemma2 | 9B-27B | Yes | No | No |
1 change: 1 addition & 0 deletions docs/zh_cn/supported_models/supported_models.md
@@ -65,3 +65,4 @@ turbomind 引擎不支持 window attention。所以,对于应用了 window att…
 | CogVLM2-Chat | 19B | Yes | No | No |
 | LLaVA(1.5,1.6) | 7B-34B | Yes | No | No |
 | InternVL-Chat(v1.5) | 2B-26B | Yes | No | No |
+| Gemma2 | 9B-27B | Yes | No | No |
2 changes: 2 additions & 0 deletions lmdeploy/model.py
@@ -1151,11 +1151,13 @@ def __init__(self,
                  eoh='<end_of_turn>\n',
                  assistant='<start_of_turn>model\n',
                  eoa='<end_of_turn>\n',
+                 stop_words=['<end_of_turn>'],
                  **kwargs):
         super().__init__(user=user,
                          eoh=eoh,
                          assistant=assistant,
                          eoa=eoa,
+                         stop_words=stop_words,
                          **kwargs)
 
     @classmethod
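The added `stop_words` entry matters because Gemma-style templates close every turn with `<end_of_turn>`; without it as a stop word, generation may run past the model's own turn. A sketch of the rendered prompt, assuming the `user` prefix (defined above the visible hunk) is `<start_of_turn>user\n`:

```python
# Sketch of the prompt this template produces; the `user` prefix is an
# assumption, since that argument sits above the visible hunk.
user = '<start_of_turn>user\n'
eoh = '<end_of_turn>\n'
assistant = '<start_of_turn>model\n'

prompt = f'{user}What is LMDeploy?{eoh}{assistant}'
# -> '<start_of_turn>user\nWhat is LMDeploy?<end_of_turn>\n<start_of_turn>model\n'
# Generation should halt when the model emits '<end_of_turn>', hence the stop word.
print(prompt)
```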
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/configurations/gemma.py
@@ -9,7 +9,7 @@ class GemmaModelConfigBuilder(AutoModelConfigBuilder):
     @classmethod
     def condition(cls, hf_config):
         """config."""
-        return hf_config.model_type == 'gemma'
+        return hf_config.model_type in ['gemma', 'gemma2']
 
     @classmethod
     def build(cls, hf_config, model_path: str = None):
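Widening `condition` to accept both `model_type` values lets Gemma1 and Gemma2 share one config builder. A simplified sketch of the condition-based dispatch pattern (names below are illustrative, not lmdeploy internals):

```python
# Simplified sketch of condition-based dispatch; not lmdeploy's registry code.
from types import SimpleNamespace


def pick_builder(hf_config, builders):
    """Return the first builder whose condition matches the HF config."""
    for builder in builders:
        if builder.condition(hf_config):
            return builder
    raise ValueError(f'unsupported model_type: {hf_config.model_type}')


class FakeGemmaBuilder:
    @classmethod
    def condition(cls, hf_config):
        return hf_config.model_type in ['gemma', 'gemma2']


print(pick_builder(SimpleNamespace(model_type='gemma2'), [FakeGemmaBuilder]))
```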
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/engine/engine.py
@@ -512,7 +512,8 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor,
             if msg.status != MessageStatus.RUNNING:
                 continue
             update_token = token
-            if stop or token in eos_token_id:
+            stop = stop or token in eos_token_id
+            if stop:
                 update_token = _EMPTY_TOKEN
             else:
                 msg.num_new_tokens += 1
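The one-line change folds the EOS check into `stop` itself, so bookkeeping after this branch also treats an EOS hit as a finished sequence, whereas before `stop` stayed `False` when only `eos_token_id` matched. A toy illustration with made-up token IDs:

```python
# Toy illustration of the updated stop handling; token IDs are made up
# (Gemma2 checkpoints typically list both <eos> and <end_of_turn> as EOS).
_EMPTY_TOKEN = -1
eos_token_id = (1, 107)


def resolve(token: int, stop: bool):
    stop = stop or token in eos_token_id   # new: an EOS token now also flips `stop`
    update_token = _EMPTY_TOKEN if stop else token
    return update_token, stop


print(resolve(42, False))   # (42, False): normal token keeps streaming
print(resolve(107, False))  # (-1, True): EOS masks the token and marks the stop
```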
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/models/gemma.py
@@ -124,6 +124,8 @@ def __rotary_emb_fn(query_states, key_states, value_states):
             block_offsets=block_offsets,
         )
 
+        window_size = getattr(self, 'sliding_window', None)
+        sm_scale = getattr(self, 'scaling', None)
         attn_output = query_states
         paged_attention_fwd(
             query_states,
@@ -135,6 +137,8 @@ def __rotary_emb_fn(query_states, key_states, value_states):
             q_seqlens=q_seq_length,
             kv_seqlens=kv_seq_length,
             max_seqlen=max_q_seq_length,
+            window_size=window_size,
+            sm_scale=sm_scale,
         )
         attn_output = attn_output.reshape(*hidden_states.shape[:-1],
                                           hidden_size)
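Gemma2 adds sliding-window attention on some layers and uses its own attention softmax scale, while Gemma1 attention modules carry neither attribute; reading them with `getattr(..., None)` lets the same patched attention serve both generations. A small sketch of that fallback, with illustrative attribute values:

```python
# Sketch of the getattr fallback; the attribute values are illustrative.
class Gemma1LikeAttention:
    pass  # no sliding_window / scaling attributes


class Gemma2LikeAttention:
    sliding_window = 4096   # illustrative window size
    scaling = 0.0625        # illustrative softmax scale


for attn in (Gemma1LikeAttention(), Gemma2LikeAttention()):
    window_size = getattr(attn, 'sliding_window', None)
    sm_scale = getattr(attn, 'scaling', None)
    # None means "fall back to the kernel's defaults" in the paged attention call.
    print(type(attn).__name__, window_size, sm_scale)
```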
18 changes: 17 additions & 1 deletion lmdeploy/pytorch/models/module_map.py
@@ -125,12 +125,28 @@
     f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
     'transformers.models.gemma.modeling_gemma.GemmaModel':
     f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaModel',
-    'transformers.models.gemma.modeling_gemma.modeling_mistral.GemmaMLP':
+    'transformers.models.gemma.modeling_gemma.GemmaMLP':
     f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
     'transformers.models.gemma.modeling_gemma.GemmaRMSNorm':
     f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaRMSNorm',
 })
 
+# gemma2
+MODULE_MAP.update({
+    'transformers.models.gemma2.modeling_gemma2.Gemma2Attention':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
+    'transformers.models.gemma2.modeling_gemma2.Gemma2FlashAttention2':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
+    'transformers.models.gemma2.modeling_gemma2.Gemma2SdpaAttention':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
+    'transformers.models.gemma2.modeling_gemma2.Gemma2Model':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaModel',
+    'transformers.models.gemma2.modeling_gemma2.Gemma2MLP':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
+    'transformers.models.gemma2.modeling_gemma2.Gemma2RMSNorm':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaRMSNorm',
+})
+
 # deepseek
 MODULE_MAP.update({
     'modeling_deepseek.DeepseekAttention':
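All three Gemma2 attention variants (eager, FlashAttention2, SDPA) map to the same patched attention, and the MLP reuses the Llama implementation. A simplified sketch of how a map like this is typically consumed (not the engine's actual patching code):

```python
# Simplified sketch of qualified-name lookup; not the engine's real patcher.
from typing import Optional


def qualified_name(module) -> str:
    cls = type(module)
    return f'{cls.__module__}.{cls.__name__}'


def rewrite_target(module, module_map: dict) -> Optional[str]:
    """Return the patched class path registered for this module's class, if any."""
    return module_map.get(qualified_name(module))
```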
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/supported_models.py
@@ -60,6 +60,8 @@
     DeepseekV2ForCausalLM=True,
     # internvl
     InternVLChatModel=True,
+    # gemma2
+    Gemma2ForCausalLM=True,
 )
 
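The entries above are keyed by the architecture name reported in a checkpoint's HF config, which can be inspected without loading weights; a quick check (the model ID is illustrative):

```python
# Quick check that a checkpoint reports the architecture enabled above.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained('google/gemma-2-9b-it')  # illustrative model id
print(cfg.architectures)  # expected: ['Gemma2ForCausalLM']
```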
