[MODEL] HunYuan MoE Support
quinnrong94 committed Oct 31, 2024
1 parent 09c7792 commit 5302fbf
Showing 5 changed files with 771 additions and 4 deletions.
3 changes: 2 additions & 1 deletion vllm/attention/backends/flash_attn.py
@@ -621,7 +621,8 @@ def forward(
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
-        num_tokens, hidden_size = query.shape
+        num_tokens = query.shape[0]
+        hidden_size = query.numel() // num_tokens
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
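Note on the flash_attn.py change: the commit does not state the motivation, but the numel-based form only assumes the leading dimension is the token count, whereas num_tokens, hidden_size = query.shape requires the query to be exactly 2-D. A minimal standalone sketch (hypothetical shapes, not vLLM code):

import torch

num_heads, head_size = 8, 64

# 2-D query: both forms agree.
q2d = torch.randn(16, num_heads * head_size)
num_tokens = q2d.shape[0]
hidden_size = q2d.numel() // num_tokens          # 512
assert (num_tokens, hidden_size) == tuple(q2d.shape)

# Higher-rank query (e.g. already split into heads): unpacking query.shape
# into two values would raise a ValueError, but the numel-based form still
# works and the subsequent view(-1, num_heads, head_size) gives the same layout.
q3d = torch.randn(16, num_heads, head_size)
num_tokens = q3d.shape[0]
hidden_size = q3d.numel() // num_tokens          # still 512
print(q3d.view(-1, num_heads, head_size).shape)  # torch.Size([16, 8, 64])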
6 changes: 6 additions & 0 deletions vllm/config.py
@@ -49,6 +49,7 @@
"Qwen2ForCausalLM",
"Qwen2MoeForCausalLM",
"QWenLMHeadModel",
"HunYuanForCausalLM",
]


@@ -1652,6 +1653,11 @@ def _get_and_verify_max_len(
             if rope_type == "yarn":
                 derived_max_model_len = rope_scaling[
                     "original_max_position_embeddings"]
+
+            # see DynamicNTKAlphaRotaryEmbedding
+            if rope_scaling["type"] == "dynamic" and "alpha" in rope_scaling:
+                scaling_factor = 1
+
             derived_max_model_len *= scaling_factor
 
     # If the user specified a max length, make sure it is smaller than the
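Note on the config.py change: alpha-style dynamic NTK (used by HunYuan) carries an "alpha" key instead of "factor" in rope_scaling and rescales the RoPE base rather than stretching positions, so the configured max_position_embeddings is already the usable context length and must not be multiplied. A minimal sketch of that rule with hypothetical values (not the actual vLLM function):

def derive_max_model_len(max_position_embeddings: int,
                         rope_scaling: dict) -> int:
    # Factor-style scaling multiplies the trained window.
    scaling_factor = rope_scaling.get("factor", 1.0)
    # Alpha-style dynamic NTK (see DynamicNTKAlphaRotaryEmbedding) leaves it
    # unchanged, mirroring the branch added above.
    if rope_scaling["type"] == "dynamic" and "alpha" in rope_scaling:
        scaling_factor = 1
    return int(max_position_embeddings * scaling_factor)


print(derive_max_model_len(4096, {"type": "dynamic", "factor": 4.0}))     # 16384
print(derive_max_model_len(32768, {"type": "dynamic", "alpha": 1000.0}))  # 32768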
52 changes: 49 additions & 3 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -3,12 +3,19 @@
 # https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 # and OPT implementations in this library. It has been modified from its
 # original forms to accommodate minor architectural differences compared
 # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
 #
+# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://github.com/Tencent/Tencent-Hunyuan-Large/blob/main/License.docx
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -384,6 +391,40 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
         return cache
 
 
+class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK scaling.
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_alpha: float,
+        dtype: torch.dtype,
+    ) -> None:
+        self.scaling_alpha = scaling_alpha
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        max_len = self.max_position_embeddings
+        base = self.base * self.scaling_alpha**(
+            self.rotary_dim / (self.rotary_dim - 2))
+
+        inv_freq = self._compute_inv_freq(base)
+        t = torch.arange(max_len, dtype=torch.float)
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
 # Inverse dim formula to find dim based on number of rotations
 def _yarn_find_correction_dim(num_rotations: int,
                               dim: int,
@@ -823,9 +864,14 @@ def get_rope(
                                                      is_neox_style,
                                                      scaling_factor, dtype)
         elif scaling_type == "dynamic":
-            rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                head_size, rotary_dim, max_position, base, is_neox_style,
-                scaling_factor, dtype)
+            if "alpha" in rope_scaling:
+                rotary_emb = DynamicNTKAlphaRotaryEmbedding(
+                    head_size, rotary_dim, max_position, base, is_neox_style,
+                    rope_scaling["alpha"], dtype)
+            else:
+                rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                    head_size, rotary_dim, max_position, base, is_neox_style,
+                    scaling_factor, dtype)
         elif scaling_type == "yarn":
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
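Note on the rotary_embedding.py change: when rope_scaling has type "dynamic" and an "alpha" key, get_rope now dispatches to DynamicNTKAlphaRotaryEmbedding, which rescales the RoPE base once by alpha ** (rotary_dim / (rotary_dim - 2)) and builds the cos/sin cache over the unscaled max_position_embeddings. A standalone sketch of the effect on the inverse frequencies (the alpha value is hypothetical, not vLLM code):

import torch

rotary_dim = 128
base = 10000.0
scaling_alpha = 1000.0

# Same base rescaling as DynamicNTKAlphaRotaryEmbedding._compute_cos_sin_cache.
scaled_base = base * scaling_alpha ** (rotary_dim / (rotary_dim - 2))

exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim
inv_freq_orig = 1.0 / (base ** exponents)
inv_freq_ntk = 1.0 / (scaled_base ** exponents)

# The highest-frequency component (exponent 0) is untouched, while the
# lowest-frequency component is divided by exactly alpha, i.e. its period
# grows by alpha, which is what extends the usable context length.
print(inv_freq_orig[0].item(), inv_freq_ntk[0].item())     # 1.0, 1.0
print(inv_freq_orig[-1].item() / inv_freq_ntk[-1].item())  # ~1000.0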
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -63,6 +63,7 @@
"EAGLEModel": ("eagle", "EAGLE"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"HunYuanForCausalLM": ("hunyuan", "HunYuanForCausalLM"),
}

_EMBEDDING_MODELS = {
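With the architecture registered, a checkpoint whose config reports "HunYuanForCausalLM" resolves to the new hunyuan module. A minimal usage sketch; the model path is a placeholder, trust_remote_code is an assumption for the HunYuan config and tokenizer, and parallelism options are omitted:

from vllm import LLM, SamplingParams

# Placeholder path; substitute a real HunYuan MoE checkpoint directory.
llm = LLM(model="path/to/hunyuan-moe-checkpoint", trust_remote_code=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)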