[1/n] Triton sampling kernel #3186

Merged: 20 commits, Mar 20, 2024
51 changes: 51 additions & 0 deletions tests/kernels/test_rand.py
@@ -0,0 +1,51 @@
import torch
import pytest
import random

from vllm.model_executor.layers.triton_kernel.rand import seeded_uniform
from vllm.model_executor.utils import set_random_seed


@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_3d", [True, False])
def test_seeded_uniform(dtype: torch.dtype, use_3d: bool):
    device = "cuda"
    for seed in range(512):
        set_random_seed(seed)
        rows = random.randint(1, 512)
        cols = random.randint(1, 64000)
        if use_3d:
            third_dim = random.randint(2, 10)
            dims = [rows, third_dim, cols]
        else:
            dims = [rows, cols]
        seeds = torch.randint(torch.iinfo(torch.long).min,
                              torch.iinfo(torch.long).max, (rows, ),
                              device=device)

        # Test that the same seed produces the same output
        out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
        out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
        torch.testing.assert_close(out, out2)
        # del to save memory
        del out2

        out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
        torch.testing.assert_close(out, out3)
        # del to save memory
        del out3

        # Initialize out tensor with garbage to ensure that it is overwritten
        out_with_tensor = seeded_uniform(
            *dims,
            out=torch.full(
                (*dims, ),
                -1,
                dtype=dtype,
                device=device,
            ),
            seeds=seeds,
            dtype=dtype,
        )
        torch.testing.assert_close(out, out_with_tensor)
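
As a quick illustration of the property the test above exercises, a minimal sketch of calling seeded_uniform (the keyword arguments mirror those used in the test; anything beyond that about the kernel's interface is an assumption):

import torch
from vllm.model_executor.layers.triton_kernel.rand import seeded_uniform

# One seed per row: reusing the same seeds tensor reproduces the same draws.
seeds = torch.randint(torch.iinfo(torch.long).min,
                      torch.iinfo(torch.long).max, (4, ), device="cuda")
a = seeded_uniform(4, 1024, seeds=seeds, dtype=torch.float16, device="cuda")
b = seeded_uniform(4, 1024, seeds=seeds, dtype=torch.float16, device="cuda")
torch.testing.assert_close(a, b)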
196 changes: 196 additions & 0 deletions tests/kernels/test_sampler.py
@@ -0,0 +1,196 @@
import gc

import torch
import pytest
import triton
import triton.language as tl

from vllm.model_executor.layers.triton_kernel.sample import (
    _uniform_to_exponential, sample, get_num_triton_sampler_splits,
    MAX_TRITON_N_COLS)
from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.sampling_metadata import SamplingTensors

SINGLE_SPLIT_VOCAB_SIZE = 32000  # llama/mistral/mixtral vocab size
MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100


@pytest.fixture(autouse=True)
def _cleanup():
    yield
    gc.collect()
    torch.cuda.empty_cache()


@triton.jit
def _uniform_to_exponential_kernel(input, output, n: tl.constexpr):
    idx = tl.arange(0, n)
    x = tl.load(input + idx)
    y = _uniform_to_exponential(x)
    tl.store(output + idx, y)


def test_uniform_to_exponential():
    """Test that we can convert uniform to exponential without div by 0."""
    input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],
                         dtype=torch.float32,
                         device="cuda")
    output = torch.zeros(input.shape, dtype=torch.float32, device="cuda")
    _uniform_to_exponential_kernel[(1, )](input, output, 2)
    assert torch.all(torch.isfinite(output))
    assert torch.all(output > 0)
    assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))

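The body of _uniform_to_exponential itself is not part of this diff, so as a hedged reference only: the properties asserted above (strictly positive, finite, with a finite reciprocal, even for a uniform sample of exactly 0.0) are consistent with clamping the uniform value away from zero before applying the inverse-CDF transform -log(u). A plain-PyTorch sketch under that assumption:

def uniform_to_exponential_ref(u: torch.Tensor) -> torch.Tensor:
    # Clamp to the smallest positive normal value so -log(u) stays finite;
    # this mirrors the div-by-zero concern in the test docstring above.
    tiny = torch.finfo(u.dtype).tiny
    return -torch.log(u.clamp_min(tiny))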

@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
                         [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
@pytest.mark.parametrize("save_logprobs", [True, False])
def test_sample_decoding_only(random_sampling, max_best_of,
                              modify_greedy_probs, seed, vocab_size,
                              save_logprobs):
    set_random_seed(seed)
    bs = 8
    probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
    for i in range(bs):
        probs[i, i * (vocab_size // bs)] = 1.0
    logprobs = torch.rand_like(probs)
    sample_indices = torch.arange(bs, dtype=torch.long, device="cuda")
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if random_sampling == "mixed":
        random_sampling_mask = (torch.rand(
            (1, bs), device="cuda") < 0.5).expand(n_splits, bs)
    elif random_sampling:
        random_sampling_mask = torch.ones((n_splits, bs),
                                          dtype=torch.bool,
                                          device="cuda")
    else:
        random_sampling_mask = torch.zeros((n_splits, bs),
                                           dtype=torch.bool,
                                           device="cuda")

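    # Zeroing seeds via the boolean mask below marks greedy requests: a zero
    # seed is assumed here to be the kernel's cue to skip adding random noise.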
    seeds = torch.randint(1,
                          torch.iinfo(torch.long).max, (n_splits, bs),
                          device="cuda").mul_(random_sampling_mask)
    sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
        probs=probs,
        logprobs=logprobs,
        sample_indices=sample_indices,
        seeds=seeds,
        max_best_of=max_best_of,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        _save_modified_probs=True)
    assert sampled_tokens.shape == (bs, max_best_of)
    for i in range(bs):
        assert torch.all(sampled_tokens[i] == i * (vocab_size // bs))
        request_uses_random_sampling = random_sampling_mask[0, i]
        if modify_greedy_probs and not request_uses_random_sampling:
            # If we are modifying greedy probs and the request is greedy,
            # we want to make sure the probs tensor is modified in place
            assert torch.allclose(
                probs[i][sampled_tokens[i]],
                torch.full_like(probs[i][sampled_tokens[i]], 1.0))
            assert torch.sum(probs[i]) == 1.0
            assert torch.allclose(
                sampled_modified_probs[i][0],
                torch.full_like(sampled_modified_probs[i][0], 1.0))
        elif request_uses_random_sampling:
            # If the request is random, we want to make sure
            # sampled_modified_probs tensor has noise added
            # (and thus is different from probs tensor)
            assert not torch.allclose(sampled_modified_probs[i][0],
                                      probs[i][sampled_tokens[i]])
        elif not request_uses_random_sampling:
            # If the request is greedy and we are not modifying greedy probs,
            # we want to make sure sampled_modified_probs tensor is the same as
            # the probs tensor.
            assert torch.allclose(sampled_modified_probs[i][0],
                                  probs[i][sampled_tokens[i]])

    if save_logprobs:
        assert sampled_logprobs.shape == (bs, max_best_of)
        for i in range(bs):
            for best_of in range(max_best_of):
                assert torch.all(sampled_logprobs[i] == logprobs[i][
                    sampled_tokens[i, best_of]])
    else:
        assert sampled_logprobs is None

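The random branch above only checks that noise was added. For context, one standard way to draw a seeded categorical sample from a probability row, and a plausible reading of what the exponential helper is for, is the exponential-race (Gumbel-max style) trick sketched below. This is an illustration, not the kernel's actual implementation, which runs per vocab split in Triton.

def exponential_race_sample(probs: torch.Tensor) -> torch.Tensor:
    # Divide each probability by an independent exponential variate and take
    # the argmax; with the noise disabled this degenerates to greedy argmax.
    u = torch.rand_like(probs).clamp_min(torch.finfo(probs.dtype).tiny)
    noise = -torch.log(u)
    return torch.argmax(probs / noise, dim=-1)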

@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
                         [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
def test_sample_prompt_logprobs(random_sampling, max_best_of,
                                modify_greedy_probs, seed, vocab_size):
    set_random_seed(seed)
    prompt_sizes = [16, 32, 64, 128] * 2
    samples = 8
    bs = samples + sum(prompt_sizes)
    probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
    for i in range(bs):
        probs[i, i * (vocab_size // bs)] = 1.0
    logprobs = torch.rand_like(probs)
    sample_indices = torch.tensor(prompt_sizes,
                                  dtype=torch.long,
                                  device="cuda").cumsum_(0)
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if random_sampling == "mixed":
        random_sampling_mask = torch.rand(
            (n_splits, samples), device="cuda") < 0.5
    elif random_sampling:
        random_sampling_mask = torch.ones((n_splits, samples),
                                          dtype=torch.bool,
                                          device="cuda")
    else:
        random_sampling_mask = torch.zeros((n_splits, samples),
                                           dtype=torch.bool,
                                           device="cuda")

    seeds = torch.randint(1,
                          torch.iinfo(torch.long).max, (n_splits, samples),
                          device="cuda").mul_(random_sampling_mask)
    sampled_tokens, sampled_logprobs, _ = sample(
        probs=probs,
        logprobs=logprobs,
        sample_indices=sample_indices,
        seeds=seeds,
        max_best_of=max_best_of,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=True)
    assert sampled_tokens.shape == (samples, max_best_of)
    assert sampled_logprobs.shape == (samples, max_best_of)
    for i, t in enumerate(sample_indices):
        assert torch.all(sampled_tokens[i] == t * (vocab_size // bs))
        for best_of in range(max_best_of):
            assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]]
                             [sampled_tokens[i, best_of]])

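MULTI_SPLIT_VOCAB_SIZE is deliberately set to MAX_TRITON_N_COLS + 100 so that the multi-split code path is exercised. The splitting helper itself is not shown in this diff; under the assumption that it simply caps each split at MAX_TRITON_N_COLS columns, it amounts to ceiling division (hypothetical reimplementation, for illustration only):

def num_sampler_splits(n_cols: int, max_cols: int) -> int:
    # ceil(n_cols / max_cols): one split for a 32000-token llama-style vocab,
    # two once the vocab exceeds max_cols (as with MULTI_SPLIT_VOCAB_SIZE).
    return (n_cols + max_cols - 1) // max_cols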

@pytest.mark.parametrize("seed", list(range(16)))
def test_get_sequence_seeds(seed):
    """Ensure that we get a different child seed from base
    seed + extra entropy"""
    starting_seed = seed
    seq_seed = None
    extra_entropy = 1
    for i in range(512):
        new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed,
                                                           i,
                                                           seeds_to_generate=1,
                                                           is_greedy=False)[0]
        new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds(
            starting_seed,
            i,
            extra_entropy,
            seeds_to_generate=1,
            is_greedy=False)[0]
        assert new_seq_seed_extra_entropy != new_seq_seed
        assert seq_seed != new_seq_seed
        seq_seed = new_seq_seed
6 changes: 3 additions & 3 deletions tests/samplers/test_sampler.py
@@ -302,11 +302,11 @@ def test_sampler_logits_processors(seed: int, device: str):
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

-    # This sample logits processor gives infinite score to the i-th token,
+    # This sample logits processor gives maximum score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...]
    def pick_ith(token_ids, logits):
-        logits[len(token_ids)] = float("inf")
+        logits[len(token_ids)] = torch.finfo(logits.dtype).max
        return logits

    seq_group_metadata_list = []
@@ -385,7 +385,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):

    sample_probs = None

-    def mock_sample(probs, logprobs, sampling_metadata):
+    def mock_sample(probs, *args, **kwargs):
        nonlocal sample_probs
        sample_probs = probs
        return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]