Skip to content

Commit

Permalink
[V1] Simplify M-RoPE (#12352)
Browse files Browse the repository at this point in the history
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: imkero <kerorek@outlook.com>
  • Loading branch information
ywang96 and imkero authored Jan 23, 2025
1 parent d07efb3 commit 99d01a5
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,28 +144,24 @@ def __init__(

# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.model_config.uses_mrope:
# NOTE: `mrope_positions` is implemented as a permuted tensor to
# satisfy the following properties to allow `torch.compile` to work
# properly:
# - shape: (3, <variable>)
# - stride: (1, 3)
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1921022256
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
# with `torch.compile`.
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
# the modality of inputs. For text-only inputs, each dimension has
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = torch.zeros((self.max_num_tokens, 3),
self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
dtype=torch.int64,
device=self.device)
self.mrope_positions_cpu = torch.zeros((self.max_num_tokens, 3),
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)

self.mrope_positions = self.mrope_positions.permute((1, 0))
self.mrope_positions_cpu = self.mrope_positions_cpu.permute((1, 0))
self.mrope_positions_cpu = torch.zeros(
(3, self.max_num_tokens + 1),
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)

self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
Expand Down

0 comments on commit 99d01a5

Please sign in to comment.