Skip to content

Commit

Permalink
[misc] hide best_of from engine (vllm-project#9261)
Browse files Browse the repository at this point in the history
Co-authored-by: Brendan Wong <bjwpokemon@gmail.com>
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
  • Loading branch information
2 people authored and sumitd2 committed Nov 14, 2024
1 parent 283caf0 commit 3c52732
Show file tree
Hide file tree
Showing 14 changed files with 46 additions and 73 deletions.
4 changes: 0 additions & 4 deletions tests/entrypoints/openai/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ async def client(server):
[("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
("_count", _NUM_REQUESTS)],
"vllm:request_params_n": [("_count", _NUM_REQUESTS)],
"vllm:request_params_best_of": [("_count", _NUM_REQUESTS)],
"vllm:prompt_tokens": [("_total",
_NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
"vllm:generation_tokens":
Expand Down Expand Up @@ -151,9 +150,6 @@ async def test_metrics_counts(client: openai.AsyncOpenAI):
"vllm:request_params_n_sum",
"vllm:request_params_n_bucket",
"vllm:request_params_n_count",
"vllm:request_params_best_of_sum",
"vllm:request_params_best_of_bucket",
"vllm:request_params_best_of_count",
"vllm:num_preemptions_total",
"vllm:prompt_tokens_total",
"vllm:generation_tokens_total",
Expand Down
1 change: 0 additions & 1 deletion tests/metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,6 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
"vllm:e2e_request_latency_seconds",
"vllm:request_prompt_tokens",
"vllm:request_generation_tokens",
"vllm:request_params_best_of",
"vllm:request_params_n",
]
for metric_name in request_histogram_metrics:
Expand Down
4 changes: 0 additions & 4 deletions tests/tracing/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ def test_traces(trace_service):
SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
assert attributes.get(
SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
assert attributes.get(
SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of
assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
outputs[0].prompt_token_ids)
Expand Down Expand Up @@ -155,8 +153,6 @@ def test_traces_with_detailed_steps(trace_service):
SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
assert attributes.get(
SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
assert attributes.get(
SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of
assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
outputs[0].prompt_token_ids)
Expand Down
2 changes: 1 addition & 1 deletion vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,7 @@ def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
# async_output_proc is allowed only when we have a single sequence
# in the sequence group
no_single_seq = seq_group.sampling_params is None or (
seq_group.sampling_params.best_of == 1)
seq_group.sampling_params.n == 1)
return no_single_seq

def schedule(
Expand Down
11 changes: 2 additions & 9 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ def add_request(
Details:
- Set arrival_time to the current time if it is None.
- Set prompt_token_ids to the encoded prompt if it is None.
- Create `best_of` number of :class:`~vllm.Sequence` objects.
- Create `n` number of :class:`~vllm.Sequence` objects.
- Create a :class:`~vllm.SequenceGroup` object
from the list of :class:`~vllm.Sequence`.
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
Expand Down Expand Up @@ -1242,8 +1242,7 @@ def _advance_to_next_step(
if seq_group_metadata.do_sample:
assert len(sequence_group_outputs.samples) == 1, (
"Async output processor expects a single sample"
" (i.e sampling_params.n == 1 and no "
"sampling_params.best_of > 1)")
" (i.e sampling_params.n == 1)")
sample = sequence_group_outputs.samples[0]

assert len(seq_group.seqs) == 1
Expand Down Expand Up @@ -1612,7 +1611,6 @@ def _get_stats(self,
# Metadata
num_prompt_tokens_requests: List[int] = []
num_generation_tokens_requests: List[int] = []
best_of_requests: List[int] = []
n_requests: List[int] = []
finished_reason_requests: List[str] = []

Expand Down Expand Up @@ -1683,8 +1681,6 @@ def _get_stats(self,
for seq in seq_group.get_finished_seqs()
])
if seq_group.sampling_params is not None:
best_of_requests.append(
seq_group.sampling_params.best_of)
n_requests.append(seq_group.sampling_params.n)
finished_reason_requests.extend([
SequenceStatus.get_finished_reason(seq.status)
Expand Down Expand Up @@ -1737,7 +1733,6 @@ def _get_stats(self,
# Metadata
num_prompt_tokens_requests=num_prompt_tokens_requests,
num_generation_tokens_requests=num_generation_tokens_requests,
best_of_requests=best_of_requests,
n_requests=n_requests,
finished_reason_requests=finished_reason_requests,
)
Expand Down Expand Up @@ -1824,8 +1819,6 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None:
seq_group.sampling_params.top_p)
seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
seq_group.sampling_params.max_tokens)
seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
seq_group.sampling_params.best_of)
seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
seq_group.sampling_params.n)
seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
Expand Down
8 changes: 0 additions & 8 deletions vllm/engine/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,6 @@ def __init__(self, labelnames: List[str], max_model_len: int):
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_best_of_request = self._histogram_cls(
name="vllm:request_params_best_of",
documentation="Histogram of the best_of request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.histogram_n_request = self._histogram_cls(
name="vllm:request_params_n",
documentation="Histogram of the n request parameter.",
Expand Down Expand Up @@ -473,8 +467,6 @@ def _log_prometheus(self, stats: Stats) -> None:
self.metrics.histogram_num_generation_tokens_request,
stats.num_generation_tokens_requests)
self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
self._log_histogram(self.metrics.histogram_best_of_request,
stats.best_of_requests)

def _log_prometheus_interval(self, prompt_throughput: float,
generation_throughput: float) -> None:
Expand Down
1 change: 0 additions & 1 deletion vllm/engine/metrics_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ class Stats:
# Metadata
num_prompt_tokens_requests: List[int]
num_generation_tokens_requests: List[int]
best_of_requests: List[int]
n_requests: List[int]
finished_reason_requests: List[str]

Expand Down
2 changes: 1 addition & 1 deletion vllm/engine/output_processor/single_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
outputs: SequenceGroupOutput,
is_async: bool) -> None:
sampling_params = seq_group.sampling_params
if sampling_params.best_of == 1:
if sampling_params.n == 1:
# only have one output sample
sample = outputs.samples[0]
# only have one sequence
Expand Down
17 changes: 8 additions & 9 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def _random_sample(
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum best_of value of the prompt phase requests.
# Find the maximum n value of the prompt phase requests.
random_samples = random_samples.cpu()
sample_idx = 0
results: SampleResultType = []
Expand All @@ -523,9 +523,9 @@ def _random_sample(
num_parent_seqs = len(seq_ids)
if is_prompt:
# Prompt phase.
parent_ids = [0] * sampling_params.best_of
parent_ids = [0] * sampling_params.n
next_token_ids = random_samples[
sample_idx, :sampling_params.best_of].tolist()
sample_idx, :sampling_params.n].tolist()
else:
# Generation phase.
parent_ids = list(range(num_parent_seqs))
Expand Down Expand Up @@ -570,7 +570,7 @@ def _beam_search_sample(
is_prompt = seq_group.is_prompt
seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
num_parent_seqs = len(seq_ids)
beam_width = sampling_params.best_of
beam_width = sampling_params.n
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
if is_prompt:
# Prompt phase.
Expand Down Expand Up @@ -797,12 +797,11 @@ def _sample_with_torch(
greedy_samples)

elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of_in_batch = 1
max_n_in_batch = 1
for seq_group in seq_groups:
if seq_group.is_prompt:
sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
max_n_in_batch = max(max_n_in_batch, sampling_params.n)
seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
seq_groups)

Expand All @@ -812,13 +811,13 @@ def _sample_with_torch(
probs[long_sample_indices],
sampling_tensors.top_ks[long_sample_indices],
sampling_tensors.top_ps[long_sample_indices],
max_best_of_in_batch,
max_n_in_batch,
seq_groups_arg,
)
else:
multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices],
max_best_of_in_batch,
max_n_in_batch,
seq_groups=seq_groups_arg)

if sampled_token_ids_tensor is not None:
Expand Down
2 changes: 1 addition & 1 deletion vllm/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def from_seq_group(cls, seq_group: SequenceGroup,
top_n_seqs = seqs
else:
# Get the top-n sequences.
n = sampling_params.n
n = sampling_params._real_n or sampling_params.n
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]
Expand Down
33 changes: 17 additions & 16 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,8 @@ class SamplingParams(
n: Number of output sequences to return for the given prompt.
best_of: Number of output sequences that are generated from the prompt.
From these `best_of` sequences, the top `n` sequences are returned.
`best_of` must be greater than or equal to `n`. This is treated as
the beam width when `use_beam_search` is True. By default, `best_of`
is set to `n`.
`best_of` must be greater than or equal to `n`. By default,
`best_of` is set to `n`.
presence_penalty: Float that penalizes new tokens based on whether they
appear in the generated text so far. Values > 0 encourage the model
to use new tokens, while values < 0 encourage the model to repeat
Expand Down Expand Up @@ -173,6 +172,7 @@ class SamplingParams(

n: int = 1
best_of: Optional[int] = None
_real_n: Optional[int] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
repetition_penalty: float = 1.0
Expand Down Expand Up @@ -282,7 +282,19 @@ def from_optional(
)

def __post_init__(self) -> None:
self.best_of = self.best_of or self.n
# How we deal with `best_of`:
# if `best_of` is not set, we default to `n`;
# if `best_of` is set, we set `n` to `best_of`,
# and set `_real_n` to the original `n`.
# When we return the result, we will check
# if we need to return `n` or `_real_n` results.
if self.best_of:
if self.best_of < self.n:
raise ValueError(
f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.")
self._real_n = self.n
self.n = self.best_of
if 0 < self.temperature < _MAX_TEMP:
logger.warning(
"temperature %s is less than %s, which may cause numerical "
Expand Down Expand Up @@ -329,12 +341,6 @@ def _verify_args(self) -> None:
f"type {type(self.n)}")
if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.")
if not isinstance(self.best_of, int):
raise ValueError(f"best_of must be an int, but is of "
f"type {type(self.best_of)}")
if self.best_of < self.n:
raise ValueError(f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.")
if not -2.0 <= self.presence_penalty <= 2.0:
raise ValueError("presence_penalty must be in [-2, 2], got "
f"{self.presence_penalty}.")
Expand Down Expand Up @@ -385,18 +391,14 @@ def _verify_args(self) -> None:
raise ValueError(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop.")
if self.best_of != self.n and self.output_kind == (
if self.best_of != self._real_n and self.output_kind == (
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")

def _verify_greedy_sampling(self) -> None:
if self.n > 1:
raise ValueError("n must be 1 when using greedy sampling, "
f"got {self.n}.")
assert isinstance(self.best_of, int)
if self.best_of > 1:
raise ValueError("best_of must be 1 when using greedy sampling, "
f"got {self.best_of}.")

def update_from_generation_config(
self,
Expand Down Expand Up @@ -453,7 +455,6 @@ def clone(self) -> "SamplingParams":
def __repr__(self) -> str:
return (
f"SamplingParams(n={self.n}, "
f"best_of={self.best_of}, "
f"presence_penalty={self.presence_penalty}, "
f"frequency_penalty={self.frequency_penalty}, "
f"repetition_penalty={self.repetition_penalty}, "
Expand Down
10 changes: 5 additions & 5 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,14 +803,14 @@ def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
if self.sampling_params:
best_of = self.sampling_params.best_of
assert isinstance(best_of, int)
if best_of > self.num_seqs():
n = self.sampling_params.n
assert isinstance(n, int)
if n > self.num_seqs():
# At prompt stage, the sequence group is not yet filled up
# and only has one sequence running. However, in the
# generation stage, we will have `best_of` sequences
# generation stage, we will have `n` sequences
# running.
return best_of
return n
# At sampling stages, return the number of actual sequences
# that are not finished yet.
return self.num_unfinished_seqs()
Expand Down
1 change: 0 additions & 1 deletion vllm/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ class SpanAttributes(BaseSpanAttributes):
# The following span attribute names are added here because they are missing
# from the Semantic Conventions for LLM.
LLM_REQUEST_ID = "gen_ai.request.id"
LLM_REQUEST_BEST_OF = "gen_ai.request.best_of"
LLM_REQUEST_N = "gen_ai.request.n"
LLM_USAGE_NUM_SEQUENCES = "gen_ai.usage.num_sequences"
LLM_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
Expand Down
Loading

0 comments on commit 3c52732

Please sign in to comment.