Commit

Multiple minor fixes (#1530)
merrymercy authored Sep 28, 2024
1 parent 065bb94 commit 4e4459b
Showing 8 changed files with 26 additions and 23 deletions.
9 changes: 5 additions & 4 deletions README.md
@@ -163,7 +163,8 @@ curl http://localhost:30000/generate \
}
}'
```
Learn more about the argument format [here](docs/en/sampling_params.md).

Learn more about the argument specification, streaming, and multi-modal support [here](docs/en/sampling_params.md).

### OpenAI Compatible API
In addition, the server supports OpenAI-compatible APIs.
@@ -202,7 +203,7 @@ response = client.embeddings.create(
print(response)
```

It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/).
It supports streaming, vision, and almost all features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/).

### Additional Server Arguments
- To enable multi-GPU tensor parallelism, add `--tp 2`. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command.
@@ -223,6 +224,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096
```
- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes.
- To enable torchao quantization, add `--torchao-config int4wo-128`. It supports various quantization strategies.
- To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
- To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
- If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/en/custom_chat_template.md).
@@ -241,9 +243,9 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- Llama / Llama 2 / Llama 3 / Llama 3.1
- Mistral / Mixtral / Mistral NeMo
- Gemma / Gemma 2
- OLMoE
- Qwen / Qwen 2 / Qwen 2 MoE
- DeepSeek / DeepSeek 2
- OLMoE
- [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
- `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov --port=30000 --chat-template=chatml-llava`
- `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava`
@@ -265,7 +267,6 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- XVERSE / XVERSE MoE
- SmolLM


**Embedding Models**

- e5-mistral
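
The README hunks above document both the native `/generate` endpoint and the OpenAI-compatible API with streaming and vision support. A minimal sketch of a streaming chat completion against a locally launched server, reusing the `base_url` and `api_key="EMPTY"` convention from the batch examples later in this commit; the model name is only a placeholder and should match whatever checkpoint the server was launched with:

```python
import openai

# Sketch only: assumes a local SGLang server already launched with, e.g.,
#   python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# Streaming chat completion; the model name is a placeholder for the launched checkpoint.
stream = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "List 3 countries and their capitals."}],
    max_tokens=64,
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)
```
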
11 changes: 9 additions & 2 deletions docs/en/backend.md
@@ -19,7 +19,8 @@ curl http://localhost:30000/generate \
}
}'
```
Learn more about the argument format [here](https://sglang.readthedocs.io/en/latest/sampling_params.html).

Learn more about the argument specification, streaming, and multi-modal support [here](https://sglang.readthedocs.io/en/latest/sampling_params.html).

### OpenAI Compatible API
In addition, the server supports OpenAI-compatible APIs.
@@ -58,7 +59,7 @@ response = client.embeddings.create(
print(response)
```

It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/).
It supports streaming, vision, and almost all features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/).

### Additional Server Arguments
- To enable multi-GPU tensor parallelism, add `--tp 2`. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command.
@@ -79,6 +80,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096
```
- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes.
- To enable torchao quantization, add `--torchao-config int4wo-128`. It supports various quantization strategies.
- To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
- To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
- If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](https://sglang.readthedocs.io/en/latest/custom_chat_template.html).
@@ -99,6 +101,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- Gemma / Gemma 2
- Qwen / Qwen 2 / Qwen 2 MoE
- DeepSeek / DeepSeek 2
- OLMoE
- [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
- `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov --port=30000 --chat-template=chatml-llava`
- `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava`
@@ -115,6 +118,10 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- ChatGLM
- InternLM 2
- Exaone 3
- BaiChuan2
- MiniCPM / MiniCPM 3
- XVERSE / XVERSE MoE
- SmolLM

**Embedding Models**

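Both documentation hunks above reference the native `/generate` endpoint shown in the curl example. A rough Python equivalent of that call, assuming a server listening on port 30000 and the `text`/`sampling_params` request shape described in the sampling-params doc linked from the README:

```python
import requests

# Sketch only: mirrors the curl /generate call referenced in the hunks above,
# assuming a local server on port 30000.
response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 32, "temperature": 0},
    },
)
print(response.json()["text"])
```
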
1 change: 1 addition & 0 deletions examples/runtime/async_io_api.py
@@ -1,5 +1,6 @@
"""
Usage:
python3 async_io.py
"""

9 changes: 4 additions & 5 deletions examples/runtime/openai_batch_chat.py
@@ -1,7 +1,9 @@
"""
Usage:
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
python openai_batch_chat.py
Note: Before running this script,
you should create the input.jsonl file with the following content:
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world! List 3 NBA players and tell a story"}],"max_tokens": 300}}
@@ -13,12 +15,10 @@
import time

import openai
from openai import OpenAI


class OpenAIBatchProcessor:
def __init__(self, api_key):
# client = OpenAI(api_key=api_key)
def __init__(self):
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

self.client = client
@@ -81,8 +81,7 @@ def process_batch(self, input_file_path, endpoint, completion_window):


# Initialize the OpenAIBatchProcessor
api_key = os.environ.get("OPENAI_API_KEY")
processor = OpenAIBatchProcessor(api_key)
processor = OpenAIBatchProcessor()

# Process the batch job
input_file_path = "input.jsonl"
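
The updated docstring says to create `input.jsonl` before running the script. A small helper sketch that writes one request line in the format quoted in that docstring (the model name and messages are just the example values shown there):

```python
import json

# Sketch only: writes the input.jsonl file described in the docstring above,
# one JSON request object per line.
request = {
    "custom_id": "request-1",
    "method": "POST",
    "url": "/v1/chat/completions",
    "body": {
        "model": "gpt-3.5-turbo-0125",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello world! List 3 NBA players and tell a story"},
        ],
        "max_tokens": 300,
    },
}
with open("input.jsonl", "w") as f:
    f.write(json.dumps(request) + "\n")
```
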
10 changes: 3 additions & 7 deletions examples/runtime/openai_batch_complete.py
@@ -10,16 +10,13 @@
"""

import json
import os
import time

import openai
from openai import OpenAI


class OpenAIBatchProcessor:
def __init__(self, api_key):
# client = OpenAI(api_key=api_key)
def __init__(self):
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

self.client = client
@@ -82,11 +79,10 @@ def process_batch(self, input_file_path, endpoint, completion_window):


# Initialize the OpenAIBatchProcessor
api_key = os.environ.get("OPENAI_API_KEY")
processor = OpenAIBatchProcessor(api_key)
processor = OpenAIBatchProcessor()

# Process the batch job
input_file_path = "input_complete.jsonl"
input_file_path = "input.jsonl"
endpoint = "/v1/completions"
completion_window = "24h"

2 changes: 0 additions & 2 deletions examples/runtime/reward_model.py
@@ -1,8 +1,6 @@
# launch server
# python -m sglang.launch_server --model LxzGordon/URM-LLaMa-3.1-8B --is-embedding

import json

import requests

url = "http://127.0.0.1:30000"
5 changes: 3 additions & 2 deletions python/sglang/lang/backend/runtime_endpoint.py
@@ -235,6 +235,7 @@ def select(
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
obj = self._generate_http_request(s, data)
prompt_len = obj["meta_info"]["prompt_tokens"]
logprob_start_len = max(prompt_len - 2, 0) # For token healing

# Compute logprob
data = {
@@ -245,7 +246,7 @@
},
"return_logprob": True,
"return_text_in_logprobs": True,
"logprob_start_len": prompt_len - 2, # For token healing
"logprob_start_len": logprob_start_len,
}
obj = self._generate_http_request(s, data)

@@ -258,8 +259,8 @@
# Remove extra token if no token healing occurred
for i in range(len(input_token_logprobs)):
healed_token_str = input_token_logprobs[i][0][-1]
healed_token_logprob = input_token_logprobs[i][0][0]
if s.text_.endswith(healed_token_str):
healed_token_logprob = input_token_logprobs[i][0][0]
normalized_prompt_logprobs[i] = (
normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
- healed_token_logprob
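
The change above clamps the logprob start index so that very short prompts no longer produce a negative `logprob_start_len`. A tiny sketch of the clamping behavior, with hypothetical values for illustration:

```python
# Sketch of the clamping added in select(): keep up to two tokens of context
# for token healing, but never pass a negative start index.
def logprob_start_len(prompt_len: int) -> int:
    return max(prompt_len - 2, 0)

print(logprob_start_len(10))  # 8
print(logprob_start_len(1))   # 0 (the old expression would have given -1)
```
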
2 changes: 1 addition & 1 deletion python/sglang/srt/server.py
@@ -615,7 +615,7 @@ async def async_generate(
if chunk == "data: [DONE]\n\n":
break
data = json.loads(chunk[5:].strip("\n"))
if hasattr(data, "text"):
if "text" in data:
cur = data["text"][pos:]
if cur:
yield cur
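
The fix above replaces an attribute check with a key check: `json.loads` returns a plain `dict`, so `hasattr(data, "text")` was always false and the branch that yields new text never ran. A minimal illustration of the difference, using a hypothetical payload:

```python
import json

# Dicts expose keys, not attributes, so attribute lookup never sees "text".
data = json.loads('{"text": "hello", "meta_info": {}}')

print(hasattr(data, "text"))  # False -- the old check
print("text" in data)         # True  -- the membership test the fixed code uses
```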
