Skip to content

Commit

Permalink
[Model] Support Mistral-Nemo (#6548)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin authored Jul 18, 2024
1 parent ecdb462 commit 15c6a07
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class LlamaAttention(nn.Module):

def __init__(
self,
config: LlamaConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
Expand All @@ -115,7 +116,9 @@ def __init__(
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(config, "head_dim",
self.hidden_size // self.total_num_heads)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
Expand Down Expand Up @@ -189,6 +192,7 @@ def __init__(
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
self.self_attn = LlamaAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
Expand Down

0 comments on commit 15c6a07

Please sign in to comment.