Support beam search & parallel generation #7

Merged 43 commits on Mar 10, 2023
Commits
e0a8519
Minor
WoosukKwon Mar 6, 2023
a320382
Add get_seqs
WoosukKwon Mar 6, 2023
cf9536f
Minor
WoosukKwon Mar 6, 2023
19ff0d0
max_context_len -> context_window_size
WoosukKwon Mar 6, 2023
c6ae9e1
Add __repr__ to SamplingParams
WoosukKwon Mar 7, 2023
22822d8
Minor
WoosukKwon Mar 8, 2023
271d5df
Enhance Frontend and SamplingParams
WoosukKwon Mar 8, 2023
ab13c30
Add get_last_token_id
WoosukKwon Mar 8, 2023
2f04d46
Add InputSequenceGroup
WoosukKwon Mar 8, 2023
f743df9
Support temperature & top_p sampling
WoosukKwon Mar 8, 2023
f10fbac
Support parallel generation
WoosukKwon Mar 8, 2023
f1f49f8
Use n=2 for test inputs
WoosukKwon Mar 8, 2023
ac85d81
Enforce zero temperature for beam search
WoosukKwon Mar 9, 2023
b790887
Remove group_id from seq_groups
WoosukKwon Mar 9, 2023
bdbb3f9
Refactor Sampler
WoosukKwon Mar 9, 2023
261f3cd
Minor
WoosukKwon Mar 9, 2023
6184793
Use replacement=True for torch.multinomial
WoosukKwon Mar 9, 2023
c158f6e
Fix a bug in block copy
WoosukKwon Mar 9, 2023
4340cdb
InputSequenceGroup -> SequenceGroupInput
WoosukKwon Mar 9, 2023
893c1a0
SequenceGroupInput -> SequenceGroupInputs
WoosukKwon Mar 9, 2023
7b8889c
Add SequenceOutputs & Store logprobs for sequences
WoosukKwon Mar 9, 2023
f8493e6
Add num_logprobs to SamplingParams
WoosukKwon Mar 9, 2023
38244e4
Use num_logprobs in sampling_params
WoosukKwon Mar 9, 2023
2a4b8bb
[WIP] Refactor
WoosukKwon Mar 9, 2023
d449b3d
Minor
WoosukKwon Mar 9, 2023
2ac01dc
[WIP] Refactor
WoosukKwon Mar 9, 2023
de1c3d7
Implement beam search
WoosukKwon Mar 9, 2023
0daed38
Minor
WoosukKwon Mar 9, 2023
a0a55b0
Shallow copy -> deep copy
WoosukKwon Mar 9, 2023
e1f359a
Bugfix for beam search
WoosukKwon Mar 9, 2023
6ddcf6b
Minor
WoosukKwon Mar 9, 2023
d9610f3
Update server.py
WoosukKwon Mar 9, 2023
b6312a4
Minor
WoosukKwon Mar 9, 2023
f41b240
Refactor
WoosukKwon Mar 9, 2023
3ec2aa5
Minor
WoosukKwon Mar 9, 2023
d25d97d
Fix a logprob bug in beam search
WoosukKwon Mar 9, 2023
bec648b
Add __repr__ to SequenceOutputs
WoosukKwon Mar 9, 2023
9b08425
Minor
WoosukKwon Mar 9, 2023
69fc1ef
Minor
WoosukKwon Mar 9, 2023
7c07891
Minor
WoosukKwon Mar 10, 2023
7239d4c
Add seed
WoosukKwon Mar 10, 2023
2ccb084
Minor change in comment
WoosukKwon Mar 10, 2023
b6684f6
Minor
WoosukKwon Mar 10, 2023
4 changes: 4 additions & 0 deletions cacheflow/block.py
@@ -35,6 +35,10 @@ def append(self, token_ids: List[int]) -> None:
def get_token_ids(self) -> List[int]:
return self.token_ids[:self.num_tokens]

def get_last_token_id(self) -> int:
assert self.num_tokens > 0
return self.token_ids[self.num_tokens - 1]


class PhysicalTokenBlock:

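For context, here is a minimal self-contained sketch of a logical token block exposing the new accessor. The class name, the block_size parameter, and the constructor are assumptions pieced together from this hunk; the real class in cacheflow/block.py has additional fields.

from typing import List


class LogicalTokenBlock:
    # Hypothetical minimal version for illustration only.

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        self.token_ids: List[int] = [0] * block_size  # preallocated token slots
        self.num_tokens = 0                           # slots actually filled

    def append(self, token_ids: List[int]) -> None:
        assert self.num_tokens + len(token_ids) <= self.block_size
        end = self.num_tokens + len(token_ids)
        self.token_ids[self.num_tokens:end] = token_ids
        self.num_tokens = end

    def get_token_ids(self) -> List[int]:
        return self.token_ids[:self.num_tokens]

    def get_last_token_id(self) -> int:
        # The scheduler uses this to feed only the newest token when decoding.
        assert self.num_tokens > 0
        return self.token_ids[self.num_tokens - 1]

block = LogicalTokenBlock(block_size=8)
block.append([101, 102, 103])
assert block.get_last_token_id() == 103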
33 changes: 28 additions & 5 deletions cacheflow/master/frontend.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Set, Tuple

from transformers import AutoTokenizer

@@ -25,12 +25,35 @@ def __init__(
def query(
self,
prompt: str,
sampling_params: Optional[SamplingParams] = None,
n: int = 1,
temperature: float = 1.0,
top_p: float = 1.0,
use_beam_search: bool = False,
stop_token_ids: Set[int] = set(),
max_num_steps: int = 16, # From OpenAI API.
num_logprobs: int = 0,
context_window_size: Optional[int] = None,
) -> None:
if sampling_params is None:
sampling_params = SamplingParams()
token_ids: List[int] = self.tokenizer.encode(prompt)
# Stop when we see an EOS token.
stop_token_ids.add(self.tokenizer.eos_token_id)
sampling_params = SamplingParams(
n=n,
temperature=temperature,
top_p=top_p,
use_beam_search=use_beam_search,
stop_token_ids=stop_token_ids,
max_num_steps=max_num_steps,
num_logprobs=num_logprobs,
context_window_size=context_window_size,
)
token_ids = self.tokenizer.encode(prompt)
self._add_query(token_ids, sampling_params)

def _add_query(
self,
token_ids: List[int],
sampling_params: SamplingParams,
) -> None:
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
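A hedged usage sketch of the flattened query interface above; the Frontend constructor arguments below are assumptions, since they are not part of this diff. Exposing n, temperature, top_p, and friends as keyword arguments keeps the front end close to the OpenAI completion parameters that max_num_steps already references.

# Hypothetical instantiation; the real constructor arguments may differ.
frontend = Frontend(model_name='facebook/opt-125m', block_size=8)

# Parallel generation: two independent samples for one prompt.
frontend.query('The future of AI is', n=2, temperature=0.8, top_p=0.95)

# Beam search over four candidates; a later commit enforces zero temperature
# for this mode ("Enforce zero temperature for beam search").
frontend.query('The future of AI is', n=4, use_beam_search=True, temperature=0.0)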
91 changes: 55 additions & 36 deletions cacheflow/master/scheduler.py
@@ -1,10 +1,12 @@
from typing import Dict, List, Tuple
from typing import Dict, List

from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus

_MAX_NUM_BATCHED_TOKENS = 2048
@@ -66,15 +68,18 @@ def _allocate(self, seq_group: SequenceGroup) -> None:
def _append(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
ret = self.block_manager.append(seq)
if ret is not None:
src_block, dst_block = ret
blocks_to_copy[src_block] = dst_block
if src_block in blocks_to_copy:
blocks_to_copy[src_block].append(dst_block)
else:
blocks_to_copy[src_block] = [dst_block]

def _swap_in(
self,
@@ -83,9 +88,8 @@ def _swap_in(
) -> None:
mapping = self.block_manager.swap_in(seq_group)
blocks_to_swap_in.update(mapping)
for seq in seq_group.seqs:
if seq.status == SequenceStatus.SWAPPED:
seq.status = SequenceStatus.RUNNING
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
seq.status = SequenceStatus.RUNNING
self.running.append(seq_group)

def _swap_out(
@@ -96,16 +100,15 @@ def _swap_out(
assert self.block_manager.can_swap_out(seq_group)
mapping = self.block_manager.swap_out(seq_group)
blocks_to_swap_out.update(mapping)
for seq in seq_group.seqs:
if seq.status == SequenceStatus.RUNNING:
seq.status = SequenceStatus.SWAPPED
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED
self.swapped.append(seq_group)

def step(self) -> None:
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}

# 1. Reserve new slots for the running sequences.
# NOTE: Here we implicitly assume FCFS scheduling.
@@ -143,6 +146,10 @@ def step(self) -> None:
# All swapped sequences are swapped in.
self.swapped.clear()

# Ensure that swap-in and swap-out never happen at the same timestep.
if blocks_to_swap_in:
assert not blocks_to_swap_out

num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running
@@ -152,7 +159,6 @@
# NOTE: Here we implicitly assume FCFS scheduling.
# TODO(woosuk): Add a batching policy to control the batch size.
if not self.swapped:
# FIXME(woosuk): Acquire a lock to protect pending.
self._fetch_inputs()
for i, seq_group in enumerate(self.pending):
num_prompt_tokens = seq_group.seqs[0].get_len()
@@ -168,73 +174,86 @@ def step(self) -> None:
else:
self.pending.clear()

# Ensure that swap-in and swap-out never happen at the same timestep.
if blocks_to_swap_in:
assert not blocks_to_swap_out

# 4. Create input data structures.
prompt_tokens: Dict[int, List[int]] = {}
generation_tokens: Dict[int, int] = {}
context_lens: Dict[int, int] = {}
block_tables: Dict[int, List[int]] = {}
input_seq_groups: List[SequenceGroupInputs] = []
for seq_group in self.running:
group_id = seq_group.group_id
num_steps = self.num_steps[group_id]

# NOTE(woosuk): We assume that the number of steps is 0
# for the prompt sequences.
is_prompt = num_steps == 0
for seq in seq_group.seqs:
if seq.status != SequenceStatus.RUNNING:
continue

input_tokens: Dict[int, List[int]] = {}
seq_logprobs: Dict[int, float] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
block_tables[seq_id] = self.block_manager.get_block_table(seq)
if is_prompt:
prompt_tokens[seq_id] = seq.get_token_ids()
input_tokens[seq_id] = seq.get_token_ids()
else:
generation_tokens[seq_id] = seq.get_token_ids()[-1]
context_lens[seq_id] = seq.get_len()
input_tokens[seq_id] = [seq.get_last_token_id()]
seq_logprobs[seq_id] = seq.cumulative_logprobs
# NOTE(woosuk): Sequences in the same group have the same
# sequence length
seq_len = seq.get_len()

input_seq_group = SequenceGroupInputs(
group_id=group_id,
is_prompt=is_prompt,
input_tokens=input_tokens,
context_len=seq_len,
seq_logprobs=seq_logprobs,
sampling_params=self.sampling_params[group_id],
block_tables=block_tables,
)
input_seq_groups.append(input_seq_group)

# 5. Execute the first stage of the pipeline.
self.controllers[0].execute_stage(
prompt_tokens,
generation_tokens,
context_lens,
block_tables,
input_seq_groups,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
)

def post_step(
self,
next_tokens: Dict[int, Tuple[int, int]],
seq_outputs: Dict[int, SequenceOutputs],
) -> None:
# Update the running sequences and free blocks.
for seq_group in self.running:
group_id = seq_group.group_id
self.num_steps[group_id] += 1
stop_token_ids = self.sampling_params[group_id].stop_token_ids

# Process beam search results before processing the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue

parent_seq_id, next_token = next_tokens[seq.seq_id]
if seq.seq_id != parent_seq_id:
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam search).
# Free the current sequence.
self.block_manager.free(seq)
# Fork the parent sequence.
parent_seq = seq_group.find(parent_seq_id)
seq.logical_token_blocks = parent_seq.logical_token_blocks.copy()
parent_seq = seq_group.find(output.parent_seq_id)
parent_seq.fork(seq)
self.block_manager.fork(parent_seq, seq)

# Process the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue

# Append a new token to the sequence.
seq.append([next_token])
output = seq_outputs[seq.seq_id]
seq.append(output.output_token, output.logprobs)

# Check if the sequence has generated a stop token.
if next_token in stop_token_ids:
if output.output_token in stop_token_ids:
self._free_seq(seq)
continue

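The switch from Dict[int, int] to Dict[int, List[int]] for blocks_to_copy matters because, with parallel generation and beam search, several forked sequences can trigger copy-on-write from the same source block within a single step. Below is a standalone sketch of that accumulation with made-up block numbers; setdefault is an idiomatic shorthand for the if/else used in _append above.

from typing import Dict, List

blocks_to_copy: Dict[int, List[int]] = {}

# Three sequences in one beam all copy out of source block 7; another copies block 3.
for src_block, dst_block in [(7, 12), (7, 13), (3, 14), (7, 15)]:
    blocks_to_copy.setdefault(src_block, []).append(dst_block)

print(blocks_to_copy)  # {7: [12, 13, 15], 3: [14]}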
4 changes: 3 additions & 1 deletion cacheflow/models/__init__.py
@@ -1,8 +1,10 @@
from cacheflow.models.input_metadata import InputMetadata
from cacheflow.models.model_utils import get_model
from cacheflow.models.model_utils import set_seed


__all__ = [
'get_model',
'InputMetadata',
'get_model',
'set_seed'
]
18 changes: 11 additions & 7 deletions cacheflow/models/input_metadata.py
@@ -1,41 +1,45 @@
from typing import List
from typing import List, Dict, Tuple

import torch

from cacheflow.sampling_params import SamplingParams


class InputMetadata:

def __init__(
self,
seq_ids: List[int],
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
# FIXME: Rename
max_context_len: int,
block_tables: torch.Tensor,
) -> None:
self.seq_ids = seq_ids
self.seq_groups = seq_groups
self.seq_logprobs = seq_logprobs
self.prompt_lens = prompt_lens
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables

self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.num_generation_tokens = context_lens.shape[0]
self.num_valid_tokens = slot_mapping.shape[0]
if block_tables.numel() > 0:
self.max_num_blocks_per_seq = block_tables.shape[1]
else:
self.max_num_blocks_per_seq = 0
assert self.num_generation_tokens == block_tables.shape[0]
assert self.num_prompts + self.num_generation_tokens == len(seq_ids)
assert block_tables.shape[0] == self.num_generation_tokens
assert context_lens.shape[0] == self.num_generation_tokens

def __repr__(self) -> str:
return (f'InputMetadata('
f'seq_ids={self.seq_ids}, '
f'num_prompts={self.num_prompts}, '
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'num_valid_tokens={self.num_valid_tokens}, '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
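A hedged construction sketch showing the shape conventions that the asserts above encode: prompt sequences contribute entries to prompt_lens and their full token count to slot_mapping, while decoding sequences contribute one slot each plus a row in context_lens and block_tables. All concrete values are made up, and SamplingParams() having usable defaults is an assumption.

import torch

from cacheflow.models import InputMetadata
from cacheflow.sampling_params import SamplingParams

metadata = InputMetadata(
    # Two groups: one with a single prompt sequence (id 0), one with a
    # single decoding sequence (id 1).
    seq_groups=[([0], SamplingParams()), ([1], SamplingParams())],
    seq_logprobs={0: 0.0, 1: -1.25},           # seq id -> cumulative logprobs
    prompt_lens=[3],                           # seq 0 has a 3-token prompt
    slot_mapping=torch.tensor([0, 1, 2, 35]),  # one KV-cache slot per valid token
    context_lens=torch.tensor([5]),            # decoding sequences only
    max_context_len=5,
    block_tables=torch.tensor([[0, 1]]),       # one row per generation token
)
assert metadata.num_prompts == 1
assert metadata.num_generation_tokens == 1
assert metadata.num_valid_tokens == 4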
10 changes: 10 additions & 0 deletions cacheflow/models/model_utils.py
@@ -1,5 +1,7 @@
import random
from typing import Union

import numpy as np
import torch
import torch.nn as nn

@@ -30,3 +32,11 @@ def get_model(
model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype)
return model.eval()
raise ValueError(f'Invalid model name: {model_name}')


def set_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
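set_seed seeds Python's random module, NumPy, and (when CUDA is available) every GPU, so repeated runs with temperature or top-p sampling can be reproduced. A minimal usage sketch; wiring it to a command-line seed flag is an assumption.

from cacheflow.models import set_seed

# Call once, before any model weights or samplers are created.
set_seed(0)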
3 changes: 2 additions & 1 deletion cacheflow/models/opt.py
@@ -9,6 +9,7 @@
from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -261,7 +262,7 @@ def forward(
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, Tuple[int, int]]:
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
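The model's forward pass now returns one SequenceOutputs per sequence id instead of a bare (parent_seq_id, token) tuple. The class itself lives in cacheflow/sequence.py and is not part of this diff; the sketch below is a hypothetical reconstruction from the fields post_step reads (parent_seq_id, output_token, logprobs), and the seq_id field is an additional assumption.

from typing import Dict


class SequenceOutputs:
    # Hypothetical reconstruction for illustration; see cacheflow/sequence.py.

    def __init__(
        self,
        seq_id: int,
        parent_seq_id: int,          # differs from seq_id when beam search forks a sequence
        output_token: int,           # sampled or beam-selected token id
        logprobs: Dict[int, float],  # token id -> log probability
    ) -> None:
        self.seq_id = seq_id
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

    def __repr__(self) -> str:
        return (f'SequenceOutputs(parent_seq_id={self.parent_seq_id}, '
                f'output_token={self.output_token}, '
                f'logprobs={self.logprobs})')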