From 6eb093f3abd36d72aea6d1efd04e9e35abdaf8ff Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Tue, 1 Oct 2024 17:51:41 +0800
Subject: [PATCH] [Bugfix] Fix Fuyu tensor parallel inference (#8986)

Signed-off-by: Sumit Dubey
---
 tests/distributed/test_pipeline_parallel.py |  4 +++-
 vllm/model_executor/models/fuyu.py          |  3 ++-
 vllm/model_executor/models/persimmon.py     | 20 ++++++++++----------
 3 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 9fd1368cc2b59..2e8e83c3d271b 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -37,7 +37,9 @@
         (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
         (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
         (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
-        (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
+        (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp"),
+        # TP only models
+        (2, 1, 1, 0, 0, "adept/fuyu-8b", "mp"),
     ],
 )
 @fork_new_process_for_each_test
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 9f4dca78d435d..87b88da0dc05c 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -237,8 +237,9 @@ def __init__(self,
             self.image_feature_size,
             config.hidden_size,
             quant_config=quant_config,
+            gather_output=True,
         )
-        self.language_model = PersimmonForCausalLM(config,
+        self.language_model = PersimmonForCausalLM(config.text_config,
                                                    cache_config=cache_config,
                                                    quant_config=quant_config)
 
diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py
index ced846cbe3358..fda0602110a0b 100644
--- a/vllm/model_executor/models/persimmon.py
+++ b/vllm/model_executor/models/persimmon.py
@@ -25,11 +25,11 @@
 import torch
 from torch import nn
 from transformers import PersimmonConfig
-from transformers.activations import ReLUSquaredActivation
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
@@ -57,7 +57,7 @@ def __init__(self,
         self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                                config.hidden_size,
                                                quant_config=quant_config)
-        self.act = ReLUSquaredActivation()
+        self.act = get_act_fn(config.hidden_act, quant_config)
 
     def forward(self, hidden_states) -> torch.Tensor:
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
@@ -96,7 +96,7 @@ def __init__(self,
             quant_config=quant_config,
         )
         self.dense = RowParallelLinear(
-            self.num_heads * self.head_dim,
+            self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
@@ -213,10 +213,10 @@ def __init__(self,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
-        self.vocab_size = config.text_config.vocab_size
+        self.vocab_size = config.vocab_size
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.text_config.vocab_size, config.hidden_size)
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
         self.layers = nn.ModuleList([
             PersimmonDecoderLayer(config,
                                   cache_config=cache_config,
@@ -252,19 +252,19 @@ def forward(
 class PersimmonForCausalLM(nn.Module):
 
     def __init__(self,
-                 config,
+                 config: PersimmonConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
-        self.vocab_size = config.text_config.vocab_size
+        self.vocab_size = config.vocab_size
         self.model = PersimmonModel(config,
                                     cache_config=cache_config,
                                     quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.text_config.vocab_size,
+        self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       bias=False)
-        self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
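
Note on why these changes fix tensor parallelism: gather_output=True makes
the vision-embedding ColumnParallelLinear all-gather its column-sharded
output, so every TP rank sees the full image-patch embedding;
total_num_heads * head_dim gives RowParallelLinear the full (unsharded)
input width that it partitions across ranks, instead of the per-rank head
count; and passing config.text_config lets PersimmonForCausalLM read plain
PersimmonConfig fields (vocab_size, hidden_act) directly. The get_act_fn
swap keeps the same squared-ReLU activation (Persimmon's hidden_act is
"relu2") while going through vLLM's activation registry.

A minimal sketch for exercising the fixed path, assuming a host with two
CUDA devices and vLLM's public LLM entrypoint; the image file name and
prompt wording are illustrative only:

    # Sketch: Fuyu inference with tensor parallelism across 2 GPUs.
    from PIL import Image

    from vllm import LLM, SamplingParams

    # tensor_parallel_size=2 shards the ColumnParallel/RowParallel layers
    # touched by this patch across both devices.
    llm = LLM(model="adept/fuyu-8b", tensor_parallel_size=2)

    image = Image.open("example.jpg")  # hypothetical local image
    outputs = llm.generate(
        {
            "prompt": "What is shown in this image?\n",
            "multi_modal_data": {"image": image},
        },
        SamplingParams(temperature=0.0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)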