Skip to content

Commit

Permalink
Improve docs & Add JSON decode example (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Jan 30, 2024
1 parent 0617528 commit 97aa9b3
Show file tree
Hide file tree
Showing 19 changed files with 212 additions and 61 deletions.
43 changes: 38 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,21 @@ You can implement your prompt flow in a function decorated by `sgl.function`.
You can then invoke the function with `run` or `run_batch`.
The system will manage the state, chat template, parallelism and batching for you.

The complete code for the examples below can be found at [readme_examples.py](examples/usage/readme_examples.py)

### Control Flow
You can use any Python code within the function body, including control flow, nested function calls, and external libraries.

```python
@sgl.function
def control_flow(s, question):
s += "To answer this question: " + question + ", "
s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". "
def tool_use(s, question):
s += "To answer this question: " + question + ". "
s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". "

if s["tool"] == "calculator":
s += "The math expression is" + sgl.gen("expression")
elif s["tool"] == "web browser":
s += "The website url is" + sgl.gen("url")
elif s["tool"] == "search engine":
s += "The key word to search is" + sgl.gen("word")
```

### Parallelism
Expand Down Expand Up @@ -170,6 +172,8 @@ def image_qa(s, image_file, question):
    s += sgl.assistant(sgl.gen("answer", max_tokens=256))
```

See also [srt_example_llava.py](examples/quick_start/srt_example_llava.py).

### Constrained Decoding
Use `regex` to specify a regular expression as a decoding constraint.
This is only supported for local models.
Expand All @@ -185,6 +189,35 @@ def regular_expression_gen(s):
)
```

### JSON Decoding

```python
character_regex = (
r"""\{\n"""
+ r""" "name": "[\w\d\s]{1,16}",\n"""
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
+ r""" "wand": \{\n"""
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
+ r""" "core": "[\w\d\s]{1,16}",\n"""
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
+ r""" \},\n"""
+ r""" "alive": "(Alive|Deceased)",\n"""
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
+ r"""\}"""
)

@sgl.function
def character_gen(s, name):
    s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
```

See also [json_decode.py](examples/usage/json_decode.py).


### Batching
Use `run_batch` to run a batch of requests with continuous batching.

Expand Down
2 changes: 1 addition & 1 deletion benchmark/json_fast_forward/bench_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

# fmt: off
def character_gen(name, generate):
s = name+ " is a character in Harry Potter. Please fill in the following information about him/her.\n"
s = name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
s += generate(s, max_tokens=256, regex=character_regex)
return s
# fmt: on
Expand Down
2 changes: 1 addition & 1 deletion benchmark/json_fast_forward/bench_sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# fmt: off
@sgl.function
def character_gen(s, name):
s += name+ " is a character in Harry Potter. Please fill in the following information about him/her.\n"
s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
# fmt: on

Expand Down
12 changes: 12 additions & 0 deletions docs/sampling_params.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,21 @@ This doc describes the sampling parameters of the SGLang Runtime.
The `/generate` endpoint accepts the following arguments in the JSON format.

```python
@dataclass
class GenerateReqInput:
# The input prompt
text: Union[List[str], str]
# The image input
image_data: Optional[Union[List[str], str]] = None
# The sampling_params
sampling_params: Union[List[Dict], Dict] = None
# The request id
rid: Optional[Union[List[str], str]] = None
# Whether return logprobs of the prompts
return_logprob: Optional[Union[List[bool], bool]] = None
# The start location of the prompt for return_logprob
logprob_start_len: Optional[Union[List[int], int]] = None
# Whether to stream output
stream: bool = False
```

Expand Down Expand Up @@ -84,3 +92,7 @@ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
prev = len(output)
print("")
```

### Multi-modal

See [test_httpserver_llava.py](../test/srt/test_httpserver_llava.py).
3 changes: 3 additions & 0 deletions examples/quick_start/srt_example_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def batch():
runtime = sgl.Runtime(model_path="liuhaotian/llava-v1.5-7b",
tokenizer_path="llava-hf/llava-1.5-7b-hf")
sgl.set_default_backend(runtime)
# Or you can use API models
# sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
# sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))

# Run a single request
print("\n========== single ==========\n")
Expand Down
8 changes: 6 additions & 2 deletions examples/usage/async_io.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Usage:
python3 async_io.py
"""
import asyncio
from sglang import Runtime

Expand Down Expand Up @@ -27,8 +31,8 @@ async def generate(

if __name__ == "__main__":
runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
print("runtime ready")
print("--- runtime ready ---\n")

prompt = "Who is Alan Turing?"
sampling_params = {"max_new_tokens": 128}
asyncio.run(generate(runtime, prompt, sampling_params))
Expand Down
3 changes: 2 additions & 1 deletion examples/usage/choices_logprob.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""
Usage:
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
python choices_logprob.py
"""

import sglang as sgl


Expand Down
81 changes: 81 additions & 0 deletions examples/usage/json_decode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""
Usage:
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
python json_decode.py
"""
from enum import Enum

from pydantic import BaseModel, constr
import sglang as sgl
from sglang.srt.constrained.json_schema import build_regex_from_object


character_regex = (
r"""\{\n"""
+ r""" "name": "[\w\d\s]{1,16}",\n"""
+ r""" "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
+ r""" "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
+ r""" "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
+ r""" "wand": \{\n"""
+ r""" "wood": "[\w\d\s]{1,16}",\n"""
+ r""" "core": "[\w\d\s]{1,16}",\n"""
+ r""" "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
+ r""" \},\n"""
+ r""" "alive": "(Alive|Deceased)",\n"""
+ r""" "patronus": "[\w\d\s]{1,16}",\n"""
+ r""" "bogart": "[\w\d\s]{1,16}"\n"""
+ r"""\}"""
)


@sgl.function
def character_gen(s, name):
    """Prompt the model for *name*'s profile, decoded as regex-constrained JSON.

    The generated text is forced to match ``character_regex`` and stored in
    the state under the key ``"json_output"``.
    """
    s += name + " is a character in Harry Potter. Please fill in the following information about this character.\n"
    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)


def driver_character_gen():
    """Run the regex-constrained character generator and print the transcript."""
    result = character_gen.run(name="Hermione Granger")
    print(result.text())


# Closed vocabulary of weapons a generated wizard may carry.  Built with the
# Enum functional API; ``type=str`` mixes in ``str`` so every member also
# behaves as (and compares equal to) its plain string value — identical in
# behavior to the ``class Weapon(str, Enum)`` form.
Weapon = Enum(
    "Weapon",
    [(choice, choice) for choice in ("sword", "axe", "mace", "spear", "bow", "crossbow")],
    type=str,
)


class Wizard(BaseModel):
    """Pydantic schema from which a decoding regex is derived via ``build_regex_from_object`` (pydantic >= 2.0)."""
    name: str
    age: int
    weapon: Weapon


@sgl.function
def pydantic_wizard_gen(s):
    """Generate a wizard description constrained to the ``Wizard`` pydantic schema."""
    s += "Give me a description about a wizard in the JSON format.\n"
    s += sgl.gen(
        "character",
        max_tokens=128,
        temperature=0,  # greedy decoding so the JSON output is reproducible
        regex=build_regex_from_object(Wizard),  # Requires pydantic >= 2.0
    )


# NOTE: an exact duplicate definition of `driver_character_gen` was removed
# here; it re-declared — and silently shadowed — the identical function
# defined earlier in this module.


def driver_pydantic_wizard_gen():
    """Run the pydantic-schema-constrained wizard generator and print the transcript."""
    result = pydantic_wizard_gen.run()
    print(result.text())


if __name__ == "__main__":
    # Point the default backend at a locally running SGLang server
    # (launch command is shown in the module docstring).
    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
    driver_character_gen()
    # driver_pydantic_wizard_gen()  # needs pydantic >= 2.0; enable to try the schema-derived regex
4 changes: 4 additions & 0 deletions examples/usage/openai_speculative.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Usage:
python3 openai_speculative.py
"""
from sglang import function, gen, set_default_backend, OpenAI


Expand Down
4 changes: 4 additions & 0 deletions examples/usage/parallel_sample.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Usage:
python3 parallel_sample.py
"""
import sglang as sgl


Expand Down
34 changes: 29 additions & 5 deletions examples/usage/readme_examples.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
"""
Usage:
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
python readme_examples.py
"""
import sglang as sgl


@sgl.function
def tool_use(s, question):
s += "To answer this question: " + question + ", "
s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". "
s += "To answer this question: " + question + ". "
s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". "

if s["tool"] == "calculator":
s += "The math expression is" + sgl.gen("expression")
elif s["tool"] == "web browser":
s += "The website url is" + sgl.gen("url")
elif s["tool"] == "search engine":
s += "The key word to search is" + sgl.gen("word")


@sgl.function
Expand All @@ -28,6 +34,16 @@ def tip_suggestion(s):
s += "In summary" + sgl.gen("summary")


@sgl.function
def regular_expression_gen(s):
s += "Q: What is the IP address of the Google DNS servers?\n"
s += "A: " + sgl.gen(
"answer",
temperature=0,
regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
)


@sgl.function
def text_qa(s, question):
s += "Q: " + question + "\n"
Expand All @@ -46,6 +62,12 @@ def driver_tip_suggestion():
print("\n")


def driver_regex():
state = regular_expression_gen.run()
print(state.text())
print("\n")


def driver_batching():
states = text_qa.run_batch(
[
Expand Down Expand Up @@ -74,9 +96,11 @@ def driver_stream():


if __name__ == "__main__":
sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
#sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

driver_tool_use()
driver_tip_suggestion()
driver_regex()
driver_batching()
driver_stream()
24 changes: 0 additions & 24 deletions examples/usage/srt_example_regex.py

This file was deleted.

4 changes: 4 additions & 0 deletions examples/usage/streaming.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
Usage:
python3 streaming.py
"""
import asyncio
import sglang as sgl

Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
"zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
"pydantic", "diskcache", "cloudpickle", "pillow"]
"pydantic", "referencing", "diskcache", "cloudpickle", "pillow"]
openai = ["openai>=1.0", "numpy"]
anthropic = ["anthropic", "numpy"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
Expand Down
Loading

0 comments on commit 97aa9b3

Please sign in to comment.