
Simplify prometheus metrics #1981

Merged
merged 16 commits on Nov 10, 2024
15 changes: 0 additions & 15 deletions python/sglang/srt/managers/schedule_batch.py
@@ -31,7 +31,6 @@

import dataclasses
import logging
import time
from typing import List, Optional, Tuple, Union

import torch
@@ -255,16 +254,6 @@ def __init__(
# For Qwen2-VL
self.mrope_position_delta = [] # use mutable object

# Lifetime traces
# time when request is created and added to waitlist
self.created_time = None
# time when request is added to prefill batch
self.queued_time = None
# time when request is being processed
self.started_time = None
# time when request is finished
self.finished_time = None

# whether request reached finished condition
def finished(self) -> bool:
return self.finished_reason is not None
@@ -1038,10 +1027,6 @@ def __str__(self):
f"#req={(len(self.reqs))})"
)

def mark_reqs_started(self):
for req in self.reqs:
req.started_time = time.time()


@dataclasses.dataclass
class ModelWorkerBatch:
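Note: the per-request lifetime timestamps removed above (`created_time`, `queued_time`, `started_time`, `finished_time`) and the `mark_reqs_started` helper were only consumed by the old `get_stats` path in the scheduler. After this change, decode throughput is derived from interval counters kept on the scheduler itself (see `log_decode_stats` in the scheduler.py diff below). A self-contained sketch of that interval-based accounting follows; the wrapper class name is hypothetical, since the real logic lives inline in the scheduler:

```python
import time


class ThroughputTracker:
    """Simplified, hypothetical sketch of the interval-based accounting that
    replaces the removed per-request timestamps; the real logic lives inline
    in Scheduler.log_decode_stats (see the scheduler.py diff below)."""

    def __init__(self) -> None:
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()

    def add_tokens(self, n: int) -> None:
        # Called once per decode step with the number of tokens produced.
        self.num_generated_tokens += n

    def snapshot(self) -> float:
        # Tokens generated since the last snapshot, divided by elapsed wall time.
        elapsed = time.time() - self.last_decode_stats_tic
        gen_throughput = self.num_generated_tokens / elapsed if elapsed > 0 else 0.0
        self.num_generated_tokens = 0
        self.last_decode_stats_tic = time.time()
        return gen_throughput
```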
3 changes: 0 additions & 3 deletions python/sglang/srt/managers/schedule_policy.py
@@ -17,7 +17,6 @@

import os
import random
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
@@ -307,7 +306,6 @@ def add_one_req(self, req: Req):
):
# Non-chunked prefill
self.can_run_list.append(req)
req.queued_time = time.time()
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(
prefix_len,
@@ -326,7 +324,6 @@ def add_one_req(self, req: Req):
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
self.can_run_list.append(req)
req.queued_time = time.time()
self.new_inflight_req = req
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)
221 changes: 64 additions & 157 deletions python/sglang/srt/managers/scheduler.py
@@ -62,8 +62,7 @@
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector
from sglang.srt.metrics.metrics_types import Stats
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
broadcast_pyobj,
@@ -106,6 +105,7 @@ def __init__(
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = server_args.enable_overlap_schedule
self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics

# Init inter-process communication
context = zmq.Context(2)
@@ -224,8 +224,7 @@ def __init__(
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
self.last_stats_tic = time.time() # time of last stats for every iter
self.last_log_tic = time.time() # time of last log for print decode log
self.last_decode_stats_tic = time.time()
self.stream_interval = server_args.stream_interval

# Init chunked prefill
@@ -294,15 +293,16 @@ def __init__(
],
with_stack=True,
)

# Init metrics stats
self.stats = Stats()
self.metrics_collector = PrometheusMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
max_model_len=self.max_total_num_tokens,
)
self.stats = SchedulerStats()
if self.enable_metrics:
self.metrics_collector = SchedulerMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
)

def watchdog_thread(self):
self.watchdog_last_forward_ct = 0
@@ -350,11 +350,6 @@ def event_loop_normal(self):
else:
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
# log stats
if self.is_generation and self.server_args.enable_metrics:
stats = self.get_stats(batch)
self.log_stats(stats)
self.last_stats_tic = time.time()

self.last_batch = batch

@@ -493,7 +488,6 @@ def handle_generate_request(
self.max_req_len - len(req.origin_input_ids) - 1,
)

req.created_time = time.time()
self.waiting_queue.append(req)

def handle_embedding_request(
@@ -518,25 +512,68 @@

self.waiting_queue.append(req)

def print_decode_stats(self):
def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
if isinstance(self.tree_cache, RadixCache):
self.tree_cache_metrics["total"] += (
adder.log_input_tokens + adder.log_hit_tokens
) / 10**9
self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
)
else:
tree_cache_hit_rate = 0.0

num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)

logger.info(
f"Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
)

if self.enable_metrics:
self.stats.num_running_reqs = running_bs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
self.stats.cache_hit_rate = tree_cache_hit_rate
self.metrics_collector.log_stats(self.stats)

def log_decode_stats(self):
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (time.time() - self.last_log_tic)
gen_throughput = self.num_generated_tokens / (
time.time() - self.last_decode_stats_tic
)
self.num_generated_tokens = 0
self.last_log_tic = time.time()
# set system stats
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
self.last_decode_stats_tic = time.time()
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
logger.info(
f"Decode batch. "
f"#running-req: {num_running_reqs}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"gen throughput (token/s): {gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}"
)

if self.enable_metrics:
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = num_used / self.max_total_num_tokens
self.stats.gen_throughput = gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
self.metrics_collector.log_stats(self.stats)

def check_memory(self):
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -612,15 +649,14 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
prefix_computed = self.policy.calc_priority(self.waiting_queue)

# Prefill policy
num_mixed_running = running_bs if self.is_mixed_chunk else 0
adder = PrefillAdder(
self.tree_cache,
self.running_batch,
self.new_token_ratio,
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
self.max_prefill_tokens,
self.chunked_prefill_size,
num_mixed_running,
running_bs if self.is_mixed_chunk else 0,
)

has_inflight = self.being_chunked_req is not None
@@ -677,47 +713,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:

# Print stats
if self.tp_rank == 0:
if isinstance(self.tree_cache, RadixCache):
self.tree_cache_metrics["total"] += (
adder.log_input_tokens + adder.log_hit_tokens
) / 10**9
self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
)
else:
tree_cache_hit_rate = 0.0

num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
# set system stats
self.stats.cache_hit_rate = round(100.0 * tree_cache_hit_rate, 2)
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)

if num_mixed_running > 0:
logger.info(
f"Prefill batch"
f"(mixed #running-req: {num_mixed_running}). "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
)
else:
logger.info(
f"Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
)
self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)

# Create a new batch
new_batch = ScheduleBatch.init_new(
@@ -789,7 +785,6 @@ def run_batch(self, batch: ScheduleBatch):
if self.is_generation:
model_worker_batch = batch.get_model_worker_batch()
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
batch.mark_reqs_started()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
@@ -810,94 +805,6 @@ def run_batch(self, batch: ScheduleBatch):
ret = embeddings, model_worker_batch.bid
return ret

def get_stats(self, batch: ScheduleBatch):
# TODO: get stats for chunked prefill

now = time.time()
# system stats
# Scheduler State
new_seq: int = 0
num_running_req = len(self.running_batch.reqs) if self.running_batch else 0
num_waiting_req = len(self.waiting_queue)
# Cache State
cache_hit_rate: float = 0.0
token_usage: float = 0.0

# set stats from prefill
if self.stats is not None:
# new_seq=self.stats.new_seq
cache_hit_rate = self.stats.cache_hit_rate
token_usage = self.stats.token_usage
# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = []

# Request stats
# Decode
gen_throughput: float = 0.0
# Latency
time_e2e_requests: List[float] = []
time_waiting_requests: List[float] = []
# Metadata
num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = []
finished_reason_requests: List[str] = []

# _, next_token_ids, _ = result
if batch is not None:
num_generation_tokens_iter = len(batch.output_ids)
gen_throughput = round(
num_generation_tokens_iter / (now - self.last_stats_tic), 2
)

for i, req in enumerate(batch.reqs):
# NOTE: Batch forward mode is extend before starting decode,
if batch.forward_mode.is_extend():
num_prompt_tokens_iter = len(batch.input_ids) + sum(
batch.prefix_lens
)
time_to_first_tokens_iter.append(now - req.started_time)
else:
time_per_output_tokens_iter.append(now - self.last_stats_tic)

if req.finished():
time_e2e_requests.append(now - req.created_time)
time_waiting_requests.append(req.queued_time - req.created_time)
num_prompt_tokens_requests.append(len(req.origin_input_ids))
num_generation_tokens_requests.append(len(req.output_ids))
finished_reason_requests.append(
req.finished_reason.to_json()
if req.finished_reason is not None
else None
)

return Stats(
new_seq=new_seq,
num_running_req=num_running_req,
num_waiting_req=num_waiting_req,
cache_hit_rate=cache_hit_rate,
token_usage=token_usage,
num_prompt_tokens_iter=num_prompt_tokens_iter,
num_generation_tokens_iter=num_generation_tokens_iter,
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
gen_throughput=gen_throughput,
time_e2e_requests=time_e2e_requests,
time_waiting_requests=time_waiting_requests,
num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests,
finished_reason_requests=finished_reason_requests,
context_len=self.model_config.context_len,
max_total_num_tokens=self.max_total_num_tokens,
max_prefill_tokens=self.max_prefill_tokens,
max_running_requests=self.max_running_requests,
)

def log_stats(self, stats: Stats):
self.metrics_collector.log_stats(stats)

def process_batch_result(self, batch: ScheduleBatch, result):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
@@ -1035,7 +942,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
self.tp_rank == 0
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
):
self.print_decode_stats()
self.log_decode_stats()

def add_logprob_return_values(
self,
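For reference, the new `SchedulerStats` and `SchedulerMetricsCollector` imported at the top of scheduler.py are defined in `python/sglang/srt/metrics/collector.py`, which is not part of this diff. A minimal sketch of the interface implied by the scheduler calls above, assuming `prometheus_client` gauges and illustrative metric names (not the PR's actual definitions):

```python
from dataclasses import dataclass
from typing import Dict

from prometheus_client import Gauge


@dataclass
class SchedulerStats:
    # Fields written by the scheduler in this diff; defaults are assumptions.
    num_running_reqs: int = 0
    num_used_tokens: int = 0
    token_usage: float = 0.0
    gen_throughput: float = 0.0
    num_queue_reqs: int = 0
    cache_hit_rate: float = 0.0


class SchedulerMetricsCollector:
    def __init__(self, labels: Dict[str, str]) -> None:
        self.labels = labels
        # Metric names below are illustrative only.
        self.num_running_reqs = Gauge(
            "sglang_num_running_reqs", "Number of running requests", labels.keys()
        )
        self.token_usage = Gauge(
            "sglang_token_usage", "Fraction of KV cache tokens in use", labels.keys()
        )
        self.gen_throughput = Gauge(
            "sglang_gen_throughput", "Generation throughput (token/s)", labels.keys()
        )
        self.num_queue_reqs = Gauge(
            "sglang_num_queue_reqs", "Number of queued requests", labels.keys()
        )

    def log_stats(self, stats: SchedulerStats) -> None:
        self.num_running_reqs.labels(**self.labels).set(stats.num_running_reqs)
        self.token_usage.labels(**self.labels).set(stats.token_usage)
        self.gen_throughput.labels(**self.labels).set(stats.gen_throughput)
        self.num_queue_reqs.labels(**self.labels).set(stats.num_queue_reqs)
```

A collector like this is typically paired with an HTTP exposition endpoint (for example `prometheus_client.start_http_server` or mounting `/metrics` on the existing server); that wiring is outside the scheduler changes shown here.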