Commit: Update bench
merrymercy committed Dec 17, 2024
1 parent ba36b55 commit 401414b
Showing 2 changed files with 175 additions and 40 deletions.
130 changes: 130 additions & 0 deletions benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py
@@ -0,0 +1,130 @@
# Benchmark with lots of common prefixes. Used to benchmark prefix caching performance.
#
# Launch a server:
# python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --log-level-http warning

import random
import string
import time

from tqdm import tqdm
from transformers import AutoTokenizer

import sglang as sgl
from sglang import set_default_backend
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint


def generate_random_string(token_length: int) -> str:
random_string = "".join(
random.choices(string.ascii_letters + string.digits, k=token_length * 100)
)
tokenized_output = tokenizer.encode(random_string, add_special_tokens=False)[
:token_length
]

if len(tokenized_output) < token_length:
tokenized_output = tokenized_output + [tokenizer.pad_token_id] * (
token_length - len(tokenized_output)
)

decoded_string = tokenizer.decode(tokenized_output, skip_special_tokens=False)
return decoded_string


def generate_unique_prefix(base_text, index):
return str(index) + base_text[len(str(index)) :]


@sgl.function
def text_qa(s, question, gen_len):
s += "Q: " + question + "\n"
s += "A:" + sgl.gen("answer", stop="\n", temperature=0, max_tokens=gen_len)


def prepare_prompts(num_prefix, num_samples_per_prefix, prefix_length, suffix_length):
base_prefix = generate_random_string(prefix_length)

tot_input_len = 0
all_prompts = []
for i in tqdm(range(num_prefix), desc="prepare prompts"):
unique_prefix = generate_unique_prefix(base_prefix, i)
prompt_list = []
for j in range(num_samples_per_prefix):
suffix = generate_random_string(suffix_length)
prompt = unique_prefix + suffix
prompt_list.append(prompt)
tot_input_len += len(tokenizer.encode(prompt))
all_prompts.append(prompt_list)
return all_prompts, tot_input_len


def test_batch_by_batch(all_prompts, gen_len):
backend.flush_cache()

tot_time = 0
for i in range(len(all_prompts)):
tic = time.time()
text_qa.run_batch(
list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))),
)
tot_time += time.time() - tic

return tot_time


def test_batch_by_batch_with_hint(all_prompts, gen_len):
backend.flush_cache()

tot_time = 0
for i in range(len(all_prompts)):
tic = time.time()
# Send a hint to cache the prefix
text_qa.run_batch(list(zip(all_prompts[i][:1], [gen_len])))
# Send the batch
text_qa.run_batch(list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))))

tot_time += time.time() - tic

return tot_time


def test_send_all(all_prompts, gen_len):
backend.flush_cache()

all_prompts = [x for prompt_list in all_prompts for x in prompt_list]

tic = time.time()
text_qa.run_batch(
list(zip(all_prompts, [gen_len] * len(all_prompts))),
)
tot_time = time.time() - tic

return tot_time


if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
backend = RuntimeEndpoint("http://127.0.0.1:30000")
set_default_backend(backend)

random.seed(0)
num_prefix = 10
num_samples_per_prefix = 32
prefix_length = 1024
suffix_length = 128
gen_len = 1
all_prompts, tot_input_len = prepare_prompts(
num_prefix, num_samples_per_prefix, prefix_length, suffix_length
)

print(f"Total input token length: {tot_input_len}\n")

cost = test_batch_by_batch(all_prompts, gen_len)
print(f"Latency of test_batch_by_batch : {cost:.4f} s\n")

cost = test_batch_by_batch_with_hint(all_prompts, gen_len)
print(f"Latency of test_batch_by_batch_with_hint: {cost:.4f} s\n")

cost = test_send_all(all_prompts, gen_len)
print(f"Latency of test_send_all : {cost:.4f} s\n")
85 changes: 45 additions & 40 deletions python/sglang/srt/managers/schedule_policy.py
@@ -34,11 +34,19 @@
os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)

# The threshold to apply in-batch prefix caching.
# If we use too small value, in-batch prefix caching cannot be used. E.g.,
# imagine "the" prefix.
IN_BATCH_PREFIX_CACHING_THRESHOLD = int(
os.environ.get("SGLANG_IN_BATCH_PREFIX_CACHING_THRESHOLD", "32")
# Threshold for in-batch prefix cache.
# If a request has a matched prefix length (against existing cache) less than this value,
# the scheduler runs the in-batch prefix caching check for this request.
# If we set it to -1, it means we disable in-batch prefix caching.
IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD = int(
os.environ.get("IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD", "32")
)

# Threshold for in-batch prefix cache.
# If a request has a matched prefix length (within the waiting queue) larger than this value,
# the scheduler deprioritizes this request
IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
os.environ.get("IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD", "32")
)


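For reference, both thresholds are evaluated from environment variables at module import time, so they can be tuned without code changes. A minimal sketch (illustration only, not part of the diff), assuming the variable names shown above are set before the scheduler module is imported; the value 64 is an arbitrary example, not a recommendation:

# Illustration only: override the defaults before importing the scheduler.
import os

os.environ["IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD"] = "64"
os.environ["IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD"] = "64"
# Per the comment in the diff, setting the check threshold to "-1" disables
# the in-batch prefix caching check entirely.

from sglang.srt.managers import schedule_policy  # noqa: E402,F401
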
@@ -51,6 +59,11 @@ def __init__(self, policy: str, tree_cache: BasePrefixCache):
self.policy = policy
self.tree_cache = tree_cache

# It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache(
req_to_token_pool=None, token_to_kv_pool=None, disable=False
)

def calc_priority(self, waiting_queue: List[Req]):
if len(waiting_queue) > 128 and self.policy == "lpm":
# Turn off the expensive prefix matching and sorting when the #queue is large.
@@ -60,50 +73,51 @@ def calc_priority(self, waiting_queue: List[Req]):

# Compute matched prefix length
prefix_computed = False
# rid to deprioritize in the current run.
temporary_deprioritized = {}
if policy == "lpm" or policy == "dfs-weight":
# It is used to find the matching prefix for in-batch prefix caching.
temp_radix = RadixCache(None, None, False)
# rid to deprioritize in the current run for in-batch prefix caching.
temporary_deprioritized = set()
self.waiting_queue_radix_tree.reset()

for r in waiting_queue:
prefix_ids = r.adjust_max_prefix_ids()

# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=prefix_ids
)

# NOTE(sang): This logic is for In-batch prefix caching;
# NOTE(sang): This logic is for in-batch prefix caching;
# If there are more than 1 request that have small matching prefix from
# existing cache, but all those requests share the same prefix, we prefer
# to schedule only one of them so that we can increase the cache hit rate.
# We prefer to set IN_BATCH_PREFIX_CACHING_THRESHOLD > 0 because too small
# threshold means we cannot use in-batch prefix caching for short prefixes.
# It is kind of common when the engine is long running (e.g., imagine "the").
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_THRESHOLD:
in_batch_matching_prefixes, _ = temp_radix.match_prefix(
rid=r.rid, key=prefix_ids
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
in_batch_matching_prefixes, _ = (
self.waiting_queue_radix_tree.match_prefix(
rid=r.rid, key=prefix_ids
)
)
if (
len(in_batch_matching_prefixes)
>= IN_BATCH_PREFIX_CACHING_THRESHOLD
>= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
):
temporary_deprioritized[r.rid] = r
temporary_deprioritized.add(r.rid)
else:
temp_radix.insert(prefix_ids, torch.tensor(prefix_ids))
# Insert with a dummy key
self.waiting_queue_radix_tree.insert(
prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
)

prefix_computed = True

if policy == "lpm":
# Longest Prefix Match
def get_priority(r: Req):
score = 0
if r.rid in temporary_deprioritized:
score = float("inf")
else:
score = -len(r.prefix_indices)
return score

waiting_queue.sort(key=get_priority)
waiting_queue.sort(
key=lambda r: (
-len(r.prefix_indices)
if r.rid not in temporary_deprioritized
else float("inf")
)
)
elif policy == "fcfs":
# first come first serve
pass
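
The in-batch prefix caching logic above boils down to: a request that barely overlaps with the existing cache but shares a long prefix with an earlier request in the same waiting queue is pushed to the back, so the earlier request can populate the cache first. A self-contained sketch of that idea (illustration only, not part of the diff), using a plain list of seen prefixes in place of sglang's RadixCache and a hypothetical (rid, cached_prefix_len, token_ids) request format:

CHECK_THRESHOLD = 32          # stands in for IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD
DEPRIORITIZE_THRESHOLD = 32   # stands in for IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD


def longest_queued_prefix(tokens, queued_prefixes):
    # Naive longest common prefix against requests seen earlier in this queue.
    best = 0
    for prefix in queued_prefixes:
        n = 0
        for a, b in zip(tokens, prefix):
            if a != b:
                break
            n += 1
        best = max(best, n)
    return best


def order_waiting_queue(requests):
    queued_prefixes = []
    deprioritized = set()
    for rid, cached_prefix_len, tokens in requests:
        if cached_prefix_len <= CHECK_THRESHOLD:
            if longest_queued_prefix(tokens, queued_prefixes) >= DEPRIORITIZE_THRESHOLD:
                deprioritized.add(rid)  # let an earlier request warm the cache first
            else:
                queued_prefixes.append(tokens)
    # Longest-prefix-match order, with deprioritized requests pushed to the back,
    # mirroring the float("inf") trick in the sort key above.
    return sorted(
        requests,
        key=lambda r: float("inf") if r[0] in deprioritized else -r[1],
    )
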
@@ -113,11 +127,11 @@ def get_priority(r: Req):
elif policy == "random":
random.shuffle(waiting_queue)
elif policy == "dfs-weight":
# Experimental policy based on custom weights
last_node_to_reqs = defaultdict(list)
for req in waiting_queue:
last_node_to_reqs[req.last_node].append(req)

# node -> # of requests for that node.
node_to_weight = defaultdict(int)
for node in last_node_to_reqs:
node_to_weight[node] = len(last_node_to_reqs[node])
@@ -129,9 +143,7 @@ def get_priority(r: Req):
node_to_weight,
last_node_to_reqs,
waiting_queue,
temporary_deprioritized,
)
waiting_queue.extend(temporary_deprioritized.values())
else:
raise ValueError(f"Unknown schedule_policy: {policy=}")

@@ -148,19 +160,12 @@ def get_dfs_priority(
node_to_priority: Dict[TreeNode, int],
last_node_to_reqs: Dict[TreeNode, List[Req]],
q: List,
temporary_deprioritized: Dict[str, Req],
):
childs = [child for child in cur_node.children.values()]
childs.sort(key=lambda x: -node_to_priority[x])
for child in childs:
self.get_dfs_priority(
child, node_to_priority, last_node_to_reqs, q, temporary_deprioritized
)

for req in last_node_to_reqs[cur_node]:
if req.rid in temporary_deprioritized:
continue
q.append(req)
self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
q.extend(last_node_to_reqs[cur_node])


class AddReqResult(Enum):

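For the dfs-weight branch, the ordering can be sketched on a toy tree: walk the prefix tree depth-first, visit heavier children (those with more queued requests in their subtree) first, and append each node's requests on the way back up; with this commit the walk no longer special-cases the temporarily deprioritized set. A simplified sketch (illustration only, not part of the diff), with a hypothetical Node class standing in for sglang's TreeNode:

from collections import defaultdict


class Node:
    def __init__(self):
        self.children = {}  # label -> Node


def dfs_weight_order(root, node_to_reqs):
    # Weight of a node = number of requests attached to it or to any descendant.
    node_to_weight = defaultdict(int)

    def accumulate(node):
        weight = len(node_to_reqs.get(node, []))
        for child in node.children.values():
            weight += accumulate(child)
        node_to_weight[node] = weight
        return weight

    accumulate(root)

    order = []

    def walk(node):
        # Heavier subtrees first, then this node's own requests.
        for child in sorted(node.children.values(), key=lambda c: -node_to_weight[c]):
            walk(child)
        order.extend(node_to_reqs.get(node, []))

    walk(root)
    return order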