[Bugfix] StatLoggers: cache spec decode metrics when they get collected. (#6645)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
tdoublep authored Jul 23, 2024
1 parent 01c16ed commit 2f808e6
Showing 2 changed files with 122 additions and 16 deletions.
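
Spec decode metrics are collected asynchronously by the worker, so they rarely land on the same engine step as a logger's flush; before this fix, the loggers only looked at stats.spec_decode_metrics inside the interval-elapsed branch and usually saw nothing. The diff below caches the most recent metrics on the logger and flushes the cache at the next interval. As rough orientation, here is a minimal, self-contained sketch of that pattern; the Stats container and emit callback are stand-ins for the vLLM types, not the real API:

import time
from dataclasses import dataclass
from typing import Callable, Dict, Optional


@dataclass
class Stats:
    # Stand-in for vLLM's Stats; spec_decode_metrics is usually None
    # because the worker only produces it on some iterations.
    now: float
    spec_decode_metrics: Optional[Dict[str, float]] = None


class CachingStatLogger:
    def __init__(self, local_interval: float,
                 emit: Callable[[Dict[str, float]], None]) -> None:
        self.local_interval = local_interval
        self.last_local_log = time.time()
        self.emit = emit
        # The cache survives across iterations until the next flush.
        self.spec_decode_metrics: Optional[Dict[str, float]] = None

    def maybe_update_spec_decode_metrics(self, stats: Stats) -> None:
        # Keep the latest metrics we have seen, whenever they arrive.
        if stats.spec_decode_metrics is not None:
            self.spec_decode_metrics = stats.spec_decode_metrics

    def log(self, stats: Stats) -> None:
        self.maybe_update_spec_decode_metrics(stats)
        if stats.now - self.last_local_log > self.local_interval:
            if self.spec_decode_metrics is not None:
                self.emit(self.spec_decode_metrics)
            # Reset tracked state for the next interval.
            self.last_local_log = stats.now
            self.spec_decode_metrics = None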
91 changes: 91 additions & 0 deletions tests/metrics/test_metrics.py
@@ -1,3 +1,4 @@
import time
from typing import List

import pytest
@@ -10,6 +11,8 @@
from vllm.engine.metrics import RayPrometheusStatLogger
from vllm.sampling_params import SamplingParams

from ..conftest import cleanup

MODELS = [
"facebook/opt-125m",
]
@@ -219,6 +222,94 @@ def test_metric_spec_decode(
"does not meet expectation")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [10])
@pytest.mark.parametrize("log_interval", [1, 3, 5, 7])
def test_metric_spec_decode_interval(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
log_interval: int,
) -> None:
k = 5

engine_args = EngineArgs(model=model,
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.4,
speculative_model=model,
num_speculative_tokens=k,
use_v2_block_manager=True,
enforce_eager=True)

engine = LLMEngine.from_engine_args(engine_args)

try:

engine.add_request(
"request-id-0",
example_prompts[0],
SamplingParams(max_tokens=max_tokens),
)

        # set the log interval
stat_logger = engine.stat_loggers['prometheus']
stat_logger.local_interval = log_interval

# prefill
engine.step()

        # wait for 5 seconds to ensure that spec decode metrics
        # get triggered in the first decode step
time.sleep(5)

# first decode step should trigger async collection of metrics
engine.step()

# wait one second to allow H2D transfer to finish
time.sleep(1)

# second decode step should now be able to collect the spec
# decode stats and the request should also be finished
engine.step()

        # must have finished now
assert not engine.has_unfinished_requests()

# wait to ensure logging occurs
time.sleep(log_interval)

# force logging
engine.step()

# Note that the purpose of this test is to verify spec decode
# metrics instead of functional correctness, so the expected values
# are intended to be loose.
metric_name_to_expected_fn = {
"gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1,
"gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1,
"counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k,
"counter_spec_decode_num_draft_tokens": lambda v: v == k,
"counter_spec_decode_num_emitted_tokens":
lambda v: 0 <= v <= k + 1,
}

for metric_name, is_expected in metric_name_to_expected_fn.items():
metric_val = getattr(
stat_logger.metrics,
metric_name).labels(**stat_logger.labels)._value.get()
assert is_expected(metric_val), (
f"the value of metric {metric_name} ({metric_val}) "
"does not meet expectation")

finally:
del engine
cleanup()


def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
num_requests: int) -> None:
if disable_log_stats:
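
The assertions in the new test read metric values directly off the prometheus_client objects instead of scraping an HTTP endpoint. A small sketch of that access pattern, assuming only that prometheus_client is installed (the metric name and label here are illustrative, not vLLM's):

from prometheus_client import Gauge

gauge = Gauge("demo_spec_decode_efficiency", "Illustrative gauge",
              labelnames=["model_name"])
gauge.labels(model_name="facebook/opt-125m").set(0.8)

# .labels(...) returns the per-label child; its private _value.get()
# reads the stored float back, which is what the test relies on.
value = gauge.labels(model_name="facebook/opt-125m")._value.get()
assert 0.0 <= value <= 1.0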
47 changes: 31 additions & 16 deletions vllm/engine/metrics.py
@@ -355,6 +355,7 @@ def __init__(self, local_interval: float) -> None:
self.num_generation_tokens: List[int] = []
self.last_local_log = time.time()
self.local_interval = local_interval
self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None

@abstractmethod
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
@@ -364,6 +365,12 @@ def info(self, type: str, obj: SupportsMetricsInfo) -> None:
def log(self, stats: Stats) -> None:
raise NotImplementedError

def maybe_update_spec_decode_metrics(self, stats: Stats):
"""Save spec decode metrics (since they are unlikely
        to be emitted at the same time as the log interval)."""
if stats.spec_decode_metrics is not None:
self.spec_decode_metrics = stats.spec_decode_metrics


class LoggingStatLogger(StatLoggerBase):
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
@@ -379,6 +386,9 @@ def log(self, stats: Stats) -> None:
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter)

# Update spec decode metrics
self.maybe_update_spec_decode_metrics(stats)

# Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
@@ -408,15 +418,16 @@ def log(self, stats: Stats) -> None:
stats.cpu_cache_usage_sys * 100,
)

if self.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
self.spec_decode_metrics))

# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now

if stats.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
stats.spec_decode_metrics))
self.spec_decode_metrics = None

def _format_spec_decode_metrics_str(
self, metrics: "SpecDecodeWorkerMetrics") -> str:
@@ -533,6 +544,9 @@ def log(self, stats: Stats):
self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
self.num_generation_tokens.append(stats.num_generation_tokens_iter)

# Update spec decode metrics
self.maybe_update_spec_decode_metrics(stats)

# Log locally every local_interval seconds.
if local_interval_elapsed(stats.now, self.last_local_log,
self.local_interval):
@@ -550,26 +564,27 @@ def log(self, stats: Stats):
prompt_throughput=prompt_throughput,
generation_throughput=generation_throughput)

# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now

if stats.spec_decode_metrics is not None:
if self.spec_decode_metrics is not None:
self._log_gauge(
self.metrics.gauge_spec_decode_draft_acceptance_rate,
stats.spec_decode_metrics.draft_acceptance_rate)
self.spec_decode_metrics.draft_acceptance_rate)
self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
stats.spec_decode_metrics.system_efficiency)
self.spec_decode_metrics.system_efficiency)
self._log_counter(
self.metrics.counter_spec_decode_num_accepted_tokens,
stats.spec_decode_metrics.accepted_tokens)
self.spec_decode_metrics.accepted_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_draft_tokens,
stats.spec_decode_metrics.draft_tokens)
self.spec_decode_metrics.draft_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_emitted_tokens,
stats.spec_decode_metrics.emitted_tokens)
self.spec_decode_metrics.emitted_tokens)

# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
self.last_local_log = stats.now
self.spec_decode_metrics = None


class RayPrometheusStatLogger(PrometheusStatLogger):
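
To see why the cache matters, here is a tiny driver for the CachingStatLogger sketch from the top of this page: metrics that arrive well before the interval elapses are still emitted at the next flush, instead of being dropped as they were before this commit.

collected = []
stat_logger = CachingStatLogger(local_interval=5.0, emit=collected.append)

t0 = time.time()
# Metrics arrive on an early step, long before the interval elapses...
stat_logger.log(Stats(now=t0 + 1.0, spec_decode_metrics={"draft_tokens": 5}))
assert collected == []  # cached, not flushed yet

# ...and are still emitted once the interval finally elapses.
stat_logger.log(Stats(now=t0 + 6.0))
assert collected == [{"draft_tokens": 5}]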
