[Frontend] Add Early Validation For Chat Template / Tool Call Parser #9151

Merged · 5 commits · Oct 8, 2024

Changes from all commits
178 changes: 109 additions & 69 deletions tests/entrypoints/openai/test_cli_args.py
@@ -1,91 +1,131 @@
import json

import pytest

from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.utils import FlexibleArgumentParser

from ...utils import VLLM_PATH

LORA_MODULE = {
    "name": "module2",
    "path": "/path/to/module2",
    "base_model_name": "llama"
}
CHATML_JINJA_PATH = VLLM_PATH / "examples/template_chatml.jinja"
assert CHATML_JINJA_PATH.exists()

@pytest.fixture
def serve_parser():
    parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
    return make_arg_parser(parser)

Review comment — alex-jw-brooks (Contributor, Author), Oct 8, 2024:

    The only new tests are the bottom four. I pulled the setup into a fixture and removed the test class, since that's how most of the tests I've seen in vLLM are written, and IIRC subclassing unittest.TestCase is an issue for things like pytest.mark.parametrize, which I could see being useful for validation tests in the future.

    Happy to put it back and move the tests into the class if there is a reason they were written this way, though!
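For context, a minimal sketch (hypothetical, not part of this diff) of what the fixture-based style enables; pytest.mark.parametrize composes cleanly with plain test functions and fixtures, but does not work on unittest.TestCase methods:

    import pytest

    # Each bad value becomes its own test case; this pattern is unavailable
    # if the test lives on a unittest.TestCase subclass.
    @pytest.mark.parametrize("bad_value", ["not-json", '{"name": "m"'])
    def test_rejects_bad_lora_values(serve_parser, bad_value):
        with pytest.raises(SystemExit):
            serve_parser.parse_args(["--lora-modules", bad_value])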

### Tests for Lora module parsing
def test_valid_key_value_format(serve_parser):
    # Test old format: name=path
    args = serve_parser.parse_args([
        '--lora-modules',
        'module1=/path/to/module1',
    ])
    expected = [LoRAModulePath(name='module1', path='/path/to/module1')]
    assert args.lora_modules == expected


def test_valid_json_format(serve_parser):
    # Test valid JSON format input
    args = serve_parser.parse_args([
        '--lora-modules',
        json.dumps(LORA_MODULE),
    ])
    expected = [
        LoRAModulePath(name='module2',
                       path='/path/to/module2',
                       base_model_name='llama')
    ]
    assert args.lora_modules == expected


def test_invalid_json_format(serve_parser):
    # Test invalid JSON format input, missing closing brace
    with pytest.raises(SystemExit):
        serve_parser.parse_args([
            '--lora-modules', '{"name": "module3", "path": "/path/to/module3"'
        ])


def test_invalid_type_error(serve_parser):
    # Test type error when values are not JSON or key=value
    with pytest.raises(SystemExit):
        serve_parser.parse_args([
            '--lora-modules',
            'invalid_format'  # This is not JSON or key=value format
        ])


def test_invalid_json_field(serve_parser):
    # Test valid JSON format but missing required fields
    with pytest.raises(SystemExit):
        serve_parser.parse_args([
            '--lora-modules',
            '{"name": "module4"}'  # Missing required 'path' field
        ])


def test_empty_values(serve_parser):
    # Test when no LoRA modules are provided
    args = serve_parser.parse_args(['--lora-modules', ''])
    assert args.lora_modules == []


def test_multiple_valid_inputs(serve_parser):
    # Test multiple valid inputs (both old and JSON format)
    args = serve_parser.parse_args([
        '--lora-modules',
        'module1=/path/to/module1',
        json.dumps(LORA_MODULE),
    ])
    expected = [
        LoRAModulePath(name='module1', path='/path/to/module1'),
        LoRAModulePath(name='module2',
                       path='/path/to/module2',
                       base_model_name='llama')
    ]
    assert args.lora_modules == expected


### Tests for serve argument validation that run prior to loading
def test_enable_auto_choice_fails_without_tool_call_parser(serve_parser):
    """Ensure validation fails if tool choice is enabled with no call parser"""
    # If we enable-auto-tool-choice, explode with no tool-call-parser
    args = serve_parser.parse_args(args=["--enable-auto-tool-choice"])
    with pytest.raises(TypeError):
        validate_parsed_serve_args(args)


def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser):
    """Ensure validation passes with tool choice enabled with a call parser"""
    args = serve_parser.parse_args(args=[
        "--enable-auto-tool-choice",
        "--tool-call-parser",
        "mistral",
    ])
    validate_parsed_serve_args(args)


def test_chat_template_validation_for_happy_paths(serve_parser):
    """Ensure validation passes if the chat template exists"""
    args = serve_parser.parse_args(
        args=["--chat-template",
              CHATML_JINJA_PATH.absolute().as_posix()])
    validate_parsed_serve_args(args)


def test_chat_template_validation_for_sad_paths(serve_parser):
    """Ensure validation fails if the chat template doesn't exist"""
    args = serve_parser.parse_args(args=["--chat-template", "does/not/exist"])
    with pytest.raises(ValueError):
        validate_parsed_serve_args(args)
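These tests run with pytest like the rest of the entrypoints suite, e.g. (invocation is illustrative):

    pytest tests/entrypoints/openai/test_cli_args.py -q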
22 changes: 22 additions & 0 deletions vllm/entrypoints/chat_utils.py
@@ -303,6 +303,28 @@ def parse_audio(self, audio_url: str) -> None:
        self._add_placeholder(placeholder)


def validate_chat_template(chat_template: Optional[Union[Path, str]]):
    """Raises if the provided chat template appears invalid."""
    if chat_template is None:
        return

    elif isinstance(chat_template, Path) and not chat_template.exists():
        raise FileNotFoundError(
            "the supplied chat template path doesn't exist")

    elif isinstance(chat_template, str):
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template
                   for c in JINJA_CHARS) and not Path(chat_template).exists():
            raise ValueError(
                f"The supplied chat template string ({chat_template}) "
                f"appears path-like, but doesn't exist!")

    else:
        raise TypeError(
            f"{type(chat_template)} is not a valid chat template type")


def load_chat_template(
        chat_template: Optional[Union[Path, str]]) -> Optional[str]:
    if chat_template is None:
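A short sketch (hypothetical calls; assumes missing.jinja does not exist on disk) showing which branch of validate_chat_template each kind of input hits:

    from pathlib import Path

    validate_chat_template(None)                   # no template: returns silently
    validate_chat_template("{{ messages }}")       # contains Jinja chars: accepted
    validate_chat_template(Path("missing.jinja"))  # nonexistent Path: FileNotFoundError
    validate_chat_template("missing.jinja")        # path-like string, no file: ValueError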
4 changes: 3 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -31,7 +31,8 @@
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,

@@ -577,5 +578,6 @@ def signal_handler(*_) -> None:
        description="vLLM OpenAI-Compatible RESTful API server.")
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))
15 changes: 15 additions & 0 deletions vllm/entrypoints/openai/cli_args.py
@@ -10,6 +10,7 @@
from typing import List, Optional, Sequence, Union

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import validate_chat_template
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager

@@ -231,6 +232,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    return parser


def validate_parsed_serve_args(args: argparse.Namespace):
    """Quick checks for model serve args that raise prior to loading."""
    if hasattr(args, "subparser") and args.subparser != "serve":
        return

    # Ensure that the chat template is valid; raises if it likely isn't
    validate_chat_template(args.chat_template)

    # Enable auto tool needs a tool call parser to be valid
    if args.enable_auto_tool_choice and not args.tool_call_parser:
        raise TypeError("Error: --enable-auto-tool-choice requires "
                        "--tool-call-parser")


def create_parser_for_docs() -> FlexibleArgumentParser:
    parser_for_docs = FlexibleArgumentParser(
        prog="-m vllm.entrypoints.openai.api_server")
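A minimal sketch (the argument choice is illustrative) of the fail-fast behavior this adds; the check runs immediately after argument parsing, long before any model weights are loaded:

    from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                                  validate_parsed_serve_args)
    from vllm.utils import FlexibleArgumentParser

    parser = make_arg_parser(FlexibleArgumentParser())
    args = parser.parse_args(["--enable-auto-tool-choice"])
    validate_parsed_serve_args(args)  # raises TypeError: tool call parser required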
8 changes: 6 additions & 2 deletions vllm/scripts.py
@@ -11,7 +11,8 @@

from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser

@@ -142,7 +143,7 @@ def main():
    env_setup()

    parser = FlexibleArgumentParser(description="vLLM CLI")
    subparsers = parser.add_subparsers(required=True, dest="subparser")

    serve_parser = subparsers.add_parser(
        "serve",

@@ -186,6 +187,9 @@ def main():
    chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat")

    args = parser.parse_args()
    if args.subparser == "serve":
        validate_parsed_serve_args(args)

    # One of the sub commands should be executed.
    if hasattr(args, "dispatch_function"):
        args.dispatch_function(args)
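For clarity, a standalone sketch (plain argparse; subcommand names match the CLI above) of the dest="subparser" pattern: argparse records which subcommand was selected, so post-parse code can branch on it:

    import argparse

    parser = argparse.ArgumentParser(prog="vllm")
    subparsers = parser.add_subparsers(required=True, dest="subparser")
    subparsers.add_parser("serve")
    subparsers.add_parser("chat")

    args = parser.parse_args(["serve"])
    assert args.subparser == "serve"  # serve-arg validation only runs here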