Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Tencent Hunyuan model support #811

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion aphrodite/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,8 @@ def forward(
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")

num_tokens, hidden_size = query.shape
num_tokens = query.shape[0]
hidden_size = query.numel() // num_tokens
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
Expand Down
5 changes: 5 additions & 0 deletions aphrodite/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"NemotronForCausalLM",
"Qwen2ForCausalLM",
"Qwen2MoeForCausalLM",
"HunYuanForCausalLM",
]

_OPTIMIZED_QUANTS = [
Expand Down Expand Up @@ -1795,6 +1796,10 @@ def _get_and_verify_max_len(
if rope_type == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]

if rope_scaling["type"] == "dynamic" and "alpha" in rope_scaling:
scaling_factor = 1

derived_max_model_len *= scaling_factor

if max_model_len is None:
Expand Down
50 changes: 47 additions & 3 deletions aphrodite/modeling/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,20 @@
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.

# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
# (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://github.com/Tencent/Tencent-Hunyuan-Large/blob/main/License.docx
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -389,6 +398,36 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
return cache


class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.
Credits to the Reddit users /u/bloc97 and /u/emozilla
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_alpha: float,
dtype: torch.dtype,
) -> None:
self.scaling_alpha = scaling_alpha
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
def _compute_cos_sin_cache(self) -> torch.Tensor:
max_len = self.max_position_embeddings
base = self.base * self.scaling_alpha ** (
self.rotary_dim / (self.rotary_dim - 2))
inv_freq = self._compute_inv_freq(base)
t = torch.arange(max_len, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache


# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations: int,
dim: int,
Expand Down Expand Up @@ -835,9 +874,14 @@ def get_rope(
is_neox_style,
scaling_factor, dtype)
elif scaling_type == "dynamic":
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor, dtype)
if "alpha" in rope_scaling:
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling["alpha"], dtype)
else:
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor, dtype)
elif scaling_type == "yarn":
original_max_position = rope_scaling[
"original_max_position_embeddings"]
Expand Down
1 change: 1 addition & 0 deletions aphrodite/modeling/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"HunYuanForCausalLM": ("hunyuan", "HunYuanForCausalLM"),
}

_EMBEDDING_MODELS = {
Expand Down
Loading
Loading