support parallel grammar preprocessing #1996

Merged · 29 commits · Nov 12, 2024
Commits
8bc804c
feat(xgrammar): support xgrammar as one of the grammar backends
DarkSharpness Oct 19, 2024
cae33a9
fix: fix wrongly clearing the vocab_mask of outlines
DarkSharpness Oct 19, 2024
1b17c72
minor: fix the format by running pre-commit
DarkSharpness Oct 19, 2024
b23f632
Merge branch 'main' into xgrammar-outlines
DarkSharpness Oct 20, 2024
d93f76e
fix: set the object to error when import failed
DarkSharpness Oct 21, 2024
ee43065
minor: set the default grammar backend as outlines
DarkSharpness Oct 21, 2024
652ef54
Merge branch 'main' into xgrammar-outlines
DarkSharpness Oct 21, 2024
83d1502
Merge branch 'main' into xgrammar-outlines
DarkSharpness Oct 22, 2024
5ce813c
Merge branch 'main' into xgrammar-outlines
DarkSharpness Oct 23, 2024
b8648dd
refactor(constrained): add a new abstraction for constrained decoding
DarkSharpness Oct 23, 2024
e615ce3
minor(constrained): set import failure object as None to pass type check
DarkSharpness Oct 24, 2024
cd59ed0
fix(constrained): use DummyType to avoid type failure in 'isinstance'
DarkSharpness Oct 24, 2024
d01e7af
fix(constrained): fix wrong parameter order in initing bnf_cache
DarkSharpness Oct 24, 2024
e1de402
Merge branch 'main' into xgrammar-outlines
DarkSharpness Oct 24, 2024
c07cd0d
minor: format the code using pre-commit
DarkSharpness Oct 24, 2024
8608c2b
fix(constrained): fix wrong jump-forward assertion
DarkSharpness Oct 25, 2024
cbdca83
minor: format the code using pre-commit
DarkSharpness Oct 25, 2024
bb0b28d
Merge branch 'main' into xgrammar-outlines
DarkSharpness Oct 25, 2024
e0340eb
Merge branch 'main' into xgrammar-outlines
DarkSharpness Nov 10, 2024
0fa48e9
feat(constrained): support concurrent access of cache
DarkSharpness Nov 10, 2024
c7963d4
refactor(constrained): force to initialize grammar before a prefill b…
DarkSharpness Nov 11, 2024
3095e55
fix(constrained): update the api to sync with xgrammar
DarkSharpness Nov 11, 2024
565a197
Merge branch 'main' into xgrammar-outlines
DarkSharpness Nov 11, 2024
35aafa6
fix(constrained): need a dummy instead of object when xgrammar is not…
DarkSharpness Nov 11, 2024
6f0f00e
Merge branch 'main' into xgrammar-outlines
DarkSharpness Nov 12, 2024
e18782a
minor: rename the old 'GrammarMatcherInitContextCache' and 'GrammarMa…
DarkSharpness Nov 12, 2024
38036ce
refactor(constrained): use constrained.future to replace self-written…
DarkSharpness Nov 12, 2024
21cc82a
Improve the style
merrymercy Nov 12, 2024
8971e6b
Update unit tests
merrymercy Nov 12, 2024
39 changes: 0 additions & 39 deletions python/sglang/srt/constrained/__init__.py
@@ -13,25 +13,11 @@
limitations under the License.
"""

"""For constrained decoding."""

import json
from typing import Dict, Optional, Union

from pydantic import BaseModel

try:
from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.guide import RegexGuide
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer
except ImportError as e:
print(
f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
)
raise

try:
from outlines.fsm.json_schema import build_regex_from_object
except ImportError:
@@ -51,31 +37,6 @@ def build_regex_from_object(
return build_regex_from_schema(schema, whitespace_pattern)


try:
from xgrammar import (
GrammarMatcher,
GrammarMatcherInitContext,
GrammarMatcherInitContextCache,
)
except ImportError as e:

class Dummy:
pass

GrammarMatcher = Dummy
GrammarMatcherInitContext = Dummy
GrammarMatcherInitContextCache = Dummy

__all__ = [
"RegexGuide",
"FSMInfo",
"make_deterministic_fsm",
"build_regex_from_object",
"TransformerTokenizer",
"disk_cache",
"disable_cache",
"make_byte_level_fsm",
"GrammarMatcher",
"GrammarMatcherInitContext",
"GrammarMatcherInitContextCache",
]
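
Review note: the block deleted above bound the xgrammar names to a placeholder class, rather than `None`, when the import fails, so that later `isinstance(obj, GrammarMatcher)` checks stay valid instead of raising `TypeError` (see commits cd59ed0 and 35aafa6). A minimal sketch of that pattern, assuming only that `xgrammar` may be absent:

```python
# Minimal sketch of the optional-dependency fallback deleted above.
# If xgrammar is missing, GrammarMatcher is bound to a Dummy class so
# that isinstance() checks still work (they simply never match), unlike
# binding it to None, which would make isinstance() raise TypeError.
try:
    from xgrammar import GrammarMatcher
except ImportError:

    class Dummy:
        pass

    GrammarMatcher = Dummy


def uses_xgrammar(grammar) -> bool:
    # Safe whether or not xgrammar is installed.
    return isinstance(grammar, GrammarMatcher)
```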
85 changes: 62 additions & 23 deletions python/sglang/srt/constrained/base_tool_cache.py
@@ -13,53 +13,92 @@
limitations under the License.
"""

"""Base tool cache for constrained decoding tools."""
"""Base cache class for constrained decoding tools."""

import time
from dataclasses import dataclass
from threading import Event, Lock
from typing import Any, Dict, Tuple


@dataclass
class MapEntry:
event: Event
value: Any

def __iter__(self):
return iter((self.event, self.value))


class BaseToolCache:

def __init__(self, enable=True):
self.enable = enable
self.enable: bool = enable
self.cache: Dict[str, MapEntry] = {}
self.metrics: Dict[str, Any] = {}
self.lock_cache: Lock = Lock()
self.lock_metrics: Lock = Lock()
self.reset()

def reset(self):
self.cache = {}
self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}
with self.lock_cache:
self.cache = {}
with self.lock_metrics:
self.metrics = {"total": 0, "hit": 0, "avg_init_time": 0}

def query(self, key):
def _init_with_timer(key):
start = time.monotonic()
val = self.init_value(key)
init_time = time.monotonic() - start
def _init_with_timer(self, key) -> Tuple[Any, float]:
start = time.monotonic()
val = self.init_value(key)
init_time = time.monotonic() - start
return val, init_time

def update_time(self, init_time):
with self.lock_metrics:
curr_total = self.metrics["total"]
new_total = curr_total + 1

# Update average init time without old_avg * old_total to avoid overflow.
self.metrics["avg_init_time"] = (init_time / new_total) + (
curr_total / new_total
) * self.metrics["avg_init_time"]
return val

if key in self.cache:
self.metrics["hit"] += 1
val = self.cache[key]
else:
# Cache miss or disabled.
val = _init_with_timer(key)
def query(self, key):
if not self.enable:
value, init_time = self._init_with_timer(key)
self.update_time(init_time)
return value

with self.lock_cache:
if key in self.cache:
entry = self.cache[key]
cache_hit = True
else:
entry = MapEntry(Event(), None)
self.cache[key] = entry
cache_hit = False

if self.enable:
with self.lock_metrics:
self.metrics["total"] += 1
self.cache[key] = val
return val
if cache_hit:
self.metrics["hit"] += 1

if cache_hit:
entry.event.wait()
else:
entry.value, init_time = self._init_with_timer(key)
self.update_time(init_time)
entry.event.set()
return entry.value

def init_value(self, key):
raise NotImplementedError()

def get_cache_hit_rate(self):
if self.metrics["total"] == 0:
return 0
return self.metrics["hit"] / self.metrics["total"]
with self.lock_metrics:
if self.metrics["total"] == 0:
return 0
return self.metrics["hit"] / self.metrics["total"]

def get_avg_init_time(self):
return self.metrics["avg_init_time"]
with self.lock_metrics:
return self.metrics["avg_init_time"]
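
The rewritten cache deduplicates concurrent misses: the first thread to request a key inserts a `MapEntry` with an unset `Event` and runs `init_value`; any thread that arrives while compilation is in flight blocks on `entry.event.wait()` and then reuses the stored value. A usage sketch, assuming the `BaseToolCache` above is importable; the `SlowCache` subclass and the 0.5 s delay are illustrative only:

```python
import threading
import time

# Assumes BaseToolCache as defined in the diff above.
from sglang.srt.constrained.base_tool_cache import BaseToolCache


class SlowCache(BaseToolCache):
    """Illustrative subclass: init_value simulates an expensive compile."""

    def init_value(self, key):
        time.sleep(0.5)  # stand-in for grammar/FSM compilation
        return f"compiled({key})"


cache = SlowCache(enable=True)
results = []


def worker():
    # All threads ask for the same key; only one runs init_value,
    # the rest wait on the entry's Event and reuse its value.
    results.append(cache.query("json-schema-A"))


threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert all(r == "compiled(json-schema-A)" for r in results)
print(f"hit rate: {cache.get_cache_hit_rate():.2f}")  # 3 hits / 4 queries
```

With four threads racing on one key, `init_value` runs once and the metrics record three hits out of four queries (hit rate 0.75).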
80 changes: 36 additions & 44 deletions python/sglang/srt/constrained/grammar.py
@@ -13,50 +13,44 @@

"""Cache for the compressed finite state machine."""
import logging
from typing import List, Optional, Tuple, Union
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Tuple, Union

import torch

from sglang.srt.constrained import GrammarMatcher, RegexGuide
from sglang.srt.constrained.bnf_cache import BNFCache
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap

# from sglang.srt.managers.schedule_batch import Req
from sglang.srt.constrained.outlines_cache import OutlinesCache, RegexGuide
from sglang.srt.constrained.outlines_jump_forward import (
OutlinesJumpCache,
OutlinesJumpForwardMap,
)
from sglang.srt.constrained.xgrammar_cache import (
GrammarMatcher,
XGrammarBackend,
XGrammarJumpCache,
)

logger = logging.getLogger(__name__)

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5


class XGrammarJump:
pass


class JumpHelper:
data: Union[List, str]
state: int
suffix_ids: List[int]

def __init__(
self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
) -> None:
self.data = data
self.state = state
self.suffix_ids = suffix_ids
self.data: Union[List, str] = data
self.state: int = state
self.suffix_ids: List[int] = suffix_ids

def can_jump(self):
return len(self.data) > 0


class Grammar:
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]]
jump_map: Union[XGrammarJump, JumpForwardMap, None]

def __init__(
self,
grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
jump_map: Union[XGrammarJump, JumpForwardMap, None],
jump_map: Union[XGrammarJumpCache, OutlinesJumpForwardMap, None],
) -> None:
self.grammar = grammar
self.jump_map = jump_map
@@ -69,10 +63,10 @@ def accept_token(self, token: int):
self.grammar = guide, guide.get_next_state(state, token)

def try_jump(self, tokenizer) -> JumpHelper:
if isinstance(self.jump_map, XGrammarJump):
if isinstance(self.jump_map, XGrammarJumpCache):
assert isinstance(self.grammar, GrammarMatcher)
return JumpHelper(self.grammar.find_jump_forward_string())
elif isinstance(self.jump_map, JumpForwardMap):
elif isinstance(self.jump_map, OutlinesJumpForwardMap):
assert isinstance(self.grammar, Tuple)

_, state = self.grammar
@@ -103,7 +97,7 @@ def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
if isinstance(helper.data, str):
return helper.data, -1
else:
assert isinstance(self.jump_map, JumpForwardMap)
assert isinstance(self.jump_map, OutlinesJumpForwardMap)
return self.jump_map.jump_forward_symbol(helper.state)

def jump_and_retokenize(
@@ -129,7 +123,7 @@ def jump_and_retokenize(
def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
if isinstance(self.grammar, GrammarMatcher):
# Note that this bitmask is a bitset, not bool
bitmask = self.grammar.find_next_token_bitmask()
bitmask = self.grammar.get_next_token_bitmask()
# Mask the tokens that are not allowed
vocab_mask[
self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
@@ -140,9 +134,7 @@ def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
vocab_mask[guide.get_next_instruction(state).tokens] = 0


class GrammarCache:
grammar_cache: Union[BNFCache, FSMCache]
jump_cache: Union[XGrammarJump, JumpForwardCache, None]
class GrammarBackend:

def __init__(
self,
@@ -153,38 +145,38 @@ def __init__(
backend=None,
allow_jump=False,
):
self.executor = ThreadPoolExecutor()
self.backend = backend

if backend == "xgrammar":
self.grammar_cache = BNFCache(
self.grammar_cache = XGrammarBackend(
tokenizer_path=tokenizer_path,
tokenizer_args_dict=tokenizer_args_dict,
skip_tokenizer_init=skip_tokenizer_init,
whitespace_patterns=whitespace_patterns,
)
self.jump_cache = XGrammarJump() if allow_jump else None
self.jump_cache = XGrammarJumpCache() if allow_jump else None
else:
assert backend == "outlines"
self.grammar_cache = FSMCache(
self.grammar_cache = OutlinesCache(
tokenizer_path=tokenizer_path,
tokenizer_args_dict=tokenizer_args_dict,
skip_tokenizer_init=skip_tokenizer_init,
constrained_json_whitespace_pattern=whitespace_patterns,
enable=True,
)
self.jump_cache = JumpForwardCache() if allow_jump else None
self.jump_cache = OutlinesJumpCache() if allow_jump else None

def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
if isinstance(self.grammar_cache, BNFCache):
assert not isinstance(self.jump_cache, JumpForwardCache)
def _query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
if isinstance(self.grammar_cache, XGrammarBackend):
return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
else:
jump_map = None
guide, regex = self.grammar_cache.query(key)
if isinstance(self.jump_cache, JumpForwardCache):
jump_map = self.jump_cache.query(regex)
jump_map = self.jump_cache.query(regex)
return Grammar((guide, 0), jump_map)

def query(self, key: Tuple[str, str], vocab_size: int) -> Future:
return self.executor.submit(self._query, key, vocab_size)

def reset(self):
if isinstance(self.grammar_cache, FSMCache):
self.grammar_cache.reset()
if isinstance(self.jump_cache, JumpForwardCache):
self.jump_cache.reset()
self.grammar_cache.reset()
self.jump_cache.reset()
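
This is the core of the parallel preprocessing: `GrammarBackend.query` now submits `_query` to a `ThreadPoolExecutor` and returns a `concurrent.futures.Future`, so the scheduler can start grammar compilation early and only block on `result()` when the grammar is actually needed before a prefill (commit c7963d4). A caller-side sketch; the tokenizer path and vocab size are placeholder values, not from this PR:

```python
from concurrent.futures import Future

# Assumes GrammarBackend as defined in the diff above; the constructor
# arguments here are illustrative placeholders.
from sglang.srt.constrained.grammar import GrammarBackend

backend = GrammarBackend(
    tokenizer_path="meta-llama/Llama-3-8B",  # hypothetical model path
    tokenizer_args_dict={},
    skip_tokenizer_init=False,
    whitespace_patterns=None,
    backend="outlines",
    allow_jump=True,
)

# Kick off compilation without blocking; the request can proceed
# through tokenization and scheduling in the meantime.
future: Future = backend.query(key=("json", '{"type": "object"}'), vocab_size=32000)

# ... other per-request work overlaps with grammar preprocessing ...

grammar = future.result()  # block only when the mask is actually needed
```

Combined with the event-based `BaseToolCache`, concurrent requests that share a grammar key trigger a single compilation.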
@@ -17,16 +17,17 @@
import logging

from interegular import InvalidSyntax, parse_pattern
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.fsm.guide import RegexGuide
from outlines.models.transformers import TransformerTokenizer
from transformers import AutoTokenizer

from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained import build_regex_from_object
from sglang.srt.constrained.base_tool_cache import BaseToolCache

logger = logging.getLogger(__name__)


class FSMCache(BaseToolCache):
class OutlinesCache(BaseToolCache):
def __init__(
self,
tokenizer_path,
@@ -74,7 +75,7 @@ def init_value(self, key):
key_type, key_string = key
if key_type == "json":
try:
regex = build_regex_from_schema(
regex = build_regex_from_object(
key_string,
whitespace_pattern=self.constrained_json_whitespace_pattern,
)
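
For reference, the `("json", key_string)` branch above turns a JSON schema into the regex that is then compiled into a `RegexGuide`. A minimal sketch of that conversion step, assuming `outlines>=0.0.44` (the version the deleted `__init__.py` install hint required); the schema here is hypothetical:

```python
# Minimal sketch, assuming outlines >= 0.0.44. Converts a JSON schema
# string to the regex that OutlinesCache would compile into a RegexGuide.
import json

from outlines.fsm.json_schema import build_regex_from_schema

schema = json.dumps(
    {
        "type": "object",
        "properties": {"name": {"type": "string"}},
        "required": ["name"],
    }
)

# whitespace_pattern=None uses outlines' default whitespace handling,
# mirroring constrained_json_whitespace_pattern in the cache above.
regex = build_regex_from_schema(schema, whitespace_pattern=None)
print(regex[:80], "...")
```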