Skip to content

Commit

Permalink
[misc] improve model support check in another process (vllm-project#9208
Browse files Browse the repository at this point in the history
)

Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
  • Loading branch information
youkaichao authored and garg-amit committed Oct 28, 2024
1 parent 5fcf958 commit 1324d4b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 32 deletions.
1 change: 1 addition & 0 deletions docs/requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ sphinx-copybutton==0.5.2
myst-parser==2.0.0
sphinx-argparse==0.4.0
msgspec
cloudpickle

# packages to install to build the documentation
pydantic >= 2.8
Expand Down
67 changes: 35 additions & 32 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import importlib
import string
import pickle
import subprocess
import sys
import uuid
import tempfile
from functools import lru_cache, partial
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

import cloudpickle
import torch.nn as nn

from vllm.logger import init_logger
Expand Down Expand Up @@ -282,36 +283,28 @@ def _check_stateless(

raise

valid_name_characters = string.ascii_letters + string.digits + "._"
if any(s not in valid_name_characters for s in mod_name):
raise ValueError(f"Unsafe module name detected for {model_arch}")
if any(s not in valid_name_characters for s in cls_name):
raise ValueError(f"Unsafe class name detected for {model_arch}")
if any(s not in valid_name_characters for s in func.__module__):
raise ValueError(f"Unsafe module name detected for {func}")
if any(s not in valid_name_characters for s in func.__name__):
raise ValueError(f"Unsafe class name detected for {func}")

err_id = uuid.uuid4()

stmts = ";".join([
f"from {mod_name} import {cls_name}",
f"from {func.__module__} import {func.__name__}",
f"assert {func.__name__}({cls_name}), '{err_id}'",
])

result = subprocess.run([sys.executable, "-c", stmts],
capture_output=True)

if result.returncode != 0:
err_lines = [line.decode() for line in result.stderr.splitlines()]
if err_lines and err_lines[-1] != f"AssertionError: {err_id}":
err_str = "\n".join(err_lines)
raise RuntimeError(
"An unexpected error occurred while importing the model in "
f"another process. Error log:\n{err_str}")

return result.returncode == 0
with tempfile.NamedTemporaryFile() as output_file:
# `cloudpickle` allows pickling lambda functions directly
input_bytes = cloudpickle.dumps(
(mod_name, cls_name, func, output_file.name))
# cannot use `sys.executable __file__` here because the script
# contains relative imports
returned = subprocess.run(
[sys.executable, "-m", "vllm.model_executor.models.registry"],
input=input_bytes,
capture_output=True)

# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(f"Error happened when testing "
f"model support for{mod_name}.{cls_name}:\n"
f"{returned.stderr.decode()}") from e
with open(output_file.name, "rb") as f:
result = pickle.load(f)
return result

@staticmethod
def is_text_generation_model(architectures: Union[str, List[str]]) -> bool:
Expand Down Expand Up @@ -364,3 +357,13 @@ def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
default=False)

return any(is_pp(arch) for arch in architectures)


if __name__ == "__main__":
(mod_name, cls_name, func,
output_file) = pickle.loads(sys.stdin.buffer.read())
mod = importlib.import_module(mod_name)
klass = getattr(mod, cls_name)
result = func(klass)
with open(output_file, "wb") as f:
f.write(pickle.dumps(result))

0 comments on commit 1324d4b

Please sign in to comment.