diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py
index ba9d2d4389b21..6ca3a645c7771 100644
--- a/tests/kernels/test_pos_encoding.py
+++ b/tests/kernels/test_pos_encoding.py
@@ -5,6 +5,7 @@
 import torch
 
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.platforms import current_platform
 from vllm.utils import seed_everything
 
 from .allclose_default import get_default_atol, get_default_rtol
@@ -20,6 +21,9 @@
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
+if current_platform.is_hpu():
+    import habana_frameworks.torch as htorch
+    CUDA_DEVICES = ['hpu']
 
 
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -65,6 +69,8 @@ def test_rotary_embedding(
     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
     ref_query, ref_key = rope.forward_native(positions, query, key)
+    if current_platform.is_hpu():
+        htorch.core.mark_step()
     out_query, out_key = rope.forward(positions, query, key)
     # Compare the results.
     torch.testing.assert_close(out_query,
@@ -120,6 +126,8 @@ def test_batched_rotary_embedding(
     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
     ref_query, ref_key = rope.forward_native(positions, query, key)
+    if current_platform.is_hpu():
+        htorch.core.mark_step()
     out_query, out_key = rope.forward(positions,
                                       query,
                                       key,
@@ -193,6 +201,8 @@ def test_batched_rotary_embedding_multi_lora(
     # because the custom kernel is in-place.
     ref_query, ref_key = rope.forward_native(positions, query, key,
                                              query_offsets)
+    if current_platform.is_hpu():
+        htorch.core.mark_step()
     out_query, out_key = rope.forward(positions, query, key,
                                       query_offsets.flatten())
     # Compare the results.
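Note on the guarded htorch.core.mark_step() calls above: Gaudi runs PyTorch in lazy mode by default, so operations are only queued into a graph until a step boundary flushes them. Flushing right after the eager reference pass should guarantee that ref_query/ref_key are materialized before the in-place custom kernel overwrites query and key. A minimal sketch of that guard pattern, assuming only the APIs already used in this diff (the helper name hpu_graph_break is hypothetical):

    from vllm.platforms import current_platform

    if current_platform.is_hpu():
        import habana_frameworks.torch as htorch


    def hpu_graph_break() -> None:
        # Hypothetical helper: on HPU, flush the lazily accumulated graph so
        # results computed so far are materialized; a no-op everywhere else.
        if current_platform.is_hpu():
            htorch.core.mark_step()

    # Usage inside a test (rope/positions/query/key as in the tests above):
    #     ref_query, ref_key = rope.forward_native(positions, query, key)
    #     hpu_graph_break()
    #     out_query, out_key = rope.forward(positions, query, key)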
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 85cd700c978ea..10626d53338e3 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -28,7 +28,6 @@
 import torch.nn as nn
 
 from vllm.model_executor.custom_op import CustomOp
-from vllm.platforms import current_platform
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -195,6 +194,61 @@ def forward_xpu(
                                  self.cos_sin_cache, self.is_neox_style)
         return query, key
 
+    def forward_hpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        from habana_frameworks.torch.hpex.kernels import (
+            RotaryPosEmbeddingMode, apply_rotary_pos_emb)
+        positions = positions.flatten()
+        if offsets is not None:
+            positions = positions + offsets
+        num_tokens = positions.shape[0]
+        cos_sin = self.cos_sin_cache.index_select(0, positions).view(
+            num_tokens, 1, -1)
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
+        # to query hidden dimension, so the original tensors need to be
+        # expanded
+        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
+        # and expansion of cos/sin tensors via concatenation
+        # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE
+        # and expansion of cos/sin tensors via repeat_interleave
+        rope_mode: RotaryPosEmbeddingMode
+        if self.is_neox_style:
+            rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
+            cos = torch.cat((cos, cos), dim=-1)
+            sin = torch.cat((sin, sin), dim=-1)
+        else:
+            rope_mode = RotaryPosEmbeddingMode.PAIRWISE
+            sin = torch.repeat_interleave(sin,
+                                          2,
+                                          dim=-1,
+                                          output_size=cos_sin.shape[-1])
+            cos = torch.repeat_interleave(cos,
+                                          2,
+                                          dim=-1,
+                                          output_size=cos_sin.shape[-1])
+
+        query_shape = query.shape
+        query = query.view(num_tokens, -1, self.head_size)
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0,
+                                         rope_mode)
+        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+        key_shape = key.shape
+        key = key.view(num_tokens, -1, self.head_size)
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
+        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        return query, key
+
     def extra_repr(self) -> str:
         s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
         s += f", max_position_embeddings={self.max_position_embeddings}"
@@ -918,17 +972,8 @@ def get_rope(
         return _ROPE_DICT[key]
 
     if rope_scaling is None:
-        if current_platform.is_hpu():
-            from vllm_hpu_extension.rotary_embed import HpuRotaryEmbedding
-            rotary_emb = HpuRotaryEmbedding(head_size,
-                                            rotary_dim,
-                                            max_position,
-                                            base,
-                                            is_neox_style,
-                                            RoPEFallback=RotaryEmbedding)
-        else:
-            rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position,
-                                         base, is_neox_style, dtype)
+        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
+                                     is_neox_style, dtype)
    else:
        scaling_type = rope_scaling[
            "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
@@ -941,25 +986,12 @@
             high_freq_factor = rope_scaling["high_freq_factor"]
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
-            if current_platform.is_hpu():
-                from vllm_hpu_extension.rotary_embed import (
-                    HpuLlama3RotaryEmbedding)
-                rotary_emb = HpuLlama3RotaryEmbedding(
-                    head_size,
-                    rotary_dim,
-                    max_position,
-                    base,
-                    is_neox_style,
-                    scaling_factor,
-                    low_freq_factor,
-                    high_freq_factor,
-                    original_max_position,
-                    RoPEFallback=Llama3RotaryEmbedding)
-            else:
-                rotary_emb = Llama3RotaryEmbedding(
-                    head_size, rotary_dim, max_position, base, is_neox_style,
-                    dtype, scaling_factor, low_freq_factor, high_freq_factor,
-                    original_max_position)
+            rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
+                                               max_position, base,
+                                               is_neox_style, dtype,
+                                               scaling_factor, low_freq_factor,
+                                               high_freq_factor,
+                                               original_max_position)
         elif scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
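For context on the BLOCKWISE/PAIRWISE comments in forward_hpu: the two modes correspond to the GPT-NeoX (rotate-half) and GPT-J (rotate-pairs) RoPE layouts. The sketch below restates in plain PyTorch what each mode is expected to compute; it mirrors the math of vLLM's forward_native, while the equivalence to the Gaudi kernel itself is assumed from the comments rather than taken from the kernel source. cos and sin here are the half-width tensors obtained from cos_sin_cache via chunk(2, dim=-1).

    import torch


    def rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                       is_neox_style: bool) -> torch.Tensor:
        # x: [num_tokens, num_heads, rotary_dim]
        # cos, sin: [num_tokens, 1, rotary_dim // 2]
        if is_neox_style:
            # BLOCKWISE: widen cos/sin by concatenation, rotate the two halves.
            cos = torch.cat((cos, cos), dim=-1)
            sin = torch.cat((sin, sin), dim=-1)
            x1, x2 = torch.chunk(x, 2, dim=-1)
            rotated = torch.cat((-x2, x1), dim=-1)
        else:
            # PAIRWISE: widen cos/sin element-wise, rotate adjacent pairs.
            cos = torch.repeat_interleave(cos, 2, dim=-1)
            sin = torch.repeat_interleave(sin, 2, dim=-1)
            x1, x2 = x[..., ::2], x[..., 1::2]
            rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)
        return x * cos + rotated * sin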
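With the HPU branches removed from get_rope, every platform now builds the same RotaryEmbedding (or one of its scaling subclasses); selecting forward_hpu on Gaudi is left to the CustomOp dispatch machinery rather than to the separate HpuRotaryEmbedding classes from vllm_hpu_extension. A hypothetical usage sketch (argument values are illustrative, not taken from the diff):

    import torch

    from vllm.model_executor.layers.rotary_embedding import get_rope

    # head_size, rotary_dim, max_position, base, is_neox_style
    rope = get_rope(64, 64, 8192, 10000, True)

    num_tokens, num_heads = 16, 8
    positions = torch.arange(num_tokens)
    query = torch.randn(num_tokens, num_heads * 64)
    key = torch.randn(num_tokens, num_heads * 64)

    # On Gaudi this is expected to resolve to forward_hpu(); elsewhere to the
    # platform's forward_* implementation. The updated tests above check that
    # the HPU path matches forward_native numerically.
    out_query, out_key = rope.forward(positions, query, key)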