diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py
index 314d6215cbd9c..41c37a4813c68 100644
--- a/tests/lora/test_punica_sizes.py
+++ b/tests/lora/test_punica_sizes.py
@@ -169,6 +169,7 @@ def test_punica_sgmv(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -183,6 +184,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             scaling,
         )
     else:
@@ -195,6 +197,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             add_inputs=True,
         )
     ref_torch_groupgemm(
@@ -347,6 +350,7 @@ def test_punica_expand_nslices(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -364,6 +368,7 @@ def test_punica_expand_nslices(
                 lora_indices_tensor,
                 batches,
                 max_seq_length,
+                token_nums,
                 slice_offset,
                 hidden_size,
                 add_inputs=True,
diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py
index 28a395af19e6d..185da6399a06a 100644
--- a/tests/lora/test_punica_variation.py
+++ b/tests/lora/test_punica_variation.py
@@ -84,6 +84,7 @@ def test_punica_sgmv(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -98,6 +99,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             scaling,
         )
     else:
@@ -110,6 +112,7 @@ def test_punica_sgmv(
             lora_indices_tensor,
             batches,
             max_seq_length,
+            token_nums,
             add_inputs=True,
         )
     ref_torch_groupgemm(
@@ -262,6 +265,7 @@ def test_punica_expand_nslices(
         device,
     )
     max_seq_length = seq_len_tensor.max()
+    token_nums = seq_len_tensor.sum().item()
     if isinstance(max_seq_length, tuple):
         max_seq_length = max_seq_length[0].item()
     else:
@@ -279,6 +283,7 @@ def test_punica_expand_nslices(
                 lora_indices_tensor,
                 batches,
                 max_seq_length,
+                token_nums,
                 slice_offset,
                 hidden_size,
                 add_inputs=True,
diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py
index 619408b9315cf..6a32387a6f36c 100644
--- a/vllm/lora/ops/bgmv_expand.py
+++ b/vllm/lora/ops/bgmv_expand.py
@@ -100,7 +100,7 @@ def _bgmv_expand(
             corresponding to each batch, An index of -1 means no lora should
             be applied.
         batches (int): batch size
-        add_inputs (bool, optional): Defaults to False. adds the final lora
+        add_inputs (bool, optional): Defaults to False, adds the final lora
             results to the output.
     """
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py
index c16db233891a5..73628fd20d327 100644
--- a/vllm/lora/ops/bgmv_expand_slice.py
+++ b/vllm/lora/ops/bgmv_expand_slice.py
@@ -104,7 +104,7 @@ def _bgmv_expand_slice(
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch, An index of -1 means no lora should
             be applied.
-        slice_offst (int): output_tensor's offst
+        slice_offset (int): output_tensor's offset
         slice_size (int): current output_tensor's size
         batches (int): batch size
         add_inputs (bool, optional): Defaults to False.
diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py
index c71332d8bdfb2..adb3ab5b46b87 100644
--- a/vllm/lora/ops/sgmv_expand.py
+++ b/vllm/lora/ops/sgmv_expand.py
@@ -106,6 +106,7 @@ def _sgmv_expand(
     lora_indices_tensor: torch.Tensor,
     batches: int,
     max_seq_length: int,
+    token_nums: int,
     add_inputs: bool = False,
 ) -> None:
     """
@@ -115,17 +116,19 @@ def _sgmv_expand(
         output_tensor (torch.Tensor): output tensor
         b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
             sequence lengths of the sequences in the batch, used to index
-            into sequence. E.g.,if the sequence length is [4, 6], it is
+            into sequence. E.g., if the sequence length is [4, 6], it is
             [0, 4, 10].
-        seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
-            length of the sequences in the batch
+        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
+            length of the sequences in the batch.
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch. An index of -1 means no lora should
             be applied.
         batches (int): batch size
-        max_seq_length (int): The max sequence lengths of the sequences
-            in the batch
-        add_inputs (bool, optional): Defaults to False. adds the final lora
+        max_seq_length (int): The max sequence lengths of the sequences in the
+            batch.
+        token_nums (int): The token numbers in the batch. Used to verify if the
+            token numbers in the inputs matches the one in the metadata.
+        add_inputs (bool, optional): Defaults to False, adds the final lora
             results to the output.
     """

@@ -134,6 +137,7 @@ def _sgmv_expand(
         torch.float16,
         torch.bfloat16,
     ]
+    assert inputs.size(0) == token_nums
     assert inputs.size(1) == lora_b_weights.size(-1)
     assert b_seq_start_loc.size(0) == batches
     assert lora_indices_tensor.size(0) == batches
diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py
index b4ae9a2acbb5c..efa234520ab87 100644
--- a/vllm/lora/ops/sgmv_expand_slice.py
+++ b/vllm/lora/ops/sgmv_expand_slice.py
@@ -112,6 +112,7 @@ def _sgmv_expand_slice(
     lora_indices_tensor: torch.Tensor,
     batches: int,
     max_seq_length: int,
+    token_nums: int,
     slice_offset: int,
     slice_size: int,
     add_inputs: bool = False,
@@ -124,20 +125,22 @@ def _sgmv_expand_slice(
         output_tensor (torch.Tensor): output tensor
         b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
             sequence lengths of the sequences in the batch, used to index
-            into sequence. E.g.,if the sequence length is [4, 6], it is
+            into sequence. E.g., if the sequence length is [4, 6], it is
             [0, 4, 10].
-        seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
-            length of the sequences in the batch
+        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
+            length of the sequences in the batch
         lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
             corresponding to each batch. An index of -1 means no lora should
             be applied.
         batches (int): batch size
-        max_seq_length (int): The max sequence lengths of the sequences
+        max_seq_length (int): The max sequence lengths of the sequences in the
             batch
-        slice_offst (int): output_tensor's offst
+        token_nums (int): The token numbers in the batch. Used to verify if the
+            token numbers in the inputs matches the one in the metadata.
+        slice_offset (int): output_tensor's offset
         slice_size (int): current output_tensor's size
-        add_inputs (bool, optional): Defaults to False. adds the final lora
-            results to the output..
+        add_inputs (bool, optional): Defaults to False, adds the final lora
+            results to the output.
""" assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] @@ -145,6 +148,7 @@ def _sgmv_expand_slice( torch.float16, torch.bfloat16, ] + assert inputs.size(0) == token_nums assert inputs.size(1) == lora_b_weights.size(-1) assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index c0791c260e915..c003f3dc0ce9e 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -110,6 +110,7 @@ def _sgmv_shrink( lora_indices_tensor: torch.Tensor, batches: int, max_seq_length: int, + token_nums: int, scaling: float, ) -> None: """ @@ -120,17 +121,19 @@ def _sgmv_shrink( output_tensor (torch.Tensor): output tensor b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative sequence lengths of the sequences in the batch, used to index - into sequence. E.g.,if the sequence length is [4, 6], it is + into sequence. E.g., if the sequence length is [4, 6], it is [0, 4]. - seq_len_tensor (torch.Tensor): (batch_size,). record the sequence - length of the sequences in the batch + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch. lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index corresponding to each batch. An index of -1 means no lora should be applied. batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences - in the batch - scaling (float): Scaling factor. + max_seq_length (int): The max sequence lengths of the sequences in the + batch. + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + scaling (float): Scaling factor. """ assert inputs.dtype == lora_a_weights.dtype assert inputs.dtype in [torch.float16, torch.bfloat16] @@ -138,6 +141,7 @@ def _sgmv_shrink( torch.float16, torch.bfloat16, ] + assert inputs.size(0) == token_nums assert inputs.size(1) == lora_a_weights.size(-1) assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 6d5c834299961..5033ce4126929 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -27,7 +27,7 @@ def compute_meta( token_lora_tensor: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: """ Get the information required for the sgmv kernel. With the features: 1. If consecutive requests in the batch use the same LoRA, this function @@ -43,7 +43,7 @@ def compute_meta( b_seq_start_tensor = torch.zeros_like(seq_length_tensor) b_seq_start_tensor[1:].copy_(cum_result[:-1]) max_length = seq_length_tensor.max().item() - + token_nums = seq_length_tensor.sum().item() batch_size = lora_indices_tensor.size(0) no_lora = False # -1 means no lora should be applied. Use `no_lora` to determine whether @@ -52,7 +52,7 @@ def compute_meta( if batch_size == 1 and lora_indices_tensor == -1: no_lora = True return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, no_lora) + batch_size, max_length, token_nums, no_lora) # TODO see if this can be vectorized @@ -178,7 +178,7 @@ def convert_mapping( class PunicaWrapper: """ PunicaWrapper is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for + kernel. 
+    kernel. The main function is to maintain the state information for
     Multi-LoRA, and to provide the interface for the punica kernel.
     """

@@ -216,6 +216,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
                                                  dtype=torch.long,
                                                  device=device)
         self.max_length: int = 0
+        self.token_nums: int = 0
         self.batch_size: int = -1
         self.is_prefill = False
         self.no_lora = False
@@ -276,13 +277,13 @@ def _update_base_metadata(
                 long_lora_offsets_tensor)
         else:
             self._long_lora_indices.zero_()
-
         self.indices_len[:] = indices_len

     def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:

         (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
-         batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)
+         batch_size, max_length, token_nums,
+         no_lora) = compute_meta(token_lora_tensor)

         self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
             b_seq_start_tensor)

@@ -291,25 +292,28 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
             lora_indices_tensor)
         self.batch_size = batch_size
         self.max_length = max_length
+        self.token_nums = token_nums
         self.no_lora = no_lora

     @property
     def prefill_metadata(
-            self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
+        self
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
         """
         This property provides a convenient way to access the necessary
         metadata for prefill-related kernel computations.
-            1. seq_start_locs: Tensor of sequence start positions
-            2. seq_lengths: Tensor of sequence lengths
+            1. seq_start_locs: Tensor of sequence start positions.
+            2. seq_lengths: Tensor of sequence lengths.
             3. lora_indices_per_batch: Tensor of lora indices, and an index
                 of -1 means no lora should be applied.
-            4. batch_size: batch size after clustering identical lora indices
-            5. max_length: The maximum sequence length in the batch
+            4. batch_size: Batch size after clustering identical lora indices.
+            5. max_length: The maximum sequence length in the batch.
+            6. token_nums: The token numbers in the batch.
         """
         return (self._seq_start_locs[:self.batch_size],
                 self._seq_lengths[:self.batch_size],
                 self._lora_indices_per_batch[:self.batch_size],
-                self.batch_size, self.max_length)
+                self.batch_size, self.max_length, self.token_nums)

     @property
     def token_lora_indices(self) -> torch.Tensor:
@@ -324,7 +328,7 @@ def token_lora_indices(self) -> torch.Tensor:
     def sampler_indices(self) -> torch.Tensor:
         """
         This property is used to access the lora indices specifically for
-        LogitsProcessorWithLoRA
+        LogitsProcessorWithLoRA.
         """
         sampler_indices_len = self.indices_len[1]
         return self._sampler_indices[:sampler_indices_len]
@@ -332,7 +336,7 @@ def sampler_indices(self) -> torch.Tensor:
     @property
     def sampler_indices_padded(self) -> torch.Tensor:
         """
-        This property provides access to padded sampler indices
+        This property provides access to padded sampler indices.
         """
         indices_padded_len = self.indices_len[2]
         return self._sampler_indices_padded[:indices_padded_len]
@@ -341,7 +345,7 @@ def sampler_indices_padded(self) -> torch.Tensor:
     def embeddings_indices(self) -> torch.Tensor:
         """
         This property provides access to the indices used for lora embeddings,
-        specifically for VocabParallelEmbeddingWithLoRA
+        specifically for VocabParallelEmbeddingWithLoRA.
""" embeddings_indices_len = self.indices_len[3] return self._embeddings_indices[:, :embeddings_indices_len] @@ -350,7 +354,7 @@ def embeddings_indices(self) -> torch.Tensor: def long_lora_indices(self) -> torch.Tensor: """ This property provides access to the indices used for long context - lora, specifically for LinearScalingRotaryEmbeddingWithLora + lora, specifically for LinearScalingRotaryEmbeddingWithLora. """ long_lora_len = self.indices_len[4] return self._long_lora_indices[:long_lora_len] @@ -524,7 +528,7 @@ def add_lora(self, scale (float): Scaling factor. y_offset (Optional[int], optional): Offset to apply to the starting column of y. - y_slice_size (Optional[int], optional): Size of the y column slice.. + y_slice_size (Optional[int], optional): Size of the y column slice. buffer (Optional[torch.Tensor], optional): Defaults to None. """ y_org = y