Skip to content

Commit

Permalink
[Misc] Log spec decode metrics (vllm-project#6454)
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored Jul 16, 2024
1 parent 94162be commit 160e1d8
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 14 deletions.
49 changes: 49 additions & 0 deletions tests/metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,55 @@ def test_engine_log_metrics_regression(
assert_metrics(engine, disable_log_stats, len(example_prompts))


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [10])
def test_metric_spec_decode(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
k = 5

with vllm_runner(model,
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.4,
speculative_model=model,
num_speculative_tokens=k,
use_v2_block_manager=True) as vllm_model:

# Force log interval to be 0 to catch all metrics.
stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus']
stat_logger.local_interval = 0

# Note that the purpose of this test is to verify spec decode
# metrics instead of functional correctness, so the expected values
# are intended to be loose.
metric_name_to_expected_fn = {
"gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1,
"gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1,
"counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k,
"counter_spec_decode_num_draft_tokens": lambda v: v == k,
"counter_spec_decode_num_emitted_tokens":
lambda v: 0 <= v <= k + 1,
}

# Use one request to better inspect the metrics.
prompts = example_prompts[:1]

_ = vllm_model.generate_greedy(prompts, max_tokens)
for metric_name, is_expected in metric_name_to_expected_fn.items():
metric_val = getattr(
stat_logger.metrics,
metric_name).labels(**stat_logger.labels)._value.get()
assert is_expected(metric_val), (
f"the value of metric {metric_name} ({metric_val}) "
"does not meet expectation")


def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
num_requests: int) -> None:
if disable_log_stats:
Expand Down
44 changes: 36 additions & 8 deletions tests/spec_decode/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
}
test_name = request.node.name

model = kwargs["model"]
draft_model = kwargs.get("speculative_model", None)
same_draft_target_model = (draft_model is not None
and draft_model == model)

def generator_inner():

wait_for_gpu_memory_to_clear(
Expand All @@ -177,6 +182,13 @@ def generator_inner():

print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)

# Override logging interval to 0 for spec decode test run to
# log all metrics in time.
if (baseline_or_test == "test" and not use_async
and llm.llm_engine.log_stats):
for sate_logger in llm.llm_engine.stat_loggers.values():
sate_logger.local_interval = 0
set_random_seed(seed)

yield llm
Expand All @@ -188,6 +200,9 @@ def generator_outer():
yield llm
del llm

# Set an attribute to the generator_outer function to allow us to
# determine whether to further check the acceptance rate in tests.
generator_outer.same_draft_target_model = same_draft_target_model # type: ignore
return generator_outer


Expand All @@ -204,18 +219,26 @@ def maybe_assert_ngram_worker(llm):

def get_output_from_llm_generator(
llm_generator, prompts,
sampling_params) -> Tuple[List[str], List[List[int]]]:
sampling_params) -> Tuple[List[str], List[List[int]], float]:
tokens: List[str] = []
token_ids: List[List[int]] = []
acceptance_rate: float = -1.0
for llm in llm_generator():
maybe_assert_ngram_worker(llm)

outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]

# Fetch acceptance rate if logging is enabled.
if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
stat_logger = stat_loggers["prometheus"]
acceptance_rate = (stat_logger.metrics.
gauge_spec_decode_draft_acceptance_rate.labels(
**stat_logger.labels)._value.get())
del llm

return tokens, token_ids
return tokens, token_ids, acceptance_rate


def get_logprobs_from_llm_generator(
Expand All @@ -237,7 +260,8 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
print_tokens: bool = False):
print_tokens: bool = False,
ensure_all_accepted: bool = False):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
Expand Down Expand Up @@ -267,12 +291,13 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
temperature=temperature,
)

spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)
(spec_batch_tokens, spec_batch_token_ids,
acceptance_rate) = get_output_from_llm_generator(test_llm_generator,
prompts, sampling_params)

(baseline_batch_tokens,
baseline_batch_token_ids) = get_output_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
(baseline_batch_tokens, baseline_batch_token_ids,
_) = get_output_from_llm_generator(baseline_llm_generator, prompts,
sampling_params)

assert len(baseline_batch_token_ids) == len(prompts)
assert len(spec_batch_token_ids) == len(prompts)
Expand All @@ -287,3 +312,6 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids

if ensure_all_accepted:
assert acceptance_rate == 1.0
18 changes: 12 additions & 6 deletions tests/spec_decode/e2e/test_multistep_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
temperature=temperature,
)

batch_tokens, batch_token_ids = get_output_from_llm_generator(
batch_tokens, batch_token_ids, _ = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)

# Expect a generation for each prompt in the batch.
Expand Down Expand Up @@ -200,12 +200,18 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
Since this test is cheaper than other e2e correctness tests, we generate
with a higher output_len.
When the draft model is the same as the target model, we further check
whether all speculative tokens are accepted.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
ensure_all_accepted = test_llm_generator.same_draft_target_model
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True,
ensure_all_accepted=ensure_all_accepted)


@pytest.mark.parametrize(
Expand Down
40 changes: 40 additions & 0 deletions vllm/engine/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,30 @@ def __init__(self, labelnames: List[str], max_model_len: int):
documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason])

# Speculatie decoding stats
self.gauge_spec_decode_draft_acceptance_rate = self._base_library.Gauge(
name="vllm:spec_decode_draft_acceptance_rate",
documentation="Speulative token acceptance rate.",
labelnames=labelnames)
self.gauge_spec_decode_efficiency = self._base_library.Gauge(
name="vllm:spec_decode_efficiency",
documentation="Speculative decoding system efficiency.",
labelnames=labelnames)
self.counter_spec_decode_num_accepted_tokens = (
self._base_library.Counter(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames))
self.counter_spec_decode_num_draft_tokens = self._base_library.Counter(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames)
self.counter_spec_decode_num_emitted_tokens = (
self._base_library.Counter(
name="vllm:spec_decode_num_emitted_tokens_total",
documentation="Number of emitted tokens.",
labelnames=labelnames))

# Deprecated in favor of vllm:prompt_tokens_total
self.gauge_avg_prompt_throughput = self._base_library.Gauge(
name="vllm:avg_prompt_throughput_toks_per_s",
Expand Down Expand Up @@ -454,6 +478,22 @@ def log(self, stats: Stats):
self.num_generation_tokens = []
self.last_local_log = stats.now

if stats.spec_decode_metrics is not None:
self._log_gauge(
self.metrics.gauge_spec_decode_draft_acceptance_rate,
stats.spec_decode_metrics.draft_acceptance_rate)
self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
stats.spec_decode_metrics.system_efficiency)
self._log_counter(
self.metrics.counter_spec_decode_num_accepted_tokens,
stats.spec_decode_metrics.accepted_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_draft_tokens,
stats.spec_decode_metrics.draft_tokens)
self._log_counter(
self.metrics.counter_spec_decode_num_emitted_tokens,
stats.spec_decode_metrics.emitted_tokens)


class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""
Expand Down

0 comments on commit 160e1d8

Please sign in to comment.