diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py
index c47c5c8dd58..a8708dfea71 100644
--- a/python/sglang/srt/constrained/__init__.py
+++ b/python/sglang/srt/constrained/__init__.py
@@ -51,6 +51,21 @@ def build_regex_from_object(
     return build_regex_from_schema(schema, whitespace_pattern)
 
 
+try:
+    from xgrammar import (
+        GrammarMatcher,
+        GrammarMatcherInitContext,
+        GrammarMatcherInitContextCache,
+    )
+except ImportError:
+
+    class Dummy:
+        pass
+
+    GrammarMatcher = Dummy
+    GrammarMatcherInitContext = Dummy
+    GrammarMatcherInitContextCache = Dummy
+
 __all__ = [
     "RegexGuide",
     "FSMInfo",
@@ -60,4 +75,7 @@ def build_regex_from_object(
     "disk_cache",
     "disable_cache",
     "make_byte_level_fsm",
+    "GrammarMatcher",
+    "GrammarMatcherInitContext",
+    "GrammarMatcherInitContextCache",
 ]
diff --git a/python/sglang/srt/constrained/bnf_cache.py b/python/sglang/srt/constrained/bnf_cache.py
new file mode 100644
index 00000000000..19765731bd6
--- /dev/null
+++ b/python/sglang/srt/constrained/bnf_cache.py
@@ -0,0 +1,61 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Cache for the grammar-matcher contexts used by BNF-constrained decoding."""
+
+from typing import Tuple
+
+from transformers import AutoTokenizer
+
+from sglang.srt.constrained import (
+    GrammarMatcher,
+    GrammarMatcherInitContext,
+    GrammarMatcherInitContextCache,
+)
+
+MAX_ROLLBACK_TOKENS = 10
+
+
+class BNFCache:
+    grammar_cache: GrammarMatcherInitContextCache
+
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        skip_tokenizer_init=False,
+        whitespace_patterns=None,
+    ):
+        # TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init
+        if skip_tokenizer_init:
+            return
+
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
+        self.grammar_cache = GrammarMatcherInitContextCache(
+            tokenizer_or_vocab=tokenizer
+        )
+
+    def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext:
+        key_type, key_string = key
+        if key_type == "json":
+            return self.grammar_cache.get_init_context_for_json_schema(key_string)
+        elif key_type == "regex":
+            raise ValueError("regex is not supported by xgrammar yet")
+        else:
+            raise ValueError(f"Invalid key_type: {key_type}")
+
+    def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher:
+        ctx = self.get_context(key)
+        return GrammarMatcher(
+            ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
+        )
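A quick usage sketch for `BNFCache` (not part of the patch): it assumes xgrammar is installed — otherwise the `Dummy` fallbacks above are bound instead — and the tokenizer path and vocab size below are placeholders.

```python
# Hypothetical usage of BNFCache; tokenizer path and vocab size are placeholders.
from sglang.srt.constrained.bnf_cache import BNFCache

cache = BNFCache(
    tokenizer_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    tokenizer_args_dict={},
)

schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
# query() compiles an init context for the schema (cached per schema string),
# then wraps it in a fresh GrammarMatcher with a MAX_ROLLBACK_TOKENS window.
matcher = cache.query(("json", schema), vocab_size=128256)
```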
diff --git a/python/sglang/srt/constrained/grammar.py b/python/sglang/srt/constrained/grammar.py
new file mode 100644
index 00000000000..0281539b89c
--- /dev/null
+++ b/python/sglang/srt/constrained/grammar.py
@@ -0,0 +1,190 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""A unified grammar interface over the constrained-decoding backends."""
+import logging
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from sglang.srt.constrained import GrammarMatcher, RegexGuide
+from sglang.srt.constrained.bnf_cache import BNFCache
+from sglang.srt.constrained.fsm_cache import FSMCache
+from sglang.srt.constrained.jump_forward import JumpForwardCache, JumpForwardMap
+
+# from sglang.srt.managers.schedule_batch import Req
+
+logger = logging.getLogger(__name__)
+
+INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
+
+
+class XGrammarJump:
+    pass
+
+
+class JumpHelper:
+    data: Union[List, str]
+    state: int
+    suffix_ids: List[int]
+
+    def __init__(
+        self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
+    ) -> None:
+        self.data = data
+        self.state = state
+        self.suffix_ids = suffix_ids
+
+    def can_jump(self):
+        return len(self.data) > 0
+
+
+class Grammar:
+    grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]]
+    jump_map: Union[XGrammarJump, JumpForwardMap, None]
+
+    def __init__(
+        self,
+        grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
+        jump_map: Union[XGrammarJump, JumpForwardMap, None],
+    ) -> None:
+        self.grammar = grammar
+        self.jump_map = jump_map
+
+    def accept_token(self, token: int):
+        if isinstance(self.grammar, GrammarMatcher):
+            assert self.grammar.accept_token(token)
+        else:
+            guide, state = self.grammar
+            self.grammar = guide, guide.get_next_state(state, token)
+
+    def try_jump(self, tokenizer) -> JumpHelper:
+        if isinstance(self.jump_map, XGrammarJump):
+            assert isinstance(self.grammar, GrammarMatcher)
+            return JumpHelper(self.grammar.find_jump_forward_string())
+        elif isinstance(self.jump_map, JumpForwardMap):
+            assert isinstance(self.grammar, tuple)
+
+            _, state = self.grammar
+            jump_forward_bytes = self.jump_map.jump_forward_byte(state)
+            if jump_forward_bytes is None or len(jump_forward_bytes) == 0:
+                return JumpHelper()  # can't jump
+
+            # preprocess the jump forward string
+            suffix_bytes = []
+            continuation_range = range(0x80, 0xC0)
+            cur_state = state
+            while (
+                len(jump_forward_bytes)
+                and jump_forward_bytes[0][0] in continuation_range
+            ):
+                # continuation bytes
+                byte_edge = jump_forward_bytes.pop(0)
+                suffix_bytes.append(byte_edge[0])
+                cur_state = byte_edge[1]
+
+            suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+            suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
+            return JumpHelper(suffix_ids, cur_state, suffix_bytes)
+        else:
+            return JumpHelper()  # can't jump
+
+    def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
+        if isinstance(helper.data, str):
+            return helper.data, -1
+        else:
+            assert isinstance(self.jump_map, JumpForwardMap)
+            return self.jump_map.jump_forward_symbol(helper.state)
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        if isinstance(self.grammar, GrammarMatcher):
+            k = 0
+            for i, old_id in enumerate(old_output_ids):
+                if old_id == new_output_ids[i]:
+                    k = i + 1
+                else:
+                    break
+
+            # rollback to the last token that is the same
+            if k < len(old_output_ids):
+                self.grammar.rollback(len(old_output_ids) - k)
+
+            for i in range(k, len(new_output_ids)):
+                assert self.grammar.accept_token(new_output_ids[i])
+        else:
+            self.grammar = self.grammar[0], next_state
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
+        if isinstance(self.grammar, GrammarMatcher):
+            # Note that this bitmask is a bitset, not bool
+            bitmask = self.grammar.find_next_token_bitmask()
+            # Mask the tokens that are not allowed
+            vocab_mask[
+                self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
+            ] = 1
+        else:
+            guide, state = self.grammar
+            vocab_mask.fill_(1)
+            vocab_mask[guide.get_next_instruction(state).tokens] = 0
+
+
+class GrammarCache:
+    grammar_cache: Union[BNFCache, FSMCache]
+    jump_cache: Union[XGrammarJump, JumpForwardCache, None]
+
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        skip_tokenizer_init=False,
+        whitespace_patterns=None,
+        backend=None,
+        allow_jump=False,
+    ):
+        if backend == "xgrammar":
+            self.grammar_cache = BNFCache(
+                tokenizer_path=tokenizer_path,
+                tokenizer_args_dict=tokenizer_args_dict,
+                skip_tokenizer_init=skip_tokenizer_init,
+                whitespace_patterns=whitespace_patterns,
+            )
+            self.jump_cache = XGrammarJump() if allow_jump else None
+        else:
+            assert backend == "outlines"
+            self.grammar_cache = FSMCache(
+                tokenizer_path=tokenizer_path,
+                tokenizer_args_dict=tokenizer_args_dict,
+                skip_tokenizer_init=skip_tokenizer_init,
+                constrained_json_whitespace_pattern=whitespace_patterns,
+                enable=True,
+            )
+            self.jump_cache = JumpForwardCache() if allow_jump else None
+
+    def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
+        if isinstance(self.grammar_cache, BNFCache):
+            assert not isinstance(self.jump_cache, JumpForwardCache)
+            return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
+        else:
+            jump_map = None
+            guide, regex = self.grammar_cache.query(key)
+            if isinstance(self.jump_cache, JumpForwardCache):
+                jump_map = self.jump_cache.query(regex)
+            return Grammar((guide, 0), jump_map)
+
+    def reset(self):
+        if isinstance(self.grammar_cache, FSMCache):
+            self.grammar_cache.reset()
+        if isinstance(self.jump_cache, JumpForwardCache):
+            self.jump_cache.reset()
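To make the `Grammar` contract concrete, here is a minimal decode-step sketch (not part of the patch) showing the intended call order — `fill_vocab_mask` before sampling, `accept_token` after — with greedy sampling standing in for the real sampler:

```python
import torch

def constrained_decode_step(grammar, logits: torch.Tensor) -> int:
    """Greedy stand-in for the sampler; `grammar` is a Grammar from above."""
    vocab_size = logits.shape[-1]
    vocab_mask = torch.zeros(vocab_size, dtype=torch.bool, device=logits.device)
    grammar.fill_vocab_mask(vocab_mask, vocab_size)  # marks disallowed tokens with 1
    logits = logits.masked_fill(vocab_mask, float("-inf"))
    next_token = int(torch.argmax(logits, dim=-1))
    grammar.accept_token(next_token)  # advance the matcher / FSM state
    return next_token
```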
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index fcd06d8cc9c..85ca560a926 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -37,8 +37,7 @@
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained import RegexGuide
-from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.constrained.grammar import Grammar
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -247,9 +246,7 @@ def __init__(
         self.embedding = None
 
         # Constrained decoding
-        self.regex_fsm: RegexGuide = None
-        self.regex_fsm_state: int = 0
-        self.jump_forward_map: JumpForwardMap = None
+        self.grammar: Optional[Grammar] = None
 
         # For Qwen2-VL
         self.mrope_position_delta = []  # use mutable object
@@ -359,6 +356,8 @@ def check_finished(self):
         return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+        assert self.grammar is not None and self.tokenizer is not None
+
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
@@ -398,7 +397,8 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state):
                 self.surr_offset = self.read_offset - i
                 break
 
-        self.regex_fsm_state = next_state
+        # update the inner state of the grammar
+        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
 
         if self.return_logprob:
             # For fast-forward part's logprobs
@@ -468,8 +468,8 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False
 
-    # Has regex
-    has_regex: bool = False
+    # Has grammar
+    has_grammar: bool = False
 
     # device
     device: str = "cuda"
@@ -477,7 +477,7 @@ class ScheduleBatch:
     @classmethod
     def init_new(
         cls,
-        reqs,
+        reqs: List[Req],
         req_to_token_pool,
         token_to_kv_pool,
         tree_cache,
@@ -491,7 +491,7 @@ def init_new(
             model_config=model_config,
             return_logprob=any(req.return_logprob for req in reqs),
            has_stream=any(req.stream for req in reqs),
-            has_regex=any(req.regex_fsm for req in reqs),
+            has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
         )
 
@@ -803,26 +803,10 @@ def check_for_jump_forward(self, pad_input_ids_func):
         keep_indices = set(i for i in range(len(self.reqs)))
 
         for i, req in enumerate(self.reqs):
-            if req.jump_forward_map is not None:
-                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
-                    req.regex_fsm_state
-                )
-                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
-                    suffix_bytes = []
-                    continuation_range = range(0x80, 0xC0)
-                    cur_state = req.regex_fsm_state
-                    while (
-                        len(jump_forward_bytes)
-                        and jump_forward_bytes[0][0] in continuation_range
-                    ):
-                        # continuation bytes
-                        byte_edge = jump_forward_bytes.pop(0)
-                        suffix_bytes.append(byte_edge[0])
-                        cur_state = byte_edge[1]
-
-                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
-                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
-
+            if req.grammar is not None:
+                jump_helper = req.grammar.try_jump(req.tokenizer)
+                if jump_helper.can_jump():
+                    suffix_ids = jump_helper.suffix_ids
                     # Current ids, for cache and revert
                     cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
                     cur_output_ids = req.output_ids
@@ -836,10 +820,8 @@ def check_for_jump_forward(self, pad_input_ids_func):
                    (
                         jump_forward_str,
                         next_state,
-                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)
+                    ) = req.grammar.jump_forward_str_state(jump_helper)
 
-                    # Make the incrementally decoded text part of jump_forward_str
-                    # so that the UTF-8 will not corrupt
                     jump_forward_str = new_text + jump_forward_str
                     if not req.jump_forward_and_retokenize(
                         jump_forward_str, next_state
@@ -946,7 +928,7 @@ def filter_batch(
             self.top_logprobs_nums = None
 
         self.has_stream = any(req.stream for req in self.reqs)
-        self.has_regex = any(req.regex_fsm for req in self.reqs)
+        self.has_grammar = any(req.grammar for req in self.reqs)
 
         self.sampling_info.filter_batch(keep_indices, new_indices)
 
@@ -979,7 +961,7 @@ def merge_batch(self, other: "ScheduleBatch"):
 
         self.return_logprob = self.return_logprob or other.return_logprob
         self.has_stream = self.has_stream or other.has_stream
-        self.has_regex = self.has_regex or other.has_regex
+        self.has_grammar = self.has_grammar or other.has_grammar
 
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
@@ -989,13 +971,10 @@ def get_model_worker_batch(self):
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
 
-        if self.has_regex:
-            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
-            self.sampling_info.regex_fsm_states = [
-                req.regex_fsm_state for req in self.reqs
-            ]
+        if self.has_grammar:
+            self.sampling_info.grammars = [req.grammar for req in self.reqs]
         else:
-            self.sampling_info.regex_fsms = None
+            self.sampling_info.grammars = None
 
         global bid
         bid += 1
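The jump-forward path above can be hard to follow inside the batching code; this is the bare control flow it implements (illustrative only — the real method also handles prefix-cache upkeep, incremental detokenization, and batch filtering):

```python
def maybe_jump_forward(req) -> bool:
    """Skip ahead over grammar-forced tokens for one request (simplified)."""
    jump_helper = req.grammar.try_jump(req.tokenizer)
    if not jump_helper.can_jump():
        return False
    jump_forward_str, next_state = req.grammar.jump_forward_str_state(jump_helper)
    # Retokenizes the extended text and advances the grammar state via
    # Grammar.jump_and_retokenize() inside Req.jump_forward_and_retokenize().
    return req.jump_forward_and_retokenize(jump_forward_str, next_state)
```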
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 55b05f84698..b1fb96b2a77 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -29,8 +29,7 @@
 
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.fsm_cache import FSMCache
-from sglang.srt.constrained.jump_forward import JumpForwardCache
+from sglang.srt.constrained.grammar import GrammarCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
@@ -225,17 +224,20 @@ def __init__(
         )
 
         # Init the FSM cache for constrained generation
+        self.grammar_cache = None
+
         if not server_args.skip_tokenizer_init:
-            self.regex_fsm_cache = FSMCache(
+            self.grammar_cache = GrammarCache(
                 server_args.tokenizer_path,
                 {
                     "tokenizer_mode": server_args.tokenizer_mode,
                     "trust_remote_code": server_args.trust_remote_code,
                 },
                 skip_tokenizer_init=server_args.skip_tokenizer_init,
-                constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
+                backend=server_args.grammar_backend,
+                allow_jump=not server_args.disable_regex_jump_forward,
             )
-        self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
         assert (
@@ -402,22 +404,20 @@ def handle_generate_request(
             # By default, only return the logprobs for output tokens
             req.logprob_start_len = len(recv_req.input_ids) - 1
 
-        # Init regex FSM
+        # Init regex FSM or BNF
         if (
             req.sampling_params.json_schema is not None
             or req.sampling_params.regex is not None
         ):
+            assert self.grammar_cache is not None
             if req.sampling_params.json_schema is not None:
-                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
-                    ("json", req.sampling_params.json_schema)
+                req.grammar = self.grammar_cache.query(
+                    ("json", req.sampling_params.json_schema),
+                    self.model_config.vocab_size,
                 )
             elif req.sampling_params.regex is not None:
-                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
-                    ("regex", req.sampling_params.regex)
-                )
-            if not self.disable_regex_jump_forward:
-                req.jump_forward_map = self.jump_forward_cache.query(
-                    computed_regex_string
+                req.grammar = self.grammar_cache.query(
+                    ("regex", req.sampling_params.regex), self.model_config.vocab_size
                 )
 
         # Truncate prompts that are too long
@@ -796,10 +796,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)
 
-                if req.regex_fsm is not None:
-                    req.regex_fsm_state = req.regex_fsm.get_next_state(
-                        req.regex_fsm_state, next_token_ids[i]
-                    )
+                if req.grammar is not None:
+                    req.grammar.accept_token(next_token_ids[i])
 
                 if req.return_logprob:
                     logprob_pt += self.add_logprob_return_values(
@@ -855,10 +853,8 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
             req.output_ids.append(next_token_id)
             req.check_finished()
 
-            if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.get_next_state(
-                    req.regex_fsm_state, next_token_id
-                )
+            if req.grammar is not None:
+                req.grammar.accept_token(next_token_id)
 
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
@@ -1056,7 +1052,9 @@ def flush_cache(self):
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.regex_fsm_cache.reset()
+            if self.grammar_cache is not None:
+                self.grammar_cache.reset()
+            # TODO(dark): reset the bnf cache
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
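A condensed restatement (not part of the patch) of how the scheduler wires a request's constraint through `GrammarCache`; the tokenizer path and vocab size are placeholders:

```python
from sglang.srt.constrained.grammar import GrammarCache

grammar_cache = GrammarCache(
    tokenizer_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
    tokenizer_args_dict={"tokenizer_mode": "auto", "trust_remote_code": False},
    backend="xgrammar",  # or "outlines"
    allow_jump=True,  # mirrors `not disable_regex_jump_forward`
)

# One Grammar per request; the key is ("json", schema) or ("regex", pattern).
grammar = grammar_cache.query(("json", '{"type": "string"}'), vocab_size=128256)
```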
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 27a2d07fb27..6afd48cc8a1 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -6,7 +6,7 @@
 import torch
 
 import sglang.srt.sampling.penaltylib as penaltylib
-from sglang.srt.constrained import RegexGuide
+from sglang.srt.constrained.grammar import Grammar
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -29,11 +29,9 @@ class SamplingBatchInfo:
     # Bias Tensors
     vocab_size: int
     logit_bias: torch.Tensor = None
-    vocab_mask: torch.Tensor = None
+    vocab_mask: Optional[torch.Tensor] = None
 
-    # FSM states
-    regex_fsms: List[RegexGuide] = None
-    regex_fsm_states: List[int] = None
+    grammars: Optional[List[Optional[Grammar]]] = None
 
     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -136,8 +134,7 @@ def update_penalties(self):
             self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self):
-        has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
-        if not has_regex:
+        if not self.grammars or not any(grammar for grammar in self.grammars):
             self.vocab_mask = None
             return
 
@@ -147,12 +144,9 @@ def update_regex_vocab_mask(self):
             dtype=torch.bool,
             device=self.device,
         )
-        for i, regex_fsm in enumerate(self.regex_fsms):
-            if regex_fsm is not None:
-                self.vocab_mask[i].fill_(1)
-                self.vocab_mask[i][
-                    regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
-                ] = 0
+        for i, grammar in enumerate(self.grammars):
+            if grammar is not None:
+                grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         if self.penalizer_orchestrator:
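For reference, the batched mask built by `update_regex_vocab_mask` is consumed roughly like this downstream (a sketch with made-up shapes; rows without a grammar stay all-False and pass their logits through unchanged):

```python
import torch

batch_size, vocab_size = 4, 32000  # made-up shapes
logits = torch.randn(batch_size, vocab_size)
vocab_mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)
# ... each row whose request has a grammar is filled by Grammar.fill_vocab_mask() ...
logits = logits.masked_fill(vocab_mask, float("-inf"))
```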
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 6ccd8918577..9cb7c03310d 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -102,6 +102,7 @@ class ServerArgs:
     # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
+    grammar_backend: Optional[str] = "outlines"
 
     # Optimization/debug options
     disable_flashinfer: bool = False
@@ -537,6 +538,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
             default=ServerArgs.sampling_backend,
             help="Choose the kernels for sampling layers.",
         )
+        parser.add_argument(
+            "--grammar-backend",
+            type=str,
+            choices=["xgrammar", "outlines"],
+            default=ServerArgs.grammar_backend,
+            help="Choose the backend for constrained decoding.",
+        )
 
         # Optimization/debug options
         parser.add_argument(
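Finally, a sketch of selecting the new backend programmatically (the model path is a placeholder); on the command line this corresponds to passing `--grammar-backend xgrammar` to the launch script:

```python
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
    grammar_backend="xgrammar",  # default remains "outlines"
)
```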