Skip to content

Commit

Permalink
[Frontend] Add Early Validation For Chat Template / Tool Call Parser (v…
Browse files Browse the repository at this point in the history
…llm-project#9151)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
  • Loading branch information
alex-jw-brooks authored and garg-amit committed Oct 28, 2024
1 parent b0119f5 commit 6a72495
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 72 deletions.
178 changes: 109 additions & 69 deletions tests/entrypoints/openai/test_cli_args.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,131 @@
import json
import unittest

from vllm.entrypoints.openai.cli_args import make_arg_parser
import pytest

from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.utils import FlexibleArgumentParser

from ...utils import VLLM_PATH

LORA_MODULE = {
"name": "module2",
"path": "/path/to/module2",
"base_model_name": "llama"
}
CHATML_JINJA_PATH = VLLM_PATH / "examples/template_chatml.jinja"
assert CHATML_JINJA_PATH.exists()


class TestLoraParserAction(unittest.TestCase):
@pytest.fixture
def serve_parser():
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
return make_arg_parser(parser)

def setUp(self):
# Setting up argparse parser for tests
parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
self.parser = make_arg_parser(parser)

def test_valid_key_value_format(self):
# Test old format: name=path
args = self.parser.parse_args([
'--lora-modules',
'module1=/path/to/module1',
### Tests for Lora module parsing
def test_valid_key_value_format(serve_parser):
# Test old format: name=path
args = serve_parser.parse_args([
'--lora-modules',
'module1=/path/to/module1',
])
expected = [LoRAModulePath(name='module1', path='/path/to/module1')]
assert args.lora_modules == expected


def test_valid_json_format(serve_parser):
# Test valid JSON format input
args = serve_parser.parse_args([
'--lora-modules',
json.dumps(LORA_MODULE),
])
expected = [
LoRAModulePath(name='module2',
path='/path/to/module2',
base_model_name='llama')
]
assert args.lora_modules == expected


def test_invalid_json_format(serve_parser):
# Test invalid JSON format input, missing closing brace
with pytest.raises(SystemExit):
serve_parser.parse_args([
'--lora-modules', '{"name": "module3", "path": "/path/to/module3"'
])
expected = [LoRAModulePath(name='module1', path='/path/to/module1')]
self.assertEqual(args.lora_modules, expected)

def test_valid_json_format(self):
# Test valid JSON format input
args = self.parser.parse_args([

def test_invalid_type_error(serve_parser):
# Test type error when values are not JSON or key=value
with pytest.raises(SystemExit):
serve_parser.parse_args([
'--lora-modules',
json.dumps(LORA_MODULE),
'invalid_format' # This is not JSON or key=value format
])
expected = [
LoRAModulePath(name='module2',
path='/path/to/module2',
base_model_name='llama')
]
self.assertEqual(args.lora_modules, expected)

def test_invalid_json_format(self):
# Test invalid JSON format input, missing closing brace
with self.assertRaises(SystemExit):
self.parser.parse_args([
'--lora-modules',
'{"name": "module3", "path": "/path/to/module3"'
])

def test_invalid_type_error(self):
# Test type error when values are not JSON or key=value
with self.assertRaises(SystemExit):
self.parser.parse_args([
'--lora-modules',
'invalid_format' # This is not JSON or key=value format
])

def test_invalid_json_field(self):
# Test valid JSON format but missing required fields
with self.assertRaises(SystemExit):
self.parser.parse_args([
'--lora-modules',
'{"name": "module4"}' # Missing required 'path' field
])

def test_empty_values(self):
# Test when no LoRA modules are provided
args = self.parser.parse_args(['--lora-modules', ''])
self.assertEqual(args.lora_modules, [])

def test_multiple_valid_inputs(self):
# Test multiple valid inputs (both old and JSON format)
args = self.parser.parse_args([


def test_invalid_json_field(serve_parser):
# Test valid JSON format but missing required fields
with pytest.raises(SystemExit):
serve_parser.parse_args([
'--lora-modules',
'module1=/path/to/module1',
json.dumps(LORA_MODULE),
'{"name": "module4"}' # Missing required 'path' field
])
expected = [
LoRAModulePath(name='module1', path='/path/to/module1'),
LoRAModulePath(name='module2',
path='/path/to/module2',
base_model_name='llama')
]
self.assertEqual(args.lora_modules, expected)


if __name__ == '__main__':
unittest.main()
def test_empty_values(serve_parser):
# Test when no LoRA modules are provided
args = serve_parser.parse_args(['--lora-modules', ''])
assert args.lora_modules == []


def test_multiple_valid_inputs(serve_parser):
# Test multiple valid inputs (both old and JSON format)
args = serve_parser.parse_args([
'--lora-modules',
'module1=/path/to/module1',
json.dumps(LORA_MODULE),
])
expected = [
LoRAModulePath(name='module1', path='/path/to/module1'),
LoRAModulePath(name='module2',
path='/path/to/module2',
base_model_name='llama')
]
assert args.lora_modules == expected


### Tests for serve argument validation that run prior to loading
def test_enable_auto_choice_passes_without_tool_call_parser(serve_parser):
"""Ensure validation fails if tool choice is enabled with no call parser"""
# If we enable-auto-tool-choice, explode with no tool-call-parser
args = serve_parser.parse_args(args=["--enable-auto-tool-choice"])
with pytest.raises(TypeError):
validate_parsed_serve_args(args)


def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser):
"""Ensure validation passes with tool choice enabled with a call parser"""
args = serve_parser.parse_args(args=[
"--enable-auto-tool-choice",
"--tool-call-parser",
"mistral",
])
validate_parsed_serve_args(args)


def test_chat_template_validation_for_happy_paths(serve_parser):
"""Ensure validation passes if the chat template exists"""
args = serve_parser.parse_args(
args=["--chat-template",
CHATML_JINJA_PATH.absolute().as_posix()])
validate_parsed_serve_args(args)


def test_chat_template_validation_for_sad_paths(serve_parser):
"""Ensure validation fails if the chat template doesn't exist"""
args = serve_parser.parse_args(args=["--chat-template", "does/not/exist"])
with pytest.raises(ValueError):
validate_parsed_serve_args(args)
22 changes: 22 additions & 0 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,28 @@ def parse_audio(self, audio_url: str) -> None:
self._add_placeholder(placeholder)


def validate_chat_template(chat_template: Optional[Union[Path, str]]):
"""Raises if the provided chat template appears invalid."""
if chat_template is None:
return

elif isinstance(chat_template, Path) and not chat_template.exists():
raise FileNotFoundError(
"the supplied chat template path doesn't exist")

elif isinstance(chat_template, str):
JINJA_CHARS = "{}\n"
if not any(c in chat_template
for c in JINJA_CHARS) and not Path(chat_template).exists():
raise ValueError(
f"The supplied chat template string ({chat_template}) "
f"appears path-like, but doesn't exist!")

else:
raise TypeError(
f"{type(chat_template)} is not a valid chat template type")


def load_chat_template(
chat_template: Optional[Union[Path, str]]) -> Optional[str]:
if chat_template is None:
Expand Down
4 changes: 3 additions & 1 deletion vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
Expand Down Expand Up @@ -577,5 +578,6 @@ def signal_handler(*_) -> None:
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
args = parser.parse_args()
validate_parsed_serve_args(args)

uvloop.run(run_server(args))
15 changes: 15 additions & 0 deletions vllm/entrypoints/openai/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import List, Optional, Sequence, Union

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import validate_chat_template
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
Expand Down Expand Up @@ -231,6 +232,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
return parser


def validate_parsed_serve_args(args: argparse.Namespace):
"""Quick checks for model serve args that raise prior to loading."""
if hasattr(args, "subparser") and args.subparser != "serve":
return

# Ensure that the chat template is valid; raises if it likely isn't
validate_chat_template(args.chat_template)

# Enable auto tool needs a tool call parser to be valid
if args.enable_auto_tool_choice and not args.tool_call_parser:
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")


def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server")
Expand Down
8 changes: 6 additions & 2 deletions vllm/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser

Expand Down Expand Up @@ -142,7 +143,7 @@ def main():
env_setup()

parser = FlexibleArgumentParser(description="vLLM CLI")
subparsers = parser.add_subparsers(required=True)
subparsers = parser.add_subparsers(required=True, dest="subparser")

serve_parser = subparsers.add_parser(
"serve",
Expand Down Expand Up @@ -186,6 +187,9 @@ def main():
chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat")

args = parser.parse_args()
if args.subparser == "serve":
validate_parsed_serve_args(args)

# One of the sub commands should be executed.
if hasattr(args, "dispatch_function"):
args.dispatch_function(args)
Expand Down

0 comments on commit 6a72495

Please sign in to comment.