format
alexm-redhat committed Aug 27, 2024
1 parent 6d4099f commit a1986d1
Showing 5 changed files with 71 additions and 43 deletions.
5 changes: 4 additions & 1 deletion examples/offline_inference.py
@@ -11,7 +11,10 @@
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m", num_scheduler_steps=8, use_v2_block_manager=True, disable_async_output_proc=False)
llm = LLM(model="facebook/opt-125m",
num_scheduler_steps=8,
use_v2_block_manager=True,
disable_async_output_proc=False)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
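
The comment above notes that generate() returns a list of RequestOutput objects. For context, a minimal sketch of consuming them, continuing the example file (attribute access follows the standard vLLM offline example; this block is illustrative and not part of the diff):

# Print each prompt together with its generated completion (sketch,
# continuing from the `outputs` list produced above).
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
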
10 changes: 6 additions & 4 deletions tests/multi_step/test_correctness_async_llm.py
@@ -47,10 +47,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str,
@pytest.mark.parametrize("eager_mode", [False, True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.asyncio
async def test_multi_step(example_prompts, model: str, tp_size: int,
pp_size: int, eager_mode: int,
num_scheduler_steps: int, num_prompts: int):
num_scheduler_steps: int, num_prompts: int,
is_async: bool):

prompts = example_prompts
if len(prompts) < num_prompts:
@@ -62,9 +64,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int,
ms_server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"]

# Disable output proc callback as its not supported
# with multi-step right now
ms_server_args += ["--disable-async-output-proc"]
if not is_async:
ms_server_args += ["--disable-async-output-proc"]

if eager_mode:
ms_server_args.append("--enforce-eager")

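
The new is_async parametrization only changes how the test server is launched: is_async=False keeps the old behavior of disabling async output processing with multi-step, while is_async=True leaves it enabled. A small sketch of how the flags compose, with DEFAULT_SERVER_ARGS as a placeholder for the list defined earlier in the test module (illustrative, not the test's actual helper):

DEFAULT_SERVER_ARGS = ["--disable-log-requests"]  # placeholder base args

def build_ms_server_args(num_scheduler_steps: int, is_async: bool,
                         eager_mode: bool) -> list:
    # Multi-step scheduling is always requested.
    args = DEFAULT_SERVER_ARGS + [
        "--num-scheduler-steps", f"{num_scheduler_steps}"
    ]
    if not is_async:
        # Legacy mode: turn async output processing off.
        args += ["--disable-async-output-proc"]
    if eager_mode:
        args.append("--enforce-eager")
    return args

# e.g. build_ms_server_args(8, True, False) keeps async output processing on.
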
53 changes: 37 additions & 16 deletions vllm/engine/async_llm_engine.py
@@ -279,6 +279,10 @@ async def step_async(
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)

ctx = self.scheduler_contexts[virtual_engine]

# skip the scheduler if there are any remaining steps in the seq groups.
@@ -294,11 +298,22 @@
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)

# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)

# For async + multi-step, init the queue
if use_async_and_multi_step:
assert len(ctx.output_queue) == 0
assert seq_group_metadata_list is not None
ctx.output_queue.append(
(None, seq_group_metadata_list, scheduler_outputs))

if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
@@ -310,9 +325,6 @@
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None

assert not (self.scheduler_config.is_multi_step and \
allow_async_output_proc)

if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
@@ -340,6 +352,8 @@
if allow_async_output_proc:
execute_model_req.async_callback = self.async_callback_data[
virtual_engine]
execute_model_req.use_async_and_multi_step = \
use_async_and_multi_step

# Execute the model.
output = await self.model_executor.execute_model_async(
@@ -349,7 +363,7 @@
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
if len(ctx.output_queue) > 0:
if not use_async_and_multi_step and len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
@@ -361,22 +375,25 @@
seq_group.finish_step()

if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps
# Clear the cache if we have finished all the steps
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()

# Cache results in engine
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
if use_async_and_multi_step:
# For async + multi-step, clear the queue
ctx.output_queue.clear()
else:
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))

if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)

if not allow_async_output_proc:
self._process_model_outputs(virtual_engine=virtual_engine,
@@ -389,7 +406,11 @@
self.do_tracing(scheduler_outputs)

else:
ctx.request_outputs = []
# Multi-step case
if use_async_and_multi_step:
return []
else:
ctx.request_outputs = []

if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
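
In summary, step_async now detects the async + multi-step combination up front, seeds the per-virtual-engine output queue with a (None, seq_group_metadata_list, scheduler_outputs) placeholder so the async callback has metadata to process intermediate steps, and clears the queue once the final step has run instead of appending the last output. A heavily simplified, self-contained sketch of that queue discipline (Context, step and their arguments are illustrative stand-ins for the engine's real state, not its API):

from collections import deque
from dataclasses import dataclass, field
from typing import Any, Deque, List, Optional, Tuple


@dataclass
class Context:
    # Stand-in for the per-virtual-engine scheduler context.
    output_queue: Deque[Tuple[Optional[Any], Any, Any]] = field(
        default_factory=deque)


def step(ctx: Context, is_multi_step: bool, allow_async_output_proc: bool,
         seq_group_metadata_list: List[Any], scheduler_outputs: Any,
         output: Any, has_remaining_steps: bool) -> None:
    # Detect async + multi-step.
    use_async_and_multi_step = is_multi_step and allow_async_output_proc

    if use_async_and_multi_step:
        # Seed the queue so the async callback can process outputs while
        # the multi-step run is still in flight.
        assert len(ctx.output_queue) == 0
        ctx.output_queue.append(
            (None, seq_group_metadata_list, scheduler_outputs))

    # ... model execution and per-step callbacks happen here ...

    if not has_remaining_steps:
        if use_async_and_multi_step:
            # Outputs were already handled step by step; just drain the queue.
            ctx.output_queue.clear()
        else:
            ctx.output_queue.append(
                (output, seq_group_metadata_list, scheduler_outputs))
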
28 changes: 14 additions & 14 deletions vllm/engine/llm_engine.py
@@ -1268,6 +1268,8 @@ def _process_model_outputs(self,
(outputs, seq_group_metadata_list,
scheduler_outputs) = ctx.output_queue.popleft()

assert outputs is not None

# Sanity check
assert len(seq_group_metadata_list) == len(
scheduler_outputs.scheduled_seq_groups)
@@ -1325,20 +1327,19 @@ def _process_model_outputs(self,
self.output_processor.process_outputs(seq_group, output,
is_async)

# Free finished sequence groups.
if is_multi_step:
if is_last_output:
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()
else:
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()
# For async + multi-step, free finished seqs and create outputs
# only on the final step.
if is_multi_step and not is_last_output:
return

for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()

# Create the outputs.
for i, _ in enumerate(seq_group_metadata_list):
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

if i in finished_before:
if not is_multi_step and i in finished_before:
continue # Avoids double processing

seq_group = scheduled_seq_group.seq_group
@@ -1354,10 +1355,7 @@

# For async + multi-step, do stats only on the last output.
# Otherwise, do stats if the execution is async
if is_multi_step:
do_stats = is_last_output
else:
do_stats = is_async
do_stats = is_multi_step or is_async

if do_stats:
# Log stats.
@@ -1493,6 +1491,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# For async + multi-step, init the queue
if use_async_and_multi_step:
assert len(ctx.output_queue) == 0
assert seq_group_metadata_list is not None
ctx.output_queue.append(
(None, seq_group_metadata_list, scheduler_outputs))

@@ -1533,7 +1532,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
if allow_async_output_proc:
execute_model_req.async_callback = self.async_callback_data[
virtual_engine]
execute_model_req.use_async_and_multi_step = use_async_and_multi_step
execute_model_req.use_async_and_multi_step = \
use_async_and_multi_step

output = self.model_executor.execute_model(
execute_model_req=execute_model_req)
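
The net effect in _process_model_outputs is an early return for intermediate multi-step outputs, so sequence groups are freed and RequestOutputs are produced only once per request, and the old is_last_output/is_async branching collapses into a single do_stats expression. A condensed sketch of that tail logic (the parameters stand in for engine state and this is not the method's real signature):

def finish_processing(is_multi_step: bool, is_last_output: bool,
                      is_async: bool, schedulers, seq_group_metadata_list,
                      finished_before) -> None:
    # For async + multi-step, free finished seqs and create outputs
    # only on the final step.
    if is_multi_step and not is_last_output:
        return

    for scheduler in schedulers:
        scheduler.free_finished_seq_groups()

    for i, _ in enumerate(seq_group_metadata_list):
        if not is_multi_step and i in finished_before:
            continue  # Avoids double processing.
        # ... build the RequestOutput for sequence group i here ...

    # Multi-step only reaches this point on the last output, so the old
    # "is_last_output if multi-step else is_async" check simplifies to:
    do_stats = is_multi_step or is_async
    if do_stats:
        pass  # log stats here
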
18 changes: 10 additions & 8 deletions vllm/worker/multi_step_model_runner.py
@@ -14,9 +14,9 @@
from vllm import _custom_ops as ops
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
Logprob, SamplerOutput, SequenceGroupMetadata,
SequenceOutput, AsyncCallbackData)
from vllm.sequence import (AsyncCallbackData, CompletionSequenceGroupOutput,
IntermediateTensors, Logprob, SamplerOutput,
SequenceGroupMetadata, SequenceOutput)
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
@@ -233,6 +233,8 @@ def _async_process_outputs(self, model_input: StatefulModelInput,

def _final_process_outputs(self, model_input: StatefulModelInput,
output_proc_callback: AsyncCallbackData):
assert model_input.frozen_model_input is not None

if output_proc_callback is not None:
output_proc_fn = output_proc_callback.func
output_proc_kw_args = output_proc_callback.kw_args
@@ -325,13 +327,13 @@ def execute_model(
frozen_model_input = dataclasses.replace( # type: ignore
model_input.frozen_model_input,
async_callback=async_callback)
assert frozen_model_input is not None

# Execute the model
output = self._base_model_runner.execute_model(
frozen_model_input,
kv_caches,
intermediate_tensors,
num_steps=1)
output = self._base_model_runner.execute_model(frozen_model_input,
kv_caches,
intermediate_tensors,
num_steps=1)

# record the event for the current step so that the next step can sync
model_input.record_step_event(current_stream)
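
Because the frozen model input is an immutable dataclass, the runner attaches the per-step async callback with dataclasses.replace() rather than mutating it in place, as the hunk above shows. A toy sketch of that pattern (FrozenInput and attach_callback are made-up names for illustration, not the runner's real types):

import dataclasses
from typing import Callable, Optional, Tuple


@dataclasses.dataclass(frozen=True)
class FrozenInput:
    # Stand-in for the frozen model input held by StatefulModelInput.
    prompt_token_ids: Tuple[int, ...]
    async_callback: Optional[Callable[[], None]] = None


def attach_callback(frozen_model_input: FrozenInput,
                    async_callback: Callable[[], None]) -> FrozenInput:
    # replace() returns a copy with the callback set, leaving the
    # original frozen input untouched.
    return dataclasses.replace(frozen_model_input,
                               async_callback=async_callback)


updated = attach_callback(FrozenInput(prompt_token_ids=(1, 2, 3)),
                          lambda: print("process outputs asynchronously"))
assert updated.async_callback is not None
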
