Skip to content

Commit

Permalink
Add output streaming support to multi-step + async while ensuring RequestOutput obj reuse (#8335)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexm-redhat authored and agt committed Sep 24, 2024
1 parent 8581c76 commit 3075574
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 42 deletions.
6 changes: 5 additions & 1 deletion tests/entrypoints/openai/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
RTOL = 0.03
EXPECTED_VALUE = 0.58
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
# Extra argument sets; the test below is parametrized over this list, so each
# inner list is one test case. Defined exactly once (the earlier two-entry
# assignment was immediately overwritten and has been dropped).
MORE_ARGS_LIST = [
    ["--enable-chunked-prefill"],  # Chunked
    ["--num-scheduler-steps", "8"],  # MS
    ["--num-scheduler-steps", "8", "--multi-step-stream-outputs"]  # MS+Stream
]


@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
Expand Down
2 changes: 2 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,7 @@ def __init__(self,
is_multimodal_model: bool = False,
preemption_mode: Optional[str] = None,
num_scheduler_steps: int = 1,
multi_step_stream_outputs: bool = False,
send_delta_data: bool = False) -> None:
if max_num_batched_tokens is None:
if enable_chunked_prefill:
Expand Down Expand Up @@ -1000,6 +1001,7 @@ def __init__(self,
self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode
self.num_scheduler_steps = num_scheduler_steps
self.multi_step_stream_outputs = multi_step_stream_outputs
self.send_delta_data = send_delta_data
self._verify_args()

Expand Down
6 changes: 6 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class EngineArgs:
max_cpu_loras: Optional[int] = None
device: str = 'auto'
num_scheduler_steps: int = 1
multi_step_stream_outputs: bool = False
ray_workers_use_nsight: bool = False
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
Expand Down Expand Up @@ -595,6 +596,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help=('Maximum number of forward steps per '
'scheduler call.'))

parser.add_argument(
'--multi-step-stream-outputs',
action='store_true',
help='If True, then multi-step will stream outputs for every step')
parser.add_argument(
'--scheduler-delay-factor',
type=float,
Expand Down Expand Up @@ -999,6 +1004,7 @@ def create_engine_config(self) -> EngineConfig:
is_multimodal_model=model_config.is_multimodal_model,
preemption_mode=self.preemption_mode,
num_scheduler_steps=self.num_scheduler_steps,
multi_step_stream_outputs=self.multi_step_stream_outputs,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray),
)
Expand Down
37 changes: 28 additions & 9 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,16 @@ class OutputData(NamedTuple):

class SchedulerContext:

def __init__(self):
def __init__(self, multi_step_stream_outputs: bool = False):
self.output_queue: Deque[OutputData] = deque()
self.request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
self.seq_group_metadata_list: Optional[
List[SequenceGroupMetadata]] = None
self.scheduler_outputs: Optional[SchedulerOutputs] = None

self.multi_step_stream_outputs: bool = multi_step_stream_outputs

