diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py
new file mode 100644
index 0000000000000..7e4a6cc62d02b
--- /dev/null
+++ b/tests/spec_decode/e2e/test_medusa_correctness.py
@@ -0,0 +1,226 @@
+"""This docstring details important information on the testing methodology.
+
+Most of the tests rely on "greedy equality", where we expect the output of
+speculative decoding on a sequence to exactly match the output of normal non-
+speculative decoding.
+
+Since speculative decoding with rejection sampling guarantees that the output
+distribution matches the target model's output distribution (up to hardware
+numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
+equality.
+
+However, we still need to verify that the following scenarios pass:
+    * Batch size 1 greedy equality
+    * Batch size >1 greedy equality
+    * Greedy equality under preemption
+    * Greedy equality under various numbers of speculative tokens.
+
+With these tests, we can say that, at the very least, Medusa does not break
+the correctness of the target model outputs.
+"""
+
+import pytest
+
+from .conftest import run_greedy_equality_correctness_test
+
+# main model
+# lmsys/vicuna-7b-v1.3 was the original choice, but it causes OOM in the
+# CI pipeline, so a smaller model is used instead.
+MAIN_MODEL = "JackFram/llama-68m"
+
+# speculative model
+SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
+
+# max. number of speculative tokens: this corresponds to
+# num_heads in the config.json of the speculator model.
+MAX_SPEC_TOKENS = 5
+
+# precision
+PRECISION = "float32"
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+    },
+])
+@pytest.mark.parametrize("output_len", [
+    128,
+])
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("seed", [1])
+def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
+                                       test_llm_generator, batch_size: int,
+                                       output_len: int):
+    """Verify greedy equality with different batch sizes."""
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "block_size": 8,
+        # 2 for small prompt, 256//8 for generated.
+        "num_gpu_blocks_override": 2 + 256 // 8,
+        "max_model_len": (2 + 256 // 8) * 8,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use small output len for fast test.
+        128,
+    ])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
+                                                       test_llm_generator,
+                                                       batch_size: int,
+                                                       output_len: int):
+    """Verify greedy equality, even when some sequences are preempted mid-
+    generation.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize(
+    "test_llm_kwargs",
+    [
+        {
+            "speculative_model": SPEC_MODEL,
+            "num_speculative_tokens": k,
+        }
+        # Try a range of num. speculative tokens
+        for k in range(1, 1 + MAX_SPEC_TOKENS)
+    ])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
+                            batch_size: int, output_len: int):
+    """Verify that Medusa speculative decoding produces exactly the same
+    output as decoding without speculation, for different values of
+    num_speculative_tokens.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs",
+                         [{
+                             "speculative_model": SPEC_MODEL,
+                             "num_speculative_tokens": MAX_SPEC_TOKENS,
+                             "speculative_disable_by_batch_size": 4
+                         }])
+@pytest.mark.parametrize("batch_size", [1, 5])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_medusa_disable_queue(baseline_llm_generator, test_llm_generator,
+                              batch_size: int, output_len: int):
+    """Verify that Medusa speculative decoding produces exactly the same
+    output as decoding without speculation when speculation is disabled for
+    large batch sizes.
+ """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +if __name__ == "__main__": + import pytest + pytest.main([__file__]) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index a4fe18d52d608..95950ad0c5a1f 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -62,6 +62,7 @@ "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), + "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), "JambaForCausalLM": ("jamba", "JambaForCausalLM") } diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py new file mode 100644 index 0000000000000..6453d0cb25c91 --- /dev/null +++ b/vllm/model_executor/models/medusa.py @@ -0,0 +1,159 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +import torch.nn as nn + +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput +from vllm.transformers_utils.configs.medusa import MedusaConfig + + +class ResidualBlock(nn.Module): + + def __init__(self, hidden_size: int, num_layers: int) -> None: + super().__init__() + + self.layers = nn.ModuleList([ + nn.Linear(hidden_size, hidden_size, bias=False) + for _ in range(num_layers) + ]) + self.act = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for layer in self.layers: + x = x + self.act(layer(x)) + return x + + +class Medusa(nn.Module): + + def __init__(self, config: MedusaConfig, **_) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + ResidualBlock(hidden_size=self.config.hidden_size, + num_layers=self.config.num_hidden_layers) + for _ in range(self.config.num_heads) + ]) + self.orig_vocab_size = config.vocab_size + self.truncated_vocab_size = config.truncated_vocab_size + self.unpadded_vocab_size = self.truncated_vocab_size + + self.lm_heads = nn.ModuleList([ + ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) for _ in range(self.config.num_heads) + ]) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.truncated_vocab_size, + logit_scale) + + self.token_map = None + + def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]: + return [block(hidden_states) for block in self.blocks] + + def compute_logits( + self, hidden_states: List[torch.Tensor], + sampling_metadata: SamplingMetadata) -> List[torch.Tensor]: + logits = [] + + for hs, lm_head in zip(hidden_states, self.lm_heads): + _logits = self.logits_processor(lm_head, hs, sampling_metadata) + + if self.token_map is None: + logits.append(_logits) + else: + logits.append(-torch.inf * torch.ones( + size=(*_logits.shape[:-1], self.orig_vocab_size), + device=_logits.device, + dtype=_logits.dtype)) + + logits[-1][..., self.token_map] = _logits + + return logits + + def sample( + self, + logits: 
List[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> List[SamplerOutput]:
+        logits = torch.stack(logits, dim=0).float()
+        logprobs = torch.log_softmax(logits, dim=-1)
+        token_ids = logits.argmax(-1)  # support only top-1 for now
+        probs = torch.softmax(logits, dim=-1)
+
+        token_id_list = []
+        token_prob_list = []
+        token_logprob_list = []
+
+        for idx, seq_group in enumerate(sampling_metadata.seq_groups):
+            token_id_list.append(token_ids[:, seq_group.sample_indices])
+            token_prob_list.append(probs[:, seq_group.sample_indices])
+            token_logprob_list.append(logprobs[:, seq_group.sample_indices])
+
+        outputs: List[Optional[SamplerOutput]] = []
+        for idx in range(len(sampling_metadata.seq_groups)):
+            outputs.append(
+                SamplerOutput(
+                    outputs=None,
+                    sampled_token_probs=token_prob_list[idx].squeeze(1),
+                    logprobs=token_logprob_list[idx].squeeze(1),
+                    sampled_token_ids=token_id_list[idx].squeeze(1),
+                ))
+
+        return outputs
+
+    def generate_proposals(
+        self,
+        previous_hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> List[SamplerOutput]:
+        return self.sample(
+            logits=self.compute_logits(
+                hidden_states=self.forward(previous_hidden_states),
+                sampling_metadata=sampling_metadata,
+            ),
+            sampling_metadata=sampling_metadata,
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+
+        weights_map = {}
+
+        for name, loaded_weight in weights:
+            name = name.replace("medusa_heads.", "")
+
+            if name == "token_map":
+                if self.truncated_vocab_size < self.orig_vocab_size:
+                    self.token_map = nn.Parameter(loaded_weight,
+                                                  requires_grad=False)
+            elif name in params_dict:
+                weights_map[name] = loaded_weight
+
+        for name, loaded_weight in weights_map.items():
+            if "lm_head" in name and self.token_map is not None and\
+                loaded_weight.shape[0] > self.token_map.shape[0]:
+
+                loaded_weight = loaded_weight[self.token_map]
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        if self.token_map is not None:
+            # Tensor.to() is not in-place, so reassign the underlying data to
+            # actually move the token map onto the LM heads' device.
+            self.token_map.data = self.token_map.data.to(
+                device=self.lm_heads[0].weight.device)
+
+        assert (self.truncated_vocab_size
+                == self.orig_vocab_size) or (self.token_map is not None)
diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py
new file mode 100644
index 0000000000000..b72740fc3961c
--- /dev/null
+++ b/vllm/spec_decode/medusa_worker.py
@@ -0,0 +1,127 @@
+import weakref
+from typing import List, Optional, Tuple
+
+import torch
+
+from vllm.model_executor import SamplingMetadata
+from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
+                           SequenceGroupMetadata)
+from vllm.spec_decode.interfaces import SpeculativeProposals
+from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
+from vllm.spec_decode.top1_proposer import Top1Proposer
+from vllm.worker.worker import Worker
+
+
+class MedusaWorker(NonLLMProposerWorkerBase, Worker):
+    """Worker for Medusa.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Lazy initialization; the proposer is created in init_device().
+        self._proposer: Top1Proposer
+
+    def init_device(self):
+        super().init_device()
+
+        self._proposer = Top1Proposer(
+            weakref.proxy(self),  # type: ignore[arg-type]
+            self.device,
+            self.vocab_size,
+            max_proposal_len=self.max_model_len,
+        )
+
+    def set_include_gpu_probs_tensor(self):
+        pass
+
+    @torch.inference_mode()
+    def sampler_output(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        sample_len: int,
+    ) -> Tuple[List[SamplerOutput], bool]:
+        """Run the model forward pass to generate sample_len future tokens.
+        Returns the list of sampler outputs, one per Medusa head, along with
+        an indicator of whether the torch tensors in the sampler output need
+        to be transposed by the later sampler_output_to_torch logic.
+
+        For the Medusa worker, this indicator is always False.
+        """
+        self._raise_if_unsupported(execute_model_req)
+
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+
+        seq_lens, query_lens = self._prepare_input_tensors(
+            seq_group_metadata_list)
+
+        sampling_metadata = SamplingMetadata.prepare(
+            seq_group_metadata_list, seq_lens, query_lens, self.device,
+            self.model_runner.pin_memory)
+
+        model_outputs = self.model_runner.model.generate_proposals(
+            previous_hidden_states=execute_model_req.previous_hidden_states.
+            hidden_states,
+            sampling_metadata=sampling_metadata)
+
+        return model_outputs, False
+
+    def _prepare_input_tensors(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+    ) -> Tuple[List[int], List[int]]:
+        if not seq_group_metadata_list:
+            return [], []
+
+        seq_lens: List[int] = []
+        query_lens: List[int] = []
+
+        for seq_group_metadata in seq_group_metadata_list:
+            is_prompt = seq_group_metadata.is_prompt
+
+            for seq_data in seq_group_metadata.seq_data.values():
+                seq_data_len = seq_data.get_len()
+                if is_prompt:
+                    context_len = seq_data.get_num_computed_tokens()
+                    seq_len = min(
+                        seq_data_len,
+                        context_len + seq_group_metadata.token_chunk_size)
+                    seq_lens.append(seq_len)
+                    query_lens.append(seq_len - context_len)
+                else:
+                    seq_lens.append(seq_data_len)
+                    query_lens.append(1)
+
+        return seq_lens, query_lens
+
+    def get_spec_proposals(
+        self,
+        execute_model_req: ExecuteModelRequest,
+    ) -> SpeculativeProposals:
+        """Produce speculations given an input batch of sequences. The number
+        of speculative tokens per sequence is determined by max_proposal_len.
+        """
+
+        return self._proposer.get_spec_proposals(execute_model_req)
+
+    def _raise_if_unsupported(
+        self,
+        execute_model_req: ExecuteModelRequest,
+    ) -> None:
+        """MedusaWorker does not yet implement support for cache swap
+        operations or beam search.
+ """ + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): + raise NotImplementedError( + "MedusaWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): + raise NotImplementedError( + "MedusaWorker does not support beam search.") diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 43ce987de1e16..60a7dab68b7fd 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -18,6 +18,7 @@ from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) +from vllm.spec_decode.medusa_worker import MedusaWorker from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker from vllm.spec_decode.multi_step_worker import MultiStepWorker @@ -129,6 +130,10 @@ def create_worker( "model_config"].hf_config.model_type == "mlp_speculator": disable_bonus_tokens = False proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs) + elif draft_worker_kwargs[ + "model_config"].hf_config.model_type == "medusa": + disable_bonus_tokens = False + proposer_worker = MedusaWorker(**draft_worker_kwargs) else: if draft_tp == 1: draft_worker_kwargs[ diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 5e2fe116db9c6..652505a892142 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -6,8 +6,9 @@ from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, - JAISConfig, MLPSpeculatorConfig, - MPTConfig, RWConfig) + JAISConfig, MedusaConfig, + MLPSpeculatorConfig, MPTConfig, + RWConfig) if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -24,6 +25,7 @@ "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) "jais": JAISConfig, "mlp_speculator": MLPSpeculatorConfig, + "medusa": MedusaConfig, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index d8170858c2a9a..51de11ca3e42a 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -5,6 +5,7 @@ # `FalconConfig` class from the official HuggingFace transformers library. 
from vllm.transformers_utils.configs.falcon import RWConfig from vllm.transformers_utils.configs.jais import JAISConfig +from vllm.transformers_utils.configs.medusa import MedusaConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig @@ -14,5 +15,6 @@ "MPTConfig", "RWConfig", "JAISConfig", + "MedusaConfig", "MLPSpeculatorConfig", ] diff --git a/vllm/transformers_utils/configs/medusa.py b/vllm/transformers_utils/configs/medusa.py new file mode 100644 index 0000000000000..d71a08343be2a --- /dev/null +++ b/vllm/transformers_utils/configs/medusa.py @@ -0,0 +1,60 @@ +import os +from typing import Optional, Union + +from transformers import PretrainedConfig + + +class MedusaConfig(PretrainedConfig): + model_type = "medusa" + + def __init__(self, + hidden_size: int = 4096, + vocab_size: int = 32001, + num_heads: int = 5, + num_hidden_layers: int = 1, + max_paths: int = 64, + topk: int = 10, + truncated_vocab_size: Optional[int] = None, + **kwargs): + + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.num_heads = num_heads + self.num_hidden_layers = num_hidden_layers + self.max_paths = max_paths + self.topk = topk + self.max_seq_len = int(2**20) + self.truncated_vocab_size = vocab_size if truncated_vocab_size is None\ + else truncated_vocab_size + if "architectures" not in kwargs: + kwargs["architectures"] = ["MedusaModel"] + + super().__init__(**kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + **kwargs, + ) -> "MedusaConfig": + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + for k in list(config_dict.keys()): + if 'num' in k: + if 'heads' in k: + config_dict["num_heads"] = config_dict.pop(k) + elif 'layers' in k: + config_dict["num_hidden_layers"] = config_dict.pop(k) + return cls.from_dict(config_dict, **kwargs) + + @property + def num_attention_heads(self): + return 0 + + @property + def num_lookahead_tokens(self): + return self.num_heads + + @num_lookahead_tokens.setter + def num_lookahead_tokens(self, num_lookahead_tokens: int): + self.num_heads = num_lookahead_tokens diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b25f29f485d95..34bca397e5c72 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -76,8 +76,9 @@ def __init__( speculative_args = {} if speculative_config is None \ or (speculative_config.draft_model_config.model == model_config.model) \ - or (speculative_config.draft_model_config.hf_config.model_type != - "mlp_speculator") else {"return_hidden_states": True} + or (speculative_config.draft_model_config.hf_config.model_type + not in ["medusa", "mlp_speculator"]) \ + else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner if model_runner_cls is not None:
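
Usage sketch (not part of the diff above): the configuration exercised by test_medusa_correctness.py can be passed directly to the vllm.LLM entry point to try the Medusa path end to end. This is a minimal, illustrative example; the model names are taken from the new test file, and exact engine flags may differ between vLLM versions.

from vllm import LLM, SamplingParams

# Target model plus Medusa draft heads, mirroring the kwargs in the new test.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="abhigoyal/vllm-medusa-llama-68m-random",
    num_speculative_tokens=5,   # corresponds to num_heads in the Medusa config
    use_v2_block_manager=True,  # required for spec decode
    dtype="float32",
    enforce_eager=True,
)

# Greedy sampling, so the output should match non-speculative decoding.
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The future of AI is"], params)
print(outputs[0].outputs[0].text)

As in test_medusa_disable_queue, adding speculative_disable_by_batch_size makes the engine fall back to normal decoding once the number of enqueued requests exceeds that threshold, and greedy equality is expected to hold in that regime as well.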