
Commit

Remove events
WoosukKwon committed Sep 14, 2024
1 parent 27cf757 commit 62c454c
Showing 1 changed file with 31 additions and 50 deletions.
81 changes: 31 additions & 50 deletions vllm/worker/tpu_model_runner.py
@@ -1,4 +1,3 @@
-import threading
 import time
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
@@ -8,7 +7,6 @@
 import numpy as np
 import torch
 import torch.nn as nn
-import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr

@@ -124,9 +122,7 @@ def __init__(
             self.block_size,
             False,
         )
-        self.cached_step_outputs: List[Tuple[Optional[threading.Event],
-                                             torch.Tensor]] = []
-        self.cached_sampler_outputs: List[SamplerOutput] = []
+        self.cached_step_outputs: List[torch.Tensor] = []
 
     def load_model(self) -> None:
         self.device = self.device_config.device
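The type change above is the heart of the commit: a per-step output used to be an (event, tensor) pair with a separate cache of finished SamplerOutputs; now the queue holds bare device tensors and the SamplerOutputs are rebuilt on demand when the queue is drained. A before/after sketch of the attribute (annotations copied from the diff; the surrounding class is omitted):

import torch
from typing import List

# Before: each step's device tensor was paired with an optional
# threading.Event signaling host-side readiness, and finished
# SamplerOutputs were cached separately:
#     cached_step_outputs: List[Tuple[Optional[threading.Event],
#                                     torch.Tensor]] = []
#     cached_sampler_outputs: List[SamplerOutput] = []
#
# After: the queue holds bare device tensors; readiness is enforced
# later by the blocking .cpu() transfer in execute_model.
cached_step_outputs: List[torch.Tensor] = []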
@@ -528,36 +524,34 @@ def execute_model(
     ) -> List[SamplerOutput]:
         assert intermediate_tensors is None
         if not model_input.is_first_multi_step:
-            if model_input.async_callback is not None:
-                ctx = model_input.async_callback.keywords[  # type: ignore
-                    "ctx"]
-                ctx.append_output(
-                    outputs=[self.cached_sampler_outputs.pop(0)],
-                    seq_group_metadata_list=ctx.seq_group_metadata_list,
-                    scheduler_outputs=ctx.scheduler_outputs,
-                    is_async=False,
-                    is_last_step=False)
-                model_input.async_callback()
-
-            event, next_token_ids = self.cached_step_outputs.pop(0)
-            if event is not None:
-                event.wait()
-            next_token_ids = next_token_ids.cpu().tolist()
-            sampler_output = _make_decode_output(next_token_ids,
-                                                 model_input.seq_groups)
-            self.cached_sampler_outputs.append(sampler_output)
-
             if not model_input.is_last_step:
-                return [sampler_output]
+                return []
 
-            if model_input.async_callback is not None:
-                assert len(self.cached_sampler_outputs) == 1
-                self.cached_sampler_outputs = []
-                return [sampler_output]
-            else:
-                sampler_outputs = self.cached_sampler_outputs
-                self.cached_sampler_outputs = []
-                return sampler_outputs
+            use_async_out_proc = model_input.async_callback is not None
+            sampler_outputs = []
+            num_outputs = len(self.cached_step_outputs)
+            for i in range(num_outputs):
+                next_token_ids = self.cached_step_outputs.pop(0)
+                next_token_ids = next_token_ids.cpu().tolist()
+                sampler_output = _make_decode_output(next_token_ids,
+                                                     model_input.seq_groups)
+                sampler_outputs.append(sampler_output)
+
+                if i < num_outputs - 1 and use_async_out_proc:
+                    assert model_input.async_callback is not None
+                    ctx = model_input.async_callback.keywords[  # type: ignore
+                        "ctx"]
+                    ctx.append_output(
+                        outputs=[sampler_output],
+                        seq_group_metadata_list=ctx.seq_group_metadata_list,
+                        scheduler_outputs=ctx.scheduler_outputs,
+                        is_async=False,
+                        is_last_step=False)
+                    model_input.async_callback()
+            if use_async_out_proc:
+                return [sampler_outputs[-1]]
+            else:
+                return sampler_outputs
 
         is_prompt = model_input.attn_metadata.num_prefills > 0
         if is_prompt:
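The added block replaces the one-output-per-call flow with a drain loop on the last multi-step call: cached tensors are popped in FIFO order, each .cpu() blocks until that step's TPU computation has finished, intermediate outputs are streamed through the scheduler's async callback, and only the final output is returned directly. A self-contained toy of that control flow (plain lists stand in for device tensors and SamplerOutputs; drain_outputs is a hypothetical name, not vLLM API):

from typing import Callable, List, Optional

def drain_outputs(
    cached: List[List[int]],
    async_callback: Optional[Callable[[List[int]], None]] = None,
) -> List[List[int]]:
    outputs: List[List[int]] = []
    num = len(cached)
    for i in range(num):
        out = cached.pop(0)  # FIFO: oldest step first
        outputs.append(out)
        if i < num - 1 and async_callback is not None:
            async_callback(out)  # stream intermediate steps to the scheduler
    # With async output processing, earlier steps were already delivered
    # via the callback, so only the last output is returned.
    return [outputs[-1]] if async_callback is not None else outputs

if __name__ == "__main__":
    print(drain_outputs([[1], [2], [3]]))         # [[1], [2], [3]]
    print(drain_outputs([[1], [2], [3]], print))  # prints [1], [2]; returns [[3]]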
@@ -650,8 +644,7 @@ def execute_model(
                     model_input.num_samples,
                     kv_caches,
                     is_prompt=False)
-                event = _get_event(output_token_ids) if num_steps > 1 else None
-                self.cached_step_outputs.append((event, output_token_ids))
+                self.cached_step_outputs.append(output_token_ids)
 
                 if i < num_steps - 1:
                     # Prepare the inputs for the next step.
@@ -675,15 +668,13 @@ def execute_model(
         if model_input.async_callback is not None:
             model_input.async_callback()
 
+        if num_steps > 1:
+            return []
         # Retrieve the outputs to CPU.
-        event, next_token_ids = self.cached_step_outputs.pop(0)
-        if event is not None:
-            event.wait()
+        next_token_ids = self.cached_step_outputs.pop(0)
         next_token_ids = next_token_ids.cpu().tolist()
         sampler_output = _make_decode_output(next_token_ids,
                                              model_input.seq_groups)
-        if num_steps > 1:
-            self.cached_sampler_outputs.append(sampler_output)
         return [sampler_output]
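The two hunks above are the producer and consumer ends of the same queue: the step loop appends the raw output tensor with no event attached, and retrieval either happens immediately (single step) or is deferred to a later execute_model call via the early return []. A minimal sketch of the pairing (a CPU tensor stands in for the TPU output):

from typing import List
import torch

cached_step_outputs: List[torch.Tensor] = []

def run_step() -> None:
    # Stand-in for the model call; on TPU the appended tensor may still
    # be materializing on device at this point.
    output_token_ids = torch.tensor([[7], [3]])
    cached_step_outputs.append(output_token_ids)  # no event, no wait

def retrieve() -> List[List[int]]:
    # On XLA, .cpu() blocks until the producing computation finishes,
    # so the transfer itself is now the synchronization point that the
    # removed events used to provide.
    return cached_step_outputs.pop(0).cpu().tolist()

run_step()
print(retrieve())  # [[7], [3]]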


@@ -834,16 +825,6 @@ def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
     return logits
 
 
-def _get_event(x: torch.Tensor) -> threading.Event:
-    event = threading.Event()
-
-    def _callback_wrapper():
-        event.set()
-
-    torch_xla._XLAC._on_ready_callback(x, _callback_wrapper)
-    return event
-
-
 def _make_decode_output(
     next_token_ids: List[int],
     seq_groups: List[List[int]],
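For reference, the deleted helper relied on the private torch_xla API _XLAC._on_ready_callback to flip a threading.Event once a device tensor was ready. A commented reconstruction of the removed pattern (private APIs like this can change between torch_xla releases):

import threading

import torch
import torch_xla

def _get_event(x: torch.Tensor) -> threading.Event:
    # The event is set from an XLA runtime thread once the computation
    # producing `x` has completed on device.
    event = threading.Event()

    def _callback_wrapper():
        event.set()

    torch_xla._XLAC._on_ready_callback(x, _callback_wrapper)
    return event

After this commit the runner simply calls next_token_ids.cpu(), which blocks until the tensor is ready, trading the background notification for a synchronous transfer and dropping the threading machinery entirely.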
