Simplify prometheus metrics (#1981)
Co-authored-by: Mohit Reddy <mohitreddy1996@users.noreply.github.com>
merrymercy and mohitreddy1996 authored Nov 10, 2024
1 parent ed53ac8 commit 1929c06
Showing 11 changed files with 483 additions and 632 deletions.
15 changes: 0 additions & 15 deletions python/sglang/srt/managers/schedule_batch.py
@@ -31,7 +31,6 @@

import dataclasses
import logging
import time
from typing import List, Optional, Tuple, Union

import torch
@@ -255,16 +254,6 @@ def __init__(
# For Qwen2-VL
self.mrope_position_delta = [] # use mutable object

# Lifetime traces
# time when request is created and added to waitlist
self.created_time = None
# time when request is added to prefill batch
self.queued_time = None
# time when request is being processed
self.started_time = None
# time when request is finished
self.finished_time = None

# whether request reached finished condition
def finished(self) -> bool:
return self.finished_reason is not None
@@ -1038,10 +1027,6 @@ def __str__(self):
f"#req={(len(self.reqs))})"
)

def mark_reqs_started(self):
for req in self.reqs:
req.started_time = time.time()


@dataclasses.dataclass
class ModelWorkerBatch:
3 changes: 0 additions & 3 deletions python/sglang/srt/managers/schedule_policy.py
@@ -17,7 +17,6 @@

import os
import random
import time
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
@@ -307,7 +306,6 @@ def add_one_req(self, req: Req):
):
# Non-chunked prefill
self.can_run_list.append(req)
req.queued_time = time.time()
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(
prefix_len,
@@ -326,7 +324,6 @@ def add_one_req(self, req: Req):
req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
self.can_run_list.append(req)
req.queued_time = time.time()
self.new_inflight_req = req
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)
221 changes: 64 additions & 157 deletions python/sglang/srt/managers/scheduler.py
@@ -62,8 +62,7 @@
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.metrics_collector import PrometheusMetricsCollector
from sglang.srt.metrics.metrics_types import Stats
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
broadcast_pyobj,
@@ -106,6 +105,7 @@ def __init__(
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = server_args.enable_overlap_schedule
self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics

# Init inter-process communication
context = zmq.Context(2)
@@ -224,8 +224,7 @@ def __init__(
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
self.last_stats_tic = time.time() # time of last stats for every iter
self.last_log_tic = time.time() # time of last log for print decode log
self.last_decode_stats_tic = time.time()
self.stream_interval = server_args.stream_interval

# Init chunked prefill
@@ -294,15 +293,16 @@ def __init__(
],
with_stack=True,
)

# Init metrics stats
self.stats = Stats()
self.metrics_collector = PrometheusMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
max_model_len=self.max_total_num_tokens,
)
self.stats = SchedulerStats()
if self.enable_metrics:
self.metrics_collector = SchedulerMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
)
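
For reference, the SchedulerStats container imported above is defined in sglang/srt/metrics/collector.py and is not part of this diff. A minimal sketch consistent with the fields populated by log_prefill_stats and log_decode_stats below (an assumption about its shape, not the actual definition):

# Sketch only: a plain dataclass holding the gauge values the scheduler reports.
from dataclasses import dataclass

@dataclass
class SchedulerStats:
    num_running_reqs: int = 0
    num_used_tokens: int = 0
    token_usage: float = 0.0
    gen_throughput: float = 0.0
    num_queue_reqs: int = 0
    cache_hit_rate: float = 0.0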

def watchdog_thread(self):
self.watchdog_last_forward_ct = 0
@@ -350,11 +350,6 @@ def event_loop_normal(self):
else:
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
# log stats
if self.is_generation and self.server_args.enable_metrics:
stats = self.get_stats(batch)
self.log_stats(stats)
self.last_stats_tic = time.time()

self.last_batch = batch

@@ -493,7 +488,6 @@ def handle_generate_request(
self.max_req_len - len(req.origin_input_ids) - 1,
)

req.created_time = time.time()
self.waiting_queue.append(req)

def handle_embedding_request(
@@ -518,25 +512,68 @@ def handle_embedding_request(

self.waiting_queue.append(req)

def print_decode_stats(self):
def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
if isinstance(self.tree_cache, RadixCache):
self.tree_cache_metrics["total"] += (
adder.log_input_tokens + adder.log_hit_tokens
) / 10**9
self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
)
else:
tree_cache_hit_rate = 0.0

num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)

logger.info(
f"Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
)

if self.enable_metrics:
self.stats.num_running_reqs = running_bs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
self.stats.cache_hit_rate = tree_cache_hit_rate
self.metrics_collector.log_stats(self.stats)
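
The division by 10**9 above only rescales the cumulative counters to keep the floats small; it cancels in the ratio, so the reported value is simply cumulative hit tokens over cumulative (new + hit) tokens. A quick sanity check with made-up numbers:

# Illustrative numbers only: 512 cached tokens hit alongside 1024 newly computed tokens.
total = (1024 + 512) / 10**9
hit = 512 / 10**9
assert abs(hit / total - 512 / 1536) < 1e-12  # the 10**9 scaling cancels; hit rate ~33.3%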

def log_decode_stats(self):
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (time.time() - self.last_log_tic)
gen_throughput = self.num_generated_tokens / (
time.time() - self.last_decode_stats_tic
)
self.num_generated_tokens = 0
self.last_log_tic = time.time()
# set system stats
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
self.last_decode_stats_tic = time.time()
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
logger.info(
f"Decode batch. "
f"#running-req: {num_running_reqs}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"gen throughput (token/s): {gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}"
)

if self.enable_metrics:
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = num_used / self.max_total_num_tokens
self.stats.gen_throughput = gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
self.metrics_collector.log_stats(self.stats)
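
The SchedulerMetricsCollector itself lives in sglang/srt/metrics/collector.py and is not shown in this diff. A minimal sketch of how log_stats could map SchedulerStats fields onto prometheus_client gauges follows; the metric names and the exact set of gauges are assumptions for illustration, not the real implementation:

# Sketch only: one Gauge per reported field, labeled by model_name.
from prometheus_client import Gauge

class SchedulerMetricsCollector:
    def __init__(self, labels: dict) -> None:
        self.labels = labels
        label_names = list(labels.keys())
        # Metric names below are illustrative; the real names live in collector.py.
        self.num_running_reqs = Gauge(
            "sglang_num_running_reqs", "Number of running requests", label_names
        )
        self.token_usage = Gauge(
            "sglang_token_usage", "Fraction of the KV-cache token pool in use", label_names
        )
        self.gen_throughput = Gauge(
            "sglang_gen_throughput", "Generation throughput (token/s)", label_names
        )
        self.num_queue_reqs = Gauge(
            "sglang_num_queue_reqs", "Number of queued requests", label_names
        )

    def log_stats(self, stats) -> None:
        self.num_running_reqs.labels(**self.labels).set(stats.num_running_reqs)
        self.token_usage.labels(**self.labels).set(stats.token_usage)
        self.gen_throughput.labels(**self.labels).set(stats.gen_throughput)
        self.num_queue_reqs.labels(**self.labels).set(stats.num_queue_reqs)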

def check_memory(self):
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -612,15 +649,14 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
prefix_computed = self.policy.calc_priority(self.waiting_queue)

# Prefill policy
num_mixed_running = running_bs if self.is_mixed_chunk else 0
adder = PrefillAdder(
self.tree_cache,
self.running_batch,
self.new_token_ratio,
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
self.max_prefill_tokens,
self.chunked_prefill_size,
num_mixed_running,
running_bs if self.is_mixed_chunk else 0,
)

has_inflight = self.being_chunked_req is not None
@@ -677,47 +713,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:

# Print stats
if self.tp_rank == 0:
if isinstance(self.tree_cache, RadixCache):
self.tree_cache_metrics["total"] += (
adder.log_input_tokens + adder.log_hit_tokens
) / 10**9
self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
tree_cache_hit_rate = (
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
)
else:
tree_cache_hit_rate = 0.0

num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
# set system stats
self.stats.cache_hit_rate = round(100.0 * tree_cache_hit_rate, 2)
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)

if num_mixed_running > 0:
logger.info(
f"Prefill batch"
f"(mixed #running-req: {num_mixed_running}). "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
)
else:
logger.info(
f"Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}"
)
self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)

# Create a new batch
new_batch = ScheduleBatch.init_new(
@@ -789,7 +785,6 @@ def run_batch(self, batch: ScheduleBatch):
if self.is_generation:
model_worker_batch = batch.get_model_worker_batch()
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
batch.mark_reqs_started()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
@@ -810,94 +805,6 @@ def run_batch(self, batch: ScheduleBatch):
ret = embeddings, model_worker_batch.bid
return ret

def get_stats(self, batch: ScheduleBatch):
# TODO: get stats for chunked prefill

now = time.time()
# system stats
# Scheduler State
new_seq: int = 0
num_running_req = len(self.running_batch.reqs) if self.running_batch else 0
num_waiting_req = len(self.waiting_queue)
# Cache State
cache_hit_rate: float = 0.0
token_usage: float = 0.0

# set stats from prefill
if self.stats is not None:
# new_seq=self.stats.new_seq
cache_hit_rate = self.stats.cache_hit_rate
token_usage = self.stats.token_usage
# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = []

# Request stats
# Decode
gen_throughput: float = 0.0
# Latency
time_e2e_requests: List[float] = []
time_waiting_requests: List[float] = []
# Metadata
num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = []
finished_reason_requests: List[str] = []

# _, next_token_ids, _ = result
if batch is not None:
num_generation_tokens_iter = len(batch.output_ids)
gen_throughput = round(
num_generation_tokens_iter / (now - self.last_stats_tic), 2
)

for i, req in enumerate(batch.reqs):
# NOTE: Batch forward mode is extend before decode starts,
if batch.forward_mode.is_extend():
num_prompt_tokens_iter = len(batch.input_ids) + sum(
batch.prefix_lens
)
time_to_first_tokens_iter.append(now - req.started_time)
else:
time_per_output_tokens_iter.append(now - self.last_stats_tic)

if req.finished():
time_e2e_requests.append(now - req.created_time)
time_waiting_requests.append(req.queued_time - req.created_time)
num_prompt_tokens_requests.append(len(req.origin_input_ids))
num_generation_tokens_requests.append(len(req.output_ids))
finished_reason_requests.append(
req.finished_reason.to_json()
if req.finished_reason is not None
else None
)

return Stats(
new_seq=new_seq,
num_running_req=num_running_req,
num_waiting_req=num_waiting_req,
cache_hit_rate=cache_hit_rate,
token_usage=token_usage,
num_prompt_tokens_iter=num_prompt_tokens_iter,
num_generation_tokens_iter=num_generation_tokens_iter,
time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter,
gen_throughput=gen_throughput,
time_e2e_requests=time_e2e_requests,
time_waiting_requests=time_waiting_requests,
num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests,
finished_reason_requests=finished_reason_requests,
context_len=self.model_config.context_len,
max_total_num_tokens=self.max_total_num_tokens,
max_prefill_tokens=self.max_prefill_tokens,
max_running_requests=self.max_running_requests,
)

def log_stats(self, stats: Stats):
self.metrics_collector.log_stats(stats)

def process_batch_result(self, batch: ScheduleBatch, result):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
@@ -1035,7 +942,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
self.tp_rank == 0
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
):
self.print_decode_stats()
self.log_decode_stats()

def add_logprob_return_values(
self,
(Diffs for the remaining 8 changed files are not shown.)
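
For reference, one way to exercise the new code path end to end, assuming the option wired to server_args.enable_metrics is exposed as --enable-metrics and that the server publishes Prometheus metrics on its HTTP port at /metrics (both assumptions, not shown in this diff):

# Sketch only: launch the server with metrics enabled, e.g. (assumed flag):
#   python -m sglang.launch_server --model-path <model> --port 30000 --enable-metrics
# then scrape the scheduler gauges once requests are flowing.
import urllib.request

with urllib.request.urlopen("http://localhost:30000/metrics") as resp:
    for line in resp.read().decode().splitlines():
        if "running_reqs" in line or "token_usage" in line:
            print(line)  # gauges reported via SchedulerMetricsCollector.log_stats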

0 comments on commit 1929c06
