support parallel grammar preprocessing #1996

Merged · Nov 12, 2024 · 29 commits (diff shown as of 24 commits)

Commits
8bc804c  feat(xgrammar): support xgrammar as one of the grammar backends (DarkSharpness, Oct 19, 2024)
cae33a9  fix: fix wrongly clearing the vocab_mask of outlines (DarkSharpness, Oct 19, 2024)
1b17c72  minor: fix the format by running pre-commit (DarkSharpness, Oct 19, 2024)
b23f632  Merge branch 'main' into xgrammar-outlines (DarkSharpness, Oct 20, 2024)
d93f76e  fix: set the object to error when import failed (DarkSharpness, Oct 21, 2024)
ee43065  minor: set the default grammar backend as outlines (DarkSharpness, Oct 21, 2024)
652ef54  Merge branch 'main' into xgrammar-outlines (DarkSharpness, Oct 21, 2024)
83d1502  Merge branch 'main' into xgrammar-outlines (DarkSharpness, Oct 22, 2024)
5ce813c  Merge branch 'main' into xgrammar-outlines (DarkSharpness, Oct 23, 2024)
b8648dd  refactor(constrained): add a new abstraction for constrained decoding (DarkSharpness, Oct 23, 2024)
e615ce3  minor(constrained): set import failure object as None to pass type check (DarkSharpness, Oct 24, 2024)
cd59ed0  fix(constrained): use DummyType to avoid type failure in 'isinstance' (DarkSharpness, Oct 24, 2024)
d01e7af  fix(constrained): fix wrong parameter order in initing bnf_cache (DarkSharpness, Oct 24, 2024)
e1de402  Merge branch 'main' into xgrammar-outlines (DarkSharpness, Oct 24, 2024)
c07cd0d  minor: format the code using pre-commit (DarkSharpness, Oct 24, 2024)
8608c2b  fix(constrained): fix wrong jump-forward assertion (DarkSharpness, Oct 25, 2024)
cbdca83  minor: format the code using pre-commit (DarkSharpness, Oct 25, 2024)
bb0b28d  Merge branch 'main' into xgrammar-outlines (DarkSharpness, Oct 25, 2024)
e0340eb  Merge branch 'main' into xgrammar-outlines (DarkSharpness, Nov 10, 2024)
0fa48e9  feat(constrained): support concurrent access of cache (DarkSharpness, Nov 10, 2024)
c7963d4  refactor(constrained): force to initialize grammar before a prefill b… (DarkSharpness, Nov 11, 2024)
3095e55  fix(constrained): update the api to sync with xgrammar (DarkSharpness, Nov 11, 2024)
565a197  Merge branch 'main' into xgrammar-outlines (DarkSharpness, Nov 11, 2024)
35aafa6  fix(constrained): need a dummy instead of object when xgrammar is not… (DarkSharpness, Nov 11, 2024)
6f0f00e  Merge branch 'main' into xgrammar-outlines (DarkSharpness, Nov 12, 2024)
e18782a  minor: rename the old 'GrammarMatcherInitContextCache' and 'GrammarMa… (DarkSharpness, Nov 12, 2024)
38036ce  refactor(constrained): use constrained.future to replace self-written… (DarkSharpness, Nov 12, 2024)
21cc82a  Improve the style (merrymercy, Nov 12, 2024)
8971e6b  Update unit tests (merrymercy, Nov 12, 2024)
Files changed
13 changes: 6 additions & 7 deletions python/sglang/srt/constrained/__init__.py
@@ -52,15 +52,14 @@ def build_regex_from_object(


 try:
-    from xgrammar import (
-        GrammarMatcher,
-        GrammarMatcherInitContext,
-        GrammarMatcherInitContextCache,
-    )
+    from xgrammar import CachedGrammarCompiler as GrammarMatcherInitContextCache
+    from xgrammar import CompiledGrammar as GrammarMatcherInitContext
+    from xgrammar import GrammarMatcher
 except ImportError as e:

     # we rely on type information, so we have a dummy class here
     class Dummy:
-        pass
+        def __init__(self):
+            pass

     GrammarMatcher = Dummy
     GrammarMatcherInitContext = Dummy
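A side note on the fallback above (not part of the diff): per the commit "use DummyType to avoid type failure in 'isinstance'", the placeholder class exists so that isinstance checks against GrammarMatcher stay well-formed when xgrammar is not installed. A minimal self-contained sketch of the same pattern; "somelib" and "RealMatcher" are hypothetical names, not sglang or xgrammar APIs:

# Sketch of the optional-dependency pattern; "somelib" and "RealMatcher"
# are hypothetical stand-ins.
try:
    from somelib import RealMatcher as Matcher
except ImportError:

    class Matcher:  # placeholder so isinstance(x, Matcher) stays a valid call
        def __init__(self):
            pass


def is_matcher(obj) -> bool:
    # Never raises, whether or not somelib is installed; with the
    # placeholder it simply returns False for any real object.
    return isinstance(obj, Matcher)


print(is_matcher(object()))  # False when somelib is missing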
85 changes: 64 additions & 21 deletions python/sglang/srt/constrained/base_tool_cache.py
@@ -16,50 +16,93 @@
 """Base tool cache for constrained decoding tools."""

 import time
+from dataclasses import dataclass
+from threading import Event, Lock
+from typing import Any, Dict, Tuple
+
+
+@dataclass
+class MapEntry:
+    event: Event
+    value: Any
+
+    def __iter__(self):
+        return iter((self.event, self.value))


 class BaseToolCache:
+    enable: bool
+    cache: Dict[str, MapEntry]
+    metrics: Dict[str, Any]
+    lock_cache: Lock
+    lock_metrics: Lock

[Contributor review comment] Because BaseToolCache is not a dataclass, these fields should go in __init__.

+
     def __init__(self, enable=True):
         self.enable = enable
+        self.lock_cache = Lock()
+        self.lock_metrics = Lock()
         self.reset()

     def reset(self):
-        self.cache = {}
-        self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}
+        with self.lock_cache:
+            self.cache = {}
+        with self.lock_metrics:
+            self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}

-    def query(self, key):
-        def _init_with_timer(key):
-            start = time.monotonic()
-            val = self.init_value(key)
-            init_time = time.monotonic() - start
+    def _init_with_timer(self, key) -> Tuple[Any, float]:
+        start = time.monotonic()
+        val = self.init_value(key)
+        init_time = time.monotonic() - start
+        return val, init_time
+
+    def update_time(self, init_time):
+        with self.lock_metrics:
             curr_total = self.metrics["total"]
             new_total = curr_total + 1

             # Update average init time without old_avg * old_total to avoid overflow.
             self.metrics["avg_init_time"] = (init_time / new_total) + (
                 curr_total / new_total
             ) * self.metrics["avg_init_time"]
-            return val

-        if key in self.cache:
-            self.metrics["hit"] += 1
-            val = self.cache[key]
-        else:
-            # Cache miss or disabled.
-            val = _init_with_timer(key)
+    def query(self, key):
+        if not self.enable:
+            value, init_time = self._init_with_timer(key)
+            self.update_time(init_time)
+            return value
+
+        cache_hit = False
+
+        with self.lock_cache:
+            if key in self.cache:
+                entry = self.cache[key]
+                cache_hit = True
+            else:
+                entry = MapEntry(Event(), None)
+                self.cache[key] = entry

-        if self.enable:
+        with self.lock_metrics:
             self.metrics["total"] += 1
-            self.cache[key] = val
-        return val
+            if cache_hit:
+                self.metrics["hit"] += 1
+
+        if cache_hit:
+            entry.event.wait()
+        else:
+            entry.value, init_time = self._init_with_timer(key)
+            self.update_time(init_time)
+            entry.event.set()
+        return entry.value

     def init_value(self, key):
         raise NotImplementedError()

     def get_cache_hit_rate(self):
-        if self.metrics["total"] == 0:
-            return 0
-        return self.metrics["hit"] / self.metrics["total"]
+        with self.lock_metrics:
+            if self.metrics["total"] == 0:
+                return 0
+            return self.metrics["hit"] / self.metrics["total"]

     def get_avg_init_time(self):
-        return self.metrics["avg_init_time"]
+        with self.lock_metrics:
+            return self.metrics["avg_init_time"]
5 changes: 4 additions & 1 deletion python/sglang/srt/constrained/bnf_cache.py
@@ -48,7 +48,7 @@ def __init__(
     def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext:
         key_type, key_string = key
         if key_type == "json":
-            return self.grammar_cache.get_init_context_for_json_schema(key_string)
+            return self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
         elif key_type == "regex":
             raise ValueError(f"regex hasn't been supported by xgrammar yet")
         else:
@@ -59,3 +59,6 @@ def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher:
         return GrammarMatcher(
             ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
         )
+
+    def reset(self):
+        self.grammar_cache.clear()
48 changes: 48 additions & 0 deletions python/sglang/srt/constrained/future.py
@@ -0,0 +1,48 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Base tool cache for constrained decoding tools."""

import logging
import threading
from typing import Any, Callable

logger = logging.getLogger(__name__)


class FutureObject:
def __init__(self, f: Callable[[], Any]):
self._result = None
self._exception = None
self._done = threading.Event()
self._thread = threading.Thread(target=self._run, args=(f,))
self._thread.daemon = True
self._thread.start()

def _run(self, f: Callable[[], Any]):
try:
self._result = f()
except Exception as e:
logger.exception(f"Error in getting a FutureObject: {e}")
finally:
self._done.set()

def get(self):
self._done.wait()
assert self._result is not None
return self._result

def is_complete(self) -> bool:
return self._done.is_set()
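For orientation, a small usage sketch of the FutureObject defined above, assuming sglang with this PR is installed (the half-second sleep stands in for a grammar compile). Construction starts the work immediately on a daemon thread, and get() blocks until it finishes:

import time

from sglang.srt.constrained.future import FutureObject


def build():
    time.sleep(0.5)  # stand-in for an expensive grammar compilation
    return "compiled grammar"


fut = FutureObject(build)  # work starts on a daemon thread right away
print(fut.is_complete())   # almost certainly False this early
print(fut.get())           # blocks, then prints "compiled grammar"

One caveat visible in the code above: if f raises, _run logs the exception and leaves _result as None, so a later get() trips its assertion instead of re-raising the original error.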
61 changes: 54 additions & 7 deletions python/sglang/srt/constrained/grammar.py
@@ -20,6 +20,7 @@
 from sglang.srt.constrained import GrammarMatcher, RegexGuide
 from sglang.srt.constrained.bnf_cache import BNFCache
 from sglang.srt.constrained.fsm_cache import FSMCache
+from sglang.srt.constrained.future import FutureObject
 from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap

 # from sglang.srt.managers.schedule_batch import Req
@@ -49,7 +50,7 @@ def can_jump(self):
         return len(self.data) > 0


-class Grammar:
+class GrammarInner:
     grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]]
     jump_map: Union[XGrammarJump, JumpForwardMap, None]
@@ -129,7 +130,7 @@ def jump_and_retokenize(
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
         if isinstance(self.grammar, GrammarMatcher):
             # Note that this bitmask is a bitset, not bool
-            bitmask = self.grammar.find_next_token_bitmask()
+            bitmask = self.grammar.get_next_token_bitmask()
             # Mask the tokens that are not allowed
             vocab_mask[
                 self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
@@ -140,6 +141,48 @@ def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
             vocab_mask[guide.get_next_instruction(state).tokens] = 0


+class Grammar:
+    data: Union[FutureObject, GrammarInner]
+
+    def __init__(self, data: FutureObject) -> None:
+        self.data = data
+
+    def _get(self) -> GrammarInner:
+        if isinstance(self.data, FutureObject):
+            self.data = self.data.get()
+        assert isinstance(self.data, GrammarInner)
+        return self.data
+
+    def is_complete(self) -> bool:
+        if isinstance(self.data, FutureObject):
+            return self.data.is_complete()
+        return True
+
+    def accept_token(self, token: int):
+        self._get().accept_token(token)
+
+    def try_jump(self, tokenizer) -> JumpHelper:
+        return self._get().try_jump(tokenizer)
+
+    def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
+        return self._get().jump_forward_str_state(helper)
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        self._get().jump_and_retokenize(old_output_ids, new_output_ids, next_state)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
+        self._get().fill_vocab_mask(vocab_mask, vocab_size)
+
+    # forward all the function calls to the inner object
+    # def __getattr__(self, name):
+    #     if isinstance(self.data, FutureObject):
+    #         self.data = self.data.get()
+    #         assert isinstance(self.data, GrammarInner)
+    #     return getattr(self.data, name)
+
+
 class GrammarCache:
     grammar_cache: Union[BNFCache, FSMCache]
     jump_cache: Union[XGrammarJump, JumpForwardCache, None]
@@ -172,19 +215,23 @@ def __init__(
         )
         self.jump_cache = JumpForwardCache() if allow_jump else None

-    def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
+    def _query(self, key: Tuple[str, str], vocab_size: int) -> GrammarInner:
         if isinstance(self.grammar_cache, BNFCache):
             assert not isinstance(self.jump_cache, JumpForwardCache)
-            return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
+            return GrammarInner(
+                self.grammar_cache.query(key, vocab_size), self.jump_cache
+            )
         else:
             jump_map = None
             guide, regex = self.grammar_cache.query(key)
             if isinstance(self.jump_cache, JumpForwardCache):
                 jump_map = self.jump_cache.query(regex)
-            return Grammar((guide, 0), jump_map)
+            return GrammarInner((guide, 0), jump_map)
+
+    def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
+        return Grammar(FutureObject(lambda: self._query(key, vocab_size)))

     def reset(self):
-        if isinstance(self.grammar_cache, FSMCache):
-            self.grammar_cache.reset()
+        self.grammar_cache.reset()
         if isinstance(self.jump_cache, JumpForwardCache):
             self.jump_cache.reset()
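Taken together, GrammarCache.query now returns a Grammar immediately and defers the expensive _query to a FutureObject; the first method call that actually needs the matcher resolves the future via _get(). The commented-out __getattr__ shows the rejected alternative: implicit forwarding was traded for explicit methods, which keeps the interface visible and type-checkable. A self-contained sketch of the wrapper pattern with illustrative names (SimpleFuture, Inner, and Lazy are not sglang APIs):

# Sketch of the lazy-wrapper pattern: hold a future, swap in the resolved
# object on first real use, and forward each method explicitly.
import threading
import time


class SimpleFuture:
    # Minimal analogue of FutureObject above.
    def __init__(self, f):
        self._done = threading.Event()
        self._result = None
        threading.Thread(target=self._run, args=(f,), daemon=True).start()

    def _run(self, f):
        self._result = f()
        self._done.set()

    def is_complete(self):
        return self._done.is_set()

    def get(self):
        self._done.wait()
        return self._result


class Inner:
    def work(self):
        return "resolved"


class Lazy:
    # Holds either a SimpleFuture or the resolved Inner, like Grammar.data.
    def __init__(self, fut):
        self.data = fut

    def _get(self):
        if isinstance(self.data, SimpleFuture):
            self.data = self.data.get()  # resolve once, cache the result
        return self.data

    def is_complete(self):
        return not isinstance(self.data, SimpleFuture) or self.data.is_complete()

    def work(self):  # explicit forwarding, one method at a time, as in the diff
        return self._get().work()


def build():
    time.sleep(0.2)  # stand-in for grammar compilation
    return Inner()


lazy = Lazy(SimpleFuture(build))
print(lazy.is_complete())  # likely False right after construction
print(lazy.work())         # blocks ~0.2 s, then prints "resolved"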
16 changes: 15 additions & 1 deletion python/sglang/srt/managers/scheduler.py
@@ -236,6 +236,7 @@ def __init__(

         # Init the FSM cache for constrained generation
         self.grammar_cache = None
+        self.grammar_queue: List[Req] = []

         if not server_args.skip_tokenizer_init:
             self.grammar_cache = GrammarCache(
@@ -488,7 +489,10 @@ def handle_generate_request(
             self.max_req_len - len(req.origin_input_ids) - 1,
         )

-        self.waiting_queue.append(req)
+        if req.grammar is not None:
+            self.grammar_queue.append(req)
+        else:
+            self.waiting_queue.append(req)

     def handle_embedding_request(
         self,
@@ -634,6 +638,16 @@ def get_next_batch_to_run(self):
         return self.running_batch

     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
+        # Check if the grammar queue is ready
+        new_grammar_queue = []
+        for req in self.grammar_queue:
+            assert req.grammar is not None
+            if req.grammar.is_complete():
+                self.waiting_queue.append(req)
+            else:
+                new_grammar_queue.append(req)
+        self.grammar_queue = new_grammar_queue
+
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
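Read as a whole, the scheduler side is a non-blocking poll: a request whose grammar is still compiling parks in grammar_queue, and each prefill scheduling pass promotes the finished ones to waiting_queue. A toy sketch of that drain step, where FakeGrammar and the two-field Req are stand-ins for the real classes:

# Sketch of the grammar_queue drain in get_new_batch_prefill above.
from dataclasses import dataclass
from typing import List


@dataclass
class FakeGrammar:
    done: bool

    def is_complete(self) -> bool:
        return self.done


@dataclass
class Req:
    rid: str
    grammar: FakeGrammar


grammar_queue: List[Req] = [Req("a", FakeGrammar(True)), Req("b", FakeGrammar(False))]
waiting_queue: List[Req] = []

# Same shape as the loop in the diff: promote finished requests to the
# waiting queue, keep unfinished ones queued for the next pass.
still_pending: List[Req] = []
for req in grammar_queue:
    if req.grammar.is_complete():
        waiting_queue.append(req)
    else:
        still_pending.append(req)
grammar_queue = still_pending

print([r.rid for r in waiting_queue])  # ['a']
print([r.rid for r in grammar_queue])  # ['b']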