From 067f20851abe51ccb21df67a81125b05258815c2 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Fri, 11 Oct 2024 19:08:11 +0800
Subject: [PATCH] [Misc] Collect model support info in a single process per model (#9233)

---
 docs/source/models/adding_model.rst    |   2 +-
 vllm/engine/arg_utils.py               |   2 +
 vllm/engine/multiprocessing/engine.py  |   3 +
 vllm/model_executor/models/registry.py | 380 +++++++++++++++----------
 4 files changed, 228 insertions(+), 159 deletions(-)

diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst
index fa1003874033e..ae09259c0756c 100644
--- a/docs/source/models/adding_model.rst
+++ b/docs/source/models/adding_model.rst
@@ -99,7 +99,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a
 5. Register your model
 ----------------------
 
-Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/registry.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/registry.py>`_.
+Finally, register your :code:`*ForCausalLM` class to the :code:`_VLLM_MODELS` in `vllm/model_executor/models/registry.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/registry.py>`_.
 
 6. Out-of-Tree Model Integration
 --------------------------------------------

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index cae95d20ca23d..efdcec4ab797a 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -183,6 +183,8 @@ class EngineArgs:
     def __post_init__(self):
         if self.tokenizer is None:
             self.tokenizer = self.model
+
+        # Setup plugins
         from vllm.plugins import load_general_plugins
         load_general_plugins()

diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index eecca82cd2f7d..d68970e1da24c 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -130,6 +130,9 @@ def dead_error(self) -> BaseException:
     def from_engine_args(cls, engine_args: AsyncEngineArgs,
                          usage_context: UsageContext, ipc_path: str):
         """Creates an MQLLMEngine from the engine arguments."""
+        # Setup plugins for each process
+        from vllm.plugins import load_general_plugins
+        load_general_plugins()
 
         engine_config = engine_args.create_engine_config()

diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index f1d484521acb9..b37452877cf0c 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -3,8 +3,10 @@
 import subprocess
 import sys
 import tempfile
-from functools import lru_cache, partial
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from functools import lru_cache
+from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 import cloudpickle
 import torch.nn as nn
@@ -116,18 +118,13 @@
 }
 # yapf: enable
 
-_MODELS = {
+_VLLM_MODELS = {
     **_TEXT_GENERATION_MODELS,
     **_EMBEDDING_MODELS,
     **_MULTIMODAL_MODELS,
     **_SPECULATIVE_DECODING_MODELS,
 }
 
-# Architecture -> type or (module, class).
-# out of tree models
-_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
-_OOT_MODELS_LAZY: Dict[str, Tuple[str, str]] = {}
-
 # Models not supported by ROCm.
 _ROCM_UNSUPPORTED_MODELS: List[str] = []
@@ -154,79 +151,125 @@
 }
 
 
-class ModelRegistry:
+@dataclass(frozen=True)
+class _ModelInfo:
+    is_text_generation_model: bool
+    is_embedding_model: bool
+    supports_multimodal: bool
+    supports_pp: bool
 
     @staticmethod
-    def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
-        if model_arch in _MODELS:
-            module_relname, cls_name = _MODELS[model_arch]
-            return f"vllm.model_executor.models.{module_relname}", cls_name
+    def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
+        return _ModelInfo(
+            is_text_generation_model=is_text_generation_model(model),
+            is_embedding_model=is_embedding_model(model),
+            supports_multimodal=supports_multimodal(model),
+            supports_pp=supports_pp(model),
+        )
 
-        if model_arch in _OOT_MODELS_LAZY:
-            return _OOT_MODELS_LAZY[model_arch]
 
-        raise KeyError(model_arch)
+class _BaseRegisteredModel(ABC):
 
-    @staticmethod
-    @lru_cache(maxsize=128)
-    def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
-        try:
-            mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
-        except KeyError:
-            return None
+    @abstractmethod
+    def inspect_model_cls(self) -> _ModelInfo:
+        raise NotImplementedError
 
-        module = importlib.import_module(mod_name)
-        return getattr(module, cls_name, None)
+    @abstractmethod
+    def load_model_cls(self) -> Type[nn.Module]:
+        raise NotImplementedError
 
-    @staticmethod
-    def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]:
-        if model_arch in _OOT_MODELS:
-            return _OOT_MODELS[model_arch]
-
-        if is_hip():
-            if model_arch in _ROCM_UNSUPPORTED_MODELS:
-                raise ValueError(
-                    f"Model architecture {model_arch} is not supported by "
-                    "ROCm for now.")
-            if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
-                logger.warning(
-                    "Model architecture %s is partially supported by ROCm: %s",
-                    model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
-
-        return None
+@dataclass(frozen=True)
+class _RegisteredModel(_BaseRegisteredModel):
+    """
+    Represents a model that has already been imported in the main process.
+    """
+
+    interfaces: _ModelInfo
+    model_cls: Type[nn.Module]
 
     @staticmethod
-    def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
-        model = ModelRegistry._try_get_model_stateless(model_arch)
-        if model is not None:
-            return model
+    def from_model_cls(model_cls: Type[nn.Module]):
+        return _RegisteredModel(
+            interfaces=_ModelInfo.from_model_cls(model_cls),
+            model_cls=model_cls,
+        )
+
+    def inspect_model_cls(self) -> _ModelInfo:
+        return self.interfaces
+
+    def load_model_cls(self) -> Type[nn.Module]:
+        return self.model_cls
+
+
+@dataclass(frozen=True)
+class _LazyRegisteredModel(_BaseRegisteredModel):
+    """
+    Represents a model that has not been imported in the main process.
+ """ + module_name: str + class_name: str + + # Performed in another process to avoid initializing CUDA + def inspect_model_cls(self) -> _ModelInfo: + return _run_in_subprocess( + lambda: _ModelInfo.from_model_cls(self.load_model_cls())) + + def load_model_cls(self) -> Type[nn.Module]: + mod = importlib.import_module(self.module_name) + return getattr(mod, self.class_name) + + +@lru_cache(maxsize=128) +def _try_load_model_cls( + model_arch: str, + model: _BaseRegisteredModel, +) -> Optional[Type[nn.Module]]: + if is_hip(): + if model_arch in _ROCM_UNSUPPORTED_MODELS: + raise ValueError(f"Model architecture '{model_arch}' is not " + "supported by ROCm for now.") + + if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: + msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch] + logger.warning( + "Model architecture '%s' is partially " + "supported by ROCm: %s", model_arch, msg) + + try: + return model.load_model_cls() + except Exception: + logger.exception("Error in loading model architecture '%s'", + model_arch) + return None - return ModelRegistry._try_get_model_stateful(model_arch) - @staticmethod - def resolve_model_cls( - architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]: - if isinstance(architectures, str): - architectures = [architectures] - if not architectures: - logger.warning("No model architectures are specified") +@lru_cache(maxsize=128) +def _try_inspect_model_cls( + model_arch: str, + model: _BaseRegisteredModel, +) -> Optional[_ModelInfo]: + try: + return model.inspect_model_cls() + except Exception: + logger.exception("Error in inspecting model architecture '%s'", + model_arch) + return None - for arch in architectures: - model_cls = ModelRegistry._try_load_model_cls(arch) - if model_cls is not None: - return (model_cls, arch) - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") +@dataclass +class _ModelRegistry: + # Keyed by model_arch + models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict) - @staticmethod - def get_supported_archs() -> List[str]: - return list(_MODELS.keys()) + list(_OOT_MODELS.keys()) + def get_supported_archs(self) -> List[str]: + return list(self.models.keys()) - @staticmethod - def register_model(model_arch: str, model_cls: Union[Type[nn.Module], - str]): + def register_model( + self, + model_arch: str, + model_cls: Union[Type[nn.Module], str], + ) -> None: """ Register an external model to be used in vLLM. @@ -238,7 +281,7 @@ def register_model(model_arch: str, model_cls: Union[Type[nn.Module], when importing the model and thus the related error :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`. 
""" - if model_arch in _MODELS: + if model_arch in self.models: logger.warning( "Model architecture %s is already registered, and will be " "overwritten by the new model class %s.", model_arch, @@ -250,120 +293,141 @@ def register_model(model_arch: str, model_cls: Union[Type[nn.Module], msg = "Expected a string in the format `:`" raise ValueError(msg) - module_name, cls_name = split_str - _OOT_MODELS_LAZY[model_arch] = module_name, cls_name + model = _LazyRegisteredModel(*split_str) else: - _OOT_MODELS[model_arch] = model_cls + model = _RegisteredModel.from_model_cls(model_cls) - @staticmethod - @lru_cache(maxsize=128) - def _check_stateless( - func: Callable[[Type[nn.Module]], bool], - model_arch: str, - *, - default: Optional[bool] = None, - ) -> bool: - """ - Run a boolean function against a model and return the result. + self.models[model_arch] = model - If the model is not found, returns the provided default value. + def _raise_for_unsupported(self, architectures: List[str]): + all_supported_archs = self.get_supported_archs() - If the model is not already imported, the function is run inside a - subprocess to avoid initializing CUDA for the main program. - """ - model = ModelRegistry._try_get_model_stateless(model_arch) - if model is not None: - return func(model) + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {all_supported_archs}") - try: - mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) - except KeyError: - if default is not None: - return default - - raise - - with tempfile.NamedTemporaryFile() as output_file: - # `cloudpickle` allows pickling lambda functions directly - input_bytes = cloudpickle.dumps( - (mod_name, cls_name, func, output_file.name)) - # cannot use `sys.executable __file__` here because the script - # contains relative imports - returned = subprocess.run( - [sys.executable, "-m", "vllm.model_executor.models.registry"], - input=input_bytes, - capture_output=True) - - # check if the subprocess is successful - try: - returned.check_returncode() - except Exception as e: - # wrap raised exception to provide more information - raise RuntimeError(f"Error happened when testing " - f"model support for{mod_name}.{cls_name}:\n" - f"{returned.stderr.decode()}") from e - with open(output_file.name, "rb") as f: - result = pickle.load(f) - return result + def _try_load_model_cls(self, + model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch not in self.models: + return None - @staticmethod - def is_text_generation_model(architectures: Union[str, List[str]]) -> bool: - if isinstance(architectures, str): - architectures = [architectures] - if not architectures: - logger.warning("No model architectures are specified") + return _try_load_model_cls(model_arch, self.models[model_arch]) - is_txt_gen = partial(ModelRegistry._check_stateless, - is_text_generation_model, - default=False) + def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]: + if model_arch not in self.models: + return None - return any(is_txt_gen(arch) for arch in architectures) + return _try_inspect_model_cls(model_arch, self.models[model_arch]) - @staticmethod - def is_embedding_model(architectures: Union[str, List[str]]) -> bool: + def _normalize_archs( + self, + architectures: Union[str, List[str]], + ) -> List[str]: if isinstance(architectures, str): architectures = [architectures] if not architectures: logger.warning("No model architectures are specified") - is_emb = 
-                         is_embedding_model,
-                         default=False)
+        return architectures
 
-        return any(is_emb(arch) for arch in architectures)
+    def inspect_model_cls(
+        self,
+        architectures: Union[str, List[str]],
+    ) -> _ModelInfo:
+        architectures = self._normalize_archs(architectures)
 
-    @staticmethod
-    def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
-        if isinstance(architectures, str):
-            architectures = [architectures]
-        if not architectures:
-            logger.warning("No model architectures are specified")
+        for arch in architectures:
+            model_info = self._try_inspect_model_cls(arch)
+            if model_info is not None:
+                return model_info
 
-        is_mm = partial(ModelRegistry._check_stateless,
-                        supports_multimodal,
-                        default=False)
+        return self._raise_for_unsupported(architectures)
 
-        return any(is_mm(arch) for arch in architectures)
+    def resolve_model_cls(
+        self,
+        architectures: Union[str, List[str]],
+    ) -> Tuple[Type[nn.Module], str]:
+        architectures = self._normalize_archs(architectures)
 
-    @staticmethod
-    def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
-        if isinstance(architectures, str):
-            architectures = [architectures]
-        if not architectures:
-            logger.warning("No model architectures are specified")
+        for arch in architectures:
+            model_cls = self._try_load_model_cls(arch)
+            if model_cls is not None:
+                return (model_cls, arch)
 
-        is_pp = partial(ModelRegistry._check_stateless,
-                        supports_pp,
-                        default=False)
+        return self._raise_for_unsupported(architectures)
 
-        return any(is_pp(arch) for arch in architectures)
+    def is_text_generation_model(
+        self,
+        architectures: Union[str, List[str]],
+    ) -> bool:
+        return self.inspect_model_cls(architectures).is_text_generation_model
 
+    def is_embedding_model(
+        self,
+        architectures: Union[str, List[str]],
+    ) -> bool:
+        return self.inspect_model_cls(architectures).is_embedding_model
+
+    def is_multimodal_model(
+        self,
+        architectures: Union[str, List[str]],
+    ) -> bool:
+        return self.inspect_model_cls(architectures).supports_multimodal
+
+    def is_pp_supported_model(
+        self,
+        architectures: Union[str, List[str]],
+    ) -> bool:
+        return self.inspect_model_cls(architectures).supports_pp
+
+
+ModelRegistry = _ModelRegistry({
+    model_arch: _LazyRegisteredModel(
+        module_name=f"vllm.model_executor.models.{mod_relname}",
+        class_name=cls_name,
+    )
+    for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items()
+})
+
+_T = TypeVar("_T")
+
+
+def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
+    with tempfile.NamedTemporaryFile() as output_file:
+        # `cloudpickle` allows pickling lambda functions directly
+        input_bytes = cloudpickle.dumps((fn, output_file.name))
+
+        # cannot use `sys.executable __file__` here because the script
+        # contains relative imports
+        returned = subprocess.run(
+            [sys.executable, "-m", "vllm.model_executor.models.registry"],
+            input=input_bytes,
+            capture_output=True)
+
+        # check if the subprocess is successful
+        try:
+            returned.check_returncode()
+        except Exception as e:
+            # wrap raised exception to provide more information
+            raise RuntimeError(f"Error raised in subprocess:\n"
+                               f"{returned.stderr.decode()}") from e
+
+        with open(output_file.name, "rb") as f:
+            return pickle.load(f)
+
+
+def _run() -> None:
+    # Setup plugins
+    from vllm.plugins import load_general_plugins
+    load_general_plugins()
+
+    fn, output_file = pickle.loads(sys.stdin.buffer.read())
+
+    result = fn()
 
-if __name__ == "__main__":
-    (mod_name, cls_name, func,
-     output_file) = pickle.loads(sys.stdin.buffer.read())
-    mod = importlib.import_module(mod_name)
-    klass = getattr(mod, cls_name)
-    result = func(klass)
 
     with open(output_file, "wb") as f:
         f.write(pickle.dumps(result))
+
+
+if __name__ == "__main__":
+    _run()
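
For readers adapting out-of-tree models, here is a minimal usage sketch of the registry API changed by this patch. It is not part of the diff; MyModelForCausalLM, MyOtherModelForCausalLM, and my_package.my_model are hypothetical placeholder names.

    # Usage sketch (hypothetical model names) for the refactored ModelRegistry.
    import torch.nn as nn

    from vllm import ModelRegistry


    class MyModelForCausalLM(nn.Module):
        """Hypothetical stand-in for a real vLLM model implementation."""


    # Form 1: register the class directly. Its _ModelInfo flags are computed
    # immediately, in the current process (_RegisteredModel.from_model_cls).
    ModelRegistry.register_model("MyModelForCausalLM", MyModelForCausalLM)

    # Form 2: register lazily via a "<module>:<class>" string. The import (and
    # therefore any CUDA initialization) is deferred; support-info queries for
    # this model later run in a subprocess (_LazyRegisteredModel.inspect_model_cls).
    ModelRegistry.register_model(
        "MyOtherModelForCausalLM",
        "my_package.my_model:MyOtherModelForCausalLM",
    )

    print("MyModelForCausalLM" in ModelRegistry.get_supported_archs())  # True

Note that queries such as ModelRegistry.is_text_generation_model(...) now all resolve through inspect_model_cls, whose per-architecture results are cached by the module-level lru_cache helpers, so each architecture is imported and inspected at most once per process.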
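Because plugins are now loaded in every process (see the load_general_plugins() calls added to arg_utils.py, engine.py, and _run() above), out-of-tree registration is typically packaged as a general plugin so that each subprocess repeats it. Below is a rough sketch under that assumption; my_package and all names in it are hypothetical, while vllm.general_plugins is the entry-point group consumed by load_general_plugins().

    # my_package/__init__.py (hypothetical plugin module)
    def register():
        """Entry-point hook: runs in each vLLM process that loads plugins."""
        from vllm import ModelRegistry

        # Use the lazy string form so loading the plugin itself never imports
        # the model code (and never initializes CUDA) in the parent process.
        ModelRegistry.register_model(
            "MyModelForCausalLM",
            "my_package.my_model:MyModelForCausalLM",
        )

    # setup.py (hypothetical) exposing the hook to vLLM's plugin loader
    from setuptools import setup

    setup(
        name="my-vllm-plugin",
        version="0.1",
        packages=["my_package"],
        entry_points={
            "vllm.general_plugins": ["register_my_model = my_package:register"],
        },
    )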