Skip to content

Commit

Permalink
[Misc] Extend vLLM Metrics logging API (vllm-project#5925)
Browse files Browse the repository at this point in the history
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
  • Loading branch information
SolitaryThinker and Yard1 authored Jun 29, 2024
1 parent c4bca74 commit 906a19c
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 118 deletions.
12 changes: 6 additions & 6 deletions tests/metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_metric_counter_prompt_tokens(
vllm_prompt_token_count = sum(prompt_token_counts)

_ = vllm_model.generate_greedy(example_prompts, max_tokens)
stat_logger = vllm_model.model.llm_engine.stat_logger
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
metric_count = stat_logger.metrics.counter_prompt_tokens.labels(
**stat_logger.labels)._value.get()

Expand All @@ -64,7 +64,7 @@ def test_metric_counter_generation_tokens(
gpu_memory_utilization=0.4) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
tokenizer = vllm_model.model.get_tokenizer()
stat_logger = vllm_model.model.llm_engine.stat_logger
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
metric_count = stat_logger.metrics.counter_generation_tokens.labels(
**stat_logger.labels)._value.get()
vllm_generation_count = 0
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str,
disable_log_stats=False,
gpu_memory_utilization=0.3,
served_model_name=served_model_name) as vllm_model:
stat_logger = vllm_model.model.llm_engine.stat_logger
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
metrics_tag_content = stat_logger.labels["model_name"]

if served_model_name is None or served_model_name == []:
Expand Down Expand Up @@ -172,10 +172,10 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
num_requests: int) -> None:
if disable_log_stats:
with pytest.raises(AttributeError):
_ = engine.stat_logger
_ = engine.stat_loggers
else:
assert (engine.stat_logger
is not None), "engine.stat_logger should be set"
assert (engine.stat_loggers
is not None), "engine.stat_loggers should be set"
# Ensure the count bucket of request-level histogram metrics matches
# the number of requests as a simple sanity check to ensure metrics are
# generated
Expand Down
38 changes: 30 additions & 8 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
SchedulerOutputs)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger,
StatLoggerBase, Stats)
from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
Expand Down Expand Up @@ -160,6 +161,7 @@ def __init__(
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
Expand Down Expand Up @@ -292,11 +294,21 @@ def __init__(

# Metric Logging.
if self.log_stats:
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len)
self.stat_logger.info("cache_config", self.cache_config)
if stat_loggers is not None:
self.stat_loggers = stat_loggers
else:
self.stat_loggers = {
"logging":
LoggingStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
"prometheus":
PrometheusStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len),
}
self.stat_loggers["prometheus"].info("cache_config",
self.cache_config)

self.tracer = None
if self.observability_config.otlp_traces_endpoint:
Expand Down Expand Up @@ -833,14 +845,24 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

return request_outputs

def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
if logger_name in self.stat_loggers:
raise KeyError(f"Logger with name {logger_name} already exists.")
self.stat_loggers[logger_name] = logger

def remove_logger(self, logger_name: str) -> None:
if logger_name not in self.stat_loggers:
raise KeyError(f"Logger with name {logger_name} does not exist.")
del self.stat_loggers[logger_name]

def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None) -> None:
"""Forced log when no requests active."""
if self.log_stats:
self.stat_logger.log(
self._get_stats(scheduler_outputs, model_output))
for logger in self.stat_loggers.values():
logger.log(self._get_stats(scheduler_outputs, model_output))

def _get_stats(
self,
Expand Down
Loading

0 comments on commit 906a19c

Please sign in to comment.