support gemma2 in pytorch engine #1924

Merged · 3 commits · Jul 5, 2024
README.md · 1 addition, 1 deletion
@@ -28,7 +28,7 @@ ______________________________________________________________________

- \[2024/06\] PyTorch engine support DeepSeek-V2 and several VLMs, such as CogVLM2, Mini-InternVL, LlaVA-Next
- \[2024/05\] Balance vision model when deploying VLMs with multiple GPUs
- \[2024/05\] Support 4-bits weight-only quantization and inference on VMLs, such as InternVL v1.5, LLaVa, InternLMXComposer2
- \[2024/05\] Support 4-bits weight-only quantization and inference on VLMs, such as InternVL v1.5, LLaVa, InternLMXComposer2
- \[2024/04\] Support Llama3 and more VLMs, such as InternVL v1.1, v1.2, MiniGemini, InternLMXComposer2.
- \[2024/04\] TurboMind adds online int8/int4 KV cache quantization and inference for all supported devices. Refer [here](docs/en/quantization/kv_quant.md) for detailed guide
- \[2024/04\] TurboMind latest upgrade boosts GQA, rocketing the [internlm2-20b](https://huggingface.co/internlm/internlm2-20b) model inference to 16+ RPS, about 1.8x faster than vLLM.
docs/en/supported_models/supported_models.md · 1 addition, 0 deletions
@@ -65,3 +65,4 @@ The TurboMind engine doesn't support window attention. Therefore, for models tha
| CogVLM2-Chat | 19B | Yes | No | No |
| LLaVA(1.5,1.6) | 7B-34B | Yes | No | No |
| InternVL-Chat(v1.5) | 2B-26B | Yes | No | No |
| Gemma2 | 9B-27B | Yes | No | No |
docs/zh_cn/supported_models/supported_models.md · 1 addition, 0 deletions
@@ -65,3 +65,4 @@ The turbomind engine does not support window attention. Therefore, for models that use window att
| CogVLM2-Chat | 19B | Yes | No | No |
| LLaVA(1.5,1.6) | 7B-34B | Yes | No | No |
| InternVL-Chat(v1.5) | 2B-26B | Yes | No | No |
| Gemma2 | 9B-27B | Yes | No | No |
lmdeploy/model.py · 2 additions, 0 deletions
@@ -1127,11 +1127,13 @@ def __init__(self,
eoh='<end_of_turn>\n',
assistant='<start_of_turn>model\n',
eoa='<end_of_turn>\n',
stop_words=['<end_of_turn>'],

Collaborator

Does Gemma use the stop_words too?

Collaborator Author

Yes, I guess so. Generating the eos token should mean the generation stops.

**kwargs):
super().__init__(user=user,
eoh=eoh,
assistant=assistant,
eoa=eoa,
stop_words=stop_words,
**kwargs)

@classmethod
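The template change above registers `<end_of_turn>` as a stop word, so decoding halts when the model closes its turn. For context, a minimal sketch of how these turn markers compose a Gemma-style prompt; `build_gemma_prompt` is a hypothetical helper written for illustration, not part of lmdeploy:

```python
# Hypothetical helper, for illustration only; it is not lmdeploy's chat
# template class, just a rendering of the markers shown in the diff above.
def build_gemma_prompt(history):
    """history: list of (role, text) pairs, role in {'user', 'model'}."""
    parts = []
    for role, text in history:
        parts.append(f'<start_of_turn>{role}\n{text}<end_of_turn>\n')
    # Open the final model turn and leave it unterminated: generation is
    # expected to stop when the model emits '<end_of_turn>' itself, which is
    # exactly why it is registered as a stop word.
    parts.append('<start_of_turn>model\n')
    return ''.join(parts)


print(build_gemma_prompt([('user', 'Hello!')]))
# <start_of_turn>user
# Hello!<end_of_turn>
# <start_of_turn>model
```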
lmdeploy/pytorch/configurations/gemma.py · 1 addition, 1 deletion
@@ -9,7 +9,7 @@ class GemmaModelConfigBuilder(AutoModelConfigBuilder):
@classmethod
def condition(cls, hf_config):
"""config."""
- return hf_config.model_type == 'gemma'
+ return hf_config.model_type in ['gemma', 'gemma2']

@classmethod
def build(cls, hf_config, model_path: str = None):
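For context, a small sketch of the kind of builder dispatch this `condition` suggests: a registry of builders is scanned and the first one whose `condition` accepts the Hugging Face config is used. The class names below are made up for illustration; only the `model_type in ['gemma', 'gemma2']` check mirrors the diff.

```python
# Illustrative registry dispatch; these classes are hypothetical stand-ins,
# not lmdeploy's AutoModelConfigBuilder machinery.
class GemmaLikeBuilder:
    @classmethod
    def condition(cls, hf_config):
        # Matches both Gemma and Gemma2 checkpoints, as in this PR.
        return getattr(hf_config, 'model_type', None) in ['gemma', 'gemma2']


class DefaultBuilder:
    @classmethod
    def condition(cls, hf_config):
        return True


BUILDERS = [GemmaLikeBuilder, DefaultBuilder]


def pick_builder(hf_config):
    # The first builder whose condition() accepts the config wins.
    return next(b for b in BUILDERS if b.condition(hf_config))


class _Cfg:
    model_type = 'gemma2'


assert pick_builder(_Cfg()) is GemmaLikeBuilder
```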
lmdeploy/pytorch/engine/engine.py · 2 additions, 1 deletion
@@ -512,7 +512,8 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor,
if msg.status != MessageStatus.RUNNING:
continue
update_token = token
- if stop or token in eos_token_id:
+ stop = stop or token in eos_token_id
+ if stop:
update_token = _EMPTY_TOKEN
else:
msg.num_new_tokens += 1
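This change folds the eos check into the `stop` flag itself instead of only using it in the branch condition, so logic later in `update_running` that reads `stop` (not visible in this hunk) presumably also treats an emitted eos token as a stop. A toy sketch of the resulting control flow, with a placeholder value standing in for `_EMPTY_TOKEN`:

```python
# Toy reconstruction of the per-token update; the real engine mutates message
# state in place, and -1 is only a stand-in for the _EMPTY_TOKEN sentinel.
_EMPTY_TOKEN = -1


def update_one(token, stop, eos_token_id, num_new_tokens):
    update_token = token
    stop = stop or token in eos_token_id  # eos now also raises the stop flag
    if stop:
        update_token = _EMPTY_TOKEN
    else:
        num_new_tokens += 1
    return update_token, stop, num_new_tokens


print(update_one(2, False, {2}, 0))  # (-1, True, 0): eos token triggers stop
print(update_one(7, False, {2}, 5))  # (7, False, 6): normal token is kept
```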
lmdeploy/pytorch/models/gemma.py · 4 additions, 0 deletions
@@ -124,6 +124,8 @@ def __rotary_emb_fn(query_states, key_states, value_states):
block_offsets=block_offsets,
)

window_size = getattr(self, 'sliding_window', None)
sm_scale = getattr(self, 'scaling', None)
attn_output = query_states
paged_attention_fwd(
query_states,
@@ -135,6 +137,8 @@ def __rotary_emb_fn(query_states, key_states, value_states):
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_q_seq_length,
window_size=window_size,
sm_scale=sm_scale,
)
attn_output = attn_output.reshape(*hidden_states.shape[:-1],
hidden_size)
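Gemma2 differs from Gemma in two ways that matter here: it applies sliding-window attention on alternating layers and uses a non-default attention scale (derived from `query_pre_attn_scalar` in the Hugging Face config), so both values are now forwarded to `paged_attention_fwd`. Because plain Gemma modules do not define `sliding_window` or `scaling`, the `getattr(..., None)` defaults let the kernel keep its usual behaviour. A small sketch of the scale fallback, under the common assumption that a missing `sm_scale` means 1/sqrt(head_dim):

```python
# Sketch of the sm_scale fallback; assuming the conventional 1/sqrt(head_dim)
# default when no explicit scale is provided, not quoting the kernel itself.
import math


def effective_sm_scale(module, head_dim):
    scaling = getattr(module, 'scaling', None)  # present on Gemma2, absent on Gemma
    return scaling if scaling is not None else 1.0 / math.sqrt(head_dim)


class _GemmaAttn:             # no 'scaling' attribute, like plain Gemma
    pass


class _Gemma2Attn:            # carries an explicit scale, Gemma2-style
    scaling = 0.0833


print(effective_sm_scale(_GemmaAttn(), 128))   # ~0.0884, the 1/sqrt(head_dim) default
print(effective_sm_scale(_Gemma2Attn(), 128))  # 0.0833, the explicit value
```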
lmdeploy/pytorch/models/module_map.py · 17 additions, 1 deletion
@@ -125,12 +125,28 @@
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
'transformers.models.gemma.modeling_gemma.GemmaModel':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaModel',
- 'transformers.models.gemma.modeling_gemma.modeling_mistral.GemmaMLP':
+ 'transformers.models.gemma.modeling_gemma.GemmaMLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
'transformers.models.gemma.modeling_gemma.GemmaRMSNorm':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaRMSNorm',
})

# gemma2
MODULE_MAP.update({
'transformers.models.gemma2.modeling_gemma2.Gemma2Attention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
'transformers.models.gemma2.modeling_gemma2.Gemma2FlashAttention2':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
'transformers.models.gemma2.modeling_gemma2.Gemma2SdpaAttention':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
'transformers.models.gemma2.modeling_gemma2.Gemma2Model':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaModel',
'transformers.models.gemma2.modeling_gemma2.Gemma2MLP':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
'transformers.models.gemma2.modeling_gemma2.Gemma2RMSNorm':
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaRMSNorm',
})

# deepseek
MODULE_MAP.update({
'modeling_deepseek.DeepseekAttention':
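The new block maps every Gemma2 attention variant, the model, the MLP, and the RMSNorm onto the already-patched Gemma/Llama implementations; the edit just above it also fixes a mistyped key in the original Gemma MLP entry. As a rough illustration of what such a string-to-string module map encodes, the hypothetical applier below resolves each dotted path and rebinds the class; lmdeploy's real patching machinery is more involved than this:

```python
# Hypothetical applier, for illustration only: lmdeploy does not patch modules
# this way, but the mapping's meaning (original class -> patched class) is the same.
import importlib


def apply_module_map(module_map):
    for src, dst in module_map.items():
        src_mod, src_cls = src.rsplit('.', 1)
        dst_mod, dst_cls = dst.rsplit('.', 1)
        patched = getattr(importlib.import_module(dst_mod), dst_cls)
        # Rebind the original name so later lookups pick up the patched class.
        setattr(importlib.import_module(src_mod), src_cls, patched)
```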
lmdeploy/pytorch/supported_models.py · 2 additions, 0 deletions
@@ -60,6 +60,8 @@
DeepseekV2ForCausalLM=True,
# internvl
InternVLChatModel=True,
# gemma2
Gemma2ForCausalLM=True,
)


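The support table appears to key on the architecture name reported by the Hugging Face config, so adding `Gemma2ForCausalLM=True` is what lets Gemma2 checkpoints pass the capability check. A hypothetical version of that check (not lmdeploy's public API) might look like this:

```python
# Hypothetical support check, assuming the architecture-keyed dict shown above;
# lmdeploy's actual helper may differ in name and signature.
from transformers import AutoConfig

_SUPPORTED_ARCHS = dict(Gemma2ForCausalLM=True)


def pytorch_engine_supports(model_path: str) -> bool:
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    archs = getattr(config, 'architectures', None) or []
    return any(_SUPPORTED_ARCHS.get(arch, False) for arch in archs)


# e.g. pytorch_engine_supports('google/gemma-2-9b-it') would return True once
# the installed transformers version can load the Gemma2 config.
```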