forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support for guided decoding for offline LLM (vllm-project#6878)
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
- Loading branch information
1 parent
ef9484f
commit 03d5399
Showing
9 changed files
with
352 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
import json | ||
import re | ||
import weakref | ||
|
||
import jsonschema | ||
import pytest | ||
|
||
from vllm.entrypoints.llm import LLM | ||
from vllm.outputs import RequestOutput | ||
from vllm.sampling_params import SamplingParams | ||
|
||
from ...conftest import cleanup | ||
|
||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" | ||
|
||
|
||
@pytest.fixture(scope="module")
def llm():
    """Provide one module-scoped ``LLM`` for the guided-decoding tests.

    The engine is built once for ``MODEL_NAME`` and handed to the tests as a
    ``weakref.proxy``. pytest caches fixture values, so yielding a weak proxy
    rather than the engine itself lets the engine be garbage collected once
    the local strong reference is dropped during teardown.
    """
    engine = LLM(model=MODEL_NAME, max_model_len=1024)
    engine_proxy = weakref.proxy(engine)

    # The tests run while the engine's deprecate_legacy_api() context is open.
    with engine.deprecate_legacy_api():
        yield engine_proxy
        # Teardown: drop the strong reference so the engine can be collected.
        del engine

    cleanup()
|
||
|
||
@pytest.mark.skip_global_cleanup
def test_guided_regex(sample_regex, llm):
    """Check that regex-guided generation only produces full regex matches.

    The same prompt is submitted twice with ``guided_regex`` set to
    ``sample_regex``; the text of the first completion of every returned
    request must match the pattern in its entirety.
    """
    params = SamplingParams(temperature=0.8, top_p=0.95)
    prompt_text = (
        f"Give an example IPv4 address with this regex: {sample_regex}")
    guided_options = {"guided_regex": sample_regex}

    results = llm.generate(
        prompts=[prompt_text, prompt_text],
        sampling_params=params,
        use_tqdm=True,
        guided_options_request=guided_options,
    )
    assert results is not None

    for result in results:
        assert result is not None
        assert isinstance(result, RequestOutput)

        request_prompt = result.prompt
        generated_text = result.outputs[0].text
        print(generated_text)

        assert generated_text is not None
        # fullmatch: the whole completion must satisfy the regex.
        assert re.fullmatch(sample_regex, generated_text) is not None
        print(f"Prompt: {request_prompt!r}, "
              f"Generated text: {generated_text!r}")
|
||
|
||
@pytest.mark.skip_global_cleanup
def test_guided_json_completion(sample_json_schema, llm):
    """Check that JSON-guided generation produces schema-valid documents.

    The same prompt is submitted twice with ``guided_json`` set to
    ``sample_json_schema``; the first completion of every returned request
    must parse as JSON and validate against that schema.
    """
    params = SamplingParams(temperature=1.0, max_tokens=1000)
    prompt_text = (f"Give an example JSON for an employee profile "
                   f"that fits this schema: {sample_json_schema}")
    guided_options = {"guided_json": sample_json_schema}

    results = llm.generate(
        prompts=[prompt_text, prompt_text],
        sampling_params=params,
        use_tqdm=True,
        guided_options_request=guided_options,
    )
    assert results is not None

    for result in results:
        assert result is not None
        assert isinstance(result, RequestOutput)

        request_prompt = result.prompt
        generated_text = result.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {request_prompt!r}, "
              f"Generated text: {generated_text!r}")

        # json.loads raises on malformed JSON; jsonschema.validate raises
        # when the parsed document does not satisfy the schema.
        parsed = json.loads(generated_text)
        jsonschema.validate(instance=parsed, schema=sample_json_schema)
|
||
|
||
@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
    """Check that choice-guided generation returns one of the given choices.

    A single prompt is generated with ``guided_choice`` set to
    ``sample_guided_choice``; the first completion of every returned request
    must be contained in ``sample_guided_choice``.
    """
    params = SamplingParams(temperature=0.8, top_p=0.95)
    guided_options = {"guided_choice": sample_guided_choice}

    results = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=params,
        use_tqdm=True,
        guided_options_request=guided_options,
    )
    assert results is not None

    for result in results:
        assert result is not None
        assert isinstance(result, RequestOutput)

        request_prompt = result.prompt
        generated_text = result.outputs[0].text
        print(generated_text)

        assert generated_text is not None
        assert generated_text in sample_guided_choice
        print(f"Prompt: {request_prompt!r}, "
              f"Generated text: {generated_text!r}")
|
||
|
||
@pytest.mark.skip_global_cleanup
def test_guided_grammar(sample_sql_statements, llm):
    """Check that grammar-guided generation yields the expected SQL statement.

    A single prompt is generated with ``guided_grammar`` set to
    ``sample_sql_statements``. The first completion of every returned request
    must be parseable by a Lark parser built from the same grammar and, after
    stripping surrounding whitespace, equal the space-free reference
    statement.
    """
    # Kept function-local: in this module only this test uses lark.
    from lark import Lark

    # Build the parser once, before generating. It does not depend on the
    # outputs, and a missing lark install or a grammar Lark rejects now
    # fails before the generation run instead of after it.
    parser = Lark(sample_sql_statements)

    # remove spaces for comparison b/c we removed them in the grammar
    ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
        " ", "")

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=1000,
    )
    outputs = llm.generate(
        prompts=("Generate a sql state that select col_1 from "
                 "table_1 where it is equals to 1"),
        sampling_params=sampling_params,
        use_tqdm=True,
        guided_options_request=dict(guided_grammar=sample_sql_statements))

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None

        # Parsing raises if the text is not a valid sentence of the grammar.
        parser.parse(generated_text)

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.