def append_output(self, outputs: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduler_outputs: SchedulerOutputs, is_async: bool,
Expand Down Expand Up @@ -219,6 +221,7 @@ def __init__(
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
Expand All @@ -234,8 +237,9 @@ def __init__(
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"use_async_output_proc=%s, mm_processor_kwargs=%s)",
"num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
"enable_prefix_caching=%s, use_async_output_proc=%s, "
"use_cached_outputs=%s, mm_processor_kwargs=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
Expand Down Expand Up @@ -266,8 +270,10 @@ def __init__(
model_config.served_model_name,
scheduler_config.use_v2_block_manager,
scheduler_config.num_scheduler_steps,
scheduler_config.multi_step_stream_outputs,
cache_config.enable_prefix_caching,
model_config.use_async_output_proc,
use_cached_outputs,
model_config.mm_processor_kwargs,
)
# TODO(woosuk): Print more configs in debug mode.
Expand All @@ -287,6 +293,7 @@ def __init__(
self.observability_config = observability_config or ObservabilityConfig(
)
self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs

if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
Expand Down Expand Up @@ -379,7 +386,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
]

self.scheduler_contexts = [
SchedulerContext()
SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
multi_step_stream_outputs)
for _ in range(self.parallel_config.pipeline_parallel_size)
]

Expand Down Expand Up @@ -998,7 +1006,8 @@ def _process_model_outputs(self,

seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group)
request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
if request_output:
ctx.request_outputs.append(request_output)

Expand All @@ -1019,8 +1028,8 @@ def _process_model_outputs(self,
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()

# For multi-step, do not create outputs each iteration
if not is_last_step:
# For multi-step without streaming, don't create outputs each iteration
if not is_last_step and not ctx.multi_step_stream_outputs:
# Immediately process request outputs here (if callback is given)
if (finished_now
and self.process_request_outputs_callback is not None):
Expand All @@ -1037,17 +1046,27 @@ def _process_model_outputs(self,

seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group)
request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
if request_output:
ctx.request_outputs.append(request_output)

# For multi-step with streaming, create outputs each iteration
if not is_last_step and ctx.multi_step_stream_outputs:
# Immediately process request outputs here (if callback is given)
if self.process_request_outputs_callback is not None:
self.process_request_outputs_callback(ctx.request_outputs)
ctx.request_outputs.clear()
return

for seq_group in scheduler_outputs.ignored_seq_groups:
params = seq_group.sampling_params
if params is not None and params.output_kind == (
RequestOutputKind.DELTA) and not seq_group.is_finished():
continue

request_output = RequestOutputFactory.create(seq_group)
request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
if request_output:
ctx.request_outputs.append(request_output)

Expand Down
9 changes: 8 additions & 1 deletion vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,14 @@ def __init__(self,
*args,
log_requests: bool = True,
**kwargs) -> None:
self.engine = LLMEngine(*args, **kwargs)
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and sent over the socket, which frees
# the Python object to be reused again.
use_cached_outputs = True

self.engine = LLMEngine(*args,
**kwargs,
use_cached_outputs=use_cached_outputs)
self.log_requests = log_requests

self.use_async_sockets = use_async_sockets
Expand Down
96 changes: 74 additions & 22 deletions vllm/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,28 @@ def __init__(
self.encoder_prompt_token_ids = encoder_prompt_token_ids

@classmethod
def from_seq_group(cls,
seq_group: SequenceGroup) -> Optional["RequestOutput"]:
def from_seq_group(cls, seq_group: SequenceGroup,
use_cache: bool) -> Optional["RequestOutput"]:
sampling_params = seq_group.sampling_params
if sampling_params is None:
raise ValueError(
"Sampling parameters are missing for a CompletionRequest.")

finished = seq_group.is_finished()
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
not finished):
return None

# Init cache (if needed)
if use_cache and seq_group.cached_request_output is None:
seq_group.cached_request_output = RequestOutput( # type: ignore
request_id="",
prompt=None,
prompt_token_ids=[],
prompt_logprobs=None,
outputs=[],
finished=False)

seqs = seq_group.get_seqs()
if len(seqs) == 1:
top_n_seqs = seqs
Expand All @@ -149,29 +160,66 @@ def from_seq_group(cls,

outputs = []
include_prompt = True
for seq in top_n_seqs:
for i, seq in enumerate(top_n_seqs):
output_text = seq.get_output_text_to_return(
text_buffer_length, delta)

output_token_ids = seq.get_output_token_ids_to_return(delta)
num_output_tokens = 1 if isinstance(output_token_ids,
int) else len(output_token_ids)

output_logprobs = seq.output_logprobs if include_logprobs else None

if delta:
# Slice logprobs delta if applicable
if output_logprobs:
output_logprobs = output_logprobs[-len(output_token_ids):]
output_logprobs = output_logprobs[-num_output_tokens:]
# Don't include prompt if this is after the first output
# containing decode token ids
if include_prompt and seq.get_output_len() > len(
output_token_ids):
if include_prompt and seq.get_output_len() > num_output_tokens:
include_prompt = False

outputs.append(
CompletionOutput(
seqs.index(seq), output_text, output_token_ids,
if use_cache:
# Get cached output object
cached_outputs = seq_group.cached_request_output.outputs # type: ignore
if i >= len(cached_outputs):
cached_outputs.append(
CompletionOutput(index=i,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None))
output = cached_outputs[i]

# Init cached output object
assert output.index == i
output.text = output_text

if isinstance(output_token_ids, int):
output.token_ids.clear()
output.token_ids.append(output_token_ids)
else:
output.token_ids = output_token_ids

output.cumulative_logprob = seq.get_cumulative_logprob() \
if include_logprobs else None
output.logprobs = output_logprobs
output.finish_reason = SequenceStatus.get_finished_reason(
seq.status)
output.stop_reason = seq.stop_reason

else:
output = CompletionOutput(
seqs.index(seq), output_text, [output_token_ids]
if isinstance(output_token_ids, int) else output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
output_logprobs,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason))
seq.stop_reason)

outputs.append(output)

# Every sequence in the sequence group should have the same prompt.
if include_prompt:
Expand All @@ -188,16 +236,20 @@ def from_seq_group(cls,
prompt_logprobs = None
finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time)
return cls(seq_group.request_id,
prompt,
prompt_token_ids,
prompt_logprobs,
outputs,
finished,
seq_group.metrics,
lora_request=seq_group.lora_request,
encoder_prompt=encoder_prompt,
encoder_prompt_token_ids=encoder_prompt_token_ids)

init_args = (seq_group.request_id, prompt, prompt_token_ids,
prompt_logprobs, outputs, finished, seq_group.metrics,
seq_group.lora_request, encoder_prompt,
encoder_prompt_token_ids)

if use_cache:
request_output = seq_group.cached_request_output
request_output.__init__(*init_args) # type: ignore

else:
request_output = cls(*init_args)

return request_output

def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, "
Expand Down Expand Up @@ -261,10 +313,10 @@ def __repr__(self):
class RequestOutputFactory:
    """Picks the request-output class that matches a sequence group."""

    @staticmethod
    def create(seq_group: SequenceGroup, use_cache: bool = False):
        """Build the output object for ``seq_group``.

        Args:
            seq_group: The sequence group to convert into an output object.
            use_cache: Forwarded to ``RequestOutput.from_seq_group``. It is
                not passed on the embedding path.

        Returns:
            The result of ``EmbeddingRequestOutput.from_seq_group`` when the
            group has a non-None ``embeddings`` attribute, otherwise the
            result of ``RequestOutput.from_seq_group`` (which may be None).
        """
        # A group that carries embeddings is routed to the embedding output;
        # everything else is treated as a regular completion request.
        if getattr(seq_group, 'embeddings', None) is not None:
            return EmbeddingRequestOutput.from_seq_group(seq_group)
        else:
            return RequestOutput.from_seq_group(seq_group, use_cache)
28 changes: 19 additions & 9 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def __init__(
self.stop_reason: Union[int, str, None] = None

# These are used to keep track of delta outputs
self._last_token_ids_offset: int = 0
self._last_output_token_ids_offset: int = 0
self._last_output_text_offset: int = 0

# Used for incremental detokenization
Expand Down Expand Up @@ -499,18 +499,26 @@ def get_output_text_to_return(self, buffer_length: int,
return self.output_text[last_offset:length]
return ""

def get_output_token_ids_to_return(
        self, delta: bool) -> Union[GenericSequence[int], int]:
    """Return the output token ids to report for this sequence.

    If delta is True, only new tokens since the last call to
    this method are returned, and the internal offset is advanced to the
    current output length.

    Args:
        delta: When False, the result of ``self.get_output_token_ids()`` is
            returned unchanged and no offset is updated.

    Returns:
        In delta mode: a bare ``int`` when exactly one token is new, a fresh
        empty list when nothing is new, otherwise a slice (a new list) of
        the trailing new tokens.
    """
    if not delta:
        return self.get_output_token_ids()

    output_len = self.get_output_len()

    # Get the number of new tokens
    num_new_tokens = output_len - self._last_output_token_ids_offset
    self._last_output_token_ids_offset = output_len

    if num_new_tokens <= 0:
        # Must be handled explicitly: `tokens[-0:]` is `tokens[0:]`, i.e. the
        # whole list, not an empty one. A new list is returned (rather than a
        # shared or immutable empty value) so a caller may mutate it in place.
        return []

    # Return new tokens
    if num_new_tokens == 1:
        # Optimization for single decode token case
        # (which is what we have most of the time)
        return self.data._cached_all_token_ids[-1]

    return self.data._cached_all_token_ids[-num_new_tokens:]

def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size
Expand Down Expand Up @@ -671,6 +679,8 @@ def __init__(
self.encoder_seq = encoder_seq
self.trace_headers = trace_headers

self.cached_request_output = None

@property
def prompt(self) -> Optional[str]:
# All sequences in the group should have the same prompt.
Expand Down

0 comments on commit 3075574

Please sign in to comment.