From 906a19cdb06b390f3dde287b06a3fe26c03a45e5 Mon Sep 17 00:00:00 2001 From: William Lin Date: Fri, 28 Jun 2024 19:36:06 -0700 Subject: [PATCH] [Misc] Extend vLLM Metrics logging API (#5925) Co-authored-by: Antoni Baum --- tests/metrics/test_metrics.py | 12 +- vllm/engine/llm_engine.py | 38 ++++- vllm/engine/metrics.py | 293 ++++++++++++++++++++++------------ 3 files changed, 225 insertions(+), 118 deletions(-) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index c1164739eee31..0191d85194e33 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -39,7 +39,7 @@ def test_metric_counter_prompt_tokens( vllm_prompt_token_count = sum(prompt_token_counts) _ = vllm_model.generate_greedy(example_prompts, max_tokens) - stat_logger = vllm_model.model.llm_engine.stat_logger + stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] metric_count = stat_logger.metrics.counter_prompt_tokens.labels( **stat_logger.labels)._value.get() @@ -64,7 +64,7 @@ def test_metric_counter_generation_tokens( gpu_memory_utilization=0.4) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) tokenizer = vllm_model.model.get_tokenizer() - stat_logger = vllm_model.model.llm_engine.stat_logger + stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] metric_count = stat_logger.metrics.counter_generation_tokens.labels( **stat_logger.labels)._value.get() vllm_generation_count = 0 @@ -92,7 +92,7 @@ def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, disable_log_stats=False, gpu_memory_utilization=0.3, served_model_name=served_model_name) as vllm_model: - stat_logger = vllm_model.model.llm_engine.stat_logger + stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] metrics_tag_content = stat_logger.labels["model_name"] if served_model_name is None or served_model_name == []: @@ -172,10 +172,10 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool, num_requests: int) -> None: if disable_log_stats: with pytest.raises(AttributeError): - _ = engine.stat_logger + _ = engine.stat_loggers else: - assert (engine.stat_logger - is not None), "engine.stat_logger should be set" + assert (engine.stat_loggers + is not None), "engine.stat_loggers should be set" # Ensure the count bucket of request-level histogram metrics matches # the number of requests as a simple sanity check to ensure metrics are # generated diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fde18f60e4ddd..808a639f5dc9e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -13,7 +13,8 @@ from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs -from vllm.engine.metrics import StatLogger, Stats +from vllm.engine.metrics import (LoggingStatLogger, PrometheusStatLogger, + StatLoggerBase, Stats) from vllm.engine.output_processor.interfaces import ( SequenceGroupOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker @@ -160,6 +161,7 @@ def __init__( executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, ) -> None: logger.info( "Initializing an LLM engine (v%s) with config: " @@ -292,11 +294,21 @@ def __init__( # Metric Logging. if self.log_stats: - self.stat_logger = StatLogger( - local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.served_model_name), - max_model_len=self.model_config.max_model_len) - self.stat_logger.info("cache_config", self.cache_config) + if stat_loggers is not None: + self.stat_loggers = stat_loggers + else: + self.stat_loggers = { + "logging": + LoggingStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC), + "prometheus": + PrometheusStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict(model_name=model_config.served_model_name), + max_model_len=self.model_config.max_model_len), + } + self.stat_loggers["prometheus"].info("cache_config", + self.cache_config) self.tracer = None if self.observability_config.otlp_traces_endpoint: @@ -833,14 +845,24 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: return request_outputs + def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: + if logger_name in self.stat_loggers: + raise KeyError(f"Logger with name {logger_name} already exists.") + self.stat_loggers[logger_name] = logger + + def remove_logger(self, logger_name: str) -> None: + if logger_name not in self.stat_loggers: + raise KeyError(f"Logger with name {logger_name} does not exist.") + del self.stat_loggers[logger_name] + def do_log_stats( self, scheduler_outputs: Optional[SchedulerOutputs] = None, model_output: Optional[List[SamplerOutput]] = None) -> None: """Forced log when no requests active.""" if self.log_stats: - self.stat_logger.log( - self._get_stats(scheduler_outputs, model_output)) + for logger in self.stat_loggers.values(): + logger.log(self._get_stats(scheduler_outputs, model_output)) def _get_stats( self, diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 027f5c7e73c2b..2c1210c90c632 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -1,21 +1,27 @@ import time +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING from typing import Counter as CollectionsCounter from typing import Dict, List, Optional, Protocol, Union import numpy as np -from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, - disable_created_metrics) +import prometheus_client +from vllm.executor.ray_utils import ray from vllm.logger import init_logger +if ray is not None: + from ray.util import metrics as ray_metrics +else: + ray_metrics = None + if TYPE_CHECKING: from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics logger = init_logger(__name__) -disable_created_metrics() +prometheus_client.disable_created_metrics() # The begin-* and end* here are used by the documentation generator # to extract the metrics definitions. @@ -24,56 +30,55 @@ # begin-metrics-definitions class Metrics: labelname_finish_reason = "finished_reason" + _base_library = prometheus_client def __init__(self, labelnames: List[str], max_model_len: int): # Unregister any existing vLLM collectors - for collector in list(REGISTRY._collector_to_names): - if hasattr(collector, "_name") and "vllm" in collector._name: - REGISTRY.unregister(collector) + self._unregister_vllm_metrics() # Config Information - self.info_cache_config = Info( + self.info_cache_config = prometheus_client.Info( name='vllm:cache_config', documentation='information of cache_config') # System stats # Scheduler State - self.gauge_scheduler_running = Gauge( + self.gauge_scheduler_running = self._base_library.Gauge( name="vllm:num_requests_running", documentation="Number of requests currently running on GPU.", labelnames=labelnames) - self.gauge_scheduler_waiting = Gauge( + self.gauge_scheduler_waiting = self._base_library.Gauge( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", labelnames=labelnames) - self.gauge_scheduler_swapped = Gauge( + self.gauge_scheduler_swapped = self._base_library.Gauge( name="vllm:num_requests_swapped", documentation="Number of requests swapped to CPU.", labelnames=labelnames) # KV Cache Usage in % - self.gauge_gpu_cache_usage = Gauge( + self.gauge_gpu_cache_usage = self._base_library.Gauge( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", labelnames=labelnames) - self.gauge_cpu_cache_usage = Gauge( + self.gauge_cpu_cache_usage = self._base_library.Gauge( name="vllm:cpu_cache_usage_perc", documentation="CPU KV-cache usage. 1 means 100 percent usage.", labelnames=labelnames) # Iteration stats - self.counter_num_preemption = Counter( + self.counter_num_preemption = self._base_library.Counter( name="vllm:num_preemptions_total", documentation="Cumulative number of preemption from the engine.", labelnames=labelnames) - self.counter_prompt_tokens = Counter( + self.counter_prompt_tokens = self._base_library.Counter( name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", labelnames=labelnames) - self.counter_generation_tokens = Counter( + self.counter_generation_tokens = self._base_library.Counter( name="vllm:generation_tokens_total", documentation="Number of generation tokens processed.", labelnames=labelnames) - self.histogram_time_to_first_token = Histogram( + self.histogram_time_to_first_token = self._base_library.Histogram( name="vllm:time_to_first_token_seconds", documentation="Histogram of time to first token in seconds.", labelnames=labelnames, @@ -81,7 +86,7 @@ def __init__(self, labelnames: List[str], max_model_len: int): 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0 ]) - self.histogram_time_per_output_token = Histogram( + self.histogram_time_per_output_token = self._base_library.Histogram( name="vllm:time_per_output_token_seconds", documentation="Histogram of time per output token in seconds.", labelnames=labelnames, @@ -92,54 +97,77 @@ def __init__(self, labelnames: List[str], max_model_len: int): # Request stats # Latency - self.histogram_e2e_time_request = Histogram( + self.histogram_e2e_time_request = self._base_library.Histogram( name="vllm:e2e_request_latency_seconds", documentation="Histogram of end to end request latency in seconds.", labelnames=labelnames, buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0]) # Metadata - self.histogram_num_prompt_tokens_request = Histogram( + self.histogram_num_prompt_tokens_request = self._base_library.Histogram( name="vllm:request_prompt_tokens", documentation="Number of prefill tokens processed.", labelnames=labelnames, buckets=build_1_2_5_buckets(max_model_len), ) - self.histogram_num_generation_tokens_request = Histogram( - name="vllm:request_generation_tokens", - documentation="Number of generation tokens processed.", - labelnames=labelnames, - buckets=build_1_2_5_buckets(max_model_len), - ) - self.histogram_best_of_request = Histogram( + self.histogram_num_generation_tokens_request = \ + self._base_library.Histogram( + name="vllm:request_generation_tokens", + documentation="Number of generation tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_best_of_request = self._base_library.Histogram( name="vllm:request_params_best_of", documentation="Histogram of the best_of request parameter.", labelnames=labelnames, buckets=[1, 2, 5, 10, 20], ) - self.histogram_n_request = Histogram( + self.histogram_n_request = self._base_library.Histogram( name="vllm:request_params_n", documentation="Histogram of the n request parameter.", labelnames=labelnames, buckets=[1, 2, 5, 10, 20], ) - self.counter_request_success = Counter( + self.counter_request_success = self._base_library.Counter( name="vllm:request_success_total", documentation="Count of successfully processed requests.", labelnames=labelnames + [Metrics.labelname_finish_reason]) # Deprecated in favor of vllm:prompt_tokens_total - self.gauge_avg_prompt_throughput = Gauge( + self.gauge_avg_prompt_throughput = self._base_library.Gauge( name="vllm:avg_prompt_throughput_toks_per_s", documentation="Average prefill throughput in tokens/s.", labelnames=labelnames, ) # Deprecated in favor of vllm:generation_tokens_total - self.gauge_avg_generation_throughput = Gauge( + self.gauge_avg_generation_throughput = self._base_library.Gauge( name="vllm:avg_generation_throughput_toks_per_s", documentation="Average generation throughput in tokens/s.", labelnames=labelnames, ) + def _unregister_vllm_metrics(self) -> None: + for collector in list(self._base_library.REGISTRY._collector_to_names): + if hasattr(collector, "_name") and "vllm" in collector._name: + self._base_library.REGISTRY.unregister(collector) + + +class RayMetrics(Metrics): + """ + RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics. + Provides the same metrics as Metrics but uses Ray's util.metrics library. + """ + _base_library = ray_metrics + + def __init__(self, labelnames: List[str], max_model_len: int): + if ray_metrics is None: + raise ImportError("RayMetrics requires Ray to be installed.") + super().__init__(labelnames, max_model_len) + + def _unregister_vllm_metrics(self) -> None: + # No-op on purpose + pass + # end-metrics-definitions @@ -206,34 +234,136 @@ def metrics_info(self) -> Dict[str, str]: ... -class StatLogger: - """StatLogger is used LLMEngine to log to Promethus and Stdout.""" +def local_interval_elapsed(now: float, last_log: float, + local_interval: float) -> bool: + elapsed_time = now - last_log + return elapsed_time > local_interval + + +def get_throughput(tracked_stats: List[int], now: float, + last_log: float) -> float: + return float(np.sum(tracked_stats) / (now - last_log)) - def __init__(self, local_interval: float, labels: Dict[str, str], - max_model_len: int) -> None: - # Metadata for logging locally. - self.last_local_log = time.time() - self.local_interval = local_interval +class StatLoggerBase(ABC): + """Base class for StatLogger.""" + + def __init__(self, local_interval: float) -> None: # Tracked stats over current local logging interval. self.num_prompt_tokens: List[int] = [] self.num_generation_tokens: List[int] = [] + self.last_local_log = time.time() + self.local_interval = local_interval + + @abstractmethod + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + raise NotImplementedError + + @abstractmethod + def log(self, stats: Stats) -> None: + raise NotImplementedError + +class LoggingStatLogger(StatLoggerBase): + """LoggingStatLogger is used in LLMEngine to log to Stdout.""" + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + raise NotImplementedError + + def log(self, stats: Stats) -> None: + """Called by LLMEngine. + Logs to Stdout every self.local_interval seconds.""" + + # Save tracked stats for token counters. + self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) + self.num_generation_tokens.append(stats.num_generation_tokens_iter) + + # Log locally every local_interval seconds. + if local_interval_elapsed(stats.now, self.last_local_log, + self.local_interval): + # Compute summary metrics for tracked stats (and log them + # to promethus if applicable). + prompt_throughput = get_throughput(self.num_prompt_tokens, + now=stats.now, + last_log=self.last_local_log) + generation_throughput = get_throughput( + self.num_generation_tokens, + now=stats.now, + last_log=self.last_local_log) + + # Log to stdout. + logger.info( + "Avg prompt throughput: %.1f tokens/s, " + "Avg generation throughput: %.1f tokens/s, " + "Running: %d reqs, Swapped: %d reqs, " + "Pending: %d reqs, GPU KV cache usage: %.1f%%, " + "CPU KV cache usage: %.1f%%.", + prompt_throughput, + generation_throughput, + stats.num_running_sys, + stats.num_swapped_sys, + stats.num_waiting_sys, + stats.gpu_cache_usage_sys * 100, + stats.cpu_cache_usage_sys * 100, + ) + + # Reset tracked stats for next interval. + self.num_prompt_tokens = [] + self.num_generation_tokens = [] + self.last_local_log = stats.now + + if stats.spec_decode_metrics is not None: + logger.info( + self._format_spec_decode_metrics_str( + stats.spec_decode_metrics)) + + def _format_spec_decode_metrics_str( + self, metrics: "SpecDecodeWorkerMetrics") -> str: + + return ("Speculative metrics: " + f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, " + f"System efficiency: {metrics.system_efficiency:.3f}, " + f"Number of speculative tokens: {metrics.num_spec_tokens}, " + f"Number of accepted tokens: {metrics.accepted_tokens}, " + f"Number of draft tokens tokens: {metrics.draft_tokens}, " + f"Number of emitted tokens tokens: {metrics.emitted_tokens}.") + + +class PrometheusStatLogger(StatLoggerBase): + """PrometheusStatLogger is used LLMEngine to log to Promethus.""" + _metrics_cls = Metrics + + def __init__(self, local_interval: float, labels: Dict[str, str], + max_model_len: int) -> None: + super().__init__(local_interval) # Prometheus metrics self.labels = labels - self.metrics = Metrics(labelnames=list(labels.keys()), - max_model_len=max_model_len) + self.metrics = self._metrics_cls(labelnames=list(labels.keys()), + max_model_len=max_model_len) def info(self, type: str, obj: SupportsMetricsInfo) -> None: if type == "cache_config": self.metrics.info_cache_config.info(obj.metrics_info()) - def _get_throughput(self, tracked_stats: List[int], now: float) -> float: - return float(np.sum(tracked_stats) / (now - self.last_local_log)) + def _log_gauge(self, gauge, data: Union[int, float]) -> None: + # Convenience function for logging to gauge. + gauge.labels(**self.labels).set(data) - def _local_interval_elapsed(self, now: float) -> bool: - elapsed_time = now - self.last_local_log - return elapsed_time > self.local_interval + def _log_counter(self, counter, data: Union[int, float]) -> None: + # Convenience function for logging to counter. + counter.labels(**self.labels).inc(data) + + def _log_counter_labels(self, counter, data: CollectionsCounter, + label_key: str) -> None: + # Convenience function for collection counter of labels. + for label, count in data.items(): + counter.labels(**{**self.labels, label_key: label}).inc(count) + + def _log_histogram(self, histogram, data: Union[List[int], + List[float]]) -> None: + # Convenience function for logging list to histogram. + for datum in data: + histogram.labels(**self.labels).observe(datum) def _log_prometheus(self, stats: Stats) -> None: # System state data @@ -279,26 +409,6 @@ def _log_prometheus(self, stats: Stats) -> None: self._log_histogram(self.metrics.histogram_best_of_request, stats.best_of_requests) - def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None: - # Convenience function for logging to gauge. - gauge.labels(**self.labels).set(data) - - def _log_counter(self, counter: Counter, data: Union[int, float]) -> None: - # Convenience function for logging to counter. - counter.labels(**self.labels).inc(data) - - def _log_counter_labels(self, counter: Counter, data: CollectionsCounter, - label_key: str) -> None: - # Convenience function for collection counter of labels. - for label, count in data.items(): - counter.labels(**{**self.labels, label_key: label}).inc(count) - - def _log_histogram(self, histogram: Histogram, - data: Union[List[int], List[float]]) -> None: - # Convenience function for logging list to histogram. - for datum in data: - histogram.labels(**self.labels).observe(datum) - def _log_prometheus_interval(self, prompt_throughput: float, generation_throughput: float) -> None: # Logs metrics to prometheus that are computed every logging_interval. @@ -313,11 +423,8 @@ def _log_prometheus_interval(self, prompt_throughput: float, self.metrics.gauge_avg_generation_throughput.labels( **self.labels).set(generation_throughput) - def log(self, stats: Stats) -> None: - """Called by LLMEngine. - Logs to prometheus and tracked stats every iteration. - Logs to Stdout every self.local_interval seconds.""" - + def log(self, stats: Stats): + """Logs to prometheus and tracked stats every iteration.""" # Log to prometheus. self._log_prometheus(stats) @@ -326,50 +433,28 @@ def log(self, stats: Stats) -> None: self.num_generation_tokens.append(stats.num_generation_tokens_iter) # Log locally every local_interval seconds. - if self._local_interval_elapsed(stats.now): + if local_interval_elapsed(stats.now, self.last_local_log, + self.local_interval): # Compute summary metrics for tracked stats (and log them # to promethus if applicable). - prompt_throughput = self._get_throughput(self.num_prompt_tokens, - now=stats.now) - generation_throughput = self._get_throughput( - self.num_generation_tokens, now=stats.now) + prompt_throughput = get_throughput(self.num_prompt_tokens, + now=stats.now, + last_log=self.last_local_log) + generation_throughput = get_throughput( + self.num_generation_tokens, + now=stats.now, + last_log=self.last_local_log) + self._log_prometheus_interval( prompt_throughput=prompt_throughput, generation_throughput=generation_throughput) - # Log to stdout. - logger.info( - "Avg prompt throughput: %.1f tokens/s, " - "Avg generation throughput: %.1f tokens/s, " - "Running: %d reqs, Swapped: %d reqs, " - "Pending: %d reqs, GPU KV cache usage: %.1f%%, " - "CPU KV cache usage: %.1f%%.", - prompt_throughput, - generation_throughput, - stats.num_running_sys, - stats.num_swapped_sys, - stats.num_waiting_sys, - stats.gpu_cache_usage_sys * 100, - stats.cpu_cache_usage_sys * 100, - ) - # Reset tracked stats for next interval. self.num_prompt_tokens = [] self.num_generation_tokens = [] self.last_local_log = stats.now - if stats.spec_decode_metrics is not None: - logger.info( - self._format_spec_decode_metrics_str( - stats.spec_decode_metrics)) - - def _format_spec_decode_metrics_str( - self, metrics: "SpecDecodeWorkerMetrics") -> str: - return ("Speculative metrics: " - f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, " - f"System efficiency: {metrics.system_efficiency:.3f}, " - f"Number of speculative tokens: {metrics.num_spec_tokens}, " - f"Number of accepted tokens: {metrics.accepted_tokens}, " - f"Number of draft tokens tokens: {metrics.draft_tokens}, " - f"Number of emitted tokens tokens: {metrics.emitted_tokens}.") +class RayPrometheusStatLogger(PrometheusStatLogger): + """RayPrometheusStatLogger uses Ray metrics instead.""" + _metrics_cls = RayMetrics