diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index ca0cbef5cbf48..3ab235754a404 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -13,7 +13,6 @@
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SequenceData
@@ -21,7 +20,7 @@
 from .blip import (BlipVisionModel, dummy_image_for_blip,
                    get_max_blip_image_tokens)
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (group_weights_with_prefix, init_vllm_registered_model,
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     merge_multimodal_embeddings)
 
 # We use this internally as placeholders since there is no image token
@@ -687,35 +686,5 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_model.load_weights(weights_group["vision_model"])
-
-        # load query tokens
-        for name, loaded_weight in weights_group["query_tokens"]:
-            assert name == ""
-            param = self.query_tokens
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load qformer
-        qformer_params_dict = dict(self.qformer.named_parameters())
-        for name, loaded_weight in weights_group["qformer"]:
-            param = qformer_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load mlp projector
-        mlp_params_dict = dict(self.language_projection.named_parameters())
-        for name, loaded_weight in weights_group["language_projection"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 835931746fd4b..62a1b1f8cd4cb 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -31,7 +31,6 @@
 from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.persimmon import PersimmonForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -42,8 +41,7 @@
                            SequenceData)
 
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    merge_multimodal_embeddings)
+from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
 
 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
@@ -349,16 +347,5 @@ def sample(
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision embeddings
-        vision_params_dict = dict(self.vision_embed_tokens.named_parameters())
-        for name, loaded_weight in weights_group["vision_embed_tokens"]:
-            param = vision_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index f1899d92b02b6..c442b6d2e7c96 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -40,7 +40,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (group_weights_with_prefix, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
 
 logger = init_logger(__name__)
@@ -447,19 +447,9 @@ def sample(
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        weights_group = group_weights_with_prefix(weights)
-
-        self.model.load_weights(weights_group["model"])
-
-        if not self.config.tie_word_embeddings:
-            # NOTE: For now self.lm_head is not defined because
-            # tie_word_embeddings is assumed to the False
-            lm_head_dict = dict(self.lm_head.named_parameters())
-            for name, loaded_weight in weights_group["lm_head"]:
-                if is_pp_missing_parameter(name, self.lm_head):
-                    continue
-
-                param = lm_head_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        loader.load_weights(weights)
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 5048e9aa240c1..9024831df543c 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -20,7 +20,6 @@
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.intern_vit import InternVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -32,8 +31,8 @@
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)
 
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -609,19 +608,5 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_model.load_weights(weights_group["vision_model"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.mlp1.named_parameters())
-        for name, loaded_weight in weights_group["mlp1"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 1e97cf06394bf..ecc7917746bda 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -52,8 +52,7 @@
 from vllm.utils import is_hip
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, group_weights_with_prefix,
-                    is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
 
 
@@ -573,25 +572,14 @@ def sample(self, logits: torch.Tensor,
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        weights = [
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        loader.load_weights(
             self.maybe_remap_mistral(name, loaded_weight)
-            for name, loaded_weight in weights
-        ]
-
-        weights_group = group_weights_with_prefix(weights)
-
-        self.model.load_weights(weights_group["model"])
-
-        if not self.config.tie_word_embeddings:
-            lm_head_dict = dict(self.lm_head.named_parameters())
-            for name, loaded_weight in weights_group["lm_head"]:
-                if is_pp_missing_parameter(name, self.lm_head):
-                    continue
-
-                param = lm_head_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            for name, loaded_weight in weights)
 
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index a62231b628cb9..a3acb93dc3c11 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -13,7 +13,6 @@
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
@@ -26,8 +25,8 @@
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)
 
 
 class LlavaImagePixelInputs(TypedDict):
@@ -406,19 +405,5 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index efad800d7d760..766f6a4cc83fa 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -15,7 +15,6 @@
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
@@ -29,8 +28,8 @@
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)
 
 # Result in the max possible feature size (2x2 grid of 336x336px tiles)
 MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
@@ -642,27 +641,5 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load newline
-        for name, loaded_weight in weights_group["image_newline"]:
-            assert name == ""
-            param = self.image_newline
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index 44b3073b46358..e10c1f9e6e04b 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -15,7 +15,6 @@
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -28,7 +27,7 @@
 from .interfaces import SupportsMultiModal, SupportsPP
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip)
-from .utils import (group_weights_with_prefix, init_vllm_registered_model,
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     merge_multimodal_embeddings)
 
 # For profile run
@@ -458,19 +457,9 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(
+            self,
+            # This model doesn't support images for now
+            ignore_unexpected_prefixes=["image_newline"],
+        )
+        loader.load_weights(weights)
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index af957e35d8089..46e97e78d482b 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -20,7 +20,6 @@
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.utils import (cached_get_tokenizer,
@@ -35,8 +34,8 @@
 from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
                      dummy_video_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
 
@@ -872,19 +871,5 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index 93032b4095917..99d000ea13a2c 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -11,7 +11,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.gemma import GemmaForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -21,7 +20,7 @@
 from .interfaces import SupportsMultiModal, SupportsPP
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
-from .utils import group_weights_with_prefix, merge_multimodal_embeddings
+from .utils import AutoWeightsLoader, merge_multimodal_embeddings
 
 logger = init_logger(__name__)
 
@@ -292,19 +291,5 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision tower
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index b875a83f876be..00a04dac88789 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -31,7 +31,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -42,15 +41,11 @@
 
 from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (flatten_bn, group_weights_with_prefix,
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
 
-_KEYS_TO_MODIFY_MAPPING = {
-    "model.vision_embed_tokens": "vision_embed_tokens",
-}
-
 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 32044
 
@@ -295,35 +290,8 @@ def add_image_newline(self, image_features_hd):
         return image_features_hd_newline
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.img_processor.load_weights(weights_group["img_processor"])
-
-        # load glb_GN
-        for name, loaded_weight in weights_group["glb_GN"]:
-            assert name == ""
-            param = self.glb_GN
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load sub_GN
-        for name, loaded_weight in weights_group["sub_GN"]:
-            assert name == ""
-            param = self.sub_GN
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load mlp projector
-        mlp_params_dict = dict(self.img_projection.named_parameters())
-        for name, loaded_weight in weights_group["img_projection"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)
 
 
 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
@@ -715,27 +683,12 @@ def sample(
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapping = {
-            "model.vision_embed_tokens.": "vision_embed_tokens.",
-            "lm_head.": "language_model.lm_head.",
-            "model.": "language_model.model.",
-        }
-
-        def hf_to_vllm_name(key: str) -> str:
-            for hf_name, vllm_name in hf_to_vllm_mapping.items():
-                if key.startswith(hf_name):
-                    return key.replace(hf_name, vllm_name, 1)
-
-            return key
-
-        vllm_weights = {hf_to_vllm_name(k): v for k, v in weights}
-
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(vllm_weights.items())
-
-        # load vision embeddings and encoder
-        self.vision_embed_tokens.load_weights(
-            weights_group["vision_embed_tokens"])
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        hf_to_vllm_mapper = WeightsMapper(
+            orig_to_new_prefix={
+                "model.vision_embed_tokens.": "vision_embed_tokens.",
+                "lm_head.": "language_model.lm_head.",
+                "model.": "language_model.model.",
+            })
+
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights, mapper=hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index f9db87b7a9fbc..eb9a9aa9364cc 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -48,8 +48,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, group_weights_with_prefix,
-                    is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)
 
 
@@ -435,17 +434,9 @@ def sample(
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        weights_group = group_weights_with_prefix(weights)
-
-        self.model.load_weights(weights_group["model"])
-
-        if not self.config.tie_word_embeddings:
-            lm_head_dict = dict(self.lm_head.named_parameters())
-            for name, loaded_weight in weights_group["lm_head"]:
-                if is_pp_missing_parameter(name, self.lm_head):
-                    continue
-
-                param = lm_head_dict[name]
getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 1aeab72b46522..7dcf52a56e985 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -16,13 +16,12 @@ RowParallelLinear) from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsPP from .qwen2 import Qwen2Model -from .utils import group_weights_with_prefix +from .utils import AutoWeightsLoader class ReLU(nn.Module): @@ -120,13 +119,5 @@ def pooler( return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - weights_group = group_weights_with_prefix(weights) - - self.model.load_weights(weights_group["model"]) - - score_dict = dict(self.score.named_parameters()) - for name, loaded_weight in weights_group["score"]: - param = score_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + loader = AutoWeightsLoader(self) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 101cf38c96b01..e162e3af008e4 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -25,11 +25,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.loader import DefaultModelLoader -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.utils import (flatten_bn, - group_weights_with_prefix, - init_vllm_registered_model, - merge_multimodal_embeddings) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs, NestedTensors @@ -41,6 +36,8 @@ from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, merge_multimodal_embeddings) _AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_TOKENS_PER_SECOND = 6.25 @@ -498,30 +495,9 @@ def sample( return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # prepare weight iterators for components - weights_group = group_weights_with_prefix(weights) - - # load audio tower weights - audio_tower_weights = weights_group["audio_tower"] - audio_tower_params_dict = dict( - self.audio_tower.named_parameters( - prefix=self.audio_tower.base_model_prefix)) - for name, loaded_weight in audio_tower_weights: - if name in audio_tower_params_dict: - param = audio_tower_params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - # load projector weights - projector_weights = weights_group["multi_modal_projector"] 
-        projector_params_dict = dict(
-            self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in projector_weights:
-            param = projector_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        hf_to_vllm_mapper = WeightsMapper(
+            orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
+
+        loader = AutoWeightsLoader(self,
+                                   ignore_unexpected_prefixes=["audio_tower."])
+        loader.load_weights(weights, mapper=hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 916f373d4481e..89b64ba2fd43c 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -1,7 +1,7 @@
 import itertools
-from collections import UserDict
-from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol,
-                    Tuple, Union, overload)
+from dataclasses import dataclass, field
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
+                    Protocol, Tuple, Union, overload)
 
 import torch
 import torch.nn as nn
@@ -12,55 +12,184 @@
                          SchedulerConfig)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.loader import build_model
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal.base import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available
 
+WeightsMapping = Mapping[str, Optional[str]]
+"""If a key maps to a value of `None`, the corresponding weight is ignored."""
 
-class WeightsGroup(UserDict):
-    """
-    Wraps grouped weights dictionary for a more informative error message
-    when attempting to access a weight component that does not exist.
-    """
 
-    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
-        try:
-            return super().__getitem__(key)
-        except KeyError as exc:
-            msg = (f"There is no weights named with the prefix: {key}. "
-                   f"Available prefix: {set(self.keys())}")
-            raise KeyError(msg) from exc
+@dataclass
+class WeightsMapper:
+    """Maps the name of each weight if they match the following patterns."""
+    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
+    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
+    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)
 
-def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
-                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
-    """
-    Helper function to load weights for inner vLLM models.
+    def _map_name(self, key: str) -> Optional[str]:
+        for substr, new_key in self.orig_to_new_substr.items():
+            if substr in key:
+                if new_key is None:
+                    return None
 
-    See also:
-    :ref:`init_vllm_registered_model`
-    """
-    for name, loaded_weight in weights:
-        name = name.split(".")
-        if prefix == name.pop(0):
-            name = ".".join(name)
-            yield name, loaded_weight
+                key = key.replace(substr, new_key, 1)
+
+        for prefix, new_key in self.orig_to_new_prefix.items():
+            if key.startswith(prefix):
+                if new_key is None:
+                    return None
+
+                key = key.replace(prefix, new_key, 1)
+
+        for suffix, new_key in self.orig_to_new_suffix.items():
+            if key.endswith(suffix):
+                if new_key is None:
+                    return None
+
+                key = new_key.join(key.rsplit(suffix, 1))
 
+        return key
 
-def group_weights_with_prefix(
-        weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup:
+    def apply(
+        self, weights: Iterable[Tuple[str, torch.Tensor]]
+    ) -> Iterable[Tuple[str, torch.Tensor]]:
+        return ((out_name, data) for name, data in weights
+                if (out_name := self._map_name(name)) is not None)
+
+
+class AutoWeightsLoader:
     """
-    Helper function to group weights with prefix
+    Helper class to load weights into a :class:`torch.nn.Module`. It is able
+    to automatically detect child modules and parameters while iterating over
+    the weights only once.
+
+    The weight loading logic for individual modules can be overridden
+    by defining a ``load_weights`` method.
+
+    Similarly, the weight loading logic for individual parameters can be
+    overridden by defining a ``weight_loader`` method.
     """
-    init_weights, repeated_weights = itertools.tee(weights, 2)
-    weights_prefix = {name.split(".")[0] for name, _ in init_weights}
-    repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))
-
-    return WeightsGroup({
-        prefix: filter_weights(component, prefix)
-        for component, prefix in zip(repeated_weights, weights_prefix)
-    })
+
+    def __init__(
+        self,
+        module: nn.Module,
+        *,
+        skip_prefixes: Optional[List[str]] = None,
+        ignore_unexpected_prefixes: Optional[List[str]] = None,
+    ) -> None:
+        super().__init__()
+
+        self.module = module
+        self.skip_prefixes = skip_prefixes or []
+        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
+
+    def _groupby_prefix(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
+        weights_by_parts = ((weight_name.split(".", 1), weight_data)
+                            for weight_name, weight_data in weights)
+
+        for prefix, group in itertools.groupby(weights_by_parts,
+                                               key=lambda x: x[0][0]):
+            yield (
+                prefix,
+                # Because maxsplit=1 in weight_name.split(...),
+                # the length of `parts` must either be 1 or 2
+                (("" if len(parts) == 1 else parts[1], weights_data)
+                 for parts, weights_data in group),
+            )
+
+    def _get_qualname(self, prefix: str, rest: str) -> str:
+        if prefix == "":
+            return rest
+        if rest == "":
+            return prefix
+
+        return ".".join((prefix, rest))
+
+    def _can_skip(self, qualname: str) -> bool:
+        return any(qualname.startswith(p) for p in self.skip_prefixes)
+
+    def _can_ignore_unexpected(self, qualname: str) -> bool:
+        return any(
+            qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
+
+    def _load_param(
+        self,
+        base_prefix: str,
+        param: nn.Parameter,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> None:
+        for weight_name, weight_data in weights:
+            weight_qualname = self._get_qualname(base_prefix, weight_name)
+
+            if self._can_skip(weight_qualname):
+                continue
+
+            if weight_name != "":
+                if not self._can_ignore_unexpected(weight_qualname):
+                    raise ValueError(
+                        f"Attempted to load nested weight '{weight_qualname}' "
+                        f"into a single parameter '{base_prefix}'")
+
+                continue
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, weight_data)
+
+    def _load_module(
+        self,
+        base_prefix: str,
+        module: nn.Module,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> None:
+        if isinstance(module, PPMissingLayer):
+            return
+
+        # Avoid infinite recursion since this function is typically
+        # called inside load_weights of the module itself
+        if module != self.module:
+            module_load_weights = getattr(module, "load_weights", None)
+            if callable(module_load_weights):
+                module_load_weights(weights)
+                return
+
+        child_modules = dict(module.named_children())
+        child_params = dict(module.named_parameters(recurse=False))
+
+        for child_prefix, child_weights in self._groupby_prefix(weights):
+            prefix = self._get_qualname(base_prefix, child_prefix)
+
+            if self._can_skip(prefix):
+                continue
+
+            if child_prefix in child_modules:
+                self._load_module(prefix, child_modules[child_prefix],
+                                  child_weights)
+            elif child_prefix in child_params:
+                self._load_param(prefix, child_params[child_prefix],
+                                 child_weights)
+            else:
+                if not self._can_ignore_unexpected(prefix):
+                    msg = f"There is no module or parameter named '{prefix}'"
+                    raise ValueError(msg)
+
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+        *,
+        mapper: Optional[WeightsMapper] = None,
+    ) -> None:
+        if mapper is not None:
+            weights = mapper.apply(weights)
+
+        self._load_module("", self.module, weights)
 
 
 def init_vllm_registered_model(