chore: multi-step args and sequence modifications #713

Merged: 1 commit, Sep 12, 2024
14 changes: 13 additions & 1 deletion aphrodite/common/config.py
@@ -923,7 +923,8 @@ def __init__(self,
                 delay_factor: float = 0.0,
                 enable_chunked_prefill: bool = False,
                 embedding_mode: Optional[bool] = False,
                 preemption_mode: Optional[str] = None) -> None:
                 preemption_mode: Optional[str] = None,
                 num_scheduler_steps: int = 1) -> None:
        if max_num_batched_tokens is not None:
            self.max_num_batched_tokens = max_num_batched_tokens
        else:
@@ -952,6 +953,7 @@ def __init__(self,
        self.chunked_prefill_enabled = enable_chunked_prefill
        self.embedding_mode = embedding_mode
        self.preemption_mode = preemption_mode
        self.num_scheduler_steps = num_scheduler_steps

        self._verify_args()

@@ -978,6 +980,16 @@ def _verify_args(self) -> None:
                f"({self.num_lookahead_slots}) must be greater than or "
                "equal to 0.")

        if self.num_scheduler_steps < 1:
            raise ValueError(
                "num_scheduler_steps "
                f"({self.num_scheduler_steps}) must be greater than or "
                "equal to 1.")

    @property
    def is_multi_step(self) -> bool:
        return self.num_scheduler_steps > 1


class DeviceConfig:

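For context, a self-contained sketch of what this hunk adds: the real `SchedulerConfig` takes many more constructor arguments, so only the new field, its validation, and the derived property are mirrored here (`SchedulerConfigSketch` is an illustrative name, not a class from the codebase):

```python
class SchedulerConfigSketch:
    """Mirrors only the multi-step pieces of SchedulerConfig."""

    def __init__(self, num_scheduler_steps: int = 1) -> None:
        # Same check as _verify_args above: at least one step per call.
        if num_scheduler_steps < 1:
            raise ValueError(
                f"num_scheduler_steps ({num_scheduler_steps}) must be "
                "greater than or equal to 1.")
        self.num_scheduler_steps = num_scheduler_steps

    @property
    def is_multi_step(self) -> bool:
        # More than one forward step per scheduler call is "multi-step".
        return self.num_scheduler_steps > 1


assert SchedulerConfigSketch(4).is_multi_step
assert not SchedulerConfigSketch().is_multi_step
```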
57 changes: 56 additions & 1 deletion aphrodite/common/sequence.py
@@ -7,6 +7,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union, cast

import numpy
import torch

from aphrodite.common.pooling_params import PoolingParams
@@ -474,6 +475,19 @@ def __repr__(self) -> str:
                f"num_blocks={self.n_blocks}, ")


@dataclass
class SequenceGroupState:
    """Mutable state tied to a specific sequence group"""

    # for multi-step decoding
    num_steps: int = 1
    current_step: int = 0

    @property
    def remaining_steps(self) -> int:
        return self.num_steps - self.current_step
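The new dataclass is small enough to exercise directly. A quick sketch of the counter semantics, assuming a build with this branch installed:

```python
from aphrodite.common.sequence import SequenceGroupState

state = SequenceGroupState(num_steps=4)  # one scheduler call, four steps
assert state.remaining_steps == 4
state.current_step += 2                  # two decode steps completed
assert state.remaining_steps == 2
```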


class SequenceGroup:
"""A group of sequences that are generated from the same prompt.

@@ -516,6 +530,7 @@ def __init__(
            time_in_queue=None)
        self.lora_request = lora_request
        self.prompt_logprobs: Optional[PromptLogprobs] = None
        self.state = SequenceGroupState()
        self.embeddings = embeddings
        self.pooling_params = pooling_params
        self.prompt_adapter_request = prompt_adapter_request
@@ -569,6 +584,10 @@ def prompt_adapter_num_virtual_tokens(self) -> int:
        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
            if self.prompt_adapter_request else 0

    def init_multi_step(self, num_scheduler_steps: int) -> None:
        self.state.num_steps = num_scheduler_steps
        self.state.current_step = 0

    def get_last_latency(self, now: float) -> Optional[float]:
        """Sets the last token time for Request level timings."""
        # If still in prefill phase, raise Error.
@@ -735,6 +754,7 @@ class SequenceGroupMetadata:
        token_chunk_size: The number of tokens to be processed (per sequence).
            None if chunking is not required.
        lora_request: LoRA request.
        state: Internal state tied to this sequence group.
        computed_block_nums: The block numbers that are already computed,
            used in prefix caching.
        multi_modal_data: Multi modal data.
@@ -762,6 +782,7 @@ def __init__(
        token_chunk_size: Optional[int] = None,
        lora_request: Optional[LoRARequest] = None,
        computed_block_nums: Optional[List[int]] = None,
        state: Optional[SequenceGroupState] = None,
        multi_modal_data: Optional["MultiModalDataDict"] = None,
        encoder_seq_data: Optional[SequenceData] = None,
        cross_block_table: Optional[List[int]] = None,
@@ -777,6 +798,7 @@ def __init__(
        self.prompt_adapter_request = prompt_adapter_request
        self.computed_block_nums = computed_block_nums
        self.multi_modal_data = multi_modal_data
        self.state = SequenceGroupState() if state is None else state
        self.encoder_seq_data = encoder_seq_data
        self.cross_block_table = cross_block_table
        self._token_chunk_size = token_chunk_size
@@ -815,6 +837,10 @@ def token_chunk_size(self) -> int:
        assert self._token_chunk_size is not None
        return self._token_chunk_size

    def finish_step(self) -> None:
        assert self.state.current_step < self.state.num_steps
        self.state.current_step += 1
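A sketch of the step loop `finish_step` is meant to support; the real loop will live in the worker, which this PR does not yet wire up, so `SequenceGroupState` stands in for the state carried by a scheduled `SequenceGroupMetadata`:

```python
from aphrodite.common.sequence import SequenceGroupState

state = SequenceGroupState(num_steps=4)
while state.remaining_steps > 0:
    # ... one forward pass plus sampling would happen here ...
    assert state.current_step < state.num_steps  # the finish_step guard
    state.current_step += 1                      # what finish_step does
assert state.remaining_steps == 0
```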


class SequenceOutput:
"""The model output associated with a sequence.
@@ -952,6 +978,7 @@ class SamplerOutput:

    # On-device tensor containing the sampled token ids.
    sampled_token_ids: Optional[torch.Tensor] = None
    sampled_token_ids_numpy: Optional[numpy.ndarray] = None

    # Spec decode metrics populated by workers.
    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
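One plausible reading of the new field (an assumption; the diff itself does not say): a host-side copy of the sampled ids lets the multi-step loop append tokens each step without re-reading the device tensor. For a CPU tensor the conversion shares memory:

```python
import numpy
import torch

sampled_token_ids = torch.tensor([[42], [7]])        # [batch, 1] token ids
sampled_token_ids_numpy = sampled_token_ids.numpy()  # zero-copy on CPU
assert isinstance(sampled_token_ids_numpy, numpy.ndarray)
```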
@@ -1086,6 +1113,33 @@ class ExecuteModelRequest:
    num_steps: int = 1
    # Finished request ids since last step.
    finished_requests_ids: List[str] = field(default_factory=list)
    # The last sampled token ids for multi step decoding.
    last_sampled_token_ids: Optional[torch.Tensor] = None

    @property
    def is_first_multi_step(self) -> bool:
        # TODO: make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        first_seq_group = self.seq_group_metadata_list[0]
        return first_seq_group.state.current_step == 0

    @property
    def is_last_step(self) -> bool:
        # TODO: make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        first_seq_group = self.seq_group_metadata_list[0]
        num_steps = first_seq_group.state.num_steps
        current_step = first_seq_group.state.current_step
        return num_steps - current_step == 1

    @property
    def current_step(self) -> int:
        # TODO: make this be able to handle batches with variable number of
        # steps
        assert len(self.seq_group_metadata_list) > 0
        return self.seq_group_metadata_list[0].state.current_step

    def clone(
        self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -1102,4 +1156,5 @@ def clone(
            previous_hidden_states=self.previous_hidden_states,
            num_steps=self.num_steps,
            finished_requests_ids=self.finished_requests_ids,
        )
            last_sampled_token_ids=self.last_sampled_token_ids.clone()
            if self.last_sampled_token_ids is not None else None)
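A hypothetical driver loop showing how the three new properties and `last_sampled_token_ids` could be consumed once multi-step execution lands; `worker` and all of its methods are illustrative placeholders, not APIs from this codebase:

```python
def drive_one_scheduler_call(request, worker):
    """Run every step of a multi-step ExecuteModelRequest (sketch)."""
    while True:
        if request.is_first_multi_step:
            worker.prepare_inputs(request)        # hypothetical helper
        # is_last_step is true while one step remains, so read it before
        # executing; execute_step is assumed to call finish_step() on
        # each SequenceGroupMetadata it runs.
        last = request.is_last_step
        output = worker.execute_step(request)     # hypothetical helper
        # Feed step k's sampled ids into step k + 1.
        request.last_sampled_token_ids = output.sampled_token_ids
        if last:
            return worker.gather_outputs(request)  # hypothetical helper
```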
28 changes: 25 additions & 3 deletions aphrodite/engine/args_tools.py
@@ -111,6 +111,7 @@ class EngineArgs:
    guided_decoding_backend: str = 'outlines'
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    num_scheduler_steps: int = 1
    # Speculative Decoding Options
    num_lookahead_slots: int = 0
    speculative_model: Optional[str] = None
@@ -617,6 +618,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            help="Category: API Options\n"
            "maximum number of sequences per iteration",
        )
        parser.add_argument('--num-scheduler-steps',
                            type=int,
                            default=1,
                            help=('Maximum number of forward steps per '
                                  'scheduler call.'))
        # Speculative Decoding Options
        parser.add_argument("--num-lookahead-slots",
                            type=int,
@@ -970,19 +976,35 @@ def create_engine_config(self, ) -> EngineConfig:
            disable_logprobs=self.disable_logprobs_during_spec_decoding,
        )

        if self.num_scheduler_steps > 1:
            raise NotImplementedError("Multi-step is not yet supported.")
            if speculative_config is not None:
                raise ValueError("Speculative decoding is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")
            if self.enable_chunked_prefill:
                raise ValueError("Chunked prefill is not supported with "
                                 "multi-step (--num-scheduler-steps > 1)")

        # make sure num_lookahead_slots is set to the higher value depending
        # on whether we are using speculative decoding or multi-step
        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
        num_lookahead_slots = num_lookahead_slots \
            if speculative_config is None \
            else speculative_config.num_lookahead_slots

        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=self.max_num_batched_tokens,
            max_num_seqs=self.max_num_seqs,
            max_model_len=model_config.max_model_len,
            is_attention_free=model_config.is_attention_free(),
            use_v2_block_manager=self.use_v2_block_manager,
            num_lookahead_slots=(self.num_lookahead_slots
                                 if speculative_config is None else
                                 speculative_config.num_lookahead_slots),
            num_lookahead_slots=num_lookahead_slots,
            delay_factor=self.scheduler_delay_factor,
            enable_chunked_prefill=self.enable_chunked_prefill,
            embedding_mode=model_config.embedding_mode,
            preemption_mode=self.preemption_mode,
            num_scheduler_steps=self.num_scheduler_steps,
        )

        lora_config = LoRAConfig(
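The lookahead-slot arithmetic above is easy to misread, so here is the same resolution as a standalone function; the name and signature are local to this sketch:

```python
from typing import Optional


def resolve_num_lookahead_slots(cli_lookahead_slots: int,
                                num_scheduler_steps: int,
                                spec_lookahead_slots: Optional[int]) -> int:
    # Multi-step needs one lookahead slot per extra scheduler step...
    slots = max(cli_lookahead_slots, num_scheduler_steps - 1)
    # ...unless speculative decoding is configured, which owns the count
    # (the two features are mutually exclusive per the check above).
    return slots if spec_lookahead_slots is None else spec_lookahead_slots


assert resolve_num_lookahead_slots(0, 4, None) == 3  # --num-scheduler-steps 4
assert resolve_num_lookahead_slots(0, 1, None) == 0  # default single step
```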
5 changes: 5 additions & 0 deletions aphrodite/processing/scheduler.py
@@ -803,6 +803,9 @@ def _schedule_prefills(
                curr_loras.add(lora_int_id)
            waiting_queue.popleft()
            self._allocate_and_set_running(seq_group)
            seq_group.init_multi_step(
                num_scheduler_steps=self._get_num_lookahead_slots(
                    is_prefill=True) + 1)
            seq_groups.append(
                ScheduledSequenceGroup(seq_group=seq_group,
                                       token_chunk_size=num_new_tokens))
@@ -1105,6 +1108,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
                computed_block_nums=common_computed_block_nums,
                encoder_seq_data=encoder_seq_data,
                cross_block_table=cross_block_table,
                state=seq_group.state,
                # `multi_modal_data` will only be present for the 1st comm
                # between engine and worker.
                # the subsequent comms can still use delta, but
@@ -1170,6 +1174,7 @@ def _append_slots(
            slots.
        """
        num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
        seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)

        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            cows = self.block_manager.append_slots(seq, num_lookahead_slots)
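Both scheduler call sites above derive the step budget from the lookahead slots reserved through the block manager: N lookahead slots leave room for N + 1 decode steps within one scheduler call. In sketch form, assuming a build with this branch installed:

```python
from aphrodite.common.sequence import SequenceGroupState

num_lookahead_slots = 3  # what --num-scheduler-steps 4 resolves to
state = SequenceGroupState(num_steps=num_lookahead_slots + 1)
assert state.remaining_steps == 4
```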