From 546034b466bf11f12936791312981b9982850eb0 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Mon, 16 Sep 2024 20:04:48 -0700 Subject: [PATCH] [refactor] remove triton based sampler (#8524) --- tests/kernels/test_rand.py | 52 --- tests/kernels/test_sampler.py | 209 ----------- vllm/model_executor/layers/ops/__init__.py | 0 vllm/model_executor/layers/ops/rand.py | 157 -------- vllm/model_executor/layers/ops/sample.py | 394 --------------------- vllm/model_executor/layers/sampler.py | 97 +---- vllm/model_executor/sampling_metadata.py | 211 +++-------- vllm/triton_utils/sample.py | 13 - vllm/utils.py | 37 +- 9 files changed, 75 insertions(+), 1095 deletions(-) delete mode 100644 tests/kernels/test_rand.py delete mode 100644 tests/kernels/test_sampler.py delete mode 100644 vllm/model_executor/layers/ops/__init__.py delete mode 100644 vllm/model_executor/layers/ops/rand.py delete mode 100644 vllm/model_executor/layers/ops/sample.py delete mode 100644 vllm/triton_utils/sample.py diff --git a/tests/kernels/test_rand.py b/tests/kernels/test_rand.py deleted file mode 100644 index a4242d22eb489..0000000000000 --- a/tests/kernels/test_rand.py +++ /dev/null @@ -1,52 +0,0 @@ -import random - -import pytest -import torch - -from vllm.model_executor.layers.ops.rand import seeded_uniform -from vllm.model_executor.utils import set_random_seed - - -@pytest.mark.parametrize("dtype", - [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("use_3d", [True, False]) -def test_seeded_uniform(dtype: torch.dtype, use_3d: bool): - device = "cuda" - for seed in range(512): - set_random_seed(seed) - rows = random.randint(1, 512) - cols = random.randint(1, 64000) - if use_3d: - third_dim = random.randint(2, 10) - dims = [rows, third_dim, cols] - else: - dims = [rows, cols] - seeds = torch.randint(torch.iinfo(torch.long).min, - torch.iinfo(torch.long).max, (rows, ), - device=device) - - # Test that the same seed produces the same output - out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) - out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) - torch.testing.assert_close(out, out2) - # del to save memory - del out2 - - out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) - torch.testing.assert_close(out, out3) - # del to save memory - del out3 - - # Initialize out tensor with garbage to ensure that it is overwritten - out_with_tensor = seeded_uniform( - *dims, - out=torch.full( - (*dims, ), - -1, - dtype=dtype, - device=device, - ), - seeds=seeds, - dtype=dtype, - ) - torch.testing.assert_close(out, out_with_tensor) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py deleted file mode 100644 index 03844aba20f8a..0000000000000 --- a/tests/kernels/test_sampler.py +++ /dev/null @@ -1,209 +0,0 @@ -import gc -from unittest.mock import patch - -import pytest -import torch -import triton -import triton.language as tl - -from vllm.model_executor.layers.ops.sample import (_sample_triton, - _uniform_to_exponential, - sample) -from vllm.model_executor.sampling_metadata import SamplingTensors -from vllm.model_executor.utils import set_random_seed -from vllm.triton_utils.libentry import LibEntry -from vllm.triton_utils.sample import (MAX_TRITON_N_COLS, - get_num_triton_sampler_splits) - -SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size -MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100 - - -@pytest.fixture(autouse=True) -def _cleanup(): - yield - gc.collect() - torch.cuda.empty_cache() - - -@triton.jit -def _uniform_to_exponential_kernel(input, output, n: tl.constexpr): - idx = tl.arange(0, n) - x = tl.load(input + idx) - y = _uniform_to_exponential(x) - tl.store(output + idx, y) - - -def test_uniform_to_exponential(): - """Test that we can convert uniform to exponential without div by 0.""" - input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps], - dtype=torch.float32, - device="cuda") - output = torch.zeros(input.shape, dtype=torch.float32, device="cuda") - _uniform_to_exponential_kernel[(1, )](input, output, 2) - assert torch.all(torch.isfinite(output)) - assert torch.all(output > 0) - assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output)) - - -@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) -@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) -@pytest.mark.parametrize("modify_greedy_probs", [True, False]) -@pytest.mark.parametrize("seed", [1337]) -@pytest.mark.parametrize("vocab_size", - [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) -@pytest.mark.parametrize("save_logprobs", [True, False]) -def test_sample_decoding_only(random_sampling, max_best_of, - modify_greedy_probs, seed, vocab_size, - save_logprobs): - set_random_seed(seed) - bs = 8 - probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") - for i in range(bs): - probs[i, i * (vocab_size // bs)] = 1.0 - logprobs = torch.rand_like(probs) - sample_indices = torch.arange(bs, dtype=torch.long, device="cuda") - n_splits = get_num_triton_sampler_splits(probs.shape[1]) - if random_sampling == "mixed": - random_sampling_mask = (torch.rand( - (1, bs), device="cuda") < 0.5).expand(n_splits, bs) - elif random_sampling: - random_sampling_mask = torch.ones((n_splits, bs), - dtype=torch.bool, - device="cuda") - else: - random_sampling_mask = torch.zeros((n_splits, bs), - dtype=torch.bool, - device="cuda") - - seeds = torch.randint(1, - torch.iinfo(torch.long).max, (n_splits, bs), - device="cuda").mul_(random_sampling_mask) - #The current _sample_triton does not utilize the - # libentry decoration. The purpose of adding this patch is to test - # the correctness of libentry. - with patch("vllm.model_executor.layers.ops.sample._sample_triton", - LibEntry(_sample_triton)): - sampled_tokens, sampled_logprobs, sampled_modified_probs = sample( - probs=probs, - logprobs=logprobs, - sample_indices=sample_indices, - seeds=seeds, - max_best_of=max_best_of, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - _save_modified_probs=True) - assert sampled_tokens.shape == (bs, max_best_of) - for i in range(bs): - assert torch.all(sampled_tokens[i] == i * (vocab_size // bs)) - request_uses_random_sampling = random_sampling_mask[0, i] - if modify_greedy_probs and not request_uses_random_sampling: - # If we are modifying greedy probs and the request is greedy, - # we want to make sure the probs tensor is modified in place - torch.testing.assert_close( - probs[i][sampled_tokens[i]], - torch.full_like(probs[i][sampled_tokens[i]], 1.0)) - assert torch.sum(probs[i]) == 1.0 - torch.testing.assert_close( - sampled_modified_probs[i][0], - torch.full_like(sampled_modified_probs[i][0], 1.0)) - elif request_uses_random_sampling: - # If the request is random, we want to make sure - # sampled_modified_probs tensor has noise added - # (and thus is different from probs tensor) - assert not torch.allclose(sampled_modified_probs[i][0], - probs[i][sampled_tokens[i]]) - elif not request_uses_random_sampling: - # If the request is greedy and we are not modifying greedy probs, - # we want to make sure sampled_modified_probs tensor is the same as - # the probs tensor. - torch.testing.assert_close(sampled_modified_probs[i], - probs[i][sampled_tokens[i]]) - - if save_logprobs: - assert sampled_logprobs.shape == (bs, max_best_of) - for i in range(bs): - for best_of in range(max_best_of): - assert torch.all(sampled_logprobs[i] == logprobs[i][ - sampled_tokens[i, best_of]]) - else: - assert sampled_logprobs is None - - -@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) -@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) -@pytest.mark.parametrize("modify_greedy_probs", [True, False]) -@pytest.mark.parametrize("seed", [1337]) -@pytest.mark.parametrize("vocab_size", - [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) -def test_sample_prompt_logprobs(random_sampling, max_best_of, - modify_greedy_probs, seed, vocab_size): - - set_random_seed(seed) - prompt_sizes = [16, 32, 64, 128] * 2 - samples = 8 - bs = samples + sum(prompt_sizes) - probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") - for i in range(bs): - probs[i, i * (vocab_size // bs)] = 1.0 - logprobs = torch.rand_like(probs) - sample_indices = torch.tensor(prompt_sizes, - dtype=torch.long, - device="cuda").cumsum_(0) - n_splits = get_num_triton_sampler_splits(probs.shape[1]) - if random_sampling == "mixed": - random_sampling_mask = torch.rand( - (n_splits, samples), device="cuda") < 0.5 - elif random_sampling: - random_sampling_mask = torch.ones((n_splits, samples), - dtype=torch.bool, - device="cuda") - else: - random_sampling_mask = torch.zeros((n_splits, samples), - dtype=torch.bool, - device="cuda") - - seeds = torch.randint(1, - torch.iinfo(torch.long).max, (n_splits, samples), - device="cuda").mul_(random_sampling_mask) - #ditto - with patch("vllm.model_executor.layers.ops.sample._sample_triton", - LibEntry(_sample_triton)): - sampled_tokens, sampled_logprobs, _ = sample( - probs=probs, - logprobs=logprobs, - sample_indices=sample_indices, - seeds=seeds, - max_best_of=max_best_of, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=True) - assert sampled_tokens.shape == (samples, max_best_of) - assert sampled_logprobs.shape == (samples, max_best_of) - for i, t in enumerate(sample_indices): - assert torch.all(sampled_tokens[i] == t * (vocab_size // bs)) - for best_of in range(max_best_of): - assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]] - [sampled_tokens[i, best_of]]) - - -@pytest.mark.parametrize("seed", list(range(16))) -def test_get_sequence_seeds(seed): - """Ensure that we get a different child seed from base - seed + extra entropy""" - starting_seed = seed - seq_seed = None - extra_entropy = 1 - for i in range(512): - new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed, - i, - seeds_to_generate=1, - is_greedy=False)[0] - new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds( - starting_seed, - i, - extra_entropy, - seeds_to_generate=1, - is_greedy=False)[0] - assert new_seq_seed_extra_entropy != new_seq_seed - assert seq_seed != new_seq_seed - seq_seed = new_seq_seed diff --git a/vllm/model_executor/layers/ops/__init__.py b/vllm/model_executor/layers/ops/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/vllm/model_executor/layers/ops/rand.py b/vllm/model_executor/layers/ops/rand.py deleted file mode 100644 index 4a429e329567d..0000000000000 --- a/vllm/model_executor/layers/ops/rand.py +++ /dev/null @@ -1,157 +0,0 @@ -from typing import Optional, Union - -import torch -import triton -import triton.language as tl - - -def seeded_uniform( - *size, - seeds: torch.Tensor, - out: Optional[torch.Tensor] = None, - dtype: Optional[torch.dtype] = None, - device: Optional[Union[torch.device, str]] = None, - pin_memory: Optional[bool] = False, -) -> torch.Tensor: - """Similar to torch.rand, but allows for seeds to be set per row. - - seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d. - If it is 3d, the additional seeds needed will be derived automatically - in a deterministic fashion: - [ - row 0: [columns_with_seed_0], [columns_with_seed0^1], ... - ] - """ - n_dims = len(size) - - if n_dims > 3: - raise ValueError("seeded_uniform only supports up to 3D tensors") - - if out is None: - out = torch.empty(*size, - dtype=dtype, - device=device, - pin_memory=pin_memory) - elif out.shape != size: - raise ValueError("shape of out and size must be the same") - - if n_dims == 3: - n_rows, n_3d, n_cols = out.shape - stride_row = out.stride(0) - stride_3d = out.stride(1) - elif n_dims == 2: - n_rows, n_cols = out.shape - n_3d = 1 - stride_row = out.stride(0) - stride_3d = 1 - else: - n_cols = out.shape[0] - n_rows = 1 - n_3d = 1 - stride_row = 1 - stride_3d = 1 - - if seeds.ndim != 1: - raise ValueError("seeds must be a 1D tensor") - - if seeds.numel() != n_rows: - raise ValueError( - "seeds must have the same number of elements as out has rows") - - # The philox PRNG Triton uses generates 4 random numbers at once. - # Therefore, the most efficient use of it is to divide the - # block size by 4, and then save the generated random numbers to - # each of the 4 slices of the tensor. - full_block_size = triton.next_power_of_2(n_cols) - philox_block_size = max(full_block_size // 4, 1) - n_slices = full_block_size // philox_block_size - num_warps = 4 - # Manual tuning. This seems to give best performance on A100 for - # simple kernels like this. - if philox_block_size >= 8192: - num_warps = 32 - elif philox_block_size >= 4096: - num_warps = 16 - elif philox_block_size >= 2048: - num_warps = 8 - - _seeded_uniform_triton[(n_rows, n_3d)]( - out, - seeds, - stride_row, - stride_3d, - seeds.stride(0), - n_rows, - n_3d, - n_cols, - n_slices=n_slices, - num_warps=num_warps, - block_size=philox_block_size, - ) - return out - - -@triton.jit -def _seeded_uniform_triton( - out_ptr: torch.Tensor, - seed_ptr: torch.Tensor, - out_row_stride: int, - out_3d_stride: int, - seed_row_stride: int, - n_rows: int, - n_3d: int, - n_cols: int, - n_slices: tl.constexpr, - block_size: tl.constexpr, -): - """ - Generate a random float32 number in [0, 1) for each element in the output - tensor. The random numbers in a row generated using the seed for that row. - - Args: - out_ptr: The output tensor. - seed_ptr: The per-row seeds to use for random number generation. - out_row_stride: The stride between rows of the output tensor. - out_3d_stride: The stride between 3D slices of the output tensor. - seed_row_stride: The stride between rows of the seed tensor. - n_rows: The number of rows in the output tensor. - n_3d: The size of second dimension of the output tensor, - if output tensor is 3D. - n_cols: The number of columns in the output tensor. - n_slices: The number of philox outputs to use. - """ - tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4") - - # Get the row index. - row_idx = tl.program_id(axis=0) - three_d_idx = tl.program_id(axis=1) - - philox_offsets = tl.arange(0, block_size) - # Get the seed for the current element. - seed = tl.load(seed_ptr + row_idx * seed_row_stride) - if three_d_idx > 0: - seed ^= three_d_idx - # Generate random numbers in [0, 1). - out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets) - - output_row_start_ptr = (out_ptr + row_idx * out_row_stride + - three_d_idx * out_3d_stride) - out1_offsets = philox_offsets - tl.store(output_row_start_ptr + out1_offsets, - out1, - mask=out1_offsets < n_cols) - if n_slices > 1: - out2_offsets = tl.arange(block_size, block_size * 2) - tl.store(output_row_start_ptr + out2_offsets, - out2, - mask=out2_offsets < n_cols) - if n_slices > 2: - out3_offsets = tl.arange(block_size * 2, block_size * 3) - tl.store(output_row_start_ptr + out3_offsets, - out3, - mask=out3_offsets < n_cols) - if n_slices > 3: - out4_offsets = tl.arange(block_size * 3, block_size * 4) - tl.store(output_row_start_ptr + out4_offsets, - out4, - mask=out4_offsets < n_cols) diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py deleted file mode 100644 index fb88a05daf482..0000000000000 --- a/vllm/model_executor/layers/ops/sample.py +++ /dev/null @@ -1,394 +0,0 @@ -from typing import Optional, Tuple - -import torch -import triton -import triton.language as tl - -from vllm.model_executor.layers.ops.rand import seeded_uniform -from vllm.triton_utils.sample import get_num_triton_sampler_splits - -_EPS: tl.constexpr = 1e-6 - - -def _multi_split_sample( - probs: torch.Tensor, - seeds: torch.Tensor, - n_splits: int, - sampled_tokens_size: Tuple[int, int], - sampled_logprobs_size: Tuple[int, int], - sample_indices: torch.Tensor, - logprobs: torch.Tensor, - *, - modify_greedy_probs: bool = False, - save_logprobs: bool = False, -): - """Sample tokens where vocab size is split into multiple parts - (too large for Triton otherwise).""" - assert seeds.ndim == 2 and seeds.shape[0] == n_splits - split_probs = probs.tensor_split(n_splits, 1) - split_logprobs = logprobs.tensor_split(n_splits, 1) - sampled_tokens_tmp = [ - torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device) - for _ in range(n_splits) - ] - sampled_logprobs_tmp = [ - torch.empty(sampled_logprobs_size, - dtype=probs.dtype, - device=probs.device) for _ in range(n_splits) - ] - # We are purposefuly using sampled_tokens_size as we need to always - # save modified probs in this case. - sampled_modified_probs_tmp = [ - torch.empty(sampled_tokens_size, - dtype=probs.dtype, - device=probs.device) for _ in range(n_splits) - ] - for i in range(n_splits): - n_samples = sample_indices.shape[0] - n_cols = split_probs[i].shape[1] - n_best = sampled_tokens_tmp[i].shape[1] - uniform_noise = seeded_uniform(n_samples, - n_best, - n_cols, - seeds=seeds[i].flatten(), - device=split_probs[i].device, - dtype=split_probs[i].dtype) - # TODO(yard1): See if we can remove the contiguous() calls. - # Will need kernel support. - _sample( - split_probs[i].contiguous(), - split_logprobs[i].contiguous(), - sample_indices, - sampled_tokens_tmp[i], - sampled_logprobs_tmp[i], - sampled_modified_probs_tmp[i], - seeds[i], - uniform_noise, - modify_greedy_probs=False, - save_logprobs=save_logprobs, - save_modified_probs=True, - ) - if i > 0: - # Add offset to sampled tokens - sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1]) - sampled_tokens = torch.stack(sampled_tokens_tmp) - sampled_modified_probs = torch.stack(sampled_modified_probs_tmp) - # Reduce the results from the splits. - sampled_modified_probs, indices = torch.max(sampled_modified_probs, - dim=0, - keepdim=True) - sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0) - if save_logprobs: - sampled_logprobs = torch.stack(sampled_logprobs_tmp) - sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0) - else: - sampled_logprobs = None - sampled_modified_probs = sampled_modified_probs.squeeze(0) - - if modify_greedy_probs: - # We need to modify the greedy probs for the sampled tokens. - # We can't do this in the kernel as we need to know the - # sampled tokens. - probs.fill_(0.0) - probs.scatter_(1, sampled_tokens, 1.0) - - return (sampled_tokens, sampled_logprobs, sampled_modified_probs) - - -def sample( - probs: torch.Tensor, - seeds: torch.Tensor, - *, - max_best_of: int = 1, - sample_indices: Optional[torch.Tensor] = None, - logprobs: Optional[torch.Tensor] = None, - modify_greedy_probs: bool = False, - save_logprobs: bool = False, - _save_modified_probs: bool = False, # pylint: disable=invalid-name -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - """Sample tokens from probs. with per-sequence seeds. - - Can sample from a subset of sequences through sample_indices. - - Args: - probs: Probabilities to sample from. - shape = [batch_size, vocab_size] - seeds: Per-sequence seed values. - shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)] - max_best_of: Number of samples to generate per sequence. - Sequence seed will be incremented by 1 each time. - sample_indices: Indices of sequences to sample from. - If not provided, will sample from all sequences. - shape = [n] - logprobs: Log-probabilities of the sampled tokens. - Only used for saving the logprobs if save_logprobs is True. - shape = [batch_size, vocab_size] - modify_greedy_probs: Whether to modify the greedy probabilities - for speculative sampling (sampled token = 1.0, - everything else = 0.0). - save_logprobs: Whether to save the log-probabilities of the - sampled tokens to a tensor. - _save_modified_probs: Whether to save the modified probabilities - (including gumbel noise) of the sampled tokens to a tensor. - DOES NOT include the modification done by modify_greedy_probs - (because we want to use the unmodified probs to pick the best - split in case of multi-split sampling). - This is exposed only for testing. - - Returns: - sampled_tokens: shape = [n, max_best_of] - sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None - sampled_modified_probs: shape = [n, max_best_of] - if save_modified_probs else None - """ - if sample_indices is None: - sample_indices = torch.arange(0, probs.shape[0], device=probs.device) - - sampled_tokens_size = (sample_indices.size(0), max_best_of) - if save_logprobs: - if logprobs is None: - raise ValueError( - "logprobs tensor must be provided if save_logprobs is True") - sampled_logprobs_size = sampled_tokens_size - else: - # Empty tensors to invoke the kernel - sampled_logprobs_size = (0, 0) - logprobs = probs - - assert logprobs is not None - if _save_modified_probs: - sampled_modified_probs_size = sampled_tokens_size - else: - # Empty tensors to invoke the kernel - sampled_modified_probs_size = (0, 0) - - # If the number of columns in probs is too large for Triton to handle, - # we split the tensor and sample from each split separately, and then - # do an argmax+gather to combine the results. - n_splits = get_num_triton_sampler_splits(probs.shape[1]) - if n_splits > 1: - (sampled_tokens, sampled_logprobs, - sampled_modified_probs) = _multi_split_sample( - probs, - seeds, - n_splits, - sampled_tokens_size, - sampled_logprobs_size, - sample_indices, - logprobs=logprobs, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs) - else: - sampled_tokens = torch.empty(sampled_tokens_size, - dtype=torch.long, - device=probs.device) - sampled_logprobs = torch.empty(sampled_logprobs_size, - dtype=probs.dtype, - device=probs.device) - sampled_modified_probs = torch.empty(sampled_modified_probs_size, - dtype=probs.dtype, - device=probs.device) - n_samples = sample_indices.shape[0] - n_cols = probs.shape[1] - uniform_noise = seeded_uniform(n_samples, - max_best_of, - n_cols, - seeds=seeds.flatten(), - device=probs.device, - dtype=probs.dtype) - - _sample( - probs, - logprobs, - sample_indices, - sampled_tokens, - sampled_logprobs, - sampled_modified_probs, - seeds, - uniform_noise, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - save_modified_probs=_save_modified_probs, - ) - return (sampled_tokens, sampled_logprobs if save_logprobs else None, - sampled_modified_probs if _save_modified_probs else None) - - -def _sample(probs: torch.Tensor, - logprobs: torch.Tensor, - sample_indices: torch.Tensor, - output_samples: torch.Tensor, - output_logprobs: torch.Tensor, - output_modified_probs: torch.Tensor, - seeds: torch.Tensor, - uniform_noise: torch.Tensor, - *, - modify_greedy_probs: bool = False, - save_logprobs: bool = True, - save_modified_probs: bool = False) -> torch.Tensor: - """Sample tokens from probs. - - Args: - probs [batch_size, vocab_size]: probs to sample from. - logprobs [batch_size, vocab_size]: logprobs (used when - save_logprobsis True). - sample_indices [n]: Indices of the samples to use for each row of probs. - output_samples [n, n_best]: Output tensor to store samples in. - output_logprobs [n, n_best]: Output tensor to store logprobs in. - output_modified_probs [n, n_best]: Output tensor to store - probs of chosen tokens in (modified with noise). - seeds [n]: Seeds to use for sampling. If the seed is 0, we use - greedy sampling. Note this is ONLY used for determining - whether to use random sampling or not. The actual random - noise should be passed as uniform_noise. - uniform_noise [batch_size, n_best, vocab_size]: Uniform - noise to use for random sampling (will be converted - to exponential gumbel noise by the kernel). - modify_greedy_probs: If True, we modify the probs tensor in-place - to encode the sampling method used for each row. This is used - in speculative decoding. Only applies in greedy decoding. - save_logprobs: If True, we save the logprobs of the sampled tokens - in the output_logprobs tensor. - save_modified_probs: If True, we save the modified probs (with noise) - of the sampled tokens in the output_modified_probs tensor. - DOES NOT include the modification done by modify_greedy_probs - (because we want to use the unmodified probs to pick the best - split in case of multi-split sampling). - """ - n_samples = sample_indices.shape[0] - n_cols = probs.shape[1] - n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1 - - # The block size is the smallest power of two greater than the number of - # columns in probs - block_size = triton.next_power_of_2(n_cols) - num_warps = 4 - # Manual tuning. This seems to give best performance on A100 for - # simple kernels like this. - if block_size >= 8192: - num_warps = 32 - elif block_size >= 4096: - num_warps = 16 - elif block_size >= 2048: - num_warps = 8 - - # Enqueue kernel. The 1D launch grid is simple: we have one kernel - # instance per row of the probs matrix - _sample_triton[(n_samples, n_best)]( - sample_indices, - output_samples, - output_logprobs, - output_modified_probs, - probs, - logprobs, - seeds, - uniform_noise, - output_samples.stride(0), - probs.stride(0), - uniform_noise.stride(0), - uniform_noise.stride(1) if n_best > 1 else 1, - n_samples, - n_cols, - n_best, - num_warps=num_warps, - block_size=block_size, - modify_greedy_probs=modify_greedy_probs, - save_logprobs=save_logprobs, - save_modified_probs=save_modified_probs, - ) - return output_samples, output_logprobs, output_modified_probs - - -@triton.jit -def _uniform_to_exponential(uniform_noise): - """Convert uniform samples to exponential samples.""" - # tl.rand returns values in [0, 1), so we clamp lower bound - # to _EPS to avoid log(0) and thus division by 0 later - lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype) - uniform_noise = tl.maximum(uniform_noise, lb) - # Use the inversion method to turn uniform samples - # into exponential samples - exponential_noise = -tl.log(uniform_noise) - return exponential_noise - - -@triton.jit -def _sample_triton( - sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor, - output_logprobs_ptr: torch.Tensor, - output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor, - logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor, - uniform_noise_ptr: torch.Tensor, output_row_stride: int, - probs_row_stride: int, uniform_noise_row_stride: int, - uniform_noise_best_stride: int, n_samples: int, n_cols: int, - n_best: int, block_size: tl.constexpr, - modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr, - save_modified_probs: tl.constexpr): - # The rows are independent, so we parallelize across those - sample_idx = tl.program_id(0) - best_idx = tl.program_id(1) - - # Load the row index from DRAM - row_idx = tl.load(sample_indices_ptr + sample_idx) - seed = tl.load(seeds_ptr + sample_idx) - uses_random_sampling = seed != 0 - - # The stride represents how much we need to increase the - # pointer to advance 1 row - row_start_ptr = probs_ptr + row_idx * probs_row_stride - - # The block size is the next power of two greater than n_cols, - # so we can fit each row in a single block - col_offsets = tl.arange(0, block_size) - - # Load the row into SRAM, using a mask since block_size may be > than n_cols - row = tl.load(row_start_ptr + col_offsets, - mask=col_offsets < n_cols, - other=float("-inf")) - - if uses_random_sampling: - uniform_noise_start_ptr = (uniform_noise_ptr + - sample_idx * uniform_noise_row_stride + - best_idx * uniform_noise_best_stride) - uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets, - mask=col_offsets < n_cols, - other=0.5) - exponential_noise = _uniform_to_exponential(uniform_noise) - row /= exponential_noise - - sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True) - # clamp sampled token to n_cols - 1 - # this should not be necessary, but we do it - # just in case - if sampled_token >= n_cols: - sampled_token = n_cols - 1 - # Write back output to DRAM - output_row_start_ptr = (output_ptr + sample_idx * output_row_stride + - best_idx) - tl.store(output_row_start_ptr, sampled_token) - - if modify_greedy_probs: # noqa - if not uses_random_sampling: - # Set the probability of the sampled token to 1, all other - # tokens to zero. This is used in speculative decoding where - # the sampling method must be encoded within the sampled - # probability distributions. - row = tl.where(col_offsets == sampled_token, 1.0, 0.0) - tl.store(row_start_ptr + col_offsets, - row, - mask=col_offsets < n_cols) - - if save_modified_probs: - output_row_start_ptr = (output_modified_probs_ptr + - sample_idx * output_row_stride + best_idx) - tl.store(output_row_start_ptr, sampled_value) - - if save_logprobs: - # Load the row into SRAM, using a mask since block_size - # may be > than n_cols - sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride + - sampled_token) - # Write back output to DRAM - output_row_start_ptr = (output_logprobs_ptr + - sample_idx * output_row_stride + best_idx) - tl.store(output_row_start_ptr, sampled_logprob) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c00da106734ae..487f5a3d2a441 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -10,12 +10,6 @@ import torch import torch.nn as nn -from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics -from vllm.triton_utils import HAS_TRITON - -if HAS_TRITON: - from vllm.model_executor.layers.ops.sample import sample as sample_triton - import vllm.envs as envs from vllm.model_executor.sampling_metadata import (SamplingMetadata, SamplingTensors, @@ -23,6 +17,7 @@ from vllm.sampling_params import SamplingType from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): import flashinfer.sampling @@ -740,7 +735,7 @@ def _sample_with_torch( ) -> SampleReturnType: '''Torch-oriented _sample() implementation. - Single-step scheduling: + Single-step scheduling: * Perform GPU-side sampling computation * Immediately Pythonize sampling result @@ -777,7 +772,7 @@ def _sample_with_torch( # Counterintiutively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type][:, 0] + sample_indices = categorized_sample_indices[sampling_type] num_tokens = len(sample_indices) if num_tokens == 0: continue @@ -863,88 +858,6 @@ def _sample_with_torch( ) -def _sample_with_triton_kernel( - probs: torch.Tensor, - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sampling_tensors: SamplingTensors, -) -> SampleResultType: - categorized_seq_group_ids: Dict[SamplingType, - List[int]] = {t: [] - for t in SamplingType} - categorized_sample_indices = sampling_metadata.categorized_sample_indices - for i, seq_group in enumerate(sampling_metadata.seq_groups): - sampling_params = seq_group.sampling_params - sampling_type = sampling_params.sampling_type - categorized_seq_group_ids[sampling_type].append(i) - - sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} - sample_metadata: Dict[SamplingType, - Tuple[List[int], List[SequenceGroupToSample], - torch.Tensor, torch.Tensor]] = {} - max_best_of_in_batch = 1 - - # Counterintiutively, having two loops here is actually faster. - # The first loop can run without waiting on GPU<->CPU sync. - for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type][:, 0] - sampled_token_indices = categorized_sample_indices[sampling_type][:, 1] - num_tokens = len(sample_indices) - if num_tokens == 0: - continue - seq_group_id = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] - sample_metadata[sampling_type] = (seq_group_id, seq_groups, - sample_indices, - sampled_token_indices) - if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM, - SamplingType.RANDOM_SEED): - for seq_group in seq_groups: - if seq_group.is_prompt: - sampling_params = seq_group.sampling_params - max_best_of_in_batch = max(max_best_of_in_batch, - sampling_params.best_of) - elif sampling_type == SamplingType.BEAM: - beam_search_logprobs = logprobs[sample_indices] - else: - raise ValueError(f"Unsupported sampling type: {sampling_type}") - - sampled_tokens, _, _ = sample_triton( - probs=probs, - seeds=sampling_tensors.sampling_seeds, - max_best_of=max_best_of_in_batch, - sample_indices=sampling_tensors.sample_indices, - logprobs=logprobs, - # don't save logprobs because we have logic for that below - # TODO: use this instead of the CPU-based logic below - save_logprobs=False, - ) - - # GPU<->CPU sync happens in the loop below. - - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups, sample_indices, - sampled_token_indices) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample( - seq_groups, sampled_tokens[sampled_token_indices][:, 0]) - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample( - seq_groups, sampled_tokens[sampled_token_indices]) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, - beam_search_logprobs) - sample_results_dict.update(zip(seq_group_id, sample_results)) - - sample_results = [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] - return sample_results - - def _sample( probs: torch.Tensor, logprobs: torch.Tensor, @@ -974,10 +887,6 @@ def _sample( modify_greedy_probs=modify_greedy_probs, ) - # TODO: Enable once Triton kernel & associated code is faster. - # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, - # sampling_tensors) - def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: """ diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index a085779bc61a7..97d36d31f2b11 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -1,4 +1,3 @@ -import random from array import array from dataclasses import dataclass from typing import Dict, List, Optional, Tuple @@ -8,15 +7,10 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, SequenceGroupMetadata) -from vllm.triton_utils.sample import get_num_triton_sampler_splits from vllm.utils import (PyObjectCache, async_tensor_h2d, - is_pin_memory_available, make_tensor_with_pad, - maybe_expand_dim) + is_pin_memory_available, make_tensor_with_pad) _SAMPLING_EPS = 1e-5 -_SEED_0_REPLACEMENT = 3403598558 -# Some triton sampler related code is guarded before it is ready. -_USE_TRITON_SAMPLER = False @dataclass @@ -74,12 +68,12 @@ def gen_seq_group_to_sample_builder(num_seqs: int): generator=None, is_prompt=True, prompt_logprob_indices=[], - sample_indices=[]) + sample_indices=[], + ) class SamplingMetadataCache: - """Used to cache SamplingMetadata objects between scheduler iterations - """ + """Used to cache SamplingMetadata objects between scheduler iterations""" def __init__(self): self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {} @@ -124,12 +118,12 @@ def sample(logits): The first tuple is [1, 2] (sampled index within original logit), and the second tuple is [0, 1] (sampled index within pruned logit). num_prompts: Number of prompt sequence groups in seq_groups. - skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU + skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU serialization of token outputs. - reuse_sampling_tensors: Indicates if we want to reuse sampling + reuse_sampling_tensors: Indicates if we want to reuse sampling tensors that are part of the sampler forward pass. Currently, it is mainly used for multi-step decode. - + """ def __init__( @@ -165,16 +159,19 @@ def prepare( num_prompts, ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, device, generators, cache) - selected_token_indices = async_tensor_h2d(selected_token_indices, - dtype=torch.long, - target_device=device, - pin_memory=pin_memory) + selected_token_indices = async_tensor_h2d( + selected_token_indices, + dtype=torch.long, + target_device=device, + pin_memory=pin_memory, + ) categorized_sample_indices = { - t: maybe_expand_dim( - async_tensor_h2d(seq_ids, - dtype=torch.int, - target_device=device, - pin_memory=pin_memory), 2, 2) + t: async_tensor_h2d( + seq_ids, + dtype=torch.int, + target_device=device, + pin_memory=pin_memory, + ) for t, seq_ids in categorized_sample_indices.items() } @@ -201,8 +198,8 @@ def _prepare_seq_groups( device: str, generators: Optional[Dict[str, torch.Generator]] = None, cache: Optional[SamplingMetadataCache] = None, -) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ - SamplingType, List[Tuple[int, int]]], int]: +) -> Tuple[List[SequenceGroupToSample], List[int], Dict[SamplingType, + List[int]], int, ]: """Prepare sequence groups and indices for sampling. Args: @@ -233,16 +230,13 @@ def _prepare_seq_groups( # Sampling type -> ( # indices to sample/prompt logprob within pruned output logits, # indices to sample within pruned logits) - categorized_sample_indices: Dict[SamplingType, List[Tuple[int, int]]] = { + categorized_sample_indices: Dict[SamplingType, List[int]] = { t: [] for t in SamplingType } # Index of logits to compute logprob. Logits include both prompt logprob # and sample logprob indices. logit_idx = 0 - # Index to sample from a sample tensor. It is used by triton sample kernel. - # See `_sample_with_triton_kernel` for more details. - sample_idx = 0 # Total number of prompts from given sequence groups. num_prompts = 0 @@ -264,10 +258,10 @@ def _prepare_seq_groups( # If the current seq group is in decode stage, it is None. seq_len: Optional[int] = None query_len: Optional[int] = None - prompt_logprob_indices: List[int] = \ - sample_obj.prompt_logprob_indices if cache is not None else [] - sample_indices: List[int] = \ - sample_obj.sample_indices if cache is not None else [] + prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices + if cache is not None else []) + sample_indices: List[int] = (sample_obj.sample_indices + if cache is not None else []) do_sample = seq_group_metadata.do_sample if seq_group_metadata.is_prompt: @@ -333,11 +327,8 @@ def sample(logits): if do_sample: sample_indices.extend(range(logit_idx, logit_idx + sample_len)) categorized_sample_indices[sampling_params.sampling_type].extend( - list( - zip(range(logit_idx, logit_idx + sample_len), - range(sample_idx, sample_idx + sample_len)))) + list(range(logit_idx, logit_idx + sample_len))) logit_idx += sample_len - sample_idx += sample_len if cache is not None: sample_obj.sampling_params = sampling_params @@ -356,7 +347,8 @@ def sample(logits): generator=generator, is_prompt=is_prompt, prompt_logprob_indices=list(prompt_logprob_indices), - sample_indices=list(sample_indices)) + sample_indices=list(sample_indices), + ) seq_groups.append(sample_obj) @@ -378,9 +370,6 @@ class SamplingTensors: presence_penalties: torch.Tensor frequency_penalties: torch.Tensor repetition_penalties: torch.Tensor - sampling_seeds: torch.Tensor - sample_indices: torch.Tensor - extra_seeds: Optional[torch.Tensor] prompt_tokens: torch.Tensor output_tokens: torch.Tensor @@ -391,15 +380,7 @@ def from_sampling_metadata( vocab_size: int, device: torch.device, dtype: torch.dtype, - *, - extra_seeds_to_generate: int = 0, - extra_entropy: Optional[Tuple[int, ...]] = None ) -> Tuple["SamplingTensors", bool, bool, bool]: - """ - extra_seeds_to_generate: extra seeds to generate using the - user-defined seed for each sequence. - extra_entropy: extra entropy to use when generating seeds. - """ prompt_tokens: List[array] = [] output_tokens: List[array] = [] top_ks: List[int] = [] @@ -409,19 +390,10 @@ def from_sampling_metadata( presence_penalties: List[float] = [] frequency_penalties: List[float] = [] repetition_penalties: List[float] = [] - sampling_seeds: List[int] = [] - sample_indices: List[int] = [] do_penalties = False do_top_p_top_k = False do_min_p = False - if _USE_TRITON_SAMPLER: - prompt_best_of: List[int] = [] - - # We need one base seed per Triton slice. - seeds_to_generate = (extra_seeds_to_generate + - get_num_triton_sampler_splits(vocab_size)) - assert sampling_metadata.seq_groups is not None for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids @@ -452,7 +424,7 @@ def from_sampling_metadata( do_penalties = True is_prompt = seq_group.is_prompt - if (is_prompt and sampling_params.prompt_logprobs is not None): + if is_prompt and sampling_params.prompt_logprobs is not None: # For tokens in the prompt that we only need to get # their logprobs query_len = seq_group.query_len @@ -477,28 +449,6 @@ def from_sampling_metadata( frequency_penalties += [f] * len(seq_ids) repetition_penalties += [r] * len(seq_ids) - if _USE_TRITON_SAMPLER: - if is_prompt: - prompt_best_of.append(sampling_params.best_of) - query_len = seq_group.query_len - assert query_len is not None - - seed = sampling_params.seed - is_greedy = sampling_params.sampling_type == SamplingType.GREEDY - - for seq_id in seq_ids: - seq_data = seq_group.seq_data[seq_id] - extra_entropy = extra_entropy or () - seq_seeds = cls._get_sequence_seeds( - seed, - seq_data.get_len(), - *extra_entropy, - seq_id, - seeds_to_generate=seeds_to_generate, - is_greedy=is_greedy) - sampling_seeds.append(seq_seeds) - sample_indices.extend(seq_group.sample_indices) - if do_penalties: for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids @@ -518,23 +468,37 @@ def from_sampling_metadata( output_tokens.append(seq_data.output_token_ids_array) sampling_tensors = SamplingTensors.from_lists( - temperatures, top_ps, top_ks, min_ps, presence_penalties, - frequency_penalties, repetition_penalties, sampling_seeds, - sample_indices, prompt_tokens, output_tokens, vocab_size, - extra_seeds_to_generate, device, dtype) + temperatures, + top_ps, + top_ks, + min_ps, + presence_penalties, + frequency_penalties, + repetition_penalties, + prompt_tokens, + output_tokens, + vocab_size, + device, + dtype, + ) return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) @classmethod - def from_lists(cls, temperatures: List[float], top_ps: List[float], - top_ks: List[int], min_ps: List[float], - presence_penalties: List[float], - frequency_penalties: List[float], - repetition_penalties: List[float], - sampling_seeds: List[int], sample_indices: List[int], - prompt_tokens: List[array], output_tokens: List[array], - vocab_size: int, extra_seeds_to_generate: int, - device: torch.device, - dtype: torch.dtype) -> "SamplingTensors": + def from_lists( + cls, + temperatures: List[float], + top_ps: List[float], + top_ks: List[int], + min_ps: List[float], + presence_penalties: List[float], + frequency_penalties: List[float], + repetition_penalties: List[float], + prompt_tokens: List[array], + output_tokens: List[array], + vocab_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> "SamplingTensors": # Note that the performance will be very bad without # pinned memory. pin_memory = is_pin_memory_available() @@ -603,34 +567,9 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], dtype=torch.int, pin_memory=pin_memory, ) - sample_indices_t = torch.tensor( - sample_indices, - device="cpu", - dtype=torch.long, - pin_memory=pin_memory, - ) - # need to transpose and make contiguous to - # copy the tensor correctly. - # [batch_size, n_seeds] -> [n_seeds, batch_size] - sampling_seeds_t = torch.tensor( - sampling_seeds, - device="cpu", - dtype=torch.long, - pin_memory=pin_memory, - ).t().contiguous() - # Because the memory is pinned, we can do non-blocking # transfer to device. - # How many seeds the sample operation itself will need. - num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate - sampling_seeds_gpu = sampling_seeds_t.to(device=device, - non_blocking=True) - extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:] - if not extra_seeds_gpu.numel(): - extra_seeds_gpu = None - sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds] - return cls( temperatures=temperatures_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True), @@ -644,38 +583,4 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], non_blocking=True), prompt_tokens=prompt_t.to(device=device, non_blocking=True), output_tokens=output_t.to(device=device, non_blocking=True), - sampling_seeds=sampling_seeds_gpu, - sample_indices=sample_indices_t.to(device=device, - non_blocking=True), - extra_seeds=extra_seeds_gpu, ) - - @staticmethod - def _get_sequence_seeds( - seed: int, - *extra_entropy: int, - seeds_to_generate: int, - is_greedy: bool, - ): - """Get `seeds_to_generate` child seeds from `seed` and extra entropy.""" - if not is_greedy: - if seed is None: - randint_fn = random.randint - else: - generator = random.Random(str((seed, ) + extra_entropy)) - randint_fn = generator.randint - lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max - # If the user/random sets seed = 0 but request should - # have sampling, we need to change it to something - # else. We use a constant in that case. - # This way we don't need to create and load a bool - # matrix in the sampling kernel, which reduces CPU - # overhead and latency. - seq_seeds = [ - randint_fn(lo, hi) or _SEED_0_REPLACEMENT - for _ in range(seeds_to_generate) - ] - else: - # For the kernel, seed == 0 means greedy decoding. - seq_seeds = [0] * seeds_to_generate - return seq_seeds diff --git a/vllm/triton_utils/sample.py b/vllm/triton_utils/sample.py deleted file mode 100644 index 401e4d28a3c99..0000000000000 --- a/vllm/triton_utils/sample.py +++ /dev/null @@ -1,13 +0,0 @@ -import math - -# This is a hardcoded limit in Triton (max block size). -MAX_TRITON_N_COLS = 131072 - - -def get_num_triton_sampler_splits(n_cols: int) -> int: - """Get the number of splits to use for Triton sampling. - - Triton has a limit on the number of columns it can handle, so we need to - split the tensor and call the kernel multiple times if it's too large. - """ - return math.ceil(n_cols / MAX_TRITON_N_COLS) diff --git a/vllm/utils.py b/vllm/utils.py index 014fc16a17c1f..1cbd9d55c68b3 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -270,7 +270,7 @@ def clear(self): class PyObjectCache: - """Used to cache python objects to avoid object allocations + """Used to cache python objects to avoid object allocations across scheduler iterations. """ @@ -289,7 +289,7 @@ def _grow_cache(self): self._obj_cache.append(self._obj_builder()) def get_object(self): - """Returns a pre-allocated cached object. If there is not enough + """Returns a pre-allocated cached object. If there is not enough objects, then the cache size will double. """ if self._index >= len(self._obj_cache): @@ -837,15 +837,6 @@ def async_tensor_h2d( return t.to(device=target_device, non_blocking=True) -def maybe_expand_dim(tensor: torch.Tensor, - target_dims: int, - size: int = 1) -> torch.Tensor: - """Expand the tensor to the target_dims.""" - if tensor.ndim < target_dims: - tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim))) - return tensor - - def get_dtype_size(dtype: torch.dtype) -> int: """Get the size of the data type in bytes.""" return torch.tensor([], dtype=dtype).element_size() @@ -1070,7 +1061,7 @@ def _cuda_device_count_stateless( def cuda_device_count_stateless() -> int: """Get number of CUDA devices, caching based on the value of CUDA_VISIBLE_DEVICES at the time of call. - + This should be used instead of torch.cuda.device_count() unless CUDA_VISIBLE_DEVICES has already been set to the desired value.""" @@ -1136,10 +1127,10 @@ def parse_args(self, args=None, namespace=None): def _pull_args_from_config(args: List[str]) -> List[str]: """Method to pull arguments specified in the config file into the command-line args variable. - - The arguments in config file will be inserted between + + The arguments in config file will be inserted between the argument list. - + example: ```yaml port: 12323 @@ -1150,21 +1141,21 @@ def _pull_args_from_config(args: List[str]) -> List[str]: --config config.yaml -tp 2 $: args = [ "serve,chat,complete", - "facebook/opt-12B", - '--config', 'config.yaml', + "facebook/opt-12B", + '--config', 'config.yaml', '-tp', '2' ] $: args = [ "serve,chat,complete", - "facebook/opt-12B", - '--port', '12323', - '--tensor-parallel-size', '4', + "facebook/opt-12B", + '--port', '12323', + '--tensor-parallel-size', '4', '-tp', '2' ] ``` Please note how the config args are inserted after the sub command. - this way the order of priorities is maintained when these are args + this way the order of priorities is maintained when these are args parsed by super(). """ assert args.count( @@ -1190,7 +1181,7 @@ def _pull_args_from_config(args: List[str]) -> List[str]: @staticmethod def _load_config_file(file_path: str) -> List[str]: - """Loads a yaml file and returns the key value pairs as a + """Loads a yaml file and returns the key value pairs as a flattened list with argparse like pattern ```yaml port: 12323 @@ -1201,7 +1192,7 @@ def _load_config_file(file_path: str) -> List[str]: '--port': '12323', '--tensor-parallel-size': '4' ] - + """ extension: str = file_path.split('.')[-1]