Skip to content

Commit

Permalink
inplace test
Browse files Browse the repository at this point in the history
  • Loading branch information
SolitaryThinker committed Aug 9, 2024
1 parent 94ed67c commit 2fb2bc1
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools
import random
from typing import Dict, List, Optional, Tuple
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest
import torch
Expand Down Expand Up @@ -703,3 +703,28 @@ def test_sampling_params(sampling_params: List[SamplingParams]):

assert tokens1[0] == tokens2[1]
assert tokens1[1] == tokens2[0]


@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_include_gpu_probs_tensor(device: str):
set_random_seed(42)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)
sampler.include_gpu_probs_tensor = True
sampler.should_modify_greedy_probs_inplace = False

sampling_params = SamplingParams(temperature=0)

mock_inplace = Mock()
with patch(
"vllm.model_executor.layers.sampler._modify_greedy_probs_inplace",
mock_inplace):

sampler_output = _do_sample(batch_size, fake_logits, sampler,
sampling_params, device)
mock_inplace.assert_not_called()

assert sampler_output.sampled_token_probs is not None
assert sampler_output.logprobs is not None
assert sampler_output.sampled_token_ids is not None

0 comments on commit 2fb2bc1

Please sign in to comment.