Skip to content

Commit

Permalink
[Kernel][LoRA] Add assertion for punica sgmv kernels (vllm-project#7585)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeejeelee authored Sep 23, 2024
1 parent a1883b3 commit 6dad705
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 38 deletions.
5 changes: 5 additions & 0 deletions tests/lora/test_punica_sizes.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def test_punica_sgmv(
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
Expand All @@ -183,6 +184,7 @@ def test_punica_sgmv(
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
scaling,
)
else:
Expand All @@ -195,6 +197,7 @@ def test_punica_sgmv(
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
add_inputs=True,
)
ref_torch_groupgemm(
Expand Down Expand Up @@ -347,6 +350,7 @@ def test_punica_expand_nslices(
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
Expand All @@ -364,6 +368,7 @@ def test_punica_expand_nslices(
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
slice_offset,
hidden_size,
add_inputs=True,
Expand Down
5 changes: 5 additions & 0 deletions tests/lora/test_punica_variation.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def test_punica_sgmv(
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
Expand All @@ -98,6 +99,7 @@ def test_punica_sgmv(
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
scaling,
)
else:
Expand All @@ -110,6 +112,7 @@ def test_punica_sgmv(
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
add_inputs=True,
)
ref_torch_groupgemm(
Expand Down Expand Up @@ -262,6 +265,7 @@ def test_punica_expand_nslices(
device,
)
max_seq_length = seq_len_tensor.max()
token_nums = seq_len_tensor.sum().item()
if isinstance(max_seq_length, tuple):
max_seq_length = max_seq_length[0].item()
else:
Expand All @@ -279,6 +283,7 @@ def test_punica_expand_nslices(
lora_indices_tensor,
batches,
max_seq_length,
token_nums,
slice_offset,
hidden_size,
add_inputs=True,
Expand Down
2 changes: 1 addition & 1 deletion vllm/lora/ops/bgmv_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _bgmv_expand(
corresponding to each batch, An index of -1 means no lora should be
applied.
batches (int): batch size
add_inputs (bool, optional): Defaults to False. adds the final lora
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
Expand Down
2 changes: 1 addition & 1 deletion vllm/lora/ops/bgmv_expand_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _bgmv_expand_slice(
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch, An index of -1 means no lora should be
applied.
slice_offst (int): output_tensor's offst
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
batches (int): batch size
add_inputs (bool, optional): Defaults to False.
Expand Down
16 changes: 10 additions & 6 deletions vllm/lora/ops/sgmv_expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def _sgmv_expand(
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False,
) -> None:
"""
Expand All @@ -115,17 +116,19 @@ def _sgmv_expand(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
length of the sequences in the batch
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
add_inputs (bool, optional): Defaults to False. adds the final lora
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""

Expand All @@ -134,6 +137,7 @@ def _sgmv_expand(
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
Expand Down
18 changes: 11 additions & 7 deletions vllm/lora/ops/sgmv_expand_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def _sgmv_expand_slice(
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
Expand All @@ -124,27 +125,30 @@ def _sgmv_expand_slice(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4, 10].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
length of the sequences in the batch
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
max_seq_length (int): The max sequence lengths of the sequences
in the batch
slice_offst (int): output_tensor's offst
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
slice_offset (int): output_tensor's offset
slice_size (int): current output_tensor's size
add_inputs (bool, optional): Defaults to False. adds the final lora
results to the output..
add_inputs (bool, optional): Defaults to False, adds the final lora
results to the output.
"""

assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
assert lora_b_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_b_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
Expand Down
16 changes: 10 additions & 6 deletions vllm/lora/ops/sgmv_shrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def _sgmv_shrink(
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
) -> None:
"""
Expand All @@ -120,24 +121,27 @@ def _sgmv_shrink(
output_tensor (torch.Tensor): output tensor
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
sequence lengths of the sequences in the batch, used to index
into sequence. E.g.,if the sequence length is [4, 6], it is
into sequence. E.g., if the sequence length is [4, 6], it is
[0, 4].
seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
length of the sequences in the batch
seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
length of the sequences in the batch.
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
corresponding to each batch. An index of -1 means no lora should be
applied.
batches (int): batch size
max_seq_length (int): The max sequence lengths of the sequences
in the batch
scaling (float): Scaling factor.
max_seq_length (int): The max sequence lengths of the sequences in the
batch.
token_nums (int): The token numbers in the batch. Used to verify if the
token numbers in the inputs matches the one in the metadata.
scaling (float): Scaling factor.
"""
assert inputs.dtype == lora_a_weights.dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
assert lora_a_weights.dtype in [
torch.float16,
torch.bfloat16,
]
assert inputs.size(0) == token_nums
assert inputs.size(1) == lora_a_weights.size(-1)
assert b_seq_start_loc.size(0) == batches
assert lora_indices_tensor.size(0) == batches
Expand Down
38 changes: 21 additions & 17 deletions vllm/lora/punica.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

def compute_meta(
token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
"""
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
Expand All @@ -43,7 +43,7 @@ def compute_meta(
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
b_seq_start_tensor[1:].copy_(cum_result[:-1])
max_length = seq_length_tensor.max().item()

token_nums = seq_length_tensor.sum().item()
batch_size = lora_indices_tensor.size(0)
no_lora = False
# -1 means no lora should be applied. Use `no_lora` to determine whether
Expand All @@ -52,7 +52,7 @@ def compute_meta(
if batch_size == 1 and lora_indices_tensor == -1:
no_lora = True
return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
batch_size, max_length, no_lora)
batch_size, max_length, token_nums, no_lora)


# TODO see if this can be vectorized
Expand Down Expand Up @@ -178,7 +178,7 @@ def convert_mapping(
class PunicaWrapper:
"""
PunicaWrapper is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica kernel.
"""

Expand Down Expand Up @@ -216,6 +216,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
dtype=torch.long,
device=device)
self.max_length: int = 0
self.token_nums: int = 0
self.batch_size: int = -1
self.is_prefill = False
self.no_lora = False
Expand Down Expand Up @@ -276,13 +277,13 @@ def _update_base_metadata(
long_lora_offsets_tensor)
else:
self._long_lora_indices.zero_()

self.indices_len[:] = indices_len

def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:

(b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)
batch_size, max_length, token_nums,
no_lora) = compute_meta(token_lora_tensor)

self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
b_seq_start_tensor)
Expand All @@ -291,25 +292,28 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
lora_indices_tensor)
self.batch_size = batch_size
self.max_length = max_length
self.token_nums = token_nums
self.no_lora = no_lora

@property
def prefill_metadata(
self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
self
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
"""
This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations.
1. seq_start_locs: Tensor of sequence start positions
2. seq_lengths: Tensor of sequence lengths
1. seq_start_locs: Tensor of sequence start positions.
2. seq_lengths: Tensor of sequence lengths.
3. lora_indices_per_batch: Tensor of lora indices, and an index of
-1 means no lora should be applied.
4. batch_size: batch size after clustering identical lora indices
5. max_length: The maximum sequence length in the batch
4. batch_size: Batch size after clustering identical lora indices.
5. max_length: The maximum sequence length in the batch.
6. token_nums: The token numbers in the batch.
"""
return (self._seq_start_locs[:self.batch_size],
self._seq_lengths[:self.batch_size],
self._lora_indices_per_batch[:self.batch_size],
self.batch_size, self.max_length)
self.batch_size, self.max_length, self.token_nums)

@property
def token_lora_indices(self) -> torch.Tensor:
Expand All @@ -324,15 +328,15 @@ def token_lora_indices(self) -> torch.Tensor:
def sampler_indices(self) -> torch.Tensor:
"""
This property is used to access the lora indices specifically for
LogitsProcessorWithLoRA
LogitsProcessorWithLoRA.
"""
sampler_indices_len = self.indices_len[1]
return self._sampler_indices[:sampler_indices_len]

@property
def sampler_indices_padded(self) -> torch.Tensor:
"""
This property provides access to padded sampler indices
This property provides access to padded sampler indices.
"""
indices_padded_len = self.indices_len[2]
return self._sampler_indices_padded[:indices_padded_len]
Expand All @@ -341,7 +345,7 @@ def sampler_indices_padded(self) -> torch.Tensor:
def embeddings_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA
specifically for VocabParallelEmbeddingWithLoRA.
"""
embeddings_indices_len = self.indices_len[3]
return self._embeddings_indices[:, :embeddings_indices_len]
Expand All @@ -350,7 +354,7 @@ def embeddings_indices(self) -> torch.Tensor:
def long_lora_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLora
lora, specifically for LinearScalingRotaryEmbeddingWithLora.
"""
long_lora_len = self.indices_len[4]
return self._long_lora_indices[:long_lora_len]
Expand Down Expand Up @@ -524,7 +528,7 @@ def add_lora(self,
scale (float): Scaling factor.
y_offset (Optional[int], optional): Offset to apply to the starting
column of y.
y_slice_size (Optional[int], optional): Size of the y column slice..
y_slice_size (Optional[int], optional): Size of the y column slice.
buffer (Optional[torch.Tensor], optional): Defaults to None.
"""
y_org = y
Expand Down

0 comments on commit 6dad705

Please sign in to comment.