Higher priority for user input of max_prefill_tokens & format #540

Merged (2 commits) on Jun 13, 2024
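
The hunks shown below are formatting changes; the max_prefill_tokens priority itself does not appear in them. A minimal sketch of the rule the title describes, with hypothetical names, might look like:

# Hypothetical sketch: an explicit user-supplied max_prefill_tokens should win over
# the value the server would otherwise derive (e.g. from the model's context length).
# The function and argument names are illustrative, not taken from this PR.
from typing import Optional


def resolve_max_prefill_tokens(user_value: Optional[int], derived_default: int) -> int:
    """Prefer the user's setting when given, otherwise fall back to the derived default."""
    return user_value if user_value is not None else derived_default


assert resolve_max_prefill_tokens(8192, 16384) == 8192  # user flag wins
assert resolve_max_prefill_tokens(None, 16384) == 16384  # fall back to the derived value
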
2 changes: 1 addition & 1 deletion benchmark/gsm8k/bench_other.py
@@ -65,7 +65,7 @@ def main(args):
def get_one_answer(i):
answer = call_generate(
prompt=few_shot_examples + questions[i],
#prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i],
# prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i],
temperature=0,
max_tokens=256,
stop="Question",
34 changes: 24 additions & 10 deletions benchmark/latency_throughput/bench_throughput.py
@@ -158,7 +158,9 @@ async def send_request(
timeout = aiohttp.ClientTimeout(total=3 * 3600)
async with aiohttp.ClientSession(timeout=timeout) as session:
while True:
async with session.post(api_url, headers=headers, json=pload) as response:
async with session.post(
api_url, headers=headers, json=pload
) as response:
chunks = []
async for chunk, _ in response.content.iter_chunks():
chunks.append(chunk)
@@ -228,19 +230,32 @@ def main(args: argparse.Namespace):
np.random.seed(args.seed)

api_url = f"http://{args.host}:{args.port}/generate"
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=args.trust_remote_code)
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code
)

if args.dataset:
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
else:
input_lens = np.random.randint(
int(args.input_len * args.range_ratio), args.input_len + 1, size=args.num_prompts)
int(args.input_len * args.range_ratio),
args.input_len + 1,
size=args.num_prompts,
)
output_lens = np.random.randint(
int(args.output_len * args.range_ratio), args.output_len + 1, size=args.num_prompts)
int(args.output_len * args.range_ratio),
args.output_len + 1,
size=args.num_prompts,
)
offsets = np.random.randint(0, tokenizer.vocab_size, size=args.num_prompts)
input_requests = []
for i in range(args.num_prompts):
prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])])
prompt = tokenizer.decode(
[
(offsets[i] + i + j) % tokenizer.vocab_size
for j in range(input_lens[i])
]
)
input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))

benchmark_start_time = time.perf_counter()
@@ -287,16 +302,15 @@ def main(args: argparse.Namespace):
)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=30000)
parser.add_argument(
"--dataset", type=str, help="Path to the dataset."
)
parser.add_argument("--dataset", type=str, help="Path to the dataset.")
parser.add_argument("--input-len", type=int, default=2048)
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--range-ratio", type=float, default=1.0)
parser.add_argument(
"--tokenizer", type=str,
"--tokenizer",
type=str,
default="NousResearch/Meta-Llama-3-8B",
help="Name or path of the tokenizer."
help="Name or path of the tokenizer.",
)
parser.add_argument(
"--best-of",
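
The re-wrapped calls above do not change behavior. For context, a minimal sketch of the kind of request this benchmark sends to a running SGLang server's /generate endpoint (the payload field names are an assumption based on common SGLang usage, not taken from this diff):

# Sketch: one synchronous request against the endpoint the benchmark targets.
# Assumes an SGLang server is already listening on localhost:30000.
import requests

api_url = "http://localhost:30000/generate"
payload = {
    "text": "The capital of France is",
    "sampling_params": {"temperature": 0, "max_new_tokens": 32},
}
response = requests.post(api_url, json=payload)
print(response.json())
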
2 changes: 1 addition & 1 deletion benchmark/mmlu/bench_other.py
@@ -170,4 +170,4 @@ def main(args):
parser.add_argument("--data_dir", type=str, default="data")
parser.add_argument("--nsub", type=int, default=60)
args = add_common_other_args_and_parse(parser)
main(args)
main(args)
2 changes: 1 addition & 1 deletion python/sglang/__init__.py
@@ -24,10 +24,10 @@

# SGL Backends
from sglang.backend.anthropic import Anthropic
from sglang.backend.litellm import LiteLLM
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.backend.vertexai import VertexAI
from sglang.backend.litellm import LiteLLM

# Global Configurations
from sglang.global_config import global_config
3 changes: 2 additions & 1 deletion python/sglang/backend/litellm.py
@@ -33,7 +33,8 @@ def __init__(
self.model_name = model_name

self.chat_template = chat_template or get_chat_template_by_model_path(
model_name)
model_name
)

self.client_params = {
"api_key": api_key,
41 changes: 26 additions & 15 deletions python/sglang/backend/openai.py
@@ -1,7 +1,7 @@
import dataclasses
import logging
import time
import warnings
import dataclasses
from typing import Callable, List, Optional, Union

import numpy as np
@@ -105,14 +105,16 @@ def __init__(
def get_chat_template(self):
return self.chat_template

def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
num_api_spec_tokens: int, spec_var_name: str):
def _prepare_spec_execution(
self,
sampling_params: SglSamplingParams,
num_api_spec_tokens: int,
spec_var_name: str,
):
if "max_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_tokens"] = num_api_spec_tokens
else:
assert (
self.spec_kwargs["max_tokens"] == num_api_spec_tokens
)
assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens

params = sampling_params.to_openai_kwargs()
for key, value in params.items():
@@ -151,8 +153,9 @@ def generate(
)
prompt = s.messages_
else:
return self._prepare_spec_execution(sampling_params,
s.num_api_spec_tokens, spec_var_name)
return self._prepare_spec_execution(
sampling_params, s.num_api_spec_tokens, spec_var_name
)
else:
prompt = s.text_

@@ -325,7 +328,7 @@ def select(
ret_str = ret.choices[0].text
ret_token = self.tokenizer.encode(ret_str)[0]
self.token_usage.prompt_tokens += ret.usage.prompt_tokens
self.token_usage.completion_tokens= ret.usage.completion_tokens
self.token_usage.completion_tokens = ret.usage.completion_tokens

# TODO:
# 1. return logits as the scores
@@ -355,7 +358,9 @@ def select(
return decision, scores, None, None


def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
def openai_completion(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
for attempt in range(retries):
try:
if is_chat:
@@ -385,15 +390,19 @@ def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None,
return comp


def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
def openai_completion_stream(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
for attempt in range(retries):
try:
if is_chat:
if "stop" in kwargs and kwargs["stop"] is None:
kwargs.pop("stop")
generator = client.chat.completions.create(
messages=prompt, stream=True, stream_options={"include_usage": True},
**kwargs
messages=prompt,
stream=True,
stream_options={"include_usage": True},
**kwargs,
)
for ret in generator:
if len(ret.choices) == 0:
@@ -405,8 +414,10 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, promp
yield content or "", {}
else:
generator = client.completions.create(
prompt=prompt, stream=True, stream_options={"include_usage": True},
**kwargs
prompt=prompt,
stream=True,
stream_options={"include_usage": True},
**kwargs,
)
for ret in generator:
if len(ret.choices) == 0:
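
The stream_options={"include_usage": True} arguments reformatted above are why the nearby checks for len(ret.choices) == 0 exist: with usage reporting enabled, the final streamed chunk carries only aggregate token usage and an empty choices list. A minimal sketch of consuming such a stream (model name and client setup are assumptions):

# Sketch: chat streaming with usage reporting; the last chunk has no choices.
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
stream = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
    stream_options={"include_usage": True},
)
text = ""
for chunk in stream:
    if len(chunk.choices) == 0:
        # Usage-only final chunk, mirroring the guard in openai.py above.
        print("prompt tokens:", chunk.usage.prompt_tokens)
        print("completion tokens:", chunk.usage.completion_tokens)
    else:
        text += chunk.choices[0].delta.content or ""
print(text)
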
2 changes: 1 addition & 1 deletion python/sglang/lang/interpreter.py
@@ -507,7 +507,7 @@ def _execute_gen(self, expr: SglGen):
)
return

else: # Speculative execution on models with completion interface
else: # Speculative execution on models with completion interface
comp, meta_info = self._spec_gen(sampling_params)

self.text_ += comp
6 changes: 2 additions & 4 deletions python/sglang/lang/ir.py
@@ -81,12 +81,10 @@ def to_anthropic_kwargs(self):
"top_p": self.top_p,
"top_k": self.top_k,
}

def to_litellm_kwargs(self):
if self.regex is not None:
warnings.warn(
"Regular expression is not supported in the LiteLLM backend."
)
warnings.warn("Regular expression is not supported in the LiteLLM backend.")
return {
"max_tokens": self.max_new_tokens,
"stop": self.stop or None,
2 changes: 1 addition & 1 deletion python/sglang/launch_server.py
@@ -10,4 +10,4 @@
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)

launch_server(server_args, None)
launch_server(server_args, None)
1 change: 1 addition & 0 deletions python/sglang/launch_server_llavavid.py
@@ -1,4 +1,5 @@
"""Launch the inference server for Llava-video model."""

import argparse
import multiprocessing as mp

2 changes: 1 addition & 1 deletion python/sglang/srt/constrained/__init__.py
@@ -4,7 +4,7 @@
from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.guide import RegexGuide
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel

1 change: 1 addition & 0 deletions python/sglang/srt/constrained/fsm_cache.py
@@ -1,4 +1,5 @@
"""Cache for the compressed finite state machine."""

from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_cache import BaseCache

3 changes: 2 additions & 1 deletion python/sglang/srt/constrained/jump_forward.py
@@ -8,11 +8,12 @@

import interegular
import outlines.caching

from sglang.srt.constrained import (
FSMInfo,
disk_cache,
make_deterministic_fsm,
make_byte_level_fsm,
make_deterministic_fsm,
)
from sglang.srt.constrained.base_cache import BaseCache

1 change: 1 addition & 0 deletions python/sglang/srt/conversation.py
@@ -1,4 +1,5 @@
"""Conversation templates."""

# Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses
31 changes: 23 additions & 8 deletions python/sglang/srt/hf_transformers_utils.py
@@ -1,10 +1,10 @@
"""Utilities for Huggingface Transformers."""

import functools
import json
import os
import warnings
import functools
from typing import Optional, Union, AbstractSet, Collection, Literal
from typing import AbstractSet, Collection, Literal, Optional, Union

from huggingface_hub import snapshot_download
from transformers import (
@@ -179,6 +179,7 @@ def get_processor(
class TiktokenTokenizer:
def __init__(self, tokenizer_path):
import tiktoken

PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

# Read JSON
@@ -190,7 +191,8 @@ def __init__(self, tokenizer_path):
bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
}
special_tokens = {
bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
bytes(item["bytes"]).decode(): item["token"]
for item in tok_dict["special_tokens"]
}
assert tok_dict["word_split"] == "V1"

@@ -202,7 +204,10 @@ def __init__(self, tokenizer_path):
}
if "default_allowed_special" in tok_dict:
default_allowed_special = set(
[bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
[
bytes(bytes_list).decode()
for bytes_list in tok_dict["default_allowed_special"]
]
)
else:
default_allowed_special = None
@@ -216,14 +221,20 @@ def encode_patched(
self,
text: str,
*,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006
allowed_special: Union[
Literal["all"], AbstractSet[str]
] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
) -> list[int]:
if isinstance(allowed_special, set):
allowed_special |= self._default_allowed_special
return tiktoken.Encoding.encode(
self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
self,
text,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
)

tokenizer.encode = functools.partial(encode_patched, tokenizer)

# Convert to HF interface
@@ -237,10 +248,14 @@ def encode(self, x, add_special_tokens=False):
def decode(self, x):
return self.tokenizer.decode(x)

def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
def batch_decode(
self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
):
if isinstance(batch[0], int):
batch = [[x] for x in batch]
return self.tokenizer.decode_batch(batch)

def convert_ids_to_tokens(self, index):
return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
return self.tokenizer.decode_single_token_bytes(index).decode(
"utf-8", errors="ignore"
)
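
The encode_patched wrapper reformatted above merges a tokenizer-defined set of default allowed special tokens into whatever the caller passes before delegating to tiktoken. A minimal sketch of the same pattern on a stock tiktoken encoding (the encoding name and special token are placeholders, not the ones this tokenizer ships with):

# Sketch: let a default set of special tokens pass through tiktoken's encoder,
# mirroring the allowed_special merging done by encode_patched above.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
DEFAULT_ALLOWED_SPECIAL = {"<|endoftext|>"}  # illustrative default set


def encode_with_defaults(text, allowed_special=frozenset()):
    # Merge the defaults in rather than replacing the caller's set.
    if isinstance(allowed_special, (set, frozenset)):
        allowed_special = set(allowed_special) | DEFAULT_ALLOWED_SPECIAL
    return enc.encode(text, allowed_special=allowed_special)


print(encode_with_defaults("Hello <|endoftext|> world"))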