Commit: sync
alexm-redhat committed Jul 11, 2024
1 parent 046479f commit bccd4b6
Showing 3 changed files with 79 additions and 41 deletions.
98 changes: 63 additions & 35 deletions vllm/model_executor/layers/sampler.py
@@ -47,6 +47,23 @@ def __init__(self):
# speculative decoding.
self.include_gpu_probs_tensor = False

def _init_sampling_tensors(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
):
assert logits is not None
_, vocab_size = logits.shape

(sampling_tensors, do_penalties, do_top_p_top_k,
do_min_p) = SamplingTensors.from_sampling_metadata(
sampling_metadata, vocab_size, logits.device, logits.dtype)

self._sampling_tensors = sampling_tensors
self._do_penalties = do_penalties
self._do_top_p_top_k = do_top_p_top_k
self._do_min_p = do_min_p

def forward(
self,
logits: torch.Tensor,
@@ -61,12 +78,21 @@ def forward(
_, vocab_size = logits.shape

# Prepare sampling tensors with pinned memory to avoid blocking.
(sampling_tensors, do_penalties, do_top_p_top_k,
do_min_p) = SamplingTensors.from_sampling_metadata(
sampling_metadata, vocab_size, logits.device, logits.dtype)
if not sampling_metadata.reuse_sampling_tensors:
self._init_sampling_tensors(logits, sampling_metadata)
else:
if self._do_penalties:
# In this case, we depend on the output tokens
# TODO: Check with Cade if this is needed for spec tokens
self._init_sampling_tensors(logits, sampling_metadata)

sampling_tensors = self._sampling_tensors
do_penalties = self._do_penalties
do_top_p_top_k = self._do_top_p_top_k
do_min_p = self._do_min_p

logits = _apply_min_tokens_penalty(logits, sampling_metadata)

# Apply presence and frequency penalties.
if do_penalties:
logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
@@ -77,8 +103,8 @@ def forward(

# Apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))

logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
if do_top_p_top_k:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
@@ -113,18 +139,19 @@ def forward(
on_device_tensors = None

# Get the logprobs query results.
if not sampling_metadata.skip_logprobs:
prompt_logprobs = None
sample_logprobs = None
if not sampling_metadata.skip_cpu_samples:
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results)
else:
prompt_logprobs = None
sample_logprobs = None
return _build_sampler_output(sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_logprobs=sampling_metadata.skip_logprobs)

return _build_sampler_output(
sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_cpu_samples=sampling_metadata.skip_cpu_samples)

@property
def _should_modify_greedy_probs_inplace(self) -> bool:
@@ -544,19 +571,20 @@ def _sample_with_torch(

# GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.
if not sampling_metadata.skip_logprobs:
if not sampling_metadata.skip_cpu_samples:
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type])
elif sampling_type in (SamplingType.RANDOM,
SamplingType.RANDOM_SEED):
sample_results = _random_sample(
seq_groups, multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
beam_search_logprobs)
sample_results_dict.update(zip(seq_group_id, sample_results))

sample_results = [
@@ -785,7 +813,7 @@ def _get_logprobs(

# print("query_indices_gpu: shape = {}, vals = {}".format(query_indices_gpu.shape, query_indices_gpu))
# print("next_token_ids_gpu: shape = {}, vals = {}".format(next_token_ids_gpu.shape, next_token_ids_gpu))

# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
selected_logprobs = logprobs[[
@@ -1013,11 +1041,11 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
def _build_sampler_output(
sample_results: SampleResultType,
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
sample_logprobs: Optional[List[SampleLogprobs]],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
torch.Tensor]],
skip_logprobs: bool = False,
skip_cpu_samples: bool = False,
) -> SamplerOutput:
"""Construct Python objects with the output of sampling.
@@ -1027,23 +1055,23 @@ def _build_sampler_output(
allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling.
"""

sampler_output: List[CompletionSequenceGroupOutput] = []
if not skip_logprobs:
if not skip_cpu_samples:
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs,
sample_logprobs):
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs,
sample_logprobs):
seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result
seq_outputs: List[SequenceOutput] = []
for parent_id, next_token_id, logprobs in zip(parent_ids,
next_token_ids,
group_sample_logprobs):
for parent_id, next_token_id, logprobs in zip(
parent_ids, next_token_ids, group_sample_logprobs):
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
SequenceOutput(seq_ids[parent_id], next_token_id,
logprobs))
sampler_output.append(
CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))
CompletionSequenceGroupOutput(seq_outputs,
group_prompt_logprobs))

# If not specified, store None values in SamplerOutput.
if on_device_tensors is not None:
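For orientation, a minimal sketch of the rebuild-vs-reuse decision that the new _init_sampling_tensors / reuse_sampling_tensors path in Sampler.forward implements (illustrative code with stand-in names, not part of the commit): sampling tensors are rebuilt on the first step, or whenever penalties are active, because penalties depend on the tokens generated so far.

# Illustrative sketch, not vLLM code: mirrors the condition added to
# Sampler.forward above.
from dataclasses import dataclass

@dataclass
class FakeSamplingMetadata:
    reuse_sampling_tensors: bool = False
    skip_cpu_samples: bool = False

def needs_tensor_rebuild(metadata: FakeSamplingMetadata,
                         do_penalties: bool) -> bool:
    if not metadata.reuse_sampling_tensors:
        return True  # first step, or the caller did not opt into reuse
    # Cached tensors are stale when penalties must see newly sampled tokens.
    return do_penalties

assert needs_tensor_rebuild(FakeSamplingMetadata(), do_penalties=False)
assert not needs_tensor_rebuild(
    FakeSamplingMetadata(reuse_sampling_tensors=True), do_penalties=False)
assert needs_tensor_rebuild(
    FakeSamplingMetadata(reuse_sampling_tensors=True), do_penalties=True)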
8 changes: 6 additions & 2 deletions vllm/model_executor/sampling_metadata.py
@@ -94,13 +94,17 @@ def __init__(
selected_token_indices: torch.Tensor,
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
num_prompts: int,
skip_logprobs: bool = False,
skip_cpu_samples: bool = False,
reuse_sampling_tensors: bool = False,
) -> None:
self.seq_groups = seq_groups
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices
self.num_prompts = num_prompts
self.skip_logprobs = skip_logprobs

# TODO: Add docs
self.skip_cpu_samples = skip_cpu_samples
self.reuse_sampling_tensors = reuse_sampling_tensors

@staticmethod
def prepare(
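The new SamplingMetadata flags still carry a "# TODO: Add docs" comment. A hedged summary of their apparent intent, inferred only from how this same commit uses them in the sampler and the draft model runner (the class below is a documentation stand-in, not the real SamplingMetadata):

# Stand-in documenting the two new flags; inferred from usage in this commit,
# not from existing vLLM documentation.
from dataclasses import dataclass

@dataclass
class SamplingMetadataFlags:
    # Skip the GPU->CPU copy and Python-object conversion of sample results
    # and logprobs; callers consume the on-device tensors instead.
    skip_cpu_samples: bool = False
    # Let Sampler.forward reuse the SamplingTensors cached on a previous call
    # instead of rebuilding them (unless penalties force a rebuild).
    reuse_sampling_tensors: bool = False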
14 changes: 10 additions & 4 deletions vllm/spec_decode/draft_model_runner.py
@@ -17,7 +17,7 @@
logger = init_logger(__name__)

log_advance_input = False

enable_advance_step = True

class TP1DraftModelRunner(ModelRunner):
"""Specialized model runner for speculative decoding draft model.
@@ -260,6 +260,9 @@ def _advance_step(
return new_model_input

def _can_use_advance_step(self):
if not enable_advance_step:
return False

# TODO: Add support for other attn backends
if self.attn_backend.get_name() != "flash-attn":
return False
@@ -311,8 +314,8 @@ def _execute_model_with_advance_step(
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)

model_input.sampling_metadata.skip_logprobs = True

model_input.sampling_metadata.skip_cpu_samples = True
# Sample the next token.
outputs.append(
self.model.sample(
@@ -328,6 +331,8 @@
else:
assert not model_input.is_prompt
model_input = self._advance_step(model_input, outputs[-1])

model_input.sampling_metadata.reuse_sampling_tensors = True

return outputs

@@ -381,7 +386,7 @@ def execute_model(
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)

model_input.sampling_metadata.skip_logprobs = True
model_input.sampling_metadata.skip_cpu_samples = True

# Sample the next token.
outputs.append(
Expand All @@ -393,5 +398,6 @@ def execute_model(
# Prepare the inputs for the next step.
if step != num_steps - 1:
model_input = self.update_model_input(model_input, outputs[-1])
model_input.sampling_metadata.reuse_sampling_tensors = True

return outputs
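Putting the pieces together, a schematic (not vLLM code; the model methods are hypothetical stand-ins for the runner calls shown above) of the multi-step draft loop these flags enable: sampling stays on the GPU during the draft steps, and the sampling tensors built on the first step are reused on every later step, avoiding a GPU<->CPU sync per draft token.

# Schematic of the multi-step draft loop; model methods are hypothetical
# stand-ins for the runner calls in the diff above.
def run_draft_steps(model, model_input, num_steps):
    outputs = []
    for step in range(num_steps):
        hidden_states = model.run(model_input)
        logits = model.compute_logits(hidden_states,
                                      model_input.sampling_metadata)

        # Keep sampled tokens/logprobs on the GPU for the draft phase.
        model_input.sampling_metadata.skip_cpu_samples = True
        outputs.append(
            model.sample(logits=logits,
                         sampling_metadata=model_input.sampling_metadata))

        if step != num_steps - 1:
            # Prepare inputs for the next step, then reuse the cached tensors.
            model_input = model.advance_step(model_input, outputs[-1])
            model_input.sampling_metadata.reuse_sampling_tensors = True
    return outputs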
