Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Optimize Async + Multi-step #8050

Merged
merged 8 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/multi_step/test_correctness_async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ async def test_multi_step(
model,
server_args + distributed_args,
num_logprobs,
max_wait_seconds=3 * 240)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why was this change needed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was increased originally for multi-step tests, but I think it was still sensitive, so I had one instance when I had a timeout. Increasing more did make the test stable.

max_wait_seconds=5 * 240)
test_completions = await completions_with_server_args(
prompts,
model,
ms_server_args + distributed_args,
num_logprobs,
max_wait_seconds=3 * 240)
max_wait_seconds=5 * 240)

# Assert multi-step scheduling produces identical tokens
# to single-step scheduling.
Expand Down
109 changes: 54 additions & 55 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,40 +280,27 @@ async def step_async(
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)

ctx = self.scheduler_contexts[virtual_engine]

# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()

# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):

# Clear outputs on scheduler iteration start
ctx.request_outputs.clear()

# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()

# Detect async + multi-step
use_async_and_multi_step = (self.scheduler_config.is_multi_step
and allow_async_output_proc)
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs

# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)

# For async + multi-step, init the queue
if use_async_and_multi_step:
assert len(ctx.output_queue) == 0
assert seq_group_metadata_list is not None
ctx.output_queue.append(
(None, seq_group_metadata_list, scheduler_outputs))
self._process_model_outputs(ctx=ctx)

if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
Expand Down Expand Up @@ -351,26 +338,20 @@ async def step_async(
last_sampled_token_ids=last_sampled_token_ids)

if allow_async_output_proc:
async_callback = self.async_callback_multi_step[
virtual_engine] if use_async_and_multi_step \
else self.async_callback[virtual_engine]

execute_model_req.async_callback = async_callback
execute_model_req.use_async_and_multi_step = \
use_async_and_multi_step
execute_model_req.async_callback = self.async_callback[
virtual_engine]

# Execute the model.
output = await self.model_executor.execute_model_async(
execute_model_req)

# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
else:
if not use_async_and_multi_step and len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
output = []

# Finish the current step for all the sequence groups.
Expand All @@ -384,24 +365,22 @@ async def step_async(
self.cached_scheduler_outputs[
virtual_engine] = SchedulerOutputState()

if use_async_and_multi_step:
# For async + multi-step, clear the queue
ctx.output_queue.clear()
else:
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs))
is_async = allow_async_output_proc
is_last_step = True
ctx.output_queue.append(
(output, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step))

if output and allow_async_output_proc:
assert len(
output
) == 1, "Multi step decoding does not work with async output processing." # noqa: E501
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
if output and allow_async_output_proc:
assert len(
output
) == 1, "Async postprocessor expects only a single output set"
self._advance_to_next_step(
output[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)

if not allow_async_output_proc:
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=False)
self._process_model_outputs(ctx=ctx)

# Log stats.
self.do_log_stats(scheduler_outputs, output)
Expand All @@ -411,17 +390,12 @@ async def step_async(

else:
# Multi-step case
if use_async_and_multi_step:
return []
else:
ctx.request_outputs = []
return ctx.request_outputs

if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
assert not self.scheduler_config.is_multi_step
self._process_model_outputs(virtual_engine=virtual_engine,
is_async=True)
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0

return ctx.request_outputs
Expand Down Expand Up @@ -640,6 +614,17 @@ def __init__(self,
self.log_requests = log_requests
self.engine = self._init_engine(*args, **kwargs)

# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# especially for multi-step.
#
# TODO: Currently, disabled for engine_use_ray, ask
# Cody/Will/Woosuk about this case.
self.use_process_request_outputs_callback = not self.engine_use_ray
if self.use_process_request_outputs_callback:
self.engine.process_request_outputs_callback = \
self.process_request_outputs
comaniac marked this conversation as resolved.
Show resolved Hide resolved

if self.engine_use_ray:
print_warning_once(
"DEPRECATED. `--engine-use-ray` is deprecated and will "
Expand Down Expand Up @@ -883,13 +868,27 @@ async def engine_step(self, virtual_engine: int) -> bool:
request_outputs = await self.engine.step_async(virtual_engine)

# Put the outputs into the corresponding streams.
finished = True
# If used as a callback, then already invoked inside
# LLMEngine's _process_model_outputs
if not self.use_process_request_outputs_callback:
all_finished = self.process_request_outputs(request_outputs)
else:
# For callback case, we only need to detect when all
# requests are finished
all_finished = all(request_output.finished
for request_output in request_outputs)

return not all_finished

def process_request_outputs(self, request_outputs) -> bool:
# Put the outputs into the corresponding streams.
all_finished = True
for request_output in request_outputs:
self._request_tracker.process_request_output(
request_output, verbose=self.log_requests)
finished = finished and request_output.finished
all_finished = all_finished and request_output.finished

return not finished
return all_finished

async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray:
Expand Down
Loading
Loading