[Hardware][TPU] Raise errors for unsupported sampling params #5850

Merged · 1 commit · Jun 25, 2024
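This change makes the TPU backend fail fast on sampling features it cannot execute, instead of silently approximating or ignoring them. A hedged sketch (illustration only, not taken from the PR) of which SamplingParams now pass or get rejected once a request reaches the TPU model runner:

```python
# Illustration only: accepted vs. rejected sampling settings on the TPU backend.
from vllm import SamplingParams

ok = SamplingParams(temperature=0.8, top_p=1.0)  # plain sampling: accepted

# Each of these is rejected by _prepare_sample with NotImplementedError
# when the request reaches a TPU worker:
SamplingParams(top_p=0.9)          # top-p (while _ENABLE_TOP_P is False)
SamplingParams(top_k=50)           # top-k
SamplingParams(best_of=2)          # best_of > 1
SamplingParams(logprobs=5)         # logprobs
SamplingParams(prompt_logprobs=1)  # prompt_logprobs
# Beam search (use_beam_search=True) is rejected as well.
```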
63 changes: 44 additions & 19 deletions vllm/worker/tpu_model_runner.py
@@ -20,6 +20,8 @@
 logger = init_logger(__name__)
 
 _PAD_SLOT_ID = 0  # FIXME(woosuk)
+# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
+_ENABLE_TOP_P = False
 
 
 class TPUModelRunner:
@@ -339,9 +341,34 @@ def _prepare_sample(
             assert seq_group_metadata.sampling_params is not None
             sampling_params = seq_group_metadata.sampling_params
 
+            # NOTE(woosuk): Here we mimic argmax sampling by applying a very
+            # low temperature. This is not accurate.
             t.append(sampling_params.temperature
                      if sampling_params.temperature >= 1e-5 else 1e-5)
+            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
+                raise NotImplementedError(
+                    "Top-p sampling is currently disabled for the TPU backend "
+                    "due to performance issues.")
             p.append(sampling_params.top_p)
+            if sampling_params.top_k != -1:
+                raise NotImplementedError(
+                    "Top-k sampling is currently disabled for the TPU backend "
+                    "due to performance issues.")
+            if sampling_params.best_of > 1:
+                raise NotImplementedError(
+                    "best_of > 1 is not currently supported by the TPU "
+                    "backend.")
+            if sampling_params.use_beam_search:
+                raise NotImplementedError(
+                    "Beam search is not supported by the TPU backend.")
+            if sampling_params.logprobs is not None:
+                raise NotImplementedError(
+                    "logprobs is not currently supported by the TPU backend.")
+            if sampling_params.prompt_logprobs is not None:
+                raise NotImplementedError(
+                    "prompt_logprobs is not currently supported by the TPU "
+                    "backend.")
+
         num_paddings = padded_batch_size - len(seq_group_metadata_list)
         t += [1.0] * num_paddings
         p += [1.0] * num_paddings
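The temperature clamp kept above means greedy requests (temperature 0) are approximated with a very low temperature rather than an exact argmax, as the NOTE says. A small standalone sketch (illustration only, not part of the diff) of why 1e-5 behaves like greedy decoding:

```python
import torch

# Dividing logits by a tiny temperature makes the softmax essentially one-hot,
# so multinomial sampling almost always returns the argmax token.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
probs = torch.softmax(logits / 1e-5, dim=-1)
print(probs)                                    # ~tensor([[1., 0., 0., 0.]])
print(torch.multinomial(probs, num_samples=1))  # tensor([[0]]), i.e. the argmax
```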
@@ -350,35 +377,32 @@
         p = torch.tensor(p, dtype=torch.float32, device=self.device)
         return t, p
 
-    def prepare_inputs(
+    def _execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ):
-        assert seq_group_metadata_list is not None
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> List[CompletionSequenceGroupOutput]:
+        # Prepare inputs.
         assert len(seq_group_metadata_list) > 0
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
-        if seq_group_metadata_list[0].is_prompt:
+        is_prompt = seq_group_metadata_list[0].is_prompt
+        if is_prompt:
             inputs = self._prepare_prompt(seq_group_metadata_list)
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
         padded_batch_size = inputs[0].shape[0]
-        sample_inputs = self._prepare_sample(seq_group_metadata_list,
-                                             padded_batch_size)
-        return inputs + sample_inputs
+        t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
 
-    def _execute_model(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-    ) -> List[CompletionSequenceGroupOutput]:
-        inputs = self.prepare_inputs(seq_group_metadata_list)
+        # Execute the model.
         next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
-                                    *inputs[2:])
-        if not self.is_driver_worker:
-            return []
+                                    *inputs[2:], t, p)
+        # Retrieve the outputs to CPU.
         next_token_ids = next_token_ids.cpu().tolist()
 
+        # NOTE(woosuk): Minimal code to construct the sampler outputs.
+        # The TPU backend does not reuse the sampler, since the TPU backend
+        # does not support the advanced sampling parameters such as logprobs.
         i = 0
         sampler_outputs = []
         for seq_group_metadata in seq_group_metadata_list:
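The hunk is collapsed just before the loop body, so the "minimal" construction the NOTE refers to is not visible here. A hypothetical sketch of what a per-sequence output along those lines could look like, assuming vLLM's SequenceOutput, CompletionSequenceGroupOutput, and Logprob types from vllm.sequence (names, signatures, and structure assumed, not copied from the file):

```python
# Hypothetical sketch -- not the actual loop body from this diff.
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
                           SequenceOutput)

for seq_group_metadata in seq_group_metadata_list:
    seq_ids = list(seq_group_metadata.seq_data.keys())
    assert len(seq_ids) == 1  # best_of > 1 was rejected in _prepare_sample
    token_id = next_token_ids[i]
    samples = [SequenceOutput(seq_ids[0], token_id, {token_id: Logprob(0.0)})]
    sampler_outputs.append(CompletionSequenceGroupOutput(samples, None))
    i += 1
```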
@@ -400,6 +424,7 @@ def execute_model(
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> SamplerOutput:
         assert seq_group_metadata_list is not None
+        assert len(seq_group_metadata_list) > 0
         if seq_group_metadata_list[0].is_prompt:
             # NOTE(woosuk): To reduce the compilation time, we only compile the
             # prefill inputs with batch size 1. Because the scheduler is not
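The second NOTE is cut off by the collapsed context. For orientation, a hedged sketch of the batch-size-1 prefill dispatch it describes (structure assumed, not copied from the file): prompts are fed to _execute_model one sequence group at a time so XLA only has to compile the batch-size-1 prefill graph, while decode steps keep running as one padded batch.

```python
# Hypothetical sketch of the dispatch described by the NOTE above.
if seq_group_metadata_list[0].is_prompt:
    # Prefill: one sequence group per call, so only the batch-size-1
    # prompt graph needs to be compiled.
    outputs = []
    for seq_group_metadata in seq_group_metadata_list:
        outputs += self._execute_model([seq_group_metadata], kv_caches)
else:
    # Decode: the whole (padded) batch in a single call.
    outputs = self._execute_model(seq_group_metadata_list, kv_caches)
return SamplerOutput(outputs)
```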
@@ -492,8 +517,8 @@ def forward(
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
 
         logits = logits / t.unsqueeze(dim=1)
-        # FIXME(woosuk): Disabled top-p sampling since it's too slow.
-        # logits = _apply_top_p(logits, p.unsqueeze(dim=1))
+        if _ENABLE_TOP_P:
+            logits = _apply_top_p(logits, p.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
         # FIXME(woosuk): best_of > 1 is not supported.
         next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
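_apply_top_p itself lies outside this hunk. For reference, a generic nucleus (top-p) masking helper looks roughly like the sketch below; this is a common formulation, not the actual vLLM TPU implementation:

```python
import torch

def top_p_mask(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Generic nucleus filter: keep only the smallest set of tokens whose
    cumulative probability reaches p (per row); mask the rest to -inf."""
    sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)
    cum_probs = probs.cumsum(dim=-1)
    # Keep tokens whose preceding cumulative mass is still below p, so the
    # token that crosses the threshold stays in; never drop the top token.
    keep = (cum_probs - probs) < p
    keep[..., 0] = True
    # Map the keep-mask back to the original vocabulary order.
    mask = torch.zeros_like(keep).scatter(-1, sorted_idx, keep)
    return logits.masked_fill(~mask, float("-inf"))

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
filtered = top_p_mask(logits, torch.tensor([[0.7]]))
probs = torch.softmax(filtered, dim=-1)  # only the kept nucleus has nonzero mass
```

After such a mask, the softmax and multinomial above sample only from the nucleus; on this backend that path stays disabled while _ENABLE_TOP_P is False.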