diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index a0498315516b8..684c54b7d8139 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -601,7 +601,7 @@ def _execute_model(*args): batch_idx += 1 else: for seq_id in seq_ids: - next_token_id = next_token_ids[batch_idx][0] + next_token_id = next_token_ids[batch_idx] seq_outputs.append( SequenceOutput(seq_id, next_token_id, {next_token_id: zero_logprob})) @@ -722,6 +722,9 @@ def forward( sampled_token_ids = torch.multinomial(probs, num_samples, replacement=True) + if num_samples == 1: + argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + sampled_token_ids = sampled_token_ids.squeeze(dim=-1) next_token_ids = torch.where(t != 0, sampled_token_ids, argmax_token_ids) return next_token_ids