diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index 42b15cd6c458e..23a7a85580a0a 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -1,3 +1,4 @@
+import time
 from typing import List
 
 import pytest
@@ -10,6 +11,8 @@
 from vllm.engine.metrics import RayPrometheusStatLogger
 from vllm.sampling_params import SamplingParams
 
+from ..conftest import cleanup
+
 MODELS = [
     "facebook/opt-125m",
 ]
@@ -219,6 +222,94 @@ def test_metric_spec_decode(
                 "does not meet expectation")
 
 
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [10])
+@pytest.mark.parametrize("log_interval", [1, 3, 5, 7])
+def test_metric_spec_decode_interval(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    log_interval: int,
+) -> None:
+    k = 5
+
+    engine_args = EngineArgs(model=model,
+                             dtype=dtype,
+                             disable_log_stats=False,
+                             gpu_memory_utilization=0.4,
+                             speculative_model=model,
+                             num_speculative_tokens=k,
+                             use_v2_block_manager=True,
+                             enforce_eager=True)
+
+    engine = LLMEngine.from_engine_args(engine_args)
+
+    try:
+
+        engine.add_request(
+            "request-id-0",
+            example_prompts[0],
+            SamplingParams(max_tokens=max_tokens),
+        )
+
+        # set the log interval
+        stat_logger = engine.stat_loggers['prometheus']
+        stat_logger.local_interval = log_interval
+
+        # prefill
+        engine.step()
+
+        # wait for 5 seconds to ensure that spec decode metrics
+        # get triggered in the first decode step
+        time.sleep(5)
+
+        # the first decode step should trigger async collection of metrics
+        engine.step()
+
+        # wait one second to allow the async GPU->CPU transfer to finish
+        time.sleep(1)
+
+        # the second decode step should now be able to collect the spec
+        # decode stats, and the request should also be finished
+        engine.step()
+
+        # must have finished by now
+        assert not engine.has_unfinished_requests()
+
+        # wait to ensure logging occurs
+        time.sleep(log_interval)
+
+        # force logging
+        engine.step()
+
+        # Note that the purpose of this test is to verify spec decode
+        # metrics rather than functional correctness, so the expected
+        # values are intended to be loose.
+        metric_name_to_expected_fn = {
+            "gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1,
+            "gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1,
+            "counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k,
+            "counter_spec_decode_num_draft_tokens": lambda v: v == k,
+            "counter_spec_decode_num_emitted_tokens":
+            lambda v: 0 <= v <= k + 1,
+        }
+
+        for metric_name, is_expected in metric_name_to_expected_fn.items():
+            metric_val = getattr(
+                stat_logger.metrics,
+                metric_name).labels(**stat_logger.labels)._value.get()
+            assert is_expected(metric_val), (
+                f"the value of metric {metric_name} ({metric_val}) "
+                "does not meet expectation")
+
+    finally:
+        del engine
+        cleanup()
+
+
 def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
                    num_requests: int) -> None:
     if disable_log_stats:
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index b1531249d0453..0124ccb62c683 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -356,6 +356,7 @@ def __init__(self, local_interval: float) -> None:
         self.num_generation_tokens: List[int] = []
         self.last_local_log = time.time()
         self.local_interval = local_interval
+        self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
 
     @abstractmethod
     def info(self, type: str, obj: SupportsMetricsInfo) -> None:
@@ -365,6 +366,12 @@ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
     def log(self, stats: Stats) -> None:
         raise NotImplementedError
 
+    def maybe_update_spec_decode_metrics(self, stats: Stats):
+        """Save spec decode metrics (since they are unlikely
+        to be emitted at the same time as the log interval)."""
+        if stats.spec_decode_metrics is not None:
+            self.spec_decode_metrics = stats.spec_decode_metrics
+
 
 class LoggingStatLogger(StatLoggerBase):
     """LoggingStatLogger is used in LLMEngine to log to Stdout."""
@@ -380,6 +387,9 @@ def log(self, stats: Stats) -> None:
         self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
         self.num_generation_tokens.append(stats.num_generation_tokens_iter)
 
+        # Save the latest spec decode metrics until the next log interval.
+        self.maybe_update_spec_decode_metrics(stats)
+
         # Log locally every local_interval seconds.
         if local_interval_elapsed(stats.now, self.last_local_log,
                                   self.local_interval):
@@ -413,16 +423,17 @@ def log(self, stats: Stats) -> None:
                 stats.cpu_cache_usage_sys * 100,
             )
 
-            # Reset tracked stats for next interval.
-            self.num_prompt_tokens = []
-            self.num_generation_tokens = []
-            self.last_local_log = stats.now
-
             if stats.spec_decode_metrics is not None:
                 logger.info(
                     self._format_spec_decode_metrics_str(
                         stats.spec_decode_metrics))
 
+            # Reset tracked stats for next interval.
+            self.num_prompt_tokens = []
+            self.num_generation_tokens = []
+            self.last_local_log = stats.now
+            self.spec_decode_metrics = None
+
     def _format_spec_decode_metrics_str(
             self, metrics: "SpecDecodeWorkerMetrics") -> str:
@@ -537,6 +548,9 @@ def log(self, stats: Stats):
         self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
         self.num_generation_tokens.append(stats.num_generation_tokens_iter)
 
+        # Save the latest spec decode metrics until the next log interval.
+        self.maybe_update_spec_decode_metrics(stats)
+
         # Log locally every local_interval seconds.
         if local_interval_elapsed(stats.now, self.last_local_log,
                                   self.local_interval):
@@ -554,26 +568,27 @@ def log(self, stats: Stats):
                 prompt_throughput=prompt_throughput,
                 generation_throughput=generation_throughput)
 
-            # Reset tracked stats for next interval.
-            self.num_prompt_tokens = []
-            self.num_generation_tokens = []
-            self.last_local_log = stats.now
-
-            if stats.spec_decode_metrics is not None:
+            if self.spec_decode_metrics is not None:
                 self._log_gauge(
                     self.metrics.gauge_spec_decode_draft_acceptance_rate,
-                    stats.spec_decode_metrics.draft_acceptance_rate)
+                    self.spec_decode_metrics.draft_acceptance_rate)
                 self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
-                                stats.spec_decode_metrics.system_efficiency)
+                                self.spec_decode_metrics.system_efficiency)
                 self._log_counter(
                     self.metrics.counter_spec_decode_num_accepted_tokens,
-                    stats.spec_decode_metrics.accepted_tokens)
+                    self.spec_decode_metrics.accepted_tokens)
                 self._log_counter(
                     self.metrics.counter_spec_decode_num_draft_tokens,
-                    stats.spec_decode_metrics.draft_tokens)
+                    self.spec_decode_metrics.draft_tokens)
                 self._log_counter(
                     self.metrics.counter_spec_decode_num_emitted_tokens,
-                    stats.spec_decode_metrics.emitted_tokens)
+                    self.spec_decode_metrics.emitted_tokens)
+
+            # Reset tracked stats for next interval.
+            self.num_prompt_tokens = []
+            self.num_generation_tokens = []
+            self.last_local_log = stats.now
+            self.spec_decode_metrics = None
 
 
 class RayPrometheusStatLogger(PrometheusStatLogger):
     """RayPrometheusStatLogger uses Ray metrics instead."""
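
Note: the heart of this change is a save-and-flush pattern. Spec decode metrics surface on whichever engine iteration the asynchronous GPU->CPU copy happens to complete, which rarely coincides with an interval boundary, so StatLoggerBase now caches the latest snapshot and each subclass emits and then clears it once local_interval elapses. Resetting after the emit (the old code reset before checking stats.spec_decode_metrics) is what lets a cached snapshot survive until it is actually logged. A minimal standalone sketch of the pattern, with illustrative names rather than vLLM's actual classes:

import time
from typing import Optional


class IntervalLogger:
    """Toy logger: caches the newest spec decode snapshot and
    flushes it at the next interval boundary (not vLLM code)."""

    def __init__(self, local_interval: float) -> None:
        self.local_interval = local_interval
        self.last_local_log = time.time()
        self.spec_decode_metrics: Optional[dict] = None

    def maybe_update_spec_decode_metrics(self, stats: dict) -> None:
        # Cache the snapshot even on iterations that do not log.
        if stats.get("spec_decode_metrics") is not None:
            self.spec_decode_metrics = stats["spec_decode_metrics"]

    def log(self, stats: dict) -> None:
        self.maybe_update_spec_decode_metrics(stats)
        now = time.time()
        if now - self.last_local_log >= self.local_interval:
            if self.spec_decode_metrics is not None:
                print("spec decode:", self.spec_decode_metrics)
            # Reset only after emitting, mirroring the reordering above.
            self.last_local_log = now
            self.spec_decode_metrics = None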
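The test asserts on raw collector values through prometheus_client's private _value.get() accessor on labeled metrics. A hedged sketch of that access pattern, alongside the public registry-based equivalent; the metric and label names here are made up for illustration:

from prometheus_client import CollectorRegistry, Counter

registry = CollectorRegistry()
counter = Counter("spec_decode_num_draft_tokens",
                  "Number of draft tokens proposed",
                  labelnames=["model_name"],
                  registry=registry)
counter.labels(model_name="facebook/opt-125m").inc(5)

# Private accessor, as used in the test above:
assert counter.labels(model_name="facebook/opt-125m")._value.get() == 5

# Public equivalent (counters export their value as a "_total" sample):
assert registry.get_sample_value(
    "spec_decode_num_draft_tokens_total",
    labels={"model_name": "facebook/opt-125m"}) == 5

Reading _value directly couples the test to prometheus_client internals; get_sample_value is the stable way to read a value back when that coupling is undesirable.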