diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py
index b3c7850556f90..22db9e885d795 100644
--- a/tests/metrics/test_metrics.py
+++ b/tests/metrics/test_metrics.py
@@ -246,6 +246,7 @@ def test_metric_spec_decode(
         "counter_spec_decode_num_draft_tokens": lambda v: v == k,
         "counter_spec_decode_num_emitted_tokens": lambda v: 0 <= v <= k + 1,
+        "gauge_spec_decode_mean_accepted_tokens": lambda v: 0 <= v <= k,
     }
 
     # Use one request to better inspect the metrics.
@@ -333,6 +334,7 @@ def test_metric_spec_decode_interval(
         "counter_spec_decode_num_draft_tokens": lambda v: v == k,
         "counter_spec_decode_num_emitted_tokens": lambda v: 0 <= v <= k + 1,
+        "gauge_spec_decode_mean_accepted_tokens": lambda v: 0 <= v <= k,
     }
 
     for metric_name, is_expected in metric_name_to_expected_fn.items():
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index c8aec8dd3afa3..cdf3e9d4ca8f1 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -257,6 +257,11 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
                 name="vllm:spec_decode_num_emitted_tokens_total",
                 documentation="Number of emitted tokens.",
                 labelnames=labelnames))
+        self.gauge_spec_decode_mean_accepted_tokens = self._gauge_cls(
+            name="vllm:spec_decode_mean_accepted_tokens",
+            documentation="Mean number of accepted speculative tokens per step.",
+            labelnames=labelnames,
+            multiprocess_mode="all")
 
         # Deprecated in favor of vllm:prompt_tokens_total
         self.gauge_avg_prompt_throughput = self._gauge_cls(
@@ -504,13 +509,16 @@ def _reset(self, stats, prompt_throughput, generation_throughput) -> None:
 
     def _format_spec_decode_metrics_str(
             self, metrics: "SpecDecodeWorkerMetrics") -> str:
-        return ("Speculative metrics: "
-                f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
-                f"System efficiency: {metrics.system_efficiency:.3f}, "
-                f"Number of speculative tokens: {metrics.num_spec_tokens}, "
-                f"Number of accepted tokens: {metrics.accepted_tokens}, "
-                f"Number of draft tokens: {metrics.draft_tokens}, "
-                f"Number of emitted tokens: {metrics.emitted_tokens}.")
+        return (
+            "Speculative metrics: "
+            f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
+            f"System efficiency: {metrics.system_efficiency:.3f}, "
+            f"Number of speculative tokens: {metrics.num_spec_tokens}, "
+            f"Number of accepted tokens: {metrics.accepted_tokens}, "
+            f"Number of draft tokens: {metrics.draft_tokens}, "
+            f"Number of emitted tokens: {metrics.emitted_tokens}, "
+            f"Mean accepted tokens: {metrics.mean_accepted_tokens:.3f}."
+        )
 
     def info(self, type: str, obj: SupportsMetricsInfo) -> None:
         raise NotImplementedError
@@ -692,6 +700,9 @@ def log(self, stats: Stats):
             self._log_counter(
                 self.metrics.counter_spec_decode_num_emitted_tokens,
                 self.spec_decode_metrics.emitted_tokens)
+            self._log_gauge(
+                self.metrics.gauge_spec_decode_mean_accepted_tokens,
+                self.spec_decode_metrics.mean_accepted_tokens)
 
         # Reset tracked stats for next interval.
         self.num_prompt_tokens = []
diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py
index 3ab0ba9e9f5c2..57c8c6998d388 100644
--- a/vllm/model_executor/layers/rejection_sampler.py
+++ b/vllm/model_executor/layers/rejection_sampler.py
@@ -135,6 +135,10 @@ def forward(
             self.num_accepted_tokens += accepted_token_num.sum()
             self.num_emitted_tokens += emitted_token_num.sum() + batch_size
             self.num_draft_tokens += batch_size * k
+            self.num_invocation += 1
+            self.mean_accepted_tokens = (
+                (self.mean_accepted_tokens * (self.num_invocation - 1)) +
+                accepted_token_num.sum()) / self.num_invocation
         else:
             accepted, recovered_token_ids = (
                 self._batch_modified_rejection_sampling(
diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py
index 6aa4b8bd34cde..841d9b2619b4b 100644
--- a/vllm/model_executor/layers/spec_decode_base_sampler.py
+++ b/vllm/model_executor/layers/spec_decode_base_sampler.py
@@ -28,6 +28,8 @@ def __init__(self, strict_mode: bool = False):
         self.num_accepted_tokens: Optional[torch.Tensor] = None
         self.num_emitted_tokens: Optional[torch.Tensor] = None
+        self.num_invocation: Optional[torch.Tensor] = None
+        self.mean_accepted_tokens: Optional[torch.Tensor] = None
         self.num_draft_tokens: int = 0
 
     def init_gpu_tensors(self, device: Union[int, str]) -> None:
@@ -57,6 +59,10 @@ def init_tensors(self,
         self.num_emitted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)
+        self.num_invocation = torch.tensor(0, dtype=torch.long, device=device)
+        self.mean_accepted_tokens = torch.tensor(0,
+                                                 dtype=torch.float,
+                                                 device=device)
 
     @property
     def probs_dtype(self):
@@ -126,6 +132,10 @@ def _create_output(
         self.num_accepted_tokens += accepted.sum()
         self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
+        self.num_invocation += 1
+        self.mean_accepted_tokens = (
+            (self.mean_accepted_tokens *
+             (self.num_invocation - 1)) + accepted.sum()) / self.num_invocation
         self.num_draft_tokens += batch_size * k
 
         return output_with_bonus_tokens
diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py
index 03dc46600d8a9..b6c8e3d730291 100644
--- a/vllm/spec_decode/metrics.py
+++ b/vllm/spec_decode/metrics.py
@@ -45,6 +45,9 @@ class SpecDecodeWorkerMetrics(
     # The number of speculative tokens per sequence.
     num_spec_tokens: int
 
+    # Mean length of the speculative sequence that is accepted by the sampler.
+    mean_accepted_tokens: float
+
 
 Timer = Callable[[], float]
 
@@ -73,6 +76,8 @@ def __init__(self,
             0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
         self._aggregate_num_emitted_tokens = torch.tensor(
             0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
+        self._aggregate_mean_accepted_tokens = torch.tensor(
+            0, dtype=torch.float, device="cpu", pin_memory=pin_memory)
         self._aggregate_num_draft_tokens = 0
 
         self._rejsample_metrics_collect_interval_s = collect_interval_s
@@ -134,6 +139,9 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
             non_blocking=True)
         self._aggregate_num_emitted_tokens.copy_(
             self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
+        self._aggregate_mean_accepted_tokens.copy_(
+            self.spec_decode_sampler.mean_accepted_tokens,
+            non_blocking=True)
         # Number of draft tokens is calculated on CPU, so no copy is
         # required.
         self._aggregate_num_draft_tokens = (
@@ -163,6 +171,7 @@ def _collect_rejsample_metrics(
         accepted_tokens = self._aggregate_num_accepted_tokens.item()
         emitted_tokens = self._aggregate_num_emitted_tokens.item()
+        mean_accepted_tokens = self._aggregate_mean_accepted_tokens.item()
         draft_tokens = self._aggregate_num_draft_tokens
 
         max_num_emitted_tokens = self.get_max_num_emitted_tokens(
@@ -181,6 +190,7 @@ def _collect_rejsample_metrics(
         return SpecDecodeWorkerMetrics(
             num_spec_tokens=k,
             draft_acceptance_rate=draft_acceptance_rate,
+            mean_accepted_tokens=mean_accepted_tokens,
             system_efficiency=system_efficiency,
             accepted_tokens=accepted_tokens,
             draft_tokens=draft_tokens,
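
For context, both sampler hunks above maintain the same incremental (running) mean over invocations: mean_n = (mean_{n-1} * (n - 1) + x_n) / n, where x_n is the total number of tokens accepted in step n. Below is a minimal, self-contained sketch of that update; `RunningMean` is a hypothetical helper used only for illustration, not part of vLLM.

```python
# Sketch only: illustrates the incremental-mean recurrence used in the diff.
import torch


class RunningMean:
    """Tracks the mean number of accepted tokens across sampler invocations."""

    def __init__(self, device: str = "cpu") -> None:
        self.num_invocation = torch.tensor(0, dtype=torch.long, device=device)
        self.mean_accepted_tokens = torch.tensor(0,
                                                 dtype=torch.float,
                                                 device=device)

    def update(self, accepted_token_num: torch.Tensor) -> None:
        # mean_n = (mean_{n-1} * (n - 1) + x_n) / n
        self.num_invocation += 1
        self.mean_accepted_tokens = (
            self.mean_accepted_tokens * (self.num_invocation - 1) +
            accepted_token_num.sum()) / self.num_invocation


tracker = RunningMean()
for step in (torch.tensor([3, 2]), torch.tensor([1, 4]), torch.tensor([0, 5])):
    tracker.update(step)
print(float(tracker.mean_accepted_tokens))  # 5.0: each step accepted 5 tokens total
```

Note that, as written in the diff, each update folds in `accepted_token_num.sum()` over the whole batch, so the gauge reports the mean number of accepted tokens per sampler step rather than per sequence.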