Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[misc] improve model support check in another process #9208

Merged
merged 3 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ sphinx-copybutton==0.5.2
myst-parser==2.0.0
sphinx-argparse==0.4.0
msgspec
cloudpickle

# packages to install to build the documentation
pydantic >= 2.8
Expand Down
67 changes: 35 additions & 32 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import importlib
import string
import pickle
import subprocess
import sys
import uuid
import tempfile
from functools import lru_cache, partial
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

import cloudpickle
import torch.nn as nn

from vllm.logger import init_logger
Expand Down Expand Up @@ -282,36 +283,28 @@ def _check_stateless(

raise

valid_name_characters = string.ascii_letters + string.digits + "._"
if any(s not in valid_name_characters for s in mod_name):
raise ValueError(f"Unsafe module name detected for {model_arch}")
if any(s not in valid_name_characters for s in cls_name):
raise ValueError(f"Unsafe class name detected for {model_arch}")
if any(s not in valid_name_characters for s in func.__module__):
raise ValueError(f"Unsafe module name detected for {func}")
if any(s not in valid_name_characters for s in func.__name__):
raise ValueError(f"Unsafe class name detected for {func}")

err_id = uuid.uuid4()

stmts = ";".join([
f"from {mod_name} import {cls_name}",
f"from {func.__module__} import {func.__name__}",
f"assert {func.__name__}({cls_name}), '{err_id}'",
])

result = subprocess.run([sys.executable, "-c", stmts],
capture_output=True)

if result.returncode != 0:
err_lines = [line.decode() for line in result.stderr.splitlines()]
if err_lines and err_lines[-1] != f"AssertionError: {err_id}":
err_str = "\n".join(err_lines)
raise RuntimeError(
"An unexpected error occurred while importing the model in "
f"another process. Error log:\n{err_str}")

return result.returncode == 0
with tempfile.NamedTemporaryFile() as output_file:
# `cloudpickle` allows pickling lambda functions directly
input_bytes = cloudpickle.dumps(
(mod_name, cls_name, func, output_file.name))
# cannot use `sys.executable __file__` here because the script
# contains relative imports
returned = subprocess.run(
[sys.executable, "-m", "vllm.model_executor.models.registry"],
input=input_bytes,
capture_output=True)

# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(f"Error happened when testing "
f"model support for{mod_name}.{cls_name}:\n"
f"{returned.stderr.decode()}") from e
with open(output_file.name, "rb") as f:
result = pickle.load(f)
return result

@staticmethod
def is_text_generation_model(architectures: Union[str, List[str]]) -> bool:
Expand Down Expand Up @@ -364,3 +357,13 @@ def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
default=False)

return any(is_pp(arch) for arch in architectures)


if __name__ == "__main__":
(mod_name, cls_name, func,
output_file) = pickle.loads(sys.stdin.buffer.read())
mod = importlib.import_module(mod_name)
klass = getattr(mod, cls_name)
result = func(klass)
with open(output_file, "wb") as f:
f.write(pickle.dumps(result))
Loading