diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 80037dda20015..d58f226136918 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -4,6 +4,7 @@ sphinx-copybutton==0.5.2 myst-parser==2.0.0 sphinx-argparse==0.4.0 msgspec +cloudpickle # packages to install to build the documentation pydantic >= 2.8 diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index f7b95fdc79362..f1d484521acb9 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -1,11 +1,12 @@ import importlib -import string +import pickle import subprocess import sys -import uuid +import tempfile from functools import lru_cache, partial from typing import Callable, Dict, List, Optional, Tuple, Type, Union +import cloudpickle import torch.nn as nn from vllm.logger import init_logger @@ -282,36 +283,28 @@ def _check_stateless( raise - valid_name_characters = string.ascii_letters + string.digits + "._" - if any(s not in valid_name_characters for s in mod_name): - raise ValueError(f"Unsafe module name detected for {model_arch}") - if any(s not in valid_name_characters for s in cls_name): - raise ValueError(f"Unsafe class name detected for {model_arch}") - if any(s not in valid_name_characters for s in func.__module__): - raise ValueError(f"Unsafe module name detected for {func}") - if any(s not in valid_name_characters for s in func.__name__): - raise ValueError(f"Unsafe class name detected for {func}") - - err_id = uuid.uuid4() - - stmts = ";".join([ - f"from {mod_name} import {cls_name}", - f"from {func.__module__} import {func.__name__}", - f"assert {func.__name__}({cls_name}), '{err_id}'", - ]) - - result = subprocess.run([sys.executable, "-c", stmts], - capture_output=True) - - if result.returncode != 0: - err_lines = [line.decode() for line in result.stderr.splitlines()] - if err_lines and err_lines[-1] != f"AssertionError: {err_id}": - err_str = "\n".join(err_lines) - 
raise RuntimeError( - "An unexpected error occurred while importing the model in " - f"another process. Error log:\n{err_str}") - - return result.returncode == 0 + with tempfile.NamedTemporaryFile() as output_file: + # `cloudpickle` allows pickling lambda functions directly + input_bytes = cloudpickle.dumps( + (mod_name, cls_name, func, output_file.name)) + # cannot use `sys.executable __file__` here because the script + # contains relative imports + returned = subprocess.run( + [sys.executable, "-m", "vllm.model_executor.models.registry"], + input=input_bytes, + capture_output=True) + + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError(f"Error happened when testing " + f"model support for {mod_name}.{cls_name}:\n" + f"{returned.stderr.decode()}") from e + with open(output_file.name, "rb") as f: + result = pickle.load(f) + return result @staticmethod def is_text_generation_model(architectures: Union[str, List[str]]) -> bool: @@ -364,3 +357,13 @@ def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool: default=False) return any(is_pp(arch) for arch in architectures) + + +if __name__ == "__main__": + (mod_name, cls_name, func, + output_file) = pickle.loads(sys.stdin.buffer.read()) + mod = importlib.import_module(mod_name) + klass = getattr(mod, cls_name) + result = func(klass) + with open(output_file, "wb") as f: + f.write(pickle.dumps(result))