Enable cuda graph by default (#612)
merrymercy authored Jul 13, 2024
1 parent 396a692 commit 6658159
Showing 10 changed files with 317 additions and 70 deletions.
82 changes: 58 additions & 24 deletions benchmark/latency_throughput/bench_one.py
@@ -1,45 +1,43 @@
"""
Usage:
python3 bench_one.py --input-len 2048 --batch-size 1 2 4 8 16 32 64 128 256 512
"""

import argparse
import json
import time

import numpy as np
import requests

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=None)
parser.add_argument("--backend", type=str, default="srt")
parser.add_argument("--batch-size", type=int, default=1)
parser.add_argument("--max-tokens", type=int, default=256)
args = parser.parse_args()

if args.port is None:
if args.backend == "srt":
args.port = 30000
elif args.backend == "vllm":
args.port = 21000
elif args.backend == "lightllm":
args.port = 22000
elif args.backend == "ginfer":
args.port = 9988
else:
raise ValueError(f"Invalid backend: {args.backend}")

def run_one_batch_size(bs):
url = f"{args.host}:{args.port}"
a = 20
max_new_tokens = args.max_tokens

a = 20
prompt = f"{a, }"

tic = time.time()
if args.backend == "srt":
if args.input_len:
inputs = {"input_ids": [
[int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))] for _ in range(bs)
]}
else:
inputs = {"text": [
f"{i, }" for i in range(bs)
]}

response = requests.post(
url + "/generate",
json={
"text": [prompt] * args.batch_size,
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
"ignore_eos": True,
},
**inputs,
},
)
elif args.backend == "lightllm":
@@ -91,5 +89,41 @@
ret = response.json()
print(ret)

speed = args.batch_size * max_new_tokens / latency
print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
output_throughput = bs * max_new_tokens / latency
print(f"latency: {latency:.2f} s, speed: {output_throughput:.2f} token/s")

with open("tmp_output.txt", "a") as fout:
res = {
"input_len": args.input_len,
"output_len": args.max_tokens,
"batch_size": bs,
"latency": latency,
"output_throughput": output_throughput
}
fout.write(json.dumps(res) + "\n")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=None)
parser.add_argument("--backend", type=str, default="srt")
parser.add_argument("--input-len", type=int, default=None)
parser.add_argument("--batch-size", type=int, nargs='*', default=[1])
parser.add_argument("--max-tokens", type=int, default=256)
args = parser.parse_args()

if args.port is None:
if args.backend == "srt":
args.port = 30000
elif args.backend == "vllm":
args.port = 21000
elif args.backend == "lightllm":
args.port = 22000
elif args.backend == "ginfer":
args.port = 9988
else:
raise ValueError(f"Invalid backend: {args.backend}")

for bs in args.batch_size:
run_one_batch_size(bs)
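Note: a minimal sketch (not part of this commit) of how the updated benchmark can be driven and its results read back. The command mirrors the usage string at the top of bench_one.py; the summary loop simply parses the JSON lines the script appends to tmp_output.txt.

# Sketch only. Assumes an srt server is already running on the default port 30000.
#
#   python3 bench_one.py --backend srt --input-len 2048 \
#       --batch-size 1 2 4 8 16 32 64 128 --max-tokens 256
#
import json

with open("tmp_output.txt") as fin:
    for line in fin:
        res = json.loads(line)
        print(
            f"bs={res['batch_size']:4d}  latency={res['latency']:.2f}s  "
            f"throughput={res['output_throughput']:.2f} token/s"
        )
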
1 change: 0 additions & 1 deletion python/sglang/bench_latency.py
@@ -30,7 +30,6 @@
import dataclasses
import logging
import multiprocessing
import os
import time


38 changes: 21 additions & 17 deletions python/sglang/global_config.py
@@ -8,36 +8,40 @@ def __init__(self):
# 2: output final text after every run
self.verbosity = 0

# Default backend of the language
self.default_backend = None

# Output configs
# Runtime constants: Request dependency time due to network delay
self.request_dependency_delay = 0.02
self.wait_for_new_request_delay = 0.0006

# Runtime constants: New generation token ratio estimation
self.base_new_token_ratio = 0.4
self.base_min_new_token_ratio = 0.2
self.new_token_ratio_decay = 0.0001
self.new_token_ratio_recovery = 0.05

# Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
# This can improve the speed for large batch sizes during prefill.
self.layer_sync_threshold = 8192

# Runtime constants: Flashinfer
self.flashinfer_workspace_size = 192 * 1024 * 1024

# Output tokenization configs
self.skip_special_tokens_in_output = True
self.spaces_between_special_tokens_in_out = True

# Optimization configs
# Interpreter optimization configs
self.eager_fill_image = False
self.enable_precache_with_tracing = True
self.enable_parallel_encoding = True
self.enable_parallel_decoding = True

# Deprecated
# Choices: ["no_adjust", "adjust_cache"]
# no_adjust: Do not adjust the position embedding of KV cache.
# adjust_cache: Adjust the position embedding of KV cache.
self.concate_and_append_mode = "no_adjust"

# Request dependency time due to network delay
self.request_dependency_delay = 0.02
self.wait_for_new_request_delay = 0.0006

# New generation token ratio estimation
self.base_new_token_ratio = 0.4
self.base_min_new_token_ratio = 0.2
self.new_token_ratio_decay = 0.0001
self.new_token_ratio_recovery = 0.05

# The threshold (number of tokens) to trigger layer-wise cuda sync.
# This can improve the speed for large batch sizes during prefill.
self.layer_sync_threshold = 8192


global_config = GlobalConfig()
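Note: global_config is a module-level singleton, so the constants regrouped above are read (and can be overridden) directly on the shared instance. A short sketch, not part of the commit; the attribute names come from GlobalConfig above, and the override values are arbitrary examples.

# Sketch only: reading and overriding runtime constants on the shared singleton.
from sglang.global_config import global_config

print(global_config.flashinfer_workspace_size)  # 192 * 1024 * 1024 by default
print(global_config.layer_sync_threshold)       # 8192 tokens

# Overrides must happen before the runtime starts in order to take effect everywhere.
global_config.base_new_token_ratio = 0.5        # example value, not a recommendation
global_config.request_dependency_delay = 0.01   # example value, not a recommendation
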
173 changes: 173 additions & 0 deletions python/sglang/srt/managers/controller/cuda_graph_runner.py
@@ -0,0 +1,173 @@
"""Run the model with cuda graph."""

import bisect

import torch
from vllm.distributed.parallel_state import graph_capture

from sglang.global_config import global_config
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.managers.controller.infer_batch import (
Batch, ForwardMode, InputMetadata, init_flashinfer_args
)


class CudaGraphRunner:
def __init__(self, model_runner, max_batch_size_to_capture):
self.model_runner = model_runner
self.graphs = {}
self.input_buffers = {}
self.output_buffers = {}
self.flashinfer_handlers = {}
self.graph_memory_pool = None

# Common inputs
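        # (Explanatory note, not part of the original commit.) These tensors and the
        # FlashInfer buffers below are allocated once at max_bs so that captured
        # graphs always read from fixed device addresses; replay() later copies each
        # step's real inputs into their leading slices.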
self.max_bs = max_batch_size_to_capture
self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")

# FlashInfer inputs
self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
self.flashinfer_kv_indptr = torch.zeros(
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
)
self.flashinfer_kv_indices = torch.zeros(
(self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
)
self.flashinfer_kv_last_page_len = torch.ones(
(self.max_bs,), dtype=torch.int32, device="cuda"
)

def can_run(self, batch_size):
return batch_size < self.max_bs

def capture(self, batch_size_list):
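        # (Explanatory note, not part of the original commit.) Record one CUDA graph
        # per entry in batch_size_list; all graphs share one memory pool
        # (self.graph_memory_pool), and each keeps its own FlashInfer decode wrapper.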
self.batch_size_list = batch_size_list
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
for bs in batch_size_list:
graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
self.graphs[bs] = graph
self.input_buffers[bs] = input_buffers
self.output_buffers[bs] = output_buffers
self.flashinfer_handlers[bs] = flashinfer_handler

def capture_one_batch_size(self, bs):
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

graph = torch.cuda.CUDAGraph()
stream = self.stream

# Common inputs
input_ids = self.input_ids[:bs]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
position_ids_offsets = self.position_ids_offsets[:bs]
out_cache_loc = self.out_cache_loc[:bs]

# FlashInfer inputs
if not _grouped_size_compiled_for_decode_kernels(
self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
):
use_tensor_cores = True
else:
use_tensor_cores = False
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD",
use_cuda_graph=True,
use_tensor_cores=use_tensor_cores,
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1],
paged_kv_indices_buffer=self.flashinfer_kv_indices,
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
)
init_flashinfer_args(
ForwardMode.DECODE,
self.model_runner,
req_pool_indices,
seq_lens,
None,
flashinfer_decode_wrapper,
)

# Run and capture
def run_once():
input_metadata = InputMetadata.create(
self.model_runner,
forward_mode=ForwardMode.DECODE,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
prefix_lens=None,
position_ids_offsets=position_ids_offsets,
out_cache_loc=out_cache_loc,
out_cache_cont_start=None,
out_cache_cont_end=None,
return_logprob=False,
top_logprobs_nums=0,
skip_flashinfer_init=True,
)
input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
return self.model_runner.model.forward(
input_ids, input_metadata.positions, input_metadata
)

for _ in range(2):
run_once()

torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
out = run_once()
torch.cuda.synchronize()
self.graph_memory_pool = graph.pool()
return graph, None, out, flashinfer_decode_wrapper

def replay(self, batch: Batch):
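        # (Explanatory note, not part of the original commit.) Round the live batch
        # size up to the nearest captured size (bisect over batch_size_list), copy the
        # real inputs into the preallocated buffers, replay that graph, then trim the
        # returned logits back to the original batch size.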
assert batch.out_cache_loc is not None
assert not batch.return_logprob
raw_bs = len(batch.reqs)

# Pad
index = bisect.bisect_left(self.batch_size_list, raw_bs)
bs = self.batch_size_list[index]
if bs != raw_bs:
self.seq_lens.fill_(1)
self.out_cache_loc.zero_()

# Common inputs
self.input_ids[:raw_bs] = batch.input_ids
self.req_pool_indices[:raw_bs] = batch.req_pool_indices
self.seq_lens[:raw_bs] = batch.seq_lens
self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
self.out_cache_loc[:raw_bs] = batch.out_cache_loc

# FlashInfer inputs
init_flashinfer_args(
ForwardMode.DECODE,
self.model_runner,
self.req_pool_indices[:bs],
self.seq_lens[:bs],
None,
self.flashinfer_handlers[bs],
)

# Replay
self.graphs[bs].replay()
output = self.output_buffers[bs]

# Unpad
if bs == raw_bs:
return output
else:
output = LogitProcessorOutput(
next_token_logits=output.next_token_logits[:raw_bs],
next_token_logprobs=output.next_token_logprobs[:raw_bs] if output.next_token_logprobs is not None else None,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
)
return output
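Note: to show how the new runner is meant to be driven, a minimal usage sketch follows. It is not code from this commit: model_runner and batch stand for the existing sglang objects used above, and the captured batch sizes are arbitrary examples.

# Sketch only: capture once at startup, then replay per decode step.
capture_bs = [1, 2, 4, 8, 16, 32, 64, 128]   # example sizes, not the defaults
runner = CudaGraphRunner(model_runner, max_batch_size_to_capture=max(capture_bs))
runner.capture(capture_bs)

# Per decode step: replay a captured graph when the batch fits (can_run uses a
# strict `<` check) and no logprobs are requested; otherwise run eagerly.
if runner.can_run(len(batch.reqs)) and not batch.return_logprob:
    logits_output = runner.replay(batch)
else:
    ...  # fall back to the regular eager decode forward pass
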
19 changes: 13 additions & 6 deletions python/sglang/srt/managers/controller/infer_batch.py
Expand Up @@ -675,7 +675,11 @@ def sample(self, logits: torch.Tensor):
# TODO(lmzheng): apply penalty
probs = torch.softmax(logits, dim=-1)
probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
sampled_index = torch.multinomial(probs_sort, num_samples=1)
try:
sampled_index = torch.multinomial(probs_sort, num_samples=1)
except RuntimeError as e:
warnings.warn(f"Ignore errors in sampling: {e}")
sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-1
)
@@ -757,9 +761,11 @@ def create(
out_cache_cont_end=None,
top_logprobs_nums=None,
return_logprob=False,
skip_flashinfer_init=False,
):
if not model_runner.server_args.disable_flashinfer:
init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens)
if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
model_runner.flashinfer_decode_wrapper)

batch_size = len(req_pool_indices)

@@ -826,7 +832,8 @@ def create(
return ret


def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens):
def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
flashinfer_decode_wrapper):
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
head_dim = model_runner.model_config.head_dim
@@ -857,8 +864,8 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
)

if forward_mode == ForwardMode.DECODE:
model_runner.flashinfer_decode_wrapper.end_forward()
model_runner.flashinfer_decode_wrapper.begin_forward(
flashinfer_decode_wrapper.end_forward()
flashinfer_decode_wrapper.begin_forward(
kv_indptr,
kv_indices,
kv_last_page_len,
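Note: the signature change to init_flashinfer_args is what lets the CUDA-graph path supply its own decode wrapper instead of always using the one owned by the model runner. A rough sketch of the two call sites, assembled from the diffs above (the surrounding variables are assumed, not verbatim):

# Sketch only: the same helper now serves both decode paths, with the FlashInfer
# decode wrapper passed in explicitly rather than read from model_runner.

# Regular path (InputMetadata.create): the runner's shared decode wrapper.
init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
                     prefix_lens, model_runner.flashinfer_decode_wrapper)

# CUDA-graph path (CudaGraphRunner): a per-batch-size wrapper created with
# use_cuda_graph=True and bound to preallocated indptr/indices buffers.
init_flashinfer_args(ForwardMode.DECODE, model_runner, req_pool_indices[:bs],
                     seq_lens[:bs], None, flashinfer_decode_wrapper)
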
(Diffs for the remaining 5 changed files are not shown.)