From c636ac2ee0339dcaae9ff1b2493f93ec9d5d752d Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 29 Feb 2024 11:49:13 -0800 Subject: [PATCH 01/16] WIP --- vllm/model_executor/layers/sampler.py | 86 ++++- .../layers/triton_kernel/sample.py | 331 ++++++++++++++++++ vllm/model_executor/sampling_metadata.py | 126 ++++++- vllm/sequence.py | 6 + vllm/worker/model_runner.py | 38 +- 5 files changed, 571 insertions(+), 16 deletions(-) create mode 100644 vllm/model_executor/layers/triton_kernel/sample.py diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 884d84387e505..5bfaef1c801cb 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -10,6 +10,7 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceData, SequenceGroupOutput, SequenceOutput) +from vllm.model_executor.layers.triton_kernel.sample import sample as sample_triton class Sampler(nn.Module): @@ -105,7 +106,9 @@ def forward( logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. - sample_results = _sample(probs, logprobs, sampling_metadata) + sample_results = _sample_with_triton_kernel(probs, logprobs, + sampling_metadata, + sampling_tensors) # Get the logprobs query results. prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) @@ -385,7 +388,7 @@ def _sample( # Counterintiutively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type] + sample_indices = categorized_sample_indices[sampling_type][:, 0] num_tokens = len(sample_indices) if num_tokens == 0: continue @@ -438,6 +441,85 @@ def _sample( return sample_results +def _sample_with_triton_kernel( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, +) -> List[Tuple[List[int], List[int]]]: + categorized_seq_group_ids = {t: [] for t in SamplingType} + categorized_sample_indices = sampling_metadata.categorized_sample_indices + for i, seq_group in enumerate(sampling_metadata.seq_groups): + _, sampling_params = seq_group + sampling_type = sampling_params.sampling_type + categorized_seq_group_ids[sampling_type].append(i) + + sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} + sample_metadata = {} + max_best_of = 1 + + # Counterintiutively, having two loops here is actually faster. + # The first loop can run without waiting on GPU<->CPU sync. 
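+    # This first pass only collects per-sampling-type metadata on the CPU
+    # (seq group ids, sample indices, max best_of); the greedy/random
+    # sampling itself happens in the single sample_triton launch below.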
+ for sampling_type in SamplingType: + sample_indices = categorized_sample_indices[sampling_type][:, 0] + sampled_token_indices = categorized_sample_indices[sampling_type][:, 1] + num_tokens = len(sample_indices) + if num_tokens == 0: + continue + seq_group_ids = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] + is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] + sample_metadata[sampling_type] = (seq_group_ids, seq_groups, + is_prompts, sample_indices, + sampled_token_indices) + if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM, + SamplingType.RANDOM_SEED): + for seq_group, is_prompt in zip(seq_groups, is_prompts): + if is_prompt: + _, sampling_params = seq_group + max_best_of = max(max_best_of, sampling_params.best_of) + elif sampling_type == SamplingType.BEAM: + beam_search_logprobs = logprobs[sample_indices] + else: + raise ValueError(f"Unsupported sampling type: {sampling_type}") + + sampled_tokens, _, _ = sample_triton( + probs=probs, + seeds=sampling_tensors.sampling_seeds, + max_best_of=max_best_of, + sample_indices=sampling_tensors.sample_indices, + logprobs=logprobs, + # don't save logprobs because we have logic for that below + # TODO: use this instead of the CPU-based logic below + save_logprobs=False, + ) + + # GPU<->CPU sync happens in the loop below. + + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + seq_group_ids, seq_groups, is_prompts, sample_indices, sampled_token_indices = sample_metadata[ + sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample( + seq_groups, sampled_tokens[sampled_token_indices][:, 0]) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + sample_results = _random_sample( + seq_groups, is_prompts, sampled_tokens[sampled_token_indices]) + elif sampling_type == SamplingType.BEAM: + sample_results = _beam_search_sample(seq_groups, is_prompts, + sampling_metadata.seq_data, + beam_search_logprobs) + sample_results_dict.update(zip(seq_group_ids, sample_results)) + + sample_results = [ + sample_results_dict[i] + for i in range(len(sampling_metadata.seq_groups)) + ] + return sample_results + + def _get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, diff --git a/vllm/model_executor/layers/triton_kernel/sample.py b/vllm/model_executor/layers/triton_kernel/sample.py new file mode 100644 index 0000000000000..14c5c9d3a2890 --- /dev/null +++ b/vllm/model_executor/layers/triton_kernel/sample.py @@ -0,0 +1,331 @@ +import torch +import triton +import triton.language as tl +import math +from typing import Tuple, Optional + +_EPS = 1e-6 + +# This is a hardcoded limit in Triton (max block size). +MAX_TRITON_N_COLS = 131072 + + +def get_num_triton_sampler_splits(n_cols: int) -> int: + """Get the number of splits to use for Triton sampling. + + Triton has a limit on the number of columns it can handle, so we need to + split the tensor and call the kernel multiple times if it's too large. 
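+
+    For example, with MAX_TRITON_N_COLS = 131072 a 32000-entry vocabulary
+    (llama/mistral) fits in a single split, while a 256000-column tensor
+    needs math.ceil(256000 / 131072) = 2 splits.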
+ """ + return math.ceil(n_cols / MAX_TRITON_N_COLS) + + +def _multi_split_sample( + probs: torch.Tensor, + seeds: torch.Tensor, + n_splits: int, + sampled_tokens_size: Tuple[int, int], + sampled_logprobs_size: Tuple[int, int], + sample_indices: torch.Tensor, + *, + logprobs: Optional[torch.Tensor] = None, + modify_greedy_probs: bool = False, + save_logprobs: bool = False, +): + """Sample tokens where vocab size is split into multiple parts + (too large for Triton otherwise).""" + assert seeds.ndim == 2 and seeds.shape[0] == n_splits + split_probs = probs.tensor_split(n_splits, 1) + split_logprobs = logprobs.tensor_split(n_splits, 1) + sampled_tokens_tmp = [ + torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device) + for _ in range(n_splits) + ] + sampled_logprobs_tmp = [ + torch.empty(sampled_logprobs_size, + dtype=probs.dtype, + device=probs.device) for _ in range(n_splits) + ] + # We are purposefuly using sampled_tokens_size as we need to always + # save modified probs in this case. + sampled_modified_probs_tmp = [ + torch.empty(sampled_tokens_size, + dtype=probs.dtype, + device=probs.device) for _ in range(n_splits) + ] + for i in range(n_splits): + # TODO(yard1): See if we can remove the contiguous() calls. + # Will need kernel support. + _sample( + split_probs[i], + split_logprobs[i], + sample_indices, + sampled_tokens_tmp[i], + sampled_logprobs_tmp[i], + sampled_modified_probs_tmp[i], + seeds[i], + modify_greedy_probs=modify_greedy_probs, + # Don't save logprobs in kernel, we need to gather them + # below + save_logprobs=False, + save_modified_probs=True, + ) + if i > 0: + # Add offset to sampled tokens + sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1]) + sampled_tokens = torch.stack(sampled_tokens_tmp) + sampled_modified_probs = torch.stack(sampled_modified_probs_tmp) + # Reduce the results from the splits. + sampled_modified_probs, indices = torch.max(sampled_modified_probs, + dim=0, + keepdim=True) + sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0) + if save_logprobs: + sampled_logprobs = torch.stack(sampled_logprobs_tmp) + sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0) + else: + sampled_logprobs = None + sampled_modified_probs = sampled_modified_probs.squeeze(0) + if modify_greedy_probs: + # We need to modify the greedy probs for the sampled tokens. + # We can't do this in the kernel as we need to know the + # sampled tokens. + probs.fill_(0.0) + probs.scatter_(1, sampled_tokens, 1.0) + return (sampled_tokens, sampled_logprobs, sampled_modified_probs) + + +def sample( + probs: torch.Tensor, + seeds: torch.Tensor, + *, + max_best_of: int = 1, + sample_indices: Optional[torch.Tensor] = None, + logprobs: Optional[torch.Tensor] = None, + modify_greedy_probs: bool = False, + save_logprobs: bool = False, + _save_modified_probs: bool = False, # pylint: disable=invalid-name +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """Sample tokens from probs. with per-sequence seeds. + + Can sample from a subset of sequences through sample_indices. + + Args: + probs: Probabilities to sample from. + shape = [batch_size, vocab_size] + seeds: Per-sequence seed values. + shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)] + max_best_of: Number of samples to generate per sequence. + Sequence seed will be incremented by 1 each time. + sample_indices: Indices of sequences to sample from. + If not provided, will sample from all sequences. + shape = [n] + logprobs: Log-probabilities of the sampled tokens. 
+ Only used for saving the logprobs if save_logprobs is True. + shape = [batch_size, vocab_size] + modify_greedy_probs: Whether to modify the greedy probabilities + for speculative sampling (sampled token = 1.0, + everything else = 0.0). + save_logprobs: Whether to save the log-probabilities of the + sampled tokens to a tensor. + _save_modified_probs: Whether to save the modified probabilities + (including gumbel noise) of the sampled tokens to a tensor. + DOES NOT include the modification done by modify_greedy_probs + (because we want to use the unmodified probs to pick the best + split in case of multi-split sampling). + This is exposed only for testing. + + Returns: + sampled_tokens: shape = [n, max_best_of] + sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None + sampled_modified_probs: shape = [n, max_best_of] + if save_modified_probs else None + """ + if sample_indices is None: + sample_indices = torch.arange(0, probs.shape[0], device=probs.device) + + sampled_tokens_size = (sample_indices.size(0), max_best_of) + if save_logprobs: + if logprobs is None: + raise ValueError( + "logprobs tensor must be provided if save_logprobs is True") + sampled_logprobs_size = sampled_tokens_size + else: + # Empty tensors to invoke the kernel + sampled_logprobs_size = (0, 0) + logprobs = probs + + if _save_modified_probs: + sampled_modified_probs_size = sampled_tokens_size + else: + # Empty tensors to invoke the kernel + sampled_modified_probs_size = (0, 0) + + # If the number of columns in probs is too large for Triton to handle, + # we split the tensor and sample from each split separately, and then + # do an argmax+gather to combine the results. + n_splits = get_num_triton_sampler_splits(probs.shape[1]) + if n_splits > 1: + (sampled_tokens, sampled_logprobs, + sampled_modified_probs) = _multi_split_sample( + probs, + seeds, + n_splits, + sampled_tokens_size, + sampled_logprobs_size, + sample_indices, + logprobs=logprobs, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs) + else: + sampled_tokens = torch.empty(sampled_tokens_size, + dtype=torch.long, + device=probs.device) + sampled_logprobs = torch.empty(sampled_logprobs_size, + dtype=probs.dtype, + device=probs.device) + sampled_modified_probs = torch.empty(sampled_modified_probs_size, + dtype=probs.dtype, + device=probs.device) + + _sample(probs, + logprobs, + sample_indices, + sampled_tokens, + sampled_logprobs, + sampled_modified_probs, + seeds, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + save_modified_probs=_save_modified_probs) + return (sampled_tokens, sampled_logprobs if save_logprobs else None, + sampled_modified_probs if _save_modified_probs else None) + + +def _sample(probs: torch.Tensor, + logprobs: torch.Tensor, + sample_indices: torch.Tensor, + output_samples: torch.Tensor, + output_logprobs: torch.Tensor, + output_modified_probs: torch.Tensor, + seeds: torch.Tensor, + *, + modify_greedy_probs: bool = False, + save_logprobs: bool = True, + save_modified_probs: bool = False) -> None: + # Operates in place. + n_samples = sample_indices.shape[0] + n_cols = probs.shape[1] + n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1 + # The block size is the smallest power of two greater than the number of + # columns in probs + block_size = triton.next_power_of_2(n_cols) + num_warps = 4 + # Manual tuning. This seems to give best performance on A100 for + # simple kernels like this. 
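+    # block_size is a power of two, so these thresholds correspond to rows
+    # wider than 4096, 2048 and 1024 columns respectively.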
+ if block_size >= 8192: + num_warps = 32 + elif block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + # Enqueue kernel. The 1D launch grid is simple: we have one kernel + # instance per row of the probs matrix + _sample_triton_kernel[(n_samples, n_best)]( + sample_indices, + output_samples, + output_logprobs, + output_modified_probs, + probs, + logprobs, + seeds, + output_samples.stride(0), + probs.stride(0), + n_samples, + n_cols, + n_best, + num_warps=num_warps, + block_size=block_size, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + save_modified_probs=save_modified_probs, + ) + + +@triton.jit +def _sample_triton_kernel( + sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor, + output_logprobs_ptr: torch.Tensor, + output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor, + logprobs_ptr: torch.Tensor, seed_ptr: torch.Tensor, + output_row_stride: int, probs_row_stride: int, n_samples: int, + n_cols: int, n_best: int, block_size: tl.constexpr, + modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr, + save_modified_probs: tl.constexpr): + # The rows are independent, so we parallelize across those + sample_idx = tl.program_id(0) + best_idx = tl.program_id(1) + + # Load the row index from DRAM + row_idx = tl.load(sample_indices_ptr + sample_idx) + + # The stride represents how much we need to increase the + # pointer to advance 1 row + row_start_ptr = probs_ptr + row_idx * probs_row_stride + + # The block size is the next power of two greater than n_cols, + # so we can fit each row in a single block + col_offsets = tl.arange(0, block_size) + probs_ptrs = row_start_ptr + col_offsets + + # Load the row into SRAM, using a mask since block_size may be > than n_cols + row = tl.load(probs_ptrs, mask=col_offsets < n_cols, other=float("-inf")) + seed = tl.load(seed_ptr + sample_idx) + uses_random_sampling = seed != 0 + + if uses_random_sampling: + random_uniform = tl.rand(seed + best_idx, col_offsets) + + # tl.rand returns values in [0, 1), so we clamp lower bound + # to _EPS to avoid log(0) and thus division by nan later + lb = tl.full(random_uniform.shape, _EPS, random_uniform.dtype) + random_uniform = tl.maximum(random_uniform, lb) + # Use the inversion method to turn uniform samples + # into exponential samples + random_exponential = -tl.log(random_uniform) + + row /= random_exponential + + sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True) + # clamp sampled token to n_cols - 1 + # this should not be necessary, but we do it + # just in case + if sampled_token >= n_cols: + sampled_token = n_cols - 1 + # Write back output to DRAM + output_row_start_ptr = (output_ptr + sample_idx * output_row_stride + + best_idx) + tl.store(output_row_start_ptr, sampled_token) + + if modify_greedy_probs: # noqa + if not uses_random_sampling: + # Set the probability of the sampled token to 1, all other + # tokens to zero. This is used in speculative decoding where + # the sampling method must be encoded within the sampled + # probability distributions. 
+ row = tl.where(col_offsets == sampled_token, 1.0, 0.0) + tl.store(probs_ptrs, row, mask=col_offsets < n_cols) + + if save_modified_probs: + output_row_start_ptr = (output_modified_probs_ptr + + sample_idx * output_row_stride + best_idx) + tl.store(output_row_start_ptr, sampled_value) + + if save_logprobs: + # Load the row into SRAM, using a mask since block_size + # may be > than n_cols + sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride + + sampled_token) + # Write back output to DRAM + output_row_start_ptr = (output_logprobs_ptr + + sample_idx * output_row_stride + best_idx) + tl.store(output_row_start_ptr, sampled_logprob) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index d0ffeecd2d74d..d74d62b12dd39 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -2,12 +2,15 @@ from typing import Dict, List, Optional, Tuple import torch +import random from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData from vllm.utils import in_wsl +from vllm.model_executor.layers.triton_kernel.sample import get_num_triton_sampler_splits _SAMPLING_EPS = 1e-5 +_SEED_0_REPLACEMENT = 3403598558 class SamplingMetadata: @@ -67,14 +70,28 @@ class SamplingTensors: presence_penalties: torch.Tensor frequency_penalties: torch.Tensor repetition_penalties: torch.Tensor + sampling_seeds: torch.Tensor + sample_indices: torch.Tensor + extra_seeds: Optional[torch.Tensor] prompt_tokens: torch.Tensor output_tokens: torch.Tensor @classmethod def from_sampling_metadata( - cls, sampling_metadata: "SamplingMetadata", vocab_size: int, - device: torch.device, - dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]: + cls, + sampling_metadata: "SamplingMetadata", + vocab_size: int, + device: torch.device, + dtype: torch.dtype, + *, + extra_seeds_to_generate: int = 0, + extra_entropy: Optional[Tuple[int, ...]] = None + ) -> Tuple["SamplingTensors", bool, bool, bool]: + """ + extra_seeds_to_generate: extra seeds to generate using the + user-defined seed for each sequence. + extra_entropy: extra entropy to use when generating seeds. + """ prompt_tokens: List[List[int]] = [] output_tokens: List[List[int]] = [] top_ks: List[int] = [] @@ -84,9 +101,18 @@ def from_sampling_metadata( presence_penalties: List[float] = [] frequency_penalties: List[float] = [] repetition_penalties: List[float] = [] + sampling_seeds: List[int] = [] + sample_indices: List[int] = [] + prompt_best_of: List[int] = [] do_penalties = False do_top_p_top_k = False do_min_p = False + + # We need one base seed per Triton slice. + seeds_to_generate = (extra_seeds_to_generate + + get_num_triton_sampler_splits(vocab_size)) + + sample_indices_start_idx = 0 for i, seq_group in enumerate(sampling_metadata.seq_groups): seq_ids, sampling_params = seq_group temperature = sampling_params.temperature @@ -95,6 +121,10 @@ def from_sampling_metadata( r = sampling_params.repetition_penalty top_p = sampling_params.top_p min_p = sampling_params.min_p + seed = sampling_params.seed + + is_greedy = sampling_params.sampling_type == SamplingType.GREEDY + # k should not be greater than the vocab size. 
top_k = min(sampling_params.top_k, vocab_size) top_k = vocab_size if top_k == -1 else top_k @@ -112,6 +142,7 @@ def from_sampling_metadata( or abs(f) >= _SAMPLING_EPS or abs(r - 1.0) >= _SAMPLING_EPS): do_penalties = True + if (i < sampling_metadata.num_prompts and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get their logprobs @@ -137,10 +168,34 @@ def from_sampling_metadata( frequency_penalties += [f] * len(seq_ids) repetition_penalties += [r] * len(seq_ids) + is_prompt = i < sampling_metadata.num_prompts + if is_prompt: + prompt_best_of.append(sampling_params.best_of) + prompt_len = sampling_metadata.prompt_lens[i] + + if sampling_params.prompt_logprobs is not None: + # NOTE: the sampling position is the last token + # in the prompt + sample_indices_start_idx += prompt_len - 1 + for seq_id in seq_ids: + seq_data = sampling_metadata.seq_data[seq_id] + extra_entropy = extra_entropy or () + # extra_entropy = extra_entropy + (seq_id, ) + seq_seeds = cls._get_sequence_seeds( + seed, + seq_data.get_len(), + *extra_entropy, + seeds_to_generate=seeds_to_generate, + is_greedy=is_greedy) + sampling_seeds.append(seq_seeds) + sample_indices.append(sample_indices_start_idx) + sample_indices_start_idx += 1 + sampling_tensors = SamplingTensors.from_lists( temperatures, top_ps, top_ks, min_ps, presence_penalties, - frequency_penalties, repetition_penalties, prompt_tokens, - output_tokens, vocab_size, device, dtype) + frequency_penalties, repetition_penalties, sampling_seeds, + sample_indices, prompt_tokens, output_tokens, vocab_size, + extra_seeds_to_generate, device, dtype) return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) @classmethod @@ -149,9 +204,10 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], presence_penalties: List[float], frequency_penalties: List[float], repetition_penalties: List[float], + sampling_seeds: List[int], sample_indices: List[int], prompt_tokens: List[List[int]], output_tokens: List[List[int]], vocab_size: int, - device: torch.device, + extra_seeds_to_generate: int, device: torch.device, dtype: torch.dtype) -> "SamplingTensors": # Note that the performance will be very bad without # pinned memory. @@ -209,6 +265,12 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], dtype=torch.int, pin_memory=pin_memory, ) + sample_indices_t = torch.tensor( + sample_indices, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) prompt_tensor = torch.tensor( prompt_padded_tokens, device="cpu", @@ -221,8 +283,28 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], dtype=torch.long, pin_memory=pin_memory, ) + # need to transpose and make contiguous to + # copy the tensor correctly. + # [batch_size, n_seeds] -> [n_seeds, batch_size] + sampling_seeds_t = torch.tensor( + sampling_seeds, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ).T.contiguous() + # Because the memory is pinned, we can do non-blocking # transfer to device. + + # How many seeds the sample operation itself will need. 
+ num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate + sampling_seeds_gpu = sampling_seeds_t.to(device=device, + non_blocking=True) + extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:] + if not extra_seeds_gpu.numel(): + extra_seeds_gpu = None + sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds] + return cls( temperatures=temperatures_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True), @@ -236,4 +318,36 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], non_blocking=True), prompt_tokens=prompt_tensor.to(device=device, non_blocking=True), output_tokens=output_tensor.to(device=device, non_blocking=True), + sampling_seeds=sampling_seeds_gpu, + sample_indices=sample_indices_t.to(device=device, + non_blocking=True), + extra_seeds=extra_seeds_gpu, ) + + @staticmethod + def _get_sequence_seeds( + seed: int, + *extra_entropy: int, + seeds_to_generate: int, + is_greedy: bool, + ): + """Get `seeds_to_generate` child seeds from `seed` and extra entropy.""" + # TODO(yard1): make sure sequences in the same group with + # best_of > 1 have different seeds (not trivial...) + if not is_greedy: + generator = random.Random(str((seed, ) + extra_entropy)) + lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max + # If the user/random sets seed = 0 but request should + # have sampling, we need to change it to something + # else. We use a constant in that case. + # This way we don't need to create and load a bool + # matrix in the sampling kernel, which reduces CPU + # overhead and latency. + seq_seeds = [ + generator.randint(lo, hi) or _SEED_0_REPLACEMENT + for _ in range(seeds_to_generate) + ] + else: + # For the kernel, seed == 0 means greedy decoding. + seq_seeds = [0] * seeds_to_generate + return seq_seeds diff --git a/vllm/sequence.py b/vllm/sequence.py index 040e9756e15c6..006da6a1d06f5 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -105,6 +105,9 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self.prompt_token_ids + self.output_token_ids + def get_prompt_token_ids(self) -> List[int]: + return self.prompt_token_ids + def get_last_token_id(self) -> int: if not self.output_token_ids: return self.prompt_token_ids[-1] @@ -206,6 +209,9 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self.data.get_token_ids() + def get_prompt_token_ids(self) -> List[int]: + return self.data.get_prompt_token_ids() + def get_last_token_id(self) -> int: return self.data.get_last_token_id() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b99a409e02d1e..02a0686e831b3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -393,6 +393,7 @@ def _prepare_sample( selected_token_start_idx = 0 categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices_start_idx = 0 + categorized_sampled_token_indices_start_idx = 0 max_subquery_len = max(subquery_lens) if subquery_lens else 1 for i, seq_group_metadata in enumerate(seq_group_metadata_list): @@ -409,9 +410,12 @@ def _prepare_sample( categorized_sample_indices_start_idx += subquery_len - 1 categorized_sample_indices[ - sampling_params.sampling_type].append( - categorized_sample_indices_start_idx) + sampling_params.sampling_type].append([ + categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx + ]) categorized_sample_indices_start_idx += 1 + categorized_sampled_token_indices_start_idx += 1 if 
sampling_params.prompt_logprobs is not None: selected_token_indices.extend( @@ -433,9 +437,17 @@ def _prepare_sample( categorized_sample_indices[ sampling_params.sampling_type].extend( - range(categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + num_seqs)) + zip( + range( + categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + + num_seqs), + range( + categorized_sampled_token_indices_start_idx, + categorized_sampled_token_indices_start_idx + + num_seqs))) categorized_sample_indices_start_idx += num_seqs + categorized_sampled_token_indices_start_idx += num_seqs if sampling_params.seed is not None: generators.append(seq_group_metadata.state.generator) @@ -444,11 +456,13 @@ def _prepare_sample( dtype=torch.long, target_device=self.device, pin_memory=not self.in_wsl) + categorized_sample_indices = { - t: _async_h2d(seq_ids, - dtype=torch.int, - target_device=self.device, - pin_memory=not self.in_wsl) + t: _maybe_expand_dim( + _async_h2d(seq_ids, + dtype=torch.int, + target_device=self.device, + pin_memory=not self.in_wsl), 2, 2) for t, seq_ids in categorized_sample_indices.items() } @@ -883,3 +897,11 @@ def _async_h2d( ) -> torch.Tensor: t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") return t.to(device=target_device, non_blocking=True) + + +def _maybe_expand_dim(tensor: torch.Tensor, + target_dims: int, + size: int = 1) -> torch.Tensor: + if tensor.ndim < target_dims: + tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim))) + return tensor From f215c3e9bfe5eb8eca5c0467aecc1b46f9a1cc3f Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 4 Mar 2024 13:26:23 -0800 Subject: [PATCH 02/16] Faster kernel --- .../layers/triton_kernel/rand.py | 156 ++++++++++++++++++ .../layers/triton_kernel/sample.py | 138 ++++++++++++---- 2 files changed, 259 insertions(+), 35 deletions(-) create mode 100644 vllm/model_executor/layers/triton_kernel/rand.py diff --git a/vllm/model_executor/layers/triton_kernel/rand.py b/vllm/model_executor/layers/triton_kernel/rand.py new file mode 100644 index 0000000000000..0047d1007ffec --- /dev/null +++ b/vllm/model_executor/layers/triton_kernel/rand.py @@ -0,0 +1,156 @@ +import torch +import triton +import triton.language as tl + +from typing import Optional, Union + + +def seeded_uniform( + *size, + seeds: torch.Tensor, + out: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str]] = None, + pin_memory: Optional[bool] = False, +) -> torch.Tensor: + """Similar to torch.rand, but allows for seeds to be set per row. + + seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d. + If it is 3d, the additional seeds needed will be derived automatically + in a deterministic fashion: + [ + row 0: [columns_with_seed_0], [columns_with_seed0^1], ... 
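+        row 1: [columns_with_seed_1], [columns_with_seed1^1], ...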
+ ] + """ + n_dims = len(size) + + if n_dims > 3: + raise ValueError("seeded_uniform only supports up to 3D tensors") + + if out is None: + out = torch.empty(*size, + dtype=dtype, + device=device, + pin_memory=pin_memory) + elif out.shape != size: + raise ValueError("shape of out and size must be the same") + + if n_dims == 3: + n_rows, n_3d, n_cols = out.shape + stride_row = out.stride(0) + stride_3d = out.stride(1) + elif n_dims == 2: + n_rows, n_cols = out.shape + n_3d = 1 + stride_row = out.stride(0) + stride_3d = 1 + else: + n_cols = out.shape[0] + n_rows = 1 + n_3d = 1 + stride_row = 1 + stride_3d = 1 + + if seeds.ndim != 1: + raise ValueError("seeds must be a 1D tensor") + + if seeds.numel() != n_rows: + raise ValueError( + "seeds must have the same number of elements as out has rows") + + # The philox PRNG Triton uses generates 4 random numbers at once. + # Therefore, the most efficient use of it is to divide the + # block size by 4, and then save the generated random numbers to + # each of the 4 slices of the tensor. + full_block_size = triton.next_power_of_2(n_cols) + philox_block_size = max(full_block_size // 4, 1) + n_slices = full_block_size // philox_block_size + num_warps = 4 + # Manual tuning. This seems to give best performance on A100 for + # simple kernels like this. + if philox_block_size >= 8192: + num_warps = 32 + elif philox_block_size >= 4096: + num_warps = 16 + elif philox_block_size >= 2048: + num_warps = 8 + + _seeded_uniform_triton[(n_rows, n_3d)]( + out, + seeds, + stride_row, + stride_3d, + seeds.stride(0), + n_rows, + n_3d, + n_cols, + n_slices=n_slices, + num_warps=num_warps, + block_size=philox_block_size, + ) + return out + + +@triton.jit +def _seeded_uniform_triton( + out_ptr: torch.Tensor, + seed_ptr: torch.Tensor, + out_row_stride: int, + out_3d_stride: int, + seed_row_stride: int, + n_rows: int, + n_3d: int, + n_cols: int, + n_slices: tl.constexpr, + block_size: tl.constexpr, +): + """Generate a random number in [0, 1) for each element in the output + tensor. The random numbers in a row generated using the seed for that row. + + Args: + out_ptr: The output tensor. + seed_ptr: The per-row seeds to use for random number generation. + out_row_stride: The stride between rows of the output tensor. + out_3d_stride: The stride between 3D slices of the output tensor. + seed_row_stride: The stride between rows of the seed tensor. + n_rows: The number of rows in the output tensor. + n_3d: The size of second dimension of the output tensor, + if output tensor is 3D. + n_cols: The number of columns in the output tensor. + n_slices: The number of philox outputs to use. + """ + tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4") + + # Get the row index. + row_idx = tl.program_id(axis=0) + three_d_idx = tl.program_id(axis=1) + + philox_offsets = tl.arange(0, block_size) + # Get the seed for the current element. + seed = tl.load(seed_ptr + row_idx * seed_row_stride) + if three_d_idx > 0: + seed ^= three_d_idx + # Generate random numbers in [0, 1). 
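+    # tl.rand4x returns four independent streams per offset, so one call
+    # covers up to 4 * block_size columns; the stores below are masked so
+    # only the first n_cols values are written out.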
+ out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets) + + output_row_start_ptr = (out_ptr + row_idx * out_row_stride + + three_d_idx * out_3d_stride) + out1_offsets = philox_offsets + tl.store(output_row_start_ptr + out1_offsets, + out1, + mask=out1_offsets < n_cols) + if n_slices > 1: + out2_offsets = tl.arange(block_size, block_size * 2) + tl.store(output_row_start_ptr + out2_offsets, + out2, + mask=out2_offsets < n_cols) + if n_slices > 2: + out3_offsets = tl.arange(block_size * 2, block_size * 3) + tl.store(output_row_start_ptr + out3_offsets, + out3, + mask=out3_offsets < n_cols) + if n_slices > 3: + out4_offsets = tl.arange(block_size * 3, block_size * 4) + tl.store(output_row_start_ptr + out4_offsets, + out4, + mask=out4_offsets < n_cols) diff --git a/vllm/model_executor/layers/triton_kernel/sample.py b/vllm/model_executor/layers/triton_kernel/sample.py index 14c5c9d3a2890..30b74ce3bfad6 100644 --- a/vllm/model_executor/layers/triton_kernel/sample.py +++ b/vllm/model_executor/layers/triton_kernel/sample.py @@ -1,8 +1,11 @@ +import math +from typing import Tuple, Optional + import torch import triton import triton.language as tl -import math -from typing import Tuple, Optional + +from vllm.model_executor.layers.triton_kernel.rand import seeded_uniform _EPS = 1e-6 @@ -53,20 +56,28 @@ def _multi_split_sample( device=probs.device) for _ in range(n_splits) ] for i in range(n_splits): + n_samples = sample_indices.shape[0] + n_cols = split_probs[i].shape[1] + n_best = sampled_tokens_tmp[i].shape[1] + uniform_noise = seeded_uniform(n_samples, + n_best, + n_cols, + seeds=seeds[i].flatten(), + device=split_probs[i].device, + dtype=split_probs[i].dtype) # TODO(yard1): See if we can remove the contiguous() calls. # Will need kernel support. _sample( - split_probs[i], - split_logprobs[i], + split_probs[i].contiguous(), + split_logprobs[i].contiguous(), sample_indices, sampled_tokens_tmp[i], sampled_logprobs_tmp[i], sampled_modified_probs_tmp[i], seeds[i], - modify_greedy_probs=modify_greedy_probs, - # Don't save logprobs in kernel, we need to gather them - # below - save_logprobs=False, + uniform_noise, + modify_greedy_probs=False, + save_logprobs=save_logprobs, save_modified_probs=True, ) if i > 0: @@ -85,12 +96,14 @@ def _multi_split_sample( else: sampled_logprobs = None sampled_modified_probs = sampled_modified_probs.squeeze(0) + if modify_greedy_probs: # We need to modify the greedy probs for the sampled tokens. # We can't do this in the kernel as we need to know the # sampled tokens. 
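+        # Zero the whole distribution and scatter 1.0 at every sampled
+        # token, mirroring what modify_greedy_probs does inside the kernel
+        # for the single-split path.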
probs.fill_(0.0) probs.scatter_(1, sampled_tokens, 1.0) + return (sampled_tokens, sampled_logprobs, sampled_modified_probs) @@ -186,17 +199,28 @@ def sample( sampled_modified_probs = torch.empty(sampled_modified_probs_size, dtype=probs.dtype, device=probs.device) + n_samples = sample_indices.shape[0] + n_cols = probs.shape[1] + uniform_noise = seeded_uniform(n_samples, + max_best_of, + n_cols, + seeds=seeds.flatten(), + device=probs.device, + dtype=probs.dtype) - _sample(probs, - logprobs, - sample_indices, - sampled_tokens, - sampled_logprobs, - sampled_modified_probs, - seeds, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - save_modified_probs=_save_modified_probs) + _sample( + probs, + logprobs, + sample_indices, + sampled_tokens, + sampled_logprobs, + sampled_modified_probs, + seeds, + uniform_noise, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + save_modified_probs=_save_modified_probs, + ) return (sampled_tokens, sampled_logprobs if save_logprobs else None, sampled_modified_probs if _save_modified_probs else None) @@ -208,14 +232,44 @@ def _sample(probs: torch.Tensor, output_logprobs: torch.Tensor, output_modified_probs: torch.Tensor, seeds: torch.Tensor, + uniform_noise: torch.Tensor, *, modify_greedy_probs: bool = False, save_logprobs: bool = True, - save_modified_probs: bool = False) -> None: - # Operates in place. + save_modified_probs: bool = False) -> torch.Tensor: + """Sample tokens from probs. + + Args: + probs [batch_size, vocab_size]: probs to sample from. + logprobs [batch_size, vocab_size]: logprobs (used when + save_logprobsis True). + sample_indices [n]: Indices of the samples to use for each row of probs. + output_samples [n, n_best]: Output tensor to store samples in. + output_logprobs [n, n_best]: Output tensor to store logprobs in. + output_modified_probs [n, n_best]: Output tensor to store + probs of chosen tokens in (modified with noise). + seeds [n]: Seeds to use for sampling. If the seed is 0, we use + greedy sampling. Note this is ONLY used for determining + whether to use random sampling or not. The actual random + noise should be passed as uniform_noise. + uniform_noise [batch_size, n_best, vocab_size]: Uniform + noise to use for random sampling (will be converted + to exponential gumbel noise by the kernel). + modify_greedy_probs: If True, we modify the probs tensor in-place + to encode the sampling method used for each row. This is used + in speculative decoding. Only applies in greedy decoding. + save_logprobs: If True, we save the logprobs of the sampled tokens + in the output_logprobs tensor. + save_modified_probs: If True, we save the modified probs (with noise) + of the sampled tokens in the output_modified_probs tensor. + DOES NOT include the modification done by modify_greedy_probs + (because we want to use the unmodified probs to pick the best + split in case of multi-split sampling). + """ n_samples = sample_indices.shape[0] n_cols = probs.shape[1] n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1 + # The block size is the smallest power of two greater than the number of # columns in probs block_size = triton.next_power_of_2(n_cols) @@ -228,9 +282,10 @@ def _sample(probs: torch.Tensor, num_warps = 16 elif block_size >= 2048: num_warps = 8 + # Enqueue kernel. 
The 1D launch grid is simple: we have one kernel # instance per row of the probs matrix - _sample_triton_kernel[(n_samples, n_best)]( + _sample_triton[(n_samples, n_best)]( sample_indices, output_samples, output_logprobs, @@ -238,8 +293,11 @@ def _sample(probs: torch.Tensor, probs, logprobs, seeds, + uniform_noise, output_samples.stride(0), probs.stride(0), + uniform_noise.stride(0), + uniform_noise.stride(1) if n_best > 1 else 1, n_samples, n_cols, n_best, @@ -249,16 +307,19 @@ def _sample(probs: torch.Tensor, save_logprobs=save_logprobs, save_modified_probs=save_modified_probs, ) + return output_samples, output_logprobs, output_modified_probs @triton.jit -def _sample_triton_kernel( +def _sample_triton( sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor, output_logprobs_ptr: torch.Tensor, output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor, - logprobs_ptr: torch.Tensor, seed_ptr: torch.Tensor, - output_row_stride: int, probs_row_stride: int, n_samples: int, - n_cols: int, n_best: int, block_size: tl.constexpr, + logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor, + uniform_noise_ptr: torch.Tensor, output_row_stride: int, + probs_row_stride: int, uniform_noise_row_stride: int, + uniform_noise_best_stride: int, n_samples: int, n_cols: int, + n_best: int, block_size: tl.constexpr, modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr, save_modified_probs: tl.constexpr): # The rows are independent, so we parallelize across those @@ -267,6 +328,8 @@ def _sample_triton_kernel( # Load the row index from DRAM row_idx = tl.load(sample_indices_ptr + sample_idx) + seed = tl.load(seeds_ptr + sample_idx) + uses_random_sampling = seed != 0 # The stride represents how much we need to increase the # pointer to advance 1 row @@ -275,25 +338,28 @@ def _sample_triton_kernel( # The block size is the next power of two greater than n_cols, # so we can fit each row in a single block col_offsets = tl.arange(0, block_size) - probs_ptrs = row_start_ptr + col_offsets # Load the row into SRAM, using a mask since block_size may be > than n_cols - row = tl.load(probs_ptrs, mask=col_offsets < n_cols, other=float("-inf")) - seed = tl.load(seed_ptr + sample_idx) - uses_random_sampling = seed != 0 + row = tl.load(row_start_ptr + col_offsets, + mask=col_offsets < n_cols, + other=float("-inf")) if uses_random_sampling: - random_uniform = tl.rand(seed + best_idx, col_offsets) + uniform_noise_start_ptr = uniform_noise_ptr + sample_idx * uniform_noise_row_stride + best_idx * uniform_noise_best_stride + uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets, + mask=col_offsets < n_cols, + other=0.5) + # NEEDS TO BE MANUALLY KEPT IN SYNC WITH vllm/tests/ops/test_sampler.py # tl.rand returns values in [0, 1), so we clamp lower bound # to _EPS to avoid log(0) and thus division by nan later - lb = tl.full(random_uniform.shape, _EPS, random_uniform.dtype) - random_uniform = tl.maximum(random_uniform, lb) + lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype) + uniform_noise = tl.maximum(uniform_noise, lb) # Use the inversion method to turn uniform samples # into exponential samples - random_exponential = -tl.log(random_uniform) + exponential_noise = -tl.log(uniform_noise) - row /= random_exponential + row /= exponential_noise sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True) # clamp sampled token to n_cols - 1 @@ -313,7 +379,9 @@ def _sample_triton_kernel( # the sampling method must be encoded within the sampled # probability distributions. 
row = tl.where(col_offsets == sampled_token, 1.0, 0.0) - tl.store(probs_ptrs, row, mask=col_offsets < n_cols) + tl.store(row_start_ptr + col_offsets, + row, + mask=col_offsets < n_cols) if save_modified_probs: output_row_start_ptr = (output_modified_probs_ptr + From 4a7b24a7442efcae1b90fb66fed6a9807143bd2b Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 4 Mar 2024 14:38:26 -0800 Subject: [PATCH 03/16] WIP --- tests/samplers/test_sampler.py | 6 +++--- vllm/model_executor/layers/sampler.py | 17 +++++++++++++---- vllm/model_executor/sampling_metadata.py | 8 ++++++-- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 31e865f42ff3b..7dd67f39fcc9f 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -299,11 +299,11 @@ def test_sampler_logits_processors(seed: int, device: str): batch_size = random.randint(1, 256) input_tensor, _, sampler, model_runner = _prepare_test(batch_size) - # This sample logits processor gives infinite score to the i-th token, + # This sample logits processor gives maximum score to the i-th token, # where i is the length of the input sequence. # We therefore expect the output token sequence to be [0, 1, 2, ...] def pick_ith(token_ids, logits): - logits[len(token_ids)] = float("inf") + logits[len(token_ids)] = torch.finfo(logits.dtype).max return logits seq_group_metadata_list = [] @@ -382,7 +382,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): sample_probs = None - def mock_sample(probs, logprobs, sampling_metadata): + def mock_sample(probs, *args, **kwargs): nonlocal sample_probs sample_probs = probs return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs] diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 9706e72a8316e..1e04600e3090a 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -114,9 +114,8 @@ def forward( logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. - sample_results = _sample_with_triton_kernel(probs, logprobs, - sampling_metadata, - sampling_tensors) + sample_results = _sample(probs, logprobs, sampling_metadata, + sampling_tensors) # Get the logprobs query results. prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) @@ -377,7 +376,7 @@ def _multinomial( return probs.div_(q).argmax(dim=1).view(-1, num_samples) -def _sample( +def _sample_old( probs: torch.Tensor, logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, @@ -529,6 +528,16 @@ def _sample_with_triton_kernel( return sample_results +def _sample( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, +) -> List[Tuple[List[int], List[int]]]: + return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, + sampling_tensors) + + def _get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 086ab90df6512..5ccda95c20bba 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -335,7 +335,11 @@ def _get_sequence_seeds( # TODO(yard1): make sure sequences in the same group with # best_of > 1 have different seeds (not trivial...) 
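+        # Without an explicit request seed, fall back to the global RNG;
+        # with one, derive child seeds deterministically from
+        # (seed, *extra_entropy) so results are reproducible.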
if not is_greedy: - generator = random.Random(str((seed, ) + extra_entropy)) + if seed is None: + randint_fn = random.randint + else: + generator = random.Random(str((seed, ) + extra_entropy)) + randint_fn = generator.randint lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max # If the user/random sets seed = 0 but request should # have sampling, we need to change it to something @@ -344,7 +348,7 @@ def _get_sequence_seeds( # matrix in the sampling kernel, which reduces CPU # overhead and latency. seq_seeds = [ - generator.randint(lo, hi) or _SEED_0_REPLACEMENT + randint_fn(lo, hi) or _SEED_0_REPLACEMENT for _ in range(seeds_to_generate) ] else: From 111650ecaf77bbccda719d26c06172d6a38fca00 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 4 Mar 2024 15:44:44 -0800 Subject: [PATCH 04/16] Add test --- tests/kernels/test_rand.py | 51 +++++ tests/kernels/test_sampler.py | 201 ++++++++++++++++++ .../layers/triton_kernel/sample.py | 2 +- 3 files changed, 253 insertions(+), 1 deletion(-) create mode 100644 tests/kernels/test_rand.py create mode 100644 tests/kernels/test_sampler.py diff --git a/tests/kernels/test_rand.py b/tests/kernels/test_rand.py new file mode 100644 index 0000000000000..632e3d7a91b77 --- /dev/null +++ b/tests/kernels/test_rand.py @@ -0,0 +1,51 @@ +import torch +import pytest +import random + +from vllm.model_executor.layers.triton_kernel.rand import seeded_uniform +from vllm.model_executor.utils import set_random_seed + + +@pytest.mark.parametrize("dtype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("use_3d", [True, False]) +def test_seeded_uniform(dtype: torch.dtype, use_3d: bool): + device = "cuda" + for seed in range(512): + set_random_seed(seed) + rows = random.randint(1, 512) + cols = random.randint(1, 64000) + if use_3d: + third_dim = random.randint(2, 10) + dims = [rows, third_dim, cols] + else: + dims = [rows, cols] + seeds = torch.randint(torch.iinfo(torch.long).min, + torch.iinfo(torch.long).max, (rows, ), + device=device) + + # Test that the same seed produces the same output + out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) + out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) + torch.testing.assert_close(out, out2) + # del to save memory + del out2 + + out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) + torch.testing.assert_close(out, out3) + # del to save memory + del out3 + + # Initialize out tensor with garbage to ensure that it is overwritten + out_with_tensor = seeded_uniform( + *dims, + out=torch.full( + (*dims, ), + -1, + dtype=dtype, + device=device, + ), + seeds=seeds, + dtype=dtype, + ) + torch.testing.assert_close(out, out_with_tensor) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py new file mode 100644 index 0000000000000..32d156084bde9 --- /dev/null +++ b/tests/kernels/test_sampler.py @@ -0,0 +1,201 @@ +import torch +import pytest +import triton +import triton.language as tl + +from vllm.model_executor.layers.triton_kernel.sample import (sample, + get_num_triton_sampler_splits, + MAX_TRITON_N_COLS) +from vllm.model_executor.utils import set_random_seed +from vllm.model_executor.sampling_metadata import SamplingTensors + +_EPS = 1e-6 + +SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size +MULTI_SPLIT_VOCAB_SIZE = (MAX_TRITON_N_COLS * 2) - 100 + + +# same code as in vllm/triton_aot/kernels/sample_triton.py +# TODO(yard1): figure out how to keep in sync... 
+@triton.jit +def _uniform_to_exponential(uniform_noise): + """Convert uniform samples to exponential samples.""" + # tl.rand returns values in [0, 1), so we clamp lower bound + # to _EPS to avoid log(0) and thus division by 0 later + lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype) + uniform_noise = tl.maximum(uniform_noise, lb) + # Use the inversion method to turn uniform samples + # into exponential samples + exponential_noise = -tl.log(uniform_noise) + return exponential_noise + + +@triton.jit +def _uniform_to_exponential_kernel(input, output, n: tl.constexpr): + idx = tl.arange(0, n) + x = tl.load(input + idx) + y = _uniform_to_exponential(x) + tl.store(output + idx, y) + + +def test_uniform_to_exponential(): + """Test that we can convert uniform to exponential without div by 0.""" + input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps], + dtype=torch.float32, + device="cuda") + output = torch.zeros(input.shape, dtype=torch.float32, device="cuda") + _uniform_to_exponential_kernel[(1, )](input, output, 2) + assert torch.all(torch.isfinite(output)) + assert torch.all(output > 0) + assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output)) + + +@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) +@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) +@pytest.mark.parametrize("modify_greedy_probs", [True, False]) +@pytest.mark.parametrize("seed", [1337]) +@pytest.mark.parametrize("vocab_size", + [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) +@pytest.mark.parametrize("save_logprobs", [True, False]) +def test_sample_decoding_only(random_sampling, max_best_of, + modify_greedy_probs, seed, vocab_size, + save_logprobs): + set_random_seed(seed) + bs = 8 + probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") + for i in range(bs): + probs[i, i * (vocab_size // bs)] = 1.0 + logprobs = torch.rand_like(probs) + sample_indices = torch.arange(bs, dtype=torch.long, device="cuda") + n_splits = get_num_triton_sampler_splits(probs.shape[1]) + if random_sampling == "mixed": + random_sampling_mask = (torch.rand( + (1, bs), device="cuda") < 0.5).expand(n_splits, bs) + elif random_sampling: + random_sampling_mask = torch.ones((n_splits, bs), + dtype=torch.bool, + device="cuda") + else: + random_sampling_mask = torch.zeros((n_splits, bs), + dtype=torch.bool, + device="cuda") + + seeds = torch.randint(1, + torch.iinfo(torch.long).max, (n_splits, bs), + device="cuda").mul_(random_sampling_mask) + sampled_tokens, sampled_logprobs, sampled_modified_probs = sample( + probs=probs, + logprobs=logprobs, + sample_indices=sample_indices, + seeds=seeds, + max_best_of=max_best_of, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + _save_modified_probs=True) + assert sampled_tokens.shape == (bs, max_best_of) + for i in range(bs): + assert torch.all(sampled_tokens[i] == i * (vocab_size // bs)) + request_uses_random_sampling = random_sampling_mask[0, i] + if modify_greedy_probs and not request_uses_random_sampling: + # If we are modifying greedy probs and the request is greedy, + # we want to make sure the probs tensor is modified in place + assert torch.allclose( + probs[i][sampled_tokens[i]], + torch.full_like(probs[i][sampled_tokens[i]], 1.0)) + assert torch.sum(probs[i]) == 1.0 + assert torch.allclose( + sampled_modified_probs[i][0], + torch.full_like(sampled_modified_probs[i][0], 1.0)) + elif request_uses_random_sampling: + # If the request is random, we want to make sure + # sampled_modified_probs tensor has noise 
added + # (and thus is different from probs tensor) + assert not torch.allclose(sampled_modified_probs[i][0], + probs[i][sampled_tokens[i]]) + elif not request_uses_random_sampling: + # If the request is greedy and we are not modifying greedy probs, + # we want to make sure sampled_modified_probs tensor is the same as + # the probs tensor. + assert torch.allclose(sampled_modified_probs[i][0], + probs[i][sampled_tokens[i]]) + + if save_logprobs: + assert sampled_logprobs.shape == (bs, max_best_of) + for i in range(bs): + for best_of in range(max_best_of): + assert torch.all(sampled_logprobs[i] == logprobs[i][ + sampled_tokens[i, best_of]]) + else: + assert sampled_logprobs is None + + +@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) +@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) +@pytest.mark.parametrize("modify_greedy_probs", [True, False]) +@pytest.mark.parametrize("seed", [1337]) +@pytest.mark.parametrize("vocab_size", + [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) +def test_sample_prompt_logprobs(random_sampling, max_best_of, + modify_greedy_probs, seed, vocab_size): + set_random_seed(seed) + prompt_sizes = [16, 32, 64, 128] * 2 + samples = 8 + bs = samples + sum(prompt_sizes) + probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") + for i in range(bs): + probs[i, i * (vocab_size // bs)] = 1.0 + logprobs = torch.rand_like(probs) + sample_indices = torch.tensor(prompt_sizes, + dtype=torch.long, + device="cuda").cumsum_(0) + n_splits = get_num_triton_sampler_splits(probs.shape[1]) + if random_sampling == "mixed": + random_sampling_mask = torch.rand( + (n_splits, samples), device="cuda") < 0.5 + elif random_sampling: + random_sampling_mask = torch.ones((n_splits, samples), + dtype=torch.bool, + device="cuda") + else: + random_sampling_mask = torch.zeros((n_splits, samples), + dtype=torch.bool, + device="cuda") + + seeds = torch.randint(1, + torch.iinfo(torch.long).max, (n_splits, samples), + device="cuda").mul_(random_sampling_mask) + sampled_tokens, sampled_logprobs, _ = sample( + probs=probs, + logprobs=logprobs, + sample_indices=sample_indices, + seeds=seeds, + max_best_of=max_best_of, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=True) + assert sampled_tokens.shape == (samples, max_best_of) + assert sampled_logprobs.shape == (samples, max_best_of) + for i, t in enumerate(sample_indices): + assert torch.all(sampled_tokens[i] == t * (vocab_size // bs)) + for best_of in range(max_best_of): + assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]] + [sampled_tokens[i, best_of]]) + + +@pytest.mark.parametrize("seed", list(range(16))) +def test_get_sequence_seeds(seed): + """Ensure that we get a different child seed from base seed + extra entropy""" + starting_seed = seed + seq_seed = None + extra_entropy = 1 + for i in range(512): + new_seq_seed = SamplingTensors._get_sequence_seeds( + starting_seed, i, seeds_to_generate=1, is_greedy=False)[0] + new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds( + starting_seed, + i, + extra_entropy, + seeds_to_generate=1, + is_greedy=False)[0] + assert new_seq_seed_extra_entropy != new_seq_seed + assert seq_seed != new_seq_seed + seq_seed = new_seq_seed diff --git a/vllm/model_executor/layers/triton_kernel/sample.py b/vllm/model_executor/layers/triton_kernel/sample.py index 30b74ce3bfad6..0b8fe30b3ea50 100644 --- a/vllm/model_executor/layers/triton_kernel/sample.py +++ b/vllm/model_executor/layers/triton_kernel/sample.py @@ -350,7 +350,7 @@ def _sample_triton( 
mask=col_offsets < n_cols, other=0.5) - # NEEDS TO BE MANUALLY KEPT IN SYNC WITH vllm/tests/ops/test_sampler.py + # NEEDS TO BE MANUALLY KEPT IN SYNC WITH tests/kernels/test_rand.py # tl.rand returns values in [0, 1), so we clamp lower bound # to _EPS to avoid log(0) and thus division by nan later lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype) From bc5f9d4198b113383461a4c742f8a4d7fcb1ff3f Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 4 Mar 2024 15:58:09 -0800 Subject: [PATCH 05/16] Lint --- tests/kernels/test_sampler.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py index 32d156084bde9..8160cc68243fa 100644 --- a/tests/kernels/test_sampler.py +++ b/tests/kernels/test_sampler.py @@ -3,9 +3,8 @@ import triton import triton.language as tl -from vllm.model_executor.layers.triton_kernel.sample import (sample, - get_num_triton_sampler_splits, - MAX_TRITON_N_COLS) +from vllm.model_executor.layers.triton_kernel.sample import ( + sample, get_num_triton_sampler_splits, MAX_TRITON_N_COLS) from vllm.model_executor.utils import set_random_seed from vllm.model_executor.sampling_metadata import SamplingTensors @@ -15,7 +14,7 @@ MULTI_SPLIT_VOCAB_SIZE = (MAX_TRITON_N_COLS * 2) - 100 -# same code as in vllm/triton_aot/kernels/sample_triton.py +# same code as in vllm/model_executor/layers/triton_kernel/rand.py # TODO(yard1): figure out how to keep in sync... @triton.jit def _uniform_to_exponential(uniform_noise): @@ -188,8 +187,10 @@ def test_get_sequence_seeds(seed): seq_seed = None extra_entropy = 1 for i in range(512): - new_seq_seed = SamplingTensors._get_sequence_seeds( - starting_seed, i, seeds_to_generate=1, is_greedy=False)[0] + new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed, + i, + seeds_to_generate=1, + is_greedy=False)[0] new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds( starting_seed, i, From 5e37fc231f25bac5bc19d924f95080c5826931f5 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 4 Mar 2024 16:31:55 -0800 Subject: [PATCH 06/16] Tweak --- tests/kernels/test_sampler.py | 20 ++-------------- .../layers/triton_kernel/sample.py | 23 +++++++++++-------- 2 files changed, 15 insertions(+), 28 deletions(-) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py index 8160cc68243fa..b929158a84061 100644 --- a/tests/kernels/test_sampler.py +++ b/tests/kernels/test_sampler.py @@ -4,31 +4,15 @@ import triton.language as tl from vllm.model_executor.layers.triton_kernel.sample import ( - sample, get_num_triton_sampler_splits, MAX_TRITON_N_COLS) + _uniform_to_exponential, sample, get_num_triton_sampler_splits, + MAX_TRITON_N_COLS) from vllm.model_executor.utils import set_random_seed from vllm.model_executor.sampling_metadata import SamplingTensors -_EPS = 1e-6 - SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size MULTI_SPLIT_VOCAB_SIZE = (MAX_TRITON_N_COLS * 2) - 100 -# same code as in vllm/model_executor/layers/triton_kernel/rand.py -# TODO(yard1): figure out how to keep in sync... 
-@triton.jit -def _uniform_to_exponential(uniform_noise): - """Convert uniform samples to exponential samples.""" - # tl.rand returns values in [0, 1), so we clamp lower bound - # to _EPS to avoid log(0) and thus division by 0 later - lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype) - uniform_noise = tl.maximum(uniform_noise, lb) - # Use the inversion method to turn uniform samples - # into exponential samples - exponential_noise = -tl.log(uniform_noise) - return exponential_noise - - @triton.jit def _uniform_to_exponential_kernel(input, output, n: tl.constexpr): idx = tl.arange(0, n) diff --git a/vllm/model_executor/layers/triton_kernel/sample.py b/vllm/model_executor/layers/triton_kernel/sample.py index 0b8fe30b3ea50..4abc39cb622d4 100644 --- a/vllm/model_executor/layers/triton_kernel/sample.py +++ b/vllm/model_executor/layers/triton_kernel/sample.py @@ -310,6 +310,18 @@ def _sample(probs: torch.Tensor, return output_samples, output_logprobs, output_modified_probs +@triton.jit +def _uniform_to_exponential(uniform_noise): + """Convert uniform samples to exponential samples.""" + # tl.rand returns values in [0, 1), so we clamp lower bound + # to _EPS to avoid log(0) and thus division by 0 later + lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype) + uniform_noise = tl.maximum(uniform_noise, lb) + # Use the inversion method to turn uniform samples + # into exponential samples + exponential_noise = -tl.log(uniform_noise) + return exponential_noise + @triton.jit def _sample_triton( sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor, @@ -349,16 +361,7 @@ def _sample_triton( uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets, mask=col_offsets < n_cols, other=0.5) - - # NEEDS TO BE MANUALLY KEPT IN SYNC WITH tests/kernels/test_rand.py - # tl.rand returns values in [0, 1), so we clamp lower bound - # to _EPS to avoid log(0) and thus division by nan later - lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype) - uniform_noise = tl.maximum(uniform_noise, lb) - # Use the inversion method to turn uniform samples - # into exponential samples - exponential_noise = -tl.log(uniform_noise) - + exponential_noise = _uniform_to_exponential(uniform_noise) row /= exponential_noise sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True) From 8ccb36c86f5cda9c4d750640c47e06fbe5c3d72d Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 4 Mar 2024 16:38:17 -0800 Subject: [PATCH 07/16] Lint --- vllm/model_executor/layers/triton_kernel/sample.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/triton_kernel/sample.py b/vllm/model_executor/layers/triton_kernel/sample.py index 4abc39cb622d4..c01c1ba8e8778 100644 --- a/vllm/model_executor/layers/triton_kernel/sample.py +++ b/vllm/model_executor/layers/triton_kernel/sample.py @@ -322,6 +322,7 @@ def _uniform_to_exponential(uniform_noise): exponential_noise = -tl.log(uniform_noise) return exponential_noise + @triton.jit def _sample_triton( sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor, From 73e5d28079e7b5dede5199be4d327fa74570e6cb Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 5 Mar 2024 10:16:05 -0800 Subject: [PATCH 08/16] Try fix test --- tests/kernels/test_sampler.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py index b929158a84061..50be47fd71f52 100644 --- a/tests/kernels/test_sampler.py +++ b/tests/kernels/test_sampler.py @@ -1,3 +1,5 @@ +import gc + import torch import pytest 
import triton @@ -13,6 +15,13 @@ MULTI_SPLIT_VOCAB_SIZE = (MAX_TRITON_N_COLS * 2) - 100 +@pytest.fixture(autouse=True) +def _cleanup(): + yield + gc.collect() + torch.cuda.empty_cache() + + @triton.jit def _uniform_to_exponential_kernel(input, output, n: tl.constexpr): idx = tl.arange(0, n) From c276f4d70aaa48aca2c431a152f6dd4b0890209a Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 5 Mar 2024 11:48:37 -0800 Subject: [PATCH 09/16] Update tests/kernels/test_sampler.py --- tests/kernels/test_sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py index 50be47fd71f52..8bd96190fdf84 100644 --- a/tests/kernels/test_sampler.py +++ b/tests/kernels/test_sampler.py @@ -12,7 +12,7 @@ from vllm.model_executor.sampling_metadata import SamplingTensors SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size -MULTI_SPLIT_VOCAB_SIZE = (MAX_TRITON_N_COLS * 2) - 100 +MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100 @pytest.fixture(autouse=True) From f6bb1d02888a9aeda0e589104825cbdbd49e6bbe Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 5 Mar 2024 14:28:38 -0800 Subject: [PATCH 10/16] Disable for now --- vllm/model_executor/layers/sampler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 1e04600e3090a..aeec503d22717 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -376,7 +376,7 @@ def _multinomial( return probs.div_(q).argmax(dim=1).view(-1, num_samples) -def _sample_old( +def _sample_with_torch( probs: torch.Tensor, logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, @@ -534,8 +534,9 @@ def _sample( sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, ) -> List[Tuple[List[int], List[int]]]: - return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, - sampling_tensors) + return _sample_with_torch(probs, logprobs, sampling_metadata) + # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, + # sampling_tensors) def _get_logprobs( From 1a548bee480b8d6bfe6316149c83760b7e9f8d6e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 15 Mar 2024 15:33:01 -0700 Subject: [PATCH 11/16] Update vllm/model_executor/sampling_metadata.py Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- vllm/model_executor/sampling_metadata.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 5ccda95c20bba..ec3484b8e14a0 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -180,7 +180,6 @@ def from_sampling_metadata( for seq_id in seq_ids: seq_data = sampling_metadata.seq_data[seq_id] extra_entropy = extra_entropy or () - # extra_entropy = extra_entropy + (seq_id, ) seq_seeds = cls._get_sequence_seeds( seed, seq_data.get_len(), From c41cb11d17fbab04ec2a3ae5920fc472bcf7e502 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 18 Mar 2024 15:34:28 -0700 Subject: [PATCH 12/16] Review feedback --- vllm/model_executor/layers/attention/ops/rand.py | 3 ++- vllm/model_executor/layers/sampler.py | 14 ++++++++------ vllm/model_executor/sampling_metadata.py | 3 +-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/attention/ops/rand.py b/vllm/model_executor/layers/attention/ops/rand.py index 0047d1007ffec..5b4b7a153351f 100644 --- 
a/vllm/model_executor/layers/attention/ops/rand.py +++ b/vllm/model_executor/layers/attention/ops/rand.py @@ -104,7 +104,8 @@ def _seeded_uniform_triton( n_slices: tl.constexpr, block_size: tl.constexpr, ): - """Generate a random number in [0, 1) for each element in the output + """ + Generate a random float32 number in [0, 1) for each element in the output tensor. The random numbers in a row generated using the seed for that row. Args: diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 064bb822649e1..c67f105f99649 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -409,17 +409,17 @@ def _sample_with_torch( greedy_samples = torch.argmax(logprobs[sample_indices.long()], dim=-1) elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - max_best_of = 1 + max_best_of_in_batch = 1 for seq_group, is_prompt in zip(seq_groups, is_prompts): if is_prompt: _, sampling_params = seq_group - max_best_of = max(max_best_of, sampling_params.best_of) + max_best_of_in_batch = max(max_best_of_in_batch, sampling_params.best_of) seeded_args = {} if sampling_type == SamplingType.RANDOM else { "seq_groups": seq_groups, "generators": sampling_metadata.generators, } multinomial_samples[sampling_type] = _multinomial( - probs[sample_indices.long()], max_best_of, **seeded_args) + probs[sample_indices.long()], max_best_of_in_batch, **seeded_args) elif sampling_type == SamplingType.BEAM: beam_search_logprobs = logprobs[sample_indices] else: @@ -465,7 +465,7 @@ def _sample_with_triton_kernel( sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} sample_metadata = {} - max_best_of = 1 + max_best_of_in_batch = 1 # Counterintiutively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. @@ -486,7 +486,7 @@ def _sample_with_triton_kernel( for seq_group, is_prompt in zip(seq_groups, is_prompts): if is_prompt: _, sampling_params = seq_group - max_best_of = max(max_best_of, sampling_params.best_of) + max_best_of_in_batch = max(max_best_of_in_batch, sampling_params.best_of) elif sampling_type == SamplingType.BEAM: beam_search_logprobs = logprobs[sample_indices] else: @@ -495,7 +495,7 @@ def _sample_with_triton_kernel( sampled_tokens, _, _ = sample_triton( probs=probs, seeds=sampling_tensors.sampling_seeds, - max_best_of=max_best_of, + max_best_of=max_best_of_in_batch, sample_indices=sampling_tensors.sample_indices, logprobs=logprobs, # don't save logprobs because we have logic for that below @@ -536,6 +536,8 @@ def _sample( sampling_tensors: SamplingTensors, ) -> List[Tuple[List[int], List[int]]]: return _sample_with_torch(probs, logprobs, sampling_metadata) + + # TODO: Enable once Triton kernel & associated code is faster. 
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, # sampling_tensors) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 135f1fc753c61..d8e55a49abe11 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -185,6 +185,7 @@ def from_sampling_metadata( seed, seq_data.get_len(), *extra_entropy, + seq_id, seeds_to_generate=seeds_to_generate, is_greedy=is_greedy) sampling_seeds.append(seq_seeds) @@ -332,8 +333,6 @@ def _get_sequence_seeds( is_greedy: bool, ): """Get `seeds_to_generate` child seeds from `seed` and extra entropy.""" - # TODO(yard1): make sure sequences in the same group with - # best_of > 1 have different seeds (not trivial...) if not is_greedy: if seed is None: randint_fn = random.randint From f018ebb22b77e5bd00cfc215a612015bd6738173 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 18 Mar 2024 15:35:35 -0700 Subject: [PATCH 13/16] Lint --- tests/kernels/test_sampler.py | 3 ++- .../layers/attention/ops/sample.py | 4 +++- vllm/model_executor/layers/sampler.py | 16 ++++++++++------ vllm/model_executor/sampling_metadata.py | 3 ++- vllm/sequence.py | 3 --- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py index 8bd96190fdf84..3aad22c9c465c 100644 --- a/tests/kernels/test_sampler.py +++ b/tests/kernels/test_sampler.py @@ -175,7 +175,8 @@ def test_sample_prompt_logprobs(random_sampling, max_best_of, @pytest.mark.parametrize("seed", list(range(16))) def test_get_sequence_seeds(seed): - """Ensure that we get a different child seed from base seed + extra entropy""" + """Ensure that we get a different child seed from base + seed + extra entropy""" starting_seed = seed seq_seed = None extra_entropy = 1 diff --git a/vllm/model_executor/layers/attention/ops/sample.py b/vllm/model_executor/layers/attention/ops/sample.py index c01c1ba8e8778..353d1da8a8676 100644 --- a/vllm/model_executor/layers/attention/ops/sample.py +++ b/vllm/model_executor/layers/attention/ops/sample.py @@ -358,7 +358,9 @@ def _sample_triton( other=float("-inf")) if uses_random_sampling: - uniform_noise_start_ptr = uniform_noise_ptr + sample_idx * uniform_noise_row_stride + best_idx * uniform_noise_best_stride + uniform_noise_start_ptr = (uniform_noise_ptr + + sample_idx * uniform_noise_row_stride + + best_idx * uniform_noise_best_stride) uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets, mask=col_offsets < n_cols, other=0.5) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c67f105f99649..04ba12552e30c 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -12,7 +12,8 @@ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceData, SequenceGroupOutput, SequenceOutput) -from vllm.model_executor.layers.triton_kernel.sample import sample as sample_triton +from vllm.model_executor.layers.triton_kernel.sample import (sample as + sample_triton) from vllm.utils import is_neuron @@ -413,13 +414,15 @@ def _sample_with_torch( for seq_group, is_prompt in zip(seq_groups, is_prompts): if is_prompt: _, sampling_params = seq_group - max_best_of_in_batch = max(max_best_of_in_batch, sampling_params.best_of) + max_best_of_in_batch = max(max_best_of_in_batch, + sampling_params.best_of) seeded_args = {} if sampling_type == SamplingType.RANDOM else { "seq_groups": seq_groups, "generators": 
sampling_metadata.generators, } multinomial_samples[sampling_type] = _multinomial( - probs[sample_indices.long()], max_best_of_in_batch, **seeded_args) + probs[sample_indices.long()], max_best_of_in_batch, + **seeded_args) elif sampling_type == SamplingType.BEAM: beam_search_logprobs = logprobs[sample_indices] else: @@ -486,7 +489,8 @@ def _sample_with_triton_kernel( for seq_group, is_prompt in zip(seq_groups, is_prompts): if is_prompt: _, sampling_params = seq_group - max_best_of_in_batch = max(max_best_of_in_batch, sampling_params.best_of) + max_best_of_in_batch = max(max_best_of_in_batch, + sampling_params.best_of) elif sampling_type == SamplingType.BEAM: beam_search_logprobs = logprobs[sample_indices] else: @@ -508,8 +512,8 @@ def _sample_with_triton_kernel( for sampling_type in SamplingType: if sampling_type not in sample_metadata: continue - seq_group_ids, seq_groups, is_prompts, sample_indices, sampled_token_indices = sample_metadata[ - sampling_type] + (seq_group_ids, seq_groups, is_prompts, sample_indices, + sampled_token_indices) = sample_metadata[sampling_type] if sampling_type == SamplingType.GREEDY: sample_results = _greedy_sample( seq_groups, sampled_tokens[sampled_token_indices][:, 0]) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index d8e55a49abe11..333d343bd987d 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -7,7 +7,8 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData from vllm.utils import in_wsl, is_neuron -from vllm.model_executor.layers.triton_kernel.sample import get_num_triton_sampler_splits +from vllm.model_executor.layers.triton_kernel.sample import ( + get_num_triton_sampler_splits) _SAMPLING_EPS = 1e-5 _SEED_0_REPLACEMENT = 3403598558 diff --git a/vllm/sequence.py b/vllm/sequence.py index e289d5783ac04..ff96dd306791c 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -122,9 +122,6 @@ def get_output_len(self) -> int: def get_token_ids(self) -> List[int]: return self.prompt_token_ids + self.output_token_ids - def get_prompt_token_ids(self) -> List[int]: - return self.prompt_token_ids - def get_last_token_id(self) -> int: if not self.output_token_ids: return self.prompt_token_ids[-1] From f2177669bc26002aa21e107b421761f61129a5b0 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 18 Mar 2024 17:22:17 -0700 Subject: [PATCH 14/16] FIx C --- tests/kernels/test_rand.py | 2 +- tests/kernels/test_sampler.py | 6 +++--- vllm/model_executor/layers/ops/__init__.py | 0 vllm/model_executor/layers/{attention => }/ops/rand.py | 0 vllm/model_executor/layers/{attention => }/ops/sample.py | 2 +- vllm/model_executor/layers/sampler.py | 3 +-- vllm/model_executor/sampling_metadata.py | 3 +-- 7 files changed, 7 insertions(+), 9 deletions(-) create mode 100644 vllm/model_executor/layers/ops/__init__.py rename vllm/model_executor/layers/{attention => }/ops/rand.py (100%) rename vllm/model_executor/layers/{attention => }/ops/sample.py (99%) diff --git a/tests/kernels/test_rand.py b/tests/kernels/test_rand.py index 632e3d7a91b77..ac2382a48cc4d 100644 --- a/tests/kernels/test_rand.py +++ b/tests/kernels/test_rand.py @@ -2,7 +2,7 @@ import pytest import random -from vllm.model_executor.layers.triton_kernel.rand import seeded_uniform +from vllm.model_executor.ops.rand import seeded_uniform from vllm.model_executor.utils import set_random_seed diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py index 
3aad22c9c465c..11fdfa8260716 100644 --- a/tests/kernels/test_sampler.py +++ b/tests/kernels/test_sampler.py @@ -5,9 +5,9 @@ import triton import triton.language as tl -from vllm.model_executor.layers.triton_kernel.sample import ( - _uniform_to_exponential, sample, get_num_triton_sampler_splits, - MAX_TRITON_N_COLS) +from vllm.model_executor.ops.sample import (_uniform_to_exponential, sample, + get_num_triton_sampler_splits, + MAX_TRITON_N_COLS) from vllm.model_executor.utils import set_random_seed from vllm.model_executor.sampling_metadata import SamplingTensors diff --git a/vllm/model_executor/layers/ops/__init__.py b/vllm/model_executor/layers/ops/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/attention/ops/rand.py b/vllm/model_executor/layers/ops/rand.py similarity index 100% rename from vllm/model_executor/layers/attention/ops/rand.py rename to vllm/model_executor/layers/ops/rand.py diff --git a/vllm/model_executor/layers/attention/ops/sample.py b/vllm/model_executor/layers/ops/sample.py similarity index 99% rename from vllm/model_executor/layers/attention/ops/sample.py rename to vllm/model_executor/layers/ops/sample.py index 353d1da8a8676..5fd06ab54e575 100644 --- a/vllm/model_executor/layers/attention/ops/sample.py +++ b/vllm/model_executor/layers/ops/sample.py @@ -5,7 +5,7 @@ import triton import triton.language as tl -from vllm.model_executor.layers.triton_kernel.rand import seeded_uniform +from vllm.model_executor.ops.rand import seeded_uniform _EPS = 1e-6 diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 04ba12552e30c..c047c183541d7 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -12,8 +12,7 @@ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceData, SequenceGroupOutput, SequenceOutput) -from vllm.model_executor.layers.triton_kernel.sample import (sample as - sample_triton) +from vllm.model_executor.ops.sample import (sample as sample_triton) from vllm.utils import is_neuron diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 333d343bd987d..ecd008ea9b469 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -7,8 +7,7 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData from vllm.utils import in_wsl, is_neuron -from vllm.model_executor.layers.triton_kernel.sample import ( - get_num_triton_sampler_splits) +from vllm.model_executor.ops.sample import (get_num_triton_sampler_splits) _SAMPLING_EPS = 1e-5 _SEED_0_REPLACEMENT = 3403598558 From fe6ae2434affee4fd181d334d5e7a847185de3ba Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 18 Mar 2024 20:02:37 -0700 Subject: [PATCH 15/16] Fix --- tests/kernels/test_rand.py | 2 +- tests/kernels/test_sampler.py | 2 +- vllm/model_executor/layers/ops/sample.py | 2 +- vllm/model_executor/layers/sampler.py | 2 +- vllm/model_executor/sampling_metadata.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_rand.py b/tests/kernels/test_rand.py index ac2382a48cc4d..3b9d0d732acf5 100644 --- a/tests/kernels/test_rand.py +++ b/tests/kernels/test_rand.py @@ -2,7 +2,7 @@ import pytest import random -from vllm.model_executor.ops.rand import seeded_uniform +from vllm.model_executor.layers.ops.rand import seeded_uniform from vllm.model_executor.utils import set_random_seed diff 
--git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py index 11fdfa8260716..22b07329db60d 100644 --- a/tests/kernels/test_sampler.py +++ b/tests/kernels/test_sampler.py @@ -5,7 +5,7 @@ import triton import triton.language as tl -from vllm.model_executor.ops.sample import (_uniform_to_exponential, sample, +from vllm.model_executor.layers.ops.sample import (_uniform_to_exponential, sample, get_num_triton_sampler_splits, MAX_TRITON_N_COLS) from vllm.model_executor.utils import set_random_seed diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py index 5fd06ab54e575..0077317282204 100644 --- a/vllm/model_executor/layers/ops/sample.py +++ b/vllm/model_executor/layers/ops/sample.py @@ -5,7 +5,7 @@ import triton import triton.language as tl -from vllm.model_executor.ops.rand import seeded_uniform +from vllm.model_executor.layers.ops.rand import seeded_uniform _EPS = 1e-6 diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c047c183541d7..1fab1e734e1d7 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -12,7 +12,7 @@ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceData, SequenceGroupOutput, SequenceOutput) -from vllm.model_executor.ops.sample import (sample as sample_triton) +from vllm.model_executor.layers.ops.sample import (sample as sample_triton) from vllm.utils import is_neuron diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index ecd008ea9b469..3b039e5ce6095 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -7,7 +7,7 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData from vllm.utils import in_wsl, is_neuron -from vllm.model_executor.ops.sample import (get_num_triton_sampler_splits) +from vllm.model_executor.layers.ops.sample import (get_num_triton_sampler_splits) _SAMPLING_EPS = 1e-5 _SEED_0_REPLACEMENT = 3403598558 From 119e3551e0f9f30b4af049e66b9435b1a1beef67 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 18 Mar 2024 20:09:32 -0700 Subject: [PATCH 16/16] Lint --- tests/kernels/test_sampler.py | 6 +++--- vllm/model_executor/sampling_metadata.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py index 22b07329db60d..5f8c51fb074f4 100644 --- a/tests/kernels/test_sampler.py +++ b/tests/kernels/test_sampler.py @@ -5,9 +5,9 @@ import triton import triton.language as tl -from vllm.model_executor.layers.ops.sample import (_uniform_to_exponential, sample, - get_num_triton_sampler_splits, - MAX_TRITON_N_COLS) +from vllm.model_executor.layers.ops.sample import ( + _uniform_to_exponential, sample, get_num_triton_sampler_splits, + MAX_TRITON_N_COLS) from vllm.model_executor.utils import set_random_seed from vllm.model_executor.sampling_metadata import SamplingTensors diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 3b039e5ce6095..7d08feb3fee1c 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -7,7 +7,8 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData from vllm.utils import in_wsl, is_neuron -from vllm.model_executor.layers.ops.sample import (get_num_triton_sampler_splits) +from vllm.model_executor.layers.ops.sample 
import ( + get_num_triton_sampler_splits) _SAMPLING_EPS = 1e-5 _SEED_0_REPLACEMENT = 3403598558
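
Usage sketch (illustrative, not part of the patches above): after PATCH 16, the Triton sampling entry points live in vllm.model_executor.layers.ops.sample. The snippet below mirrors the call pattern exercised in tests/kernels/test_sampler.py -- seeds shaped (n_splits, n_seqs) with zeroed seeds marking greedy rows, sample_indices selecting which rows of probs to sample, and outputs shaped (n_seqs, max_best_of). Only the import path, function names, and keyword arguments come from the patches and tests; the concrete sizes, tensors, and CUDA device below are assumptions chosen for illustration.

    # Illustrative sketch only: shapes and values are assumptions; the API
    # surface (sample, get_num_triton_sampler_splits) comes from the patches.
    import torch

    from vllm.model_executor.layers.ops.sample import (
        sample, get_num_triton_sampler_splits)

    vocab_size = 32000    # single-split case, as in SINGLE_SPLIT_VOCAB_SIZE
    n_seqs = 4            # hypothetical batch of four sequences
    max_best_of = 2       # hypothetical best_of

    # Build probs/logprobs the same way sampler.py does (softmax / log_softmax).
    logits = torch.randn(n_seqs, vocab_size, device="cuda")
    probs = torch.softmax(logits, dim=-1, dtype=torch.float)
    logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

    # One seed per (split, sequence); as in the tests, zeroing a row's seeds
    # marks that row for greedy sampling, non-zero seeds give seeded sampling.
    n_splits = get_num_triton_sampler_splits(vocab_size)
    seeds = torch.randint(1, torch.iinfo(torch.long).max,
                          (n_splits, n_seqs), device="cuda")
    seeds[:, 0] = 0  # treat the first sequence as greedy

    # Sample every row here; in the real sampler these indices pick out the
    # last-token row of each sequence, skipping prompt-logprob rows.
    sample_indices = torch.arange(n_seqs, dtype=torch.long, device="cuda")

    sampled_tokens, sampled_logprobs, _ = sample(
        probs=probs,
        logprobs=logprobs,
        sample_indices=sample_indices,
        seeds=seeds,
        max_best_of=max_best_of,
        modify_greedy_probs=False,
        save_logprobs=True)

    assert sampled_tokens.shape == (n_seqs, max_best_of)
    assert sampled_logprobs.shape == (n_seqs, max_best_of)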