diff --git a/tests/conftest.py b/tests/conftest.py index 5de3f1f2a2b90..c042160cbc44d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -885,6 +885,7 @@ def num_gpus_available(): temp_dir = tempfile.gettempdir() _dummy_opt_path = os.path.join(temp_dir, "dummy_opt") _dummy_llava_path = os.path.join(temp_dir, "dummy_llava") +_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding") @pytest.fixture @@ -923,3 +924,22 @@ def dummy_llava_path(): with open(json_path, "w") as f: json.dump(config, f) return _dummy_llava_path + + +@pytest.fixture +def dummy_gemma2_embedding_path(): + json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json") + if not os.path.exists(_dummy_gemma2_embedding_path): + snapshot_download(repo_id="BAAI/bge-multilingual-gemma2", + local_dir=_dummy_gemma2_embedding_path, + ignore_patterns=[ + "*.bin", "*.bin.index.json", "*.pt", "*.h5", + "*.msgpack" + ]) + assert os.path.exists(json_path) + with open(json_path, "r") as f: + config = json.load(f) + config["architectures"] = ["MyGemma2Embedding"] + with open(json_path, "w") as f: + json.dump(config, f) + return _dummy_gemma2_embedding_path diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index ee3f8911f318c..94be215258f89 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -2,7 +2,7 @@ import pytest -from vllm import LLM, SamplingParams +from vllm import LLM, PoolingParams, SamplingParams from vllm.assets.image import ImageAsset from ..utils import fork_new_process_for_each_test @@ -17,7 +17,7 @@ def test_plugin(dummy_opt_path): @fork_new_process_for_each_test -def test_oot_registration(dummy_opt_path): +def test_oot_registration_text_generation(dummy_opt_path): os.environ["VLLM_PLUGINS"] = "register_dummy_model" prompts = ["Hello, my name is", "The text does not matter"] sampling_params = SamplingParams(temperature=0) @@ -32,11 +32,23 @@ def test_oot_registration(dummy_opt_path): assert rest == "" +@fork_new_process_for_each_test +def test_oot_registration_embedding(dummy_gemma2_embedding_path): + os.environ["VLLM_PLUGINS"] = "register_dummy_model" + prompts = ["Hello, my name is", "The text does not matter"] + sampling_params = PoolingParams() + llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy") + outputs = llm.encode(prompts, sampling_params) + + for output in outputs: + assert all(v == 0 for v in output.outputs.embedding) + + image = ImageAsset("cherry_blossom").pil_image.convert("RGB") @fork_new_process_for_each_test -def test_oot_multimodal_registration(dummy_llava_path): +def test_oot_registration_multimodal(dummy_llava_path): os.environ["VLLM_PLUGINS"] = "register_dummy_model" prompts = [{ "prompt": "What's in the image?", diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 299aeacb9f337..a2194fa15f90e 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -3,7 +3,14 @@ import pytest import torch.cuda -from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models import (is_embedding_model, + is_text_generation_model, + supports_multimodal) +from vllm.model_executor.models.registry import (_EMBEDDING_MODELS, + _MULTIMODAL_MODELS, + _SPECULATIVE_DECODING_MODELS, + _TEXT_GENERATION_MODELS, + ModelRegistry) from vllm.platforms import current_platform from ..utils import fork_new_process_for_each_test @@ -12,7 +19,20 @@ @pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs()) def test_registry_imports(model_arch): # Ensure all model classes can be imported successfully - ModelRegistry.resolve_model_cls(model_arch) + model_cls, _ = ModelRegistry.resolve_model_cls(model_arch) + + if model_arch in _SPECULATIVE_DECODING_MODELS: + pass # Ignore these models which do not have a unified format + else: + assert is_text_generation_model(model_cls) is ( + model_arch in _TEXT_GENERATION_MODELS + or model_arch in _MULTIMODAL_MODELS) + + assert is_embedding_model(model_cls) is (model_arch + in _EMBEDDING_MODELS) + + assert supports_multimodal(model_cls) is (model_arch + in _MULTIMODAL_MODELS) @fork_new_process_for_each_test diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py index 022ba66e38cc3..62a8f871fa51b 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py @@ -9,6 +9,12 @@ def register(): ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM) # Test passing lazy model + if "MyGemma2Embedding" not in ModelRegistry.get_supported_archs(): + ModelRegistry.register_model( + "MyGemma2Embedding", + "vllm_add_dummy_model.my_gemma_embedding:MyGemma2Embedding", + ) + if "MyLlava" not in ModelRegistry.get_supported_archs(): ModelRegistry.register_model("MyLlava", "vllm_add_dummy_model.my_llava:MyLlava") diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py new file mode 100644 index 0000000000000..1d61f6b74f520 --- /dev/null +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -0,0 +1,34 @@ +from typing import List, Optional, Union + +import torch + +from vllm.attention import AttentionMetadata +from vllm.model_executor.models.gemma2_embedding import Gemma2EmbeddingModel +from vllm.sequence import IntermediateTensors + + +class MyGemma2Embedding(Gemma2EmbeddingModel): + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = super().forward( + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + if isinstance(hidden_states, IntermediateTensors): + return hidden_states + + # Return all-zero embeddings + return torch.zeros_like(hidden_states) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 51054a147a06f..eaa2b93eb3331 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,10 +1,17 @@ from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, SupportsPP, has_inner_state, supports_lora, supports_multimodal, supports_pp) +from .interfaces_base import (VllmModelForEmbedding, + VllmModelForTextGeneration, is_embedding_model, + is_text_generation_model) from .registry import ModelRegistry __all__ = [ "ModelRegistry", + "VllmModelForEmbedding", + "is_embedding_model", + "VllmModelForTextGeneration", + "is_text_generation_model", "HasInnerState", "has_inner_state", "SupportsLoRA", diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 298174fa05965..278dfc52078ef 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,4 +1,3 @@ -import inspect from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, Protocol, Type, Union, overload, runtime_checkable) @@ -6,9 +5,9 @@ from typing_extensions import TypeIs from vllm.logger import init_logger +from vllm.utils import supports_kw if TYPE_CHECKING: - from vllm.attention import AttentionMetadata from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.sequence import IntermediateTensors @@ -142,9 +141,7 @@ def supports_lora( return result -def _supports_lora( - model: Union[Type[object], object], -) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]: +def _supports_lora(model: Union[Type[object], object]) -> bool: if isinstance(model, type): return isinstance(model, _SupportsLoRAType) @@ -175,10 +172,7 @@ def make_empty_intermediate_tensors( def forward( self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: "AttentionMetadata", + *, intermediate_tensors: Optional["IntermediateTensors"], ) -> Union[torch.Tensor, "IntermediateTensors"]: """ @@ -205,10 +199,7 @@ def make_empty_intermediate_tensors( def forward( self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: "AttentionMetadata", + *, intermediate_tensors: Optional["IntermediateTensors"], ) -> Union[torch.Tensor, "IntermediateTensors"]: ... @@ -257,24 +248,19 @@ def supports_pp( return supports_attributes and supports_inspect -def _supports_pp_attributes( - model: Union[Type[object], object], -) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: +def _supports_pp_attributes(model: Union[Type[object], object]) -> bool: if isinstance(model, type): return isinstance(model, _SupportsPPType) return isinstance(model, SupportsPP) -def _supports_pp_inspect( - model: Union[Type[object], object], -) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: +def _supports_pp_inspect(model: Union[Type[object], object]) -> bool: model_forward = getattr(model, "forward", None) if not callable(model_forward): return False - forward_params = inspect.signature(model_forward).parameters - return "intermediate_tensors" in forward_params + return supports_kw(model_forward, "intermediate_tensors") @runtime_checkable diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py new file mode 100644 index 0000000000000..8d2d422f9891c --- /dev/null +++ b/vllm/model_executor/models/interfaces_base.py @@ -0,0 +1,191 @@ +from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union, + overload, runtime_checkable) + +import torch +import torch.nn as nn +from transformers import PretrainedConfig +from typing_extensions import TypeIs, TypeVar + +from vllm.logger import init_logger +from vllm.utils import supports_kw + +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.config import CacheConfig + from vllm.model_executor.layers.pooler import PoolerOutput + from vllm.model_executor.layers.quantization import QuantizationConfig + from vllm.model_executor.layers.sampler import SamplerOutput + from vllm.model_executor.pooling_metadata import PoolingMetadata + from vllm.model_executor.sampling_metadata import SamplingMetadata + +logger = init_logger(__name__) + +# The type of HF config +C_co = TypeVar("C_co", bound=PretrainedConfig, covariant=True) + +# The type of hidden states +# Currently, T = torch.Tensor for all models except for Medusa +# which has T = List[torch.Tensor] +T = TypeVar("T", default=torch.Tensor) +T_co = TypeVar("T_co", default=torch.Tensor, covariant=True) + +# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags +# for the base interfaces to avoid breaking OOT registration for existing models +# that don't inherit from the base interface classes + + +@runtime_checkable +class VllmModel(Protocol[C_co, T_co]): + + def __init__( + self, + config: C_co, + *, + cache_config: Optional["CacheConfig"], + quant_config: Optional["QuantizationConfig"], + ) -> None: + ... + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: "AttentionMetadata", + ) -> T_co: + ... + + +def _check_vllm_model_init(model: Union[Type[object], object]) -> bool: + model_init = model.__init__ + vllm_kws = ("cache_config", "quant_config") + missing_kws = tuple(kw for kw in vllm_kws + if not supports_kw(model_init, kw)) + + if missing_kws and (isinstance(model, type) + and issubclass(model, nn.Module)): + logger.warning( + "The model (%s) is missing " + "vLLM-specific keywords from its initializer: %s", + model, + missing_kws, + ) + + return len(missing_kws) == 0 + + +def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata") + missing_kws = tuple(kw for kw in vllm_kws + if not supports_kw(model_forward, kw)) + + if missing_kws and (isinstance(model, type) + and issubclass(model, nn.Module)): + logger.warning( + "The model (%s) is missing " + "vLLM-specific keywords from its initializer: %s", + model, + missing_kws, + ) + + return len(missing_kws) == 0 + + +@overload +def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]: + ... + + +@overload +def is_vllm_model(model: object) -> TypeIs[VllmModel]: + ... + + +def is_vllm_model( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]: + return _check_vllm_model_init(model) and _check_vllm_model_forward(model) + + +@runtime_checkable +class VllmModelForTextGeneration(VllmModel[C_co, T], Protocol[C_co, T]): + + def compute_logits( + self, + hidden_states: T, + sampling_metadata: "SamplingMetadata", + ) -> Optional[T]: + """Return `None` if TP rank > 0.""" + ... + + def sample( + self, + logits: T, + sampling_metadata: "SamplingMetadata", + ) -> "SamplerOutput": + """Only called on TP rank 0.""" + ... + + +@overload +def is_text_generation_model( + model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]: + ... + + +@overload +def is_text_generation_model( + model: object) -> TypeIs[VllmModelForTextGeneration]: + ... + + +def is_text_generation_model( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[VllmModelForTextGeneration]], + TypeIs[VllmModelForTextGeneration]]: + if not is_vllm_model(model): + return False + + if isinstance(model, type): + return isinstance(model, VllmModelForTextGeneration) + + return isinstance(model, VllmModelForTextGeneration) + + +@runtime_checkable +class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]): + + def pooler( + self, + hidden_states: T, + pooling_metadata: "PoolingMetadata", + ) -> "PoolerOutput": + """Only called on TP rank 0.""" + ... + + +@overload +def is_embedding_model( + model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]: + ... + + +@overload +def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]: + ... + + +def is_embedding_model( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]: + if not is_vllm_model(model): + return False + + if isinstance(model, type): + return isinstance(model, VllmModelForEmbedding) + + return isinstance(model, VllmModelForEmbedding) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index ccb0e155ff4aa..46c69f17f4471 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -12,10 +12,12 @@ from vllm.utils import is_hip from .interfaces import supports_multimodal, supports_pp +from .interfaces_base import is_embedding_model, is_text_generation_model logger = init_logger(__name__) -_GENERATION_MODELS = { +_TEXT_GENERATION_MODELS = { + # [Decoder-only] "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), @@ -74,10 +76,9 @@ "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "SolarForCausalLM": ("solar", "SolarForCausalLM"), "XverseForCausalLM": ("xverse", "XverseForCausalLM"), - # NOTE: The below models are for speculative decoding only - "MedusaModel": ("medusa", "Medusa"), - "EAGLEModel": ("eagle", "EAGLE"), - "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), + # [Encoder-decoder] + "BartModel": ("bart", "BartForConditionalGeneration"), + "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), } _EMBEDDING_MODELS = { @@ -114,16 +115,18 @@ "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), } -_CONDITIONAL_GENERATION_MODELS = { - "BartModel": ("bart", "BartForConditionalGeneration"), - "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), + +_SPECULATIVE_DECODING_MODELS = { + "EAGLEModel": ("eagle", "EAGLE"), + "MedusaModel": ("medusa", "Medusa"), + "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), } _MODELS = { - **_GENERATION_MODELS, + **_TEXT_GENERATION_MODELS, **_EMBEDDING_MODELS, **_MULTIMODAL_MODELS, - **_CONDITIONAL_GENERATION_MODELS, + **_SPECULATIVE_DECODING_MODELS, } # Architecture -> type or (module, class). @@ -317,6 +320,19 @@ def _check_stateless( return result.returncode == 0 + @staticmethod + def is_text_generation_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + is_txt_gen = partial(ModelRegistry._check_stateless, + is_text_generation_model, + default=False) + + return any(is_txt_gen(arch) for arch in architectures) + @staticmethod def is_embedding_model(architectures: Union[str, List[str]]) -> bool: if isinstance(architectures, str): @@ -324,7 +340,11 @@ def is_embedding_model(architectures: Union[str, List[str]]) -> bool: if not architectures: logger.warning("No model architectures are specified") - return any(arch in _EMBEDDING_MODELS for arch in architectures) + is_emb = partial(ModelRegistry._check_stateless, + is_embedding_model, + default=False) + + return any(is_emb(arch) for arch in architectures) @staticmethod def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: diff --git a/vllm/utils.py b/vllm/utils.py index e44365fa24990..bec2f951d69db 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1277,6 +1277,15 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, return await task(*args, **kwargs) +def supports_kw(callable: Callable[..., object], kw_name: str) -> bool: + params = inspect.signature(callable).parameters + if kw_name in params: + return True + + return any(param.kind == inspect.Parameter.VAR_KEYWORD + for param in params.values()) + + def get_allowed_kwarg_only_overrides( callable: Callable[..., object], overrides: Optional[Dict[str, Any]],