From ab5b7ce3820b1cb4a999895991e57cd570103f11 Mon Sep 17 00:00:00 2001
From: "q.yao"
Date: Fri, 5 Jul 2024 20:16:22 +0800
Subject: [PATCH] support gemma2 in pytorch engine(#1924)

* add gemma2 support

* support gemma2

* eos stop
---
 README.md                                      |  2 +-
 docs/en/supported_models/supported_models.md   |  1 +
 .../zh_cn/supported_models/supported_models.md |  1 +
 lmdeploy/model.py                              |  2 ++
 lmdeploy/pytorch/configurations/gemma.py       |  2 +-
 lmdeploy/pytorch/engine/engine.py              |  3 ++-
 lmdeploy/pytorch/models/gemma.py               |  4 ++++
 lmdeploy/pytorch/models/module_map.py          | 18 +++++++++++++++++-
 lmdeploy/pytorch/supported_models.py           |  2 ++
 9 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 7e4055f2d..77477e464 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,7 @@ ______________________________________________________________________
 
 - \[2024/06\] PyTorch engine support DeepSeek-V2 and several VLMs, such as CogVLM2, Mini-InternVL, LlaVA-Next
 - \[2024/05\] Balance vision model when deploying VLMs with multiple GPUs
-- \[2024/05\] Support 4-bits weight-only quantization and inference on VMLs, such as InternVL v1.5, LLaVa, InternLMXComposer2
+- \[2024/05\] Support 4-bits weight-only quantization and inference on VLMs, such as InternVL v1.5, LLaVa, InternLMXComposer2
 - \[2024/04\] Support Llama3 and more VLMs, such as InternVL v1.1, v1.2, MiniGemini, InternLMXComposer2.
 - \[2024/04\] TurboMind adds online int8/int4 KV cache quantization and inference for all supported devices. Refer [here](docs/en/quantization/kv_quant.md) for detailed guide
 - \[2024/04\] TurboMind latest upgrade boosts GQA, rocketing the [internlm2-20b](https://huggingface.co/internlm/internlm2-20b) model inference to 16+ RPS, about 1.8x faster than vLLM.
diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md
index b4c1fefa3..fe1a8d91b 100644
--- a/docs/en/supported_models/supported_models.md
+++ b/docs/en/supported_models/supported_models.md
@@ -65,3 +65,4 @@ The TurboMind engine doesn't support window attention. Therefore, for models tha
 | CogVLM2-Chat        | 19B    | Yes | No | No |
 | LLaVA(1.5,1.6)      | 7B-34B | Yes | No | No |
 | InternVL-Chat(v1.5) | 2B-26B | Yes | No | No |
+| Gemma2              | 9B-27B | Yes | No | No |
diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md
index 789b43b00..4bb0e73c1 100644
--- a/docs/zh_cn/supported_models/supported_models.md
+++ b/docs/zh_cn/supported_models/supported_models.md
@@ -65,3 +65,4 @@ turbomind 引擎不支持 window attention。所以，对于应用了 window att
 | CogVLM2-Chat        | 19B    | Yes | No | No |
 | LLaVA(1.5,1.6)      | 7B-34B | Yes | No | No |
 | InternVL-Chat(v1.5) | 2B-26B | Yes | No | No |
+| Gemma2              | 9B-27B | Yes | No | No |
diff --git a/lmdeploy/model.py b/lmdeploy/model.py
index 63c2834e7..f757eb64f 100644
--- a/lmdeploy/model.py
+++ b/lmdeploy/model.py
@@ -1151,11 +1151,13 @@ def __init__(self,
                  eoh='<end_of_turn>\n',
                  assistant='<start_of_turn>model\n',
                  eoa='<end_of_turn>\n',
+                 stop_words=['<end_of_turn>'],
                  **kwargs):
         super().__init__(user=user,
                          eoh=eoh,
                          assistant=assistant,
                          eoa=eoa,
+                         stop_words=stop_words,
                          **kwargs)
 
     @classmethod
diff --git a/lmdeploy/pytorch/configurations/gemma.py b/lmdeploy/pytorch/configurations/gemma.py
index 6f97a2162..338eaee6d 100644
--- a/lmdeploy/pytorch/configurations/gemma.py
+++ b/lmdeploy/pytorch/configurations/gemma.py
@@ -9,7 +9,7 @@ class GemmaModelConfigBuilder(AutoModelConfigBuilder):
     @classmethod
     def condition(cls, hf_config):
         """config."""
-        return hf_config.model_type == 'gemma'
+        return hf_config.model_type in ['gemma', 'gemma2']
 
     @classmethod
     def build(cls, hf_config, model_path: str = None):
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index 914546639..dd50bbbe4 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -512,7 +512,8 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor,
             if msg.status != MessageStatus.RUNNING:
                 continue
             update_token = token
-            if stop or token in eos_token_id:
+            stop = stop or token in eos_token_id
+            if stop:
                 update_token = _EMPTY_TOKEN
             else:
                 msg.num_new_tokens += 1
diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py
index f0c93c7cf..9bf78b87a 100644
--- a/lmdeploy/pytorch/models/gemma.py
+++ b/lmdeploy/pytorch/models/gemma.py
@@ -124,6 +124,8 @@ def __rotary_emb_fn(query_states, key_states, value_states):
             block_offsets=block_offsets,
         )
 
+        window_size = getattr(self, 'sliding_window', None)
+        sm_scale = getattr(self, 'scaling', None)
         attn_output = query_states
         paged_attention_fwd(
             query_states,
@@ -135,6 +137,8 @@ def __rotary_emb_fn(query_states, key_states, value_states):
             q_seqlens=q_seq_length,
             kv_seqlens=kv_seq_length,
             max_seqlen=max_q_seq_length,
+            window_size=window_size,
+            sm_scale=sm_scale,
         )
         attn_output = attn_output.reshape(*hidden_states.shape[:-1],
                                           hidden_size)
diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py
index 8ef359765..bb68711bc 100644
--- a/lmdeploy/pytorch/models/module_map.py
+++ b/lmdeploy/pytorch/models/module_map.py
@@ -125,12 +125,28 @@
     f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
     'transformers.models.gemma.modeling_gemma.GemmaModel':
     f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaModel',
-    'transformers.models.gemma.modeling_gemma.modeling_mistral.GemmaMLP':
+    'transformers.models.gemma.modeling_gemma.GemmaMLP':
     f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
     'transformers.models.gemma.modeling_gemma.GemmaRMSNorm':
     f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaRMSNorm',
 })
 
+# gemma2
+MODULE_MAP.update({
+    'transformers.models.gemma2.modeling_gemma2.Gemma2Attention':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
+    'transformers.models.gemma2.modeling_gemma2.Gemma2FlashAttention2':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
+    'transformers.models.gemma2.modeling_gemma2.Gemma2SdpaAttention':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaAttention',
+    'transformers.models.gemma2.modeling_gemma2.Gemma2Model':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaModel',
+    'transformers.models.gemma2.modeling_gemma2.Gemma2MLP':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.llama.LlamaMLP',
+    'transformers.models.gemma2.modeling_gemma2.Gemma2RMSNorm':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.gemma.PatchedGemmaRMSNorm',
+})
+
 # deepseek
 MODULE_MAP.update({
     'modeling_deepseek.DeepseekAttention':
diff --git a/lmdeploy/pytorch/supported_models.py b/lmdeploy/pytorch/supported_models.py
index b1882b046..ad9eaa84f 100644
--- a/lmdeploy/pytorch/supported_models.py
+++ b/lmdeploy/pytorch/supported_models.py
@@ -60,6 +60,8 @@
     DeepseekV2ForCausalLM=True,
     # internvl
     InternVLChatModel=True,
+    # gemma2
+    Gemma2ForCausalLM=True,
 )
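Usage sketch: with this change applied, a Gemma2 checkpoint can be run through the PyTorch engine via the public lmdeploy pipeline API. The snippet below is only an illustration; the model path is an example and any local or hub Gemma2 checkpoint would do.

    # Minimal sketch: run a Gemma2 model on the PyTorch engine backend.
    # 'google/gemma-2-9b-it' is an example checkpoint path, not required by the patch.
    from lmdeploy import PytorchEngineConfig, pipeline

    pipe = pipeline('google/gemma-2-9b-it',
                    backend_config=PytorchEngineConfig(tp=1))
    # Generation ends at the Gemma stop word registered in lmdeploy/model.py above.
    print(pipe(['Hello, introduce yourself.']))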