feat: use FlashInfer RoPE (llama)
james-p-xu committed Nov 16, 2024
1 parent dd0d2a3 commit 6bb3979
Showing 2 changed files with 266 additions and 1 deletion.
265 changes: 265 additions & 0 deletions python/sglang/srt/layers/rotary_embedding.py
@@ -15,6 +15,8 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from einops import rearrange
from vllm.model_executor.custom_op import CustomOp


class MRotaryEmbedding:
@@ -110,3 +112,266 @@ def get_next_input_positions(
)
for _ in range(3)
]


class RotaryEmbedding(CustomOp):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype

def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)

query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from flashinfer import apply_rope_pos_ids_inplace

if offsets is not None:
positions = positions + offsets
seq_len, num_q_heads, num_k_heads = (
positions.shape[0],
query.shape[1] // self.head_size,
key.shape[1] // self.head_size,
)

# (seq_len, num_heads * head_dim) -> flashinfer input shape (nnz=seq_len, num_heads, head_dim)
flashinfer_query, flashinfer_key = rearrange(
query.type(torch.float16),
"s (n_h h_d) -> s n_h h_d",
n_h=num_q_heads,
h_d=self.head_size,
), rearrange(
key.type(torch.float16),
"s (n_h h_d) -> s n_h h_d",
n_h=num_k_heads,
h_d=self.head_size,
)
apply_rope_pos_ids_inplace(
flashinfer_query,
flashinfer_key,
pos_ids=positions,
rotary_dim=self.rotary_dim,
rope_theta=self.base,
interleave=(not self.is_neox_style),
)

# flashinfer output shape (nnz=seq_len, num_heads, head_dim) -> (seq_len, num_heads * head_dim)
return rearrange(
flashinfer_query.type(self.dtype), "s n_h h_d -> s (n_h h_d)"
), rearrange(flashinfer_key.type(self.dtype), "s n_h h_d -> s (n_h h_d)")

def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
return s
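
# As a quick standalone check of the einops pattern used in forward_cuda above
# (a sketch with made-up sizes, not part of the commit): the
# "s (n_h h_d) -> s n_h h_d" rearrange simply splits the fused head dimension
# back out into the (nnz, num_heads, head_dim) layout FlashInfer expects,
# equivalent to a plain view.

import torch
from einops import rearrange

seq_len, num_heads, head_size = 4, 8, 64
x = torch.randn(seq_len, num_heads * head_size)

y = rearrange(x, "s (n_h h_d) -> s n_h h_d", n_h=num_heads, h_d=head_size)
assert y.shape == (seq_len, num_heads, head_size)
assert torch.equal(y, x.view(seq_len, num_heads, head_size))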


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


class Llama3RotaryEmbedding(RotaryEmbedding):

def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
scaling_factor: float,
low_freq_factor: float,
high_freq_factor: float,
orig_max_position: int,
) -> None:
self.scaling_factor = scaling_factor
self.low_freq_factor = low_freq_factor
self.high_freq_factor = high_freq_factor
self.orig_max_position = orig_max_position
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)


def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: int,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
rope_scaling_args,
dtype,
)
if key in _ROPE_DICT:
return _ROPE_DICT[key]

if rope_scaling is None:
rotary_emb = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
else:
scaling_type = rope_scaling["rope_type"]

if scaling_type == "llama3":
scaling_factor = rope_scaling["factor"]
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
scaling_factor,
low_freq_factor,
high_freq_factor,
original_max_position,
)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
)
# else:
# rotary_emb = RotaryEmbedding(
# head_size,
# rotary_dim,
# max_position,
# base,
# is_neox_style,
# dtype,
# )
# elif scaling_type == "linear":
# scaling_factor = rope_scaling["factor"]
# rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
# max_position, base,
# is_neox_style,
# scaling_factor, dtype)
# elif scaling_type == "dynamic":
# scaling_factor = rope_scaling["factor"]
# rotary_emb = DynamicNTKScalingRotaryEmbedding(
# head_size, rotary_dim, max_position, base, is_neox_style,
# scaling_factor, dtype)
# elif scaling_type == "yarn":
# scaling_factor = rope_scaling["factor"]
# original_max_position = rope_scaling[
# "original_max_position_embeddings"]
# extra_kwargs = {
# k: v
# for k, v in rope_scaling.items()
# if k in ("extrapolation_factor", "attn_factor", "beta_fast",
# "beta_slow")
# }
# rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
# original_max_position,
# base, is_neox_style,
# scaling_factor, dtype,
# **extra_kwargs)
# elif scaling_type == "deepseek_yarn":
# scaling_factor = rope_scaling["factor"]
# original_max_position = rope_scaling[
# "original_max_position_embeddings"]
# # assert max_position == original_max_position * scaling_factor
# extra_kwargs = {
# k: v
# for k, v in rope_scaling.items()
# if k in ("extrapolation_factor", "attn_factor", "beta_fast",
# "beta_slow", "mscale", "mscale_all_dim")
# }
# rotary_emb = DeepseekScalingRotaryEmbedding(
# head_size, rotary_dim, original_max_position, base,
# is_neox_style, scaling_factor, dtype, **extra_kwargs)
# elif scaling_type == "longrope":
# short_factor = rope_scaling["short_factor"]
# long_factor = rope_scaling["long_factor"]
# original_max_position = rope_scaling[
# "original_max_position_embeddings"]
# extra_kwargs = {
# k: v
# for k, v in rope_scaling.items()
# if k in ("short_mscale", "long_mscale")
# }
# rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
# head_size, rotary_dim, max_position, original_max_position,
# base, is_neox_style, dtype, short_factor, long_factor,
# **extra_kwargs)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb
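
For reference, a minimal sketch of how the new get_rope entry point and the FlashInfer-backed forward_cuda path fit together. This is illustrative only: it assumes a CUDA device with FlashInfer installed, that vllm's CustomOp dispatches the call to forward_cuda on GPU builds, and that the head counts, sequence length, and RoPE parameters below are made-up example values.

import torch

from sglang.srt.layers.rotary_embedding import get_rope

head_size, num_q_heads, num_kv_heads, seq_len = 128, 32, 8, 16

rope = get_rope(
    head_size=head_size,
    rotary_dim=head_size,  # full-dimension RoPE; partial_rotary_factor defaults to 1.0
    max_position=4096,
    base=10000,
    is_neox_style=True,
    dtype=torch.float16,
)

positions = torch.arange(seq_len, device="cuda")
query = torch.randn(seq_len, num_q_heads * head_size, dtype=torch.float16, device="cuda")
key = torch.randn(seq_len, num_kv_heads * head_size, dtype=torch.float16, device="cuda")

# Expected to dispatch to forward_cuda, which reshapes to FlashInfer's
# (nnz, num_heads, head_dim) layout and calls apply_rope_pos_ids_inplace.
query_rot, key_rot = rope(positions, query, key)
assert query_rot.shape == query.shape and key_rot.shape == key.shape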
2 changes: 1 addition & 1 deletion python/sglang/srt/models/llama.py
@@ -23,7 +23,6 @@
from torch import nn
from transformers import LlamaConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.activation import SiluAndMul
@@ -37,6 +36,7 @@
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,

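With the import in llama.py now resolving to the local implementation above, a Llama 3 style rope_scaling config takes the "llama3" branch of get_rope. A minimal sketch, where the scaling values are illustrative placeholders rather than numbers read from any particular checkpoint:

import torch

from sglang.srt.layers.rotary_embedding import Llama3RotaryEmbedding, get_rope

rope_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}

rope = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=131072,
    base=500000,
    rope_scaling=rope_scaling,
    dtype=torch.bfloat16,
)
assert isinstance(rope, Llama3RotaryEmbedding)

# Identical arguments hit the _ROPE_DICT cache and return the same object.
assert get_rope(128, 128, 131072, 500000, rope_scaling=rope_scaling, dtype=torch.bfloat16) is rope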