Time cost utils #355

Merged: 3 commits, Apr 9, 2024
3 changes: 1 addition & 2 deletions python/sglang/backend/openai.py
@@ -9,9 +9,8 @@
 from sglang.lang.ir import SglSamplingParams

 try:
-    import tiktoken
-
     import openai
+    import tiktoken
 except ImportError as e:
     openai = tiktoken = e

1 change: 1 addition & 0 deletions python/sglang/srt/constrained/fsm_cache.py
@@ -7,6 +7,7 @@ def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
         super().__init__(enable=enable)

         from importlib.metadata import version
+
         if version("outlines") >= "0.0.35":
             from transformers import AutoTokenizer

6 changes: 5 additions & 1 deletion python/sglang/srt/server.py
@@ -53,7 +53,7 @@
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import handle_port_init
+from sglang.srt.utils import enable_show_time_cost, handle_port_init
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.responses import JSONResponse

@@ -503,6 +503,10 @@ def launch_server(server_args, pipe_finish_writer):
     global tokenizer_manager
     global chat_template_name

+    # enable time cost measurement if requested
+    if server_args.show_time_cost:
+        enable_show_time_cost()
+
     # disable disk cache if needed
     if server_args.disable_disk_cache:
         disable_cache()
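Note that the launch_server change above is only a one-time toggle at startup; nothing else in the server needs to be aware of timing. A minimal sketch of what that toggle does, based on the utils.py diff further below (it assumes the sglang package is importable), is:

# Sketch only: enable_show_time_cost() flips a module-level flag; until it is
# called, mark_start()/mark_end() return immediately without measuring anything.
import sglang.srt.utils as srt_utils

assert srt_utils.show_time_cost is False  # module default from the diff below
srt_utils.enable_show_time_cost()
assert srt_utils.show_time_cost is True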
21 changes: 14 additions & 7 deletions python/sglang/srt/server_args.py
@@ -26,13 +26,14 @@ class ServerArgs:
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
+    api_key: str = ""
+    show_time_cost: bool = False

     # optional modes
     disable_radix_cache: bool = False
     enable_flashinfer: bool = False
     disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
-    api_key: str = ""

     def __post_init__(self):
         if self.tokenizer_path is None:
@@ -181,6 +182,18 @@ def add_cli_args(parser: argparse.ArgumentParser):
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
+        parser.add_argument(
+            "--api-key",
+            type=str,
+            default=ServerArgs.api_key,
+            help="Set API Key",
+        )
+        parser.add_argument(
+            "--show-time-cost",
+            action="store_true",
+            help="Show time cost of custom marks",
+        )
+
         # optional modes
         parser.add_argument(
             "--disable-radix-cache",
@@ -202,12 +215,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
)
parser.add_argument(
"--api-key",
type=str,
default=ServerArgs.api_key,
help="Set API Key",
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
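For illustration, a hypothetical round trip through the new CLI options might look like the sketch below. It assumes --model-path is the only other required argument and uses a placeholder value for it; both points go beyond what this diff shows.

# Hypothetical sketch: parse the two options added in this diff and build ServerArgs.
import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args(
    ["--model-path", "my-model", "--api-key", "secret", "--show-time-cost"]
)
server_args = ServerArgs.from_cli_args(args)
assert server_args.show_time_cost is True
assert server_args.api_key == "secret"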
66 changes: 37 additions & 29 deletions python/sglang/srt/utils.py
@@ -11,48 +11,56 @@
 import numpy as np
 import requests
 import torch
-import torch.distributed as dist

-is_show_cost_time = False
+show_time_cost = False
+time_infos = {}


-def mark_cost_time(func_name):
-    def inner_func(func):
-        def time_func(*args, **kwargs):
-            if dist.get_rank() in [0, 1] and is_show_cost_time:
-                torch.cuda.synchronize()
-                start_time = time.time()
-                ans = func(*args, **kwargs)
-                torch.cuda.synchronize()
-                print(func_name, "cost time:", (time.time() - start_time) * 1000)
-                return ans
-            else:
-                torch.cuda.synchronize()
-                ans = func(*args, **kwargs)
-                torch.cuda.synchronize()
-                return ans
+def enable_show_time_cost():
+    global show_time_cost
+    show_time_cost = True

-        return time_func
-
-    return inner_func
+
+class TimeInfo:
+    def __init__(self, name, interval=0.1, color=0, indent=0):
+        self.name = name
+        self.interval = interval
+        self.color = color
+        self.indent = indent
+
+        self.acc_time = 0
+        self.last_acc_time = 0
+
+    def check(self):
+        if self.acc_time - self.last_acc_time > self.interval:
+            self.last_acc_time = self.acc_time
+            return True
+        return False

-time_mark = {}
+    def pretty_print(self):
+        print(f"\x1b[{self.color}m", end="")
+        print("-" * self.indent * 2, end="")
+        print(f"{self.name}: {self.acc_time:.3f}s\x1b[0m")


-def mark_start(key):
+def mark_start(name, interval=0.1, color=0, indent=0):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-    global time_mark
-    time_mark[key] = time.time()
-    return
+    if time_infos.get(name, None) is None:
+        time_infos[name] = TimeInfo(name, interval, color, indent)
+    time_infos[name].acc_time -= time.time()


-def mark_end(key, print_min_cost=0.0):
+def mark_end(name):
+    global time_infos, show_time_cost
+    if not show_time_cost:
+        return
     torch.cuda.synchronize()
-    global time_mark
-    cost_time = (time.time() - time_mark[key]) * 1000
-    if cost_time > print_min_cost:
-        print(f"cost {key}:", cost_time)
+    time_infos[name].acc_time += time.time()
+    if time_infos[name].check():
+        time_infos[name].pretty_print()


 def calculate_time(show=False, min_cost_ms=0.0):
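A short usage sketch of the new timing utilities follows. The loop body is a stand-in for real work, and a CUDA device is required because mark_start/mark_end call torch.cuda.synchronize().

# Sketch of intended usage: accumulate time under a named mark and let TimeInfo
# print it once the accumulated time exceeds `interval` seconds.
import time

from sglang.srt.utils import enable_show_time_cost, mark_end, mark_start

enable_show_time_cost()  # without this, mark_start()/mark_end() do nothing

for _ in range(50):
    mark_start("decode_step", interval=0.2, color=31, indent=1)  # 31 = red ANSI
    time.sleep(0.01)  # placeholder for the region being measured
    mark_end("decode_step")
# Roughly every 0.2 s of accumulated time, a line like "--decode_step: 0.212s"
# is printed by TimeInfo.pretty_print().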
12 changes: 6 additions & 6 deletions test/srt/model/bench_llama_low_api.py
@@ -66,9 +66,9 @@ def update_extend(
             p_idx = prefix_req_idx[i // fork_num].item()
             n_idx = self.req_pool_indices[i].item()
             req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
-            req_to_token[
-                n_idx, prefix_len : prefix_len + extend_len
-            ] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
+            req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
+                self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
+            )

     def update_decode(self, predict_ids, batch_size):
         assert predict_ids.shape[0] == batch_size
@@ -81,9 +81,9 @@ def update_decode(self, predict_ids, batch_size):
             self.out_cache_cont_start,
             self.out_cache_cont_end,
         ) = self.token_to_kv_pool.alloc_contiguous(batch_size)
-        self.req_to_token_pool.req_to_token[
-            self.req_pool_indices, self.seq_lens
-        ] = self.out_cache_loc
+        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
+            self.out_cache_loc
+        )
         self.seq_lens.add_(1)

