Added Support for guided decoding in offline interface #4130

Closed
wants to merge 37 commits into from
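This PR wires guided decoding (regex, JSON schema, choice, and grammar constraints) into the offline `LLM` entrypoint. For orientation, here is a minimal usage sketch mirroring the test code further down in this diff; the `guided_options` parameter name and the `vllm.entrypoints.llm` import follow the tests as written here and may differ from what finally merges:

```python
from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams

# Offline (no API server) guided decoding: constrain the output to one
# of a fixed set of strings via the guided_choice option.
llm = LLM(model="HuggingFaceH4/zephyr-7b-beta")
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    guided_options=dict(guided_choice=["Python", "Java", "C++"]),
)
outputs = llm.generate(
    prompts="The best language for type-safe systems programming is ",
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)
```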
Changes from 14 commits
37 commits
a5aeec2
first commit, added extra parameters for offline LLM, issue with work…
kevinbu233 Apr 16, 2024
dcbfc69
clean up code
kevinbu233 Apr 16, 2024
bd4d84c
add skip cleanup
simon-mo Apr 17, 2024
76c0924
Changed model and sampling parameters, cleaned up naming
kevinbu233 Apr 17, 2024
4709a98
cleaned up code and created helper functions
kevinbu233 Apr 18, 2024
aab046d
first merge resolved
kevinbu233 Apr 18, 2024
84b1442
fix format
kevinbu233 Apr 19, 2024
1d6abe4
fix merge conflict
kevinbu233 Apr 19, 2024
daa2c8f
added docstrings for sampling params guided options
kevinbu233 Apr 23, 2024
35c73ee
fix merge conflict with main
kevinbu233 Apr 23, 2024
9e454c5
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Apr 23, 2024
ef4cf6f
fixed support for multiple sampling params for LLM
kevinbu233 Apr 23, 2024
945125d
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Apr 23, 2024
1c0769d
added noqa for extra long line
kevinbu233 Apr 24, 2024
062f0bc
Update tests/entrypoints/test_local_LLM.py
simon-mo May 1, 2024
05a2512
Merge branch 'main' of github.com:vllm-project/vllm into yihuan_issue…
simon-mo May 2, 2024
438ab37
fix typing
simon-mo May 2, 2024
06c2205
fix test and more refactoring
simon-mo May 3, 2024
4158d78
use x2
simon-mo May 3, 2024
0d9e5a5
lint
simon-mo May 3, 2024
ff9ba7f
fix isort
simon-mo May 3, 2024
bbb59bf
merge with main
kevinbu233 May 27, 2024
d779f86
fixing merge issues
kevinbu233 May 30, 2024
f923677
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 May 30, 2024
a15b511
fixed merge
kevinbu233 May 30, 2024
292264a
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 May 30, 2024
77d42a8
format
kevinbu233 May 30, 2024
60ab6f6
finished merge first draft
kevinbu233 Jun 14, 2024
ae23772
merged main
kevinbu233 Jun 14, 2024
3fb6258
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Jun 19, 2024
75cf9a7
fixed merge conflict and fixed suggestions
kevinbu233 Jun 19, 2024
c429ef8
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Jun 21, 2024
da89c1b
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Jun 22, 2024
6c8b82a
fix test_openai error
kevinbu233 Jun 23, 2024
4e759d9
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Jun 24, 2024
4ac8abb
fixing response_format test
kevinbu233 Jun 24, 2024
2eedeba
temporary push
kevinbu233 Jun 25, 2024
74 changes: 73 additions & 1 deletion tests/conftest.py
@@ -9,9 +9,10 @@
 from transformers import (AutoModelForCausalLM, AutoProcessor,
                           LlavaForConditionalGeneration)

-from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
 from vllm.distributed import destroy_model_parallel
+from vllm.entrypoints.llm import LLM
+from vllm.sampling_params import SamplingParams
 from vllm.sequence import MultiModalData
 from vllm.transformers_utils.tokenizer import get_tokenizer

@@ -414,3 +415,74 @@ def get_tokenizer_pool_config(tokenizer_group_type):
pool_type="ray",
extra_config={})
raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")


@pytest.fixture
def sample_regex():
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")


@pytest.fixture
def sample_json_schema():
return {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"skills": {
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
},
"work history": {
"type": "array",
"items": {
"type": "object",
"properties": {
"company": {
"type": "string"
},
"duration": {
"type": "string"
},
"position": {
"type": "string"
}
},
"required": ["company", "position"]
}
}
},
"required": ["name", "age", "skills", "work history"]
}


@pytest.fixture
def sample_guided_choice():
return [
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
"Ruby", "Swift", "Kotlin"
]


@pytest.fixture
def sample_sql_statements():
return ("""
start: select_statement

select_statement: "SELECT" column "from" table "where" condition

column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number

number: "1" | "2"
""")
162 changes: 162 additions & 0 deletions tests/entrypoints/test_local_LLM.py
@@ -0,0 +1,162 @@
# imports for guided decoding tests
import json
import re

import jsonschema
import pytest

from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]


@pytest.fixture(scope="session")
def llm():
return LLM(model=MODEL_NAME, max_model_len=15000)


@pytest.mark.skip_global_cleanup
def test_simple_prompts(llm):
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(
prompts=prompts,
sampling_params=sampling_params,
use_tqdm=True,
)

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_regex(sample_regex, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_options=dict(guided_regex=sample_regex))
outputs = llm.generate(
prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
],
sampling_params=sampling_params,
use_tqdm=True,
)

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
assert re.fullmatch(sample_regex, generated_text) is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_json_completion(sample_json_schema, llm):
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_options=dict(guided_json=sample_json_schema),
max_tokens=1000)
outputs = llm.generate(
prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
],
sampling_params=sampling_params,
use_tqdm=True,
)

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt

generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)


@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_options=dict(guided_choice=sample_guided_choice))
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
use_tqdm=True,
)

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
        assert generated_text in sample_guided_choice
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_grammar(sample_sql_statements, llm):
    sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
guided_options=dict(guided_grammar=sample_sql_statements))
outputs = llm.generate(
prompts=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"),
sampling_params=sampling_params,
use_tqdm=True,
)

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt

generated_text = output.outputs[0].text
assert generated_text is not None

# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(sample_sql_statements)
parser.parse(generated_text)

    # The grammar defines no whitespace handling, so the generated text
    # contains no spaces; strip them from the ground truth to match.
    ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
        " ", "")

assert generated_text.strip() == ground_truth

print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
pytest.main([__file__])
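One note on the whitespace handling in `test_guided_grammar`: the fixture grammar declares no `%ignore` rule, so Lark accepts (and guided decoding emits) only the bare concatenation of terminals, which is why the ground truth has its spaces stripped. A standalone sketch of that behavior, using the same grammar as the fixture:

```python
from lark import Lark

grammar = """
start: select_statement

select_statement: "SELECT" column "from" table "where" condition

column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number

number: "1" | "2"
"""

parser = Lark(grammar)

# No %ignore rule: only whitespace-free input parses; the same string
# with spaces between tokens would fail to parse.
parser.parse("SELECTcol_1fromtable_1wherecol_1=1")
```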