[core][misc] simplify output processing with shortcut for non-parallel sampling and non-beam search use case #7117

Merged · 2 commits · Aug 4, 2024
39 changes: 31 additions & 8 deletions vllm/engine/output_processor/single_step.py

@@ -81,6 +81,29 @@ def process_prompt_logprob(self, seq_group: SequenceGroup,
 
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                         outputs: SequenceGroupOutput) -> None:
+        sampling_params = seq_group.sampling_params
+        if sampling_params.n == 1 and not sampling_params.use_beam_search:
+            # only have one output sample
+            sample = outputs.samples[0]
+            # only have one sequence
+            seq = seq_group.seqs[0]
+            seq.append_token_id(sample.output_token, sample.logprobs)
+            if sampling_params.detokenize and self.detokenizer:
+                new_char_count = self.detokenizer.decode_sequence_inplace(
+                    seq, sampling_params)
+            else:
+                new_char_count = 0
+            self.stop_checker.maybe_stop_sequence(
+                seq,
+                new_char_count,
+                sampling_params,
+                lora_req=seq_group.lora_request,
+            )
+            if seq.is_finished():
+                for scheduler in self.scheduler:
+                    scheduler.free_seq(seq)
+            return
+
         # Process samples
         samples = outputs.samples
         parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
@@ -127,20 +150,20 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                 child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
-            if seq_group.sampling_params.detokenize and self.detokenizer:
+            if sampling_params.detokenize and self.detokenizer:
                 new_char_count = self.detokenizer.decode_sequence_inplace(
-                    seq, seq_group.sampling_params)
+                    seq, sampling_params)
             else:
                 new_char_count = 0
             self.stop_checker.maybe_stop_sequence(
                 seq,
                 new_char_count,
-                seq_group.sampling_params,
+                sampling_params,
                 lora_req=seq_group.lora_request,
             )
 
         # Non-beam search case
-        if not seq_group.sampling_params.use_beam_search:
+        if not sampling_params.use_beam_search:
             # For newly created child sequences, add them to the sequence group
             # and fork them in block manager if they are not finished.
             for seq, parent in child_seqs:
@@ -164,8 +187,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
         # Select the child sequences to keep in the sequence group.
         selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
         unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
-        beam_width = seq_group.sampling_params.best_of
-        length_penalty = seq_group.sampling_params.length_penalty
+        beam_width = sampling_params.best_of
+        length_penalty = sampling_params.length_penalty
 
         # Select the newly finished sequences with the highest scores
         # to replace existing finished sequences.
@@ -219,8 +242,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
             best_running_seq = running_child_seqs[0][0]
             current_worst_seq = all_finished_seqs[beam_width - 1][0]
             stop_beam_search = self._check_beam_search_early_stopping(
-                seq_group.sampling_params.early_stopping,
-                seq_group.sampling_params, best_running_seq, current_worst_seq)
+                sampling_params.early_stopping, sampling_params,
+                best_running_seq, current_worst_seq)
 
             if stop_beam_search:
                 # Stop the beam search and remove all the running sequences from
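To make the shortcut's control flow concrete, here is a minimal, runnable sketch with hypothetical stand-in classes (Sample and Seq below are not vLLM's real SequenceOutput/Sequence types): with one sequence and one sample per step, processing reduces to append, stop-check, and free.

    from dataclasses import dataclass, field
    from typing import Dict, List

    @dataclass
    class Sample:
        # Stand-in for a sampler output: one token id plus its logprobs.
        output_token: int
        logprobs: Dict[int, float]

    @dataclass
    class Seq:
        # Stand-in for a generation sequence.
        eos_token_id: int
        token_ids: List[int] = field(default_factory=list)
        finished: bool = False

        def append_token_id(self, token: int, logprobs: Dict[int, float]) -> None:
            self.token_ids.append(token)

        def is_finished(self) -> bool:
            return self.finished

    def process_step_fast_path(seq: Seq, sample: Sample) -> None:
        # n == 1 and no beam search: no parent/child dict, no fork(),
        # no pruning -- append the token and check for a stop condition.
        # (The real code also detokenizes and frees the finished sequence
        # in each scheduler.)
        seq.append_token_id(sample.output_token, sample.logprobs)
        if sample.output_token == seq.eos_token_id:
            seq.finished = True

    seq = Seq(eos_token_id=2)
    for token in (10, 11, 2):
        process_step_fast_path(seq, Sample(output_token=token, logprobs={}))
        if seq.is_finished():
            break
    print(seq.token_ids, seq.finished)  # [10, 11, 2] True

Because the shortcut returns early, the general parent/child path below it is left untouched, which keeps the change small and easy to revert.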