[Bugfix] Fix guided decoding with tokenizer mode mistral #11046
@@ -16,22 +16,20 @@
from vllm.model_executor.guided_decoding.xgrammar_utils import (
    convert_lark_to_gbnf, grammar_is_likely_lark)
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

    from vllm.config import ModelConfig
    from vllm.sampling_params import GuidedDecodingParams


# TODO: passing batch size to max threads here
def get_local_xgrammar_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams,
        tokenizer: PreTrainedTokenizer,
        model_config: ModelConfig,
        max_threads: int = 8):
    config = GrammarConfig.from_guided_params(guided_params=guided_params,
                                              model_config=model_config,
                                              tokenizer=tokenizer,
                                              max_threads=max_threads)
    return XGrammarLogitsProcessor(config)
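For orientation, the object returned here is used as a vLLM logits processor: a callable over the generated token ids and the raw next-token scores that returns adjusted scores. A toy, framework-free sketch of that contract (not the xgrammar implementation; masking token id 0 is purely illustrative):

```python
import torch

def toy_logits_processor(input_ids: list[int],
                         scores: torch.Tensor) -> torch.Tensor:
    """Toy stand-in for a guided-decoding processor: forbid token id 0."""
    scores = scores.clone()
    scores[0] = float("-inf")  # a grammar-driven mask would rule out many ids
    return scores

scores = torch.randn(32_000)                   # one row of logits over the vocab
out = toy_logits_processor([1, 5, 42], scores)
assert out[0] == float("-inf")                 # token 0 can no longer be sampled
```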
@@ -41,7 +39,8 @@ class TokenizerData(NamedTuple):
     """Immutable container for cached tokenizer data."""
     encoded_vocab: list[str]
     stop_token_ids: list[int] | None
-    backend_str: str
+    backend_str: str | None
+    vocab_type: xgr.VocabType | None


 class TokenizerDataCache:
@@ -68,18 +67,26 @@ def get_tokenizer_data(cls,
                     "get_vocab method.") from e

             stop_token_ids = None
-            backend_str = xgr.VocabType.RAW
+            backend_str = ""
+            vocab_type = xgr.VocabType.RAW

+            if stop_token_ids is None and hasattr(
+                    tokenizer,
+                    "eos_token_id") and tokenizer.eos_token_id is not None:
+                stop_token_ids = [tokenizer.eos_token_id]
+
             if isinstance(tokenizer, PreTrainedTokenizerFast):
                 backend_str = tokenizer.backend_tokenizer.to_str()
-                if stop_token_ids is None and hasattr(
-                        tokenizer,
-                        "eos_token_id") and tokenizer.eos_token_id is not None:
-                    stop_token_ids = [tokenizer.eos_token_id]

+            elif isinstance(tokenizer, MistralTokenizer):
+                # REF: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43  # noqa: E501
+                vocab_type = xgr.VocabType.BYTE_FALLBACK
+
             cls._cache[tokenizer_hash] = TokenizerData(
                 encoded_vocab=encoded_vocab,
                 stop_token_ids=stop_token_ids,
-                backend_str=backend_str)
+                backend_str=backend_str,
+                vocab_type=vocab_type)

         return cls._cache[tokenizer_hash]
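To see which branch of the hunk above a given tokenizer would take, a small probe helps. This is a sketch assuming `transformers` and `xgrammar` are installed; "gpt2" is just an arbitrary example checkpoint, and the `else` branch loosely mirrors the MistralTokenizer case from the diff:

```python
import xgrammar as xgr
from transformers import AutoTokenizer, PreTrainedTokenizerFast

tok = AutoTokenizer.from_pretrained("gpt2")  # any HF checkpoint works here

backend_str = ""
vocab_type = xgr.VocabType.RAW
if isinstance(tok, PreTrainedTokenizerFast):
    # Fast tokenizers expose a serialized backend string xgrammar can ingest.
    backend_str = tok.backend_tokenizer.to_str()
else:
    # Non-fast tokenizers (e.g. vLLM's MistralTokenizer in the diff) get a
    # vocab type instead; BYTE_FALLBACK tells xgrammar how to read the vocab.
    vocab_type = xgr.VocabType.BYTE_FALLBACK

print(len(backend_str), vocab_type)
```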
@@ -99,10 +106,18 @@ def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:

         if cache_key not in cls._cache:
             assert config.encoded_vocab is not None
-            tokenizer_info = xgr.TokenizerInfo._create_from_handle(
-                xgr_core.TokenizerInfo.from_huggingface(
-                    config.encoded_vocab, config.backend_str,
-                    config.vocab_size, config.stop_token_ids))
+            if config.backend_str:
+                tokenizer_info = xgr.TokenizerInfo._create_from_handle(
+                    xgr_core.TokenizerInfo.from_huggingface(
+                        config.encoded_vocab, config.backend_str,
+                        config.vocab_size, config.stop_token_ids))
+            else:
+                tokenizer_info = xgr.TokenizerInfo(
+                    config.encoded_vocab,
+                    config.vocab_type,
+                    vocab_size=config.vocab_size,
+                    stop_token_ids=config.stop_token_ids)
             cls._cache[cache_key] = xgr.GrammarCompiler(
                 tokenizer_info, max_threads=config.max_threads)
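Downstream, the tokenizer info feeds an `xgr.GrammarCompiler`, which is what actually turns a JSON schema or GBNF grammar into a token-level matcher. A rough sketch of that flow using xgrammar's public helper; the helper names and signatures here come from my reading of xgrammar's API rather than from this diff, so treat them as assumptions:

```python
import xgrammar as xgr
from transformers import AutoTokenizer

# Fast-tokenizer path: the public helper builds TokenizerInfo straight from
# the tokenizer object (the diff instead passes encoded_vocab/backend_str or
# encoded_vocab/vocab_type explicitly, the lower-level equivalent).
hf_tok = AutoTokenizer.from_pretrained("gpt2")
tok_info = xgr.TokenizerInfo.from_huggingface(hf_tok)

compiler = xgr.GrammarCompiler(tok_info, max_threads=8)
compiled = compiler.compile_builtin_json_grammar()  # "any valid JSON" grammar
print(type(compiled).__name__)
```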
@@ -122,11 +137,11 @@ class GrammarConfig:
     encoded_vocab: list[str] | None = None
     stop_token_ids: list[int] | None = None
     backend_str: str | None = None
+    vocab_type: xgr.VocabType = xgr.VocabType.RAW

     @classmethod
     def from_guided_params(cls,
                            guided_params: GuidedDecodingParams,
                            model_config: ModelConfig,
                            tokenizer: PreTrainedTokenizer,
                            max_threads: int = 8) -> GrammarConfig:
@@ -136,24 +151,27 @@ def from_guided_params(cls,
             encoded_vocab = None
             stop_token_ids = None
             backend_str = None
+            vocab_type = xgr.VocabType.RAW
         else:
             tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
             encoded_vocab = tokenizer_data.encoded_vocab
             stop_token_ids = tokenizer_data.stop_token_ids
             backend_str = tokenizer_data.backend_str
+            vocab_type = tokenizer_data.vocab_type

         if guided_params.json:
             if not isinstance(guided_params.json, str):
                 json_str = json.dumps(guided_params.json)
             else:
                 json_str = guided_params.json
             return cls(json_str=json_str,
-                       vocab_size=model_config.hf_config.vocab_size,
+                       vocab_size=tokenizer.vocab_size,
                        encoded_vocab=encoded_vocab,
                        stop_token_ids=stop_token_ids,
                        backend_str=backend_str,
                        tokenizer_hash=tokenizer_hash,
-                       max_threads=max_threads)
+                       max_threads=max_threads,
+                       vocab_type=vocab_type)
         elif guided_params.grammar:
             # XGrammar only supports GBNF grammars, so we must convert Lark
             if grammar_is_likely_lark(guided_params.grammar):
@@ -168,20 +186,22 @@ def from_guided_params(cls,
             else:
                 grammar_str = guided_params.grammar
             return cls(grammar_str=grammar_str,
-                       vocab_size=model_config.hf_config.vocab_size,
+                       vocab_size=tokenizer.vocab_size,
Review thread on this line:

"@aarnphm there is a reason why we needed to reference the model's vocab size and not the tokenizer's, correct?"

"Thanks for checking it out @mgoin! I really appreciate how fast you answered this. Yeah, but the problem with this part is that it was not correct; I reported a traceback in #11045. What do you think?"

"We need to reference tokenizer vocab size because of additional padding tokens. This is the thread: https://vllm-dev.slack.com/archives/C07QQ8DAXMK/p1732673561777159"

"Thanks! I reverted the code."
                        encoded_vocab=encoded_vocab,
                        stop_token_ids=stop_token_ids,
                        backend_str=backend_str,
                        tokenizer_hash=tokenizer_hash,
-                       max_threads=max_threads)
+                       max_threads=max_threads,
+                       vocab_type=vocab_type)
         elif guided_params.json_object:
             return cls(json_object=True,
-                       vocab_size=model_config.hf_config.vocab_size,
+                       vocab_size=tokenizer.vocab_size,
                        encoded_vocab=encoded_vocab,
                        stop_token_ids=stop_token_ids,
                        backend_str=backend_str,
                        tokenizer_hash=tokenizer_hash,
-                       max_threads=max_threads)
+                       max_threads=max_threads,
+                       vocab_type=vocab_type)
         else:
             raise ValueError(
                 "Currently only support JSON and EBNF grammar mode for xgrammar"
@@ -257,10 +277,14 @@ def __call__(self, input_ids: list[int],
         # fill_next_token_bitmask so we move it to the device of scores
         device_type = scores.device.type
         if device_type != "cuda":
-            scores = scores.to("cpu")
+            scores = scores.to("cpu").unsqueeze(0)

+        # Note: In this method, if the tensors have different dimensions
+        # on CPU device fails, but on GPU it runs without error. Hence the
+        # unsqueeze above for scores, to match the token bitmask shape
         xgr.apply_token_bitmask_inplace(scores,
                                         self.token_bitmask.to(scores.device))
         if device_type != "cuda":
-            scores = scores.to(device_type)
+            scores = scores.to(device_type).squeeze()

         return scores
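The note added in this hunk is about shapes: the token bitmask carries a batch dimension, so on CPU the 1-D scores row is lifted to (1, vocab) before masking and squeezed back afterwards. A torch-only sketch of that shape handling, using a plain boolean mask as a stand-in for xgrammar's packed bitmask and for `apply_token_bitmask_inplace`:

```python
import torch

vocab_size = 1_000
scores = torch.randn(vocab_size)                        # 1-D logits row on CPU
allowed = torch.zeros(1, vocab_size, dtype=torch.bool)  # (batch=1, vocab) mask
allowed[0, :10] = True                                  # pretend only ids 0-9 are legal

scores = scores.unsqueeze(0)                  # (vocab,) -> (1, vocab), as in the diff
scores = scores.masked_fill(~allowed, float("-inf"))    # stand-in for the in-place
                                                        # xgrammar masking call
scores = scores.squeeze()                     # back to (vocab,) for the caller
print(scores.shape)                           # torch.Size([1000])
```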
A separate review thread on the change to the guided decoding entry points:

"Can you revert this change? This is being used for the offline use case, with LLM, whereas get_guided_decoding_logits_processor is being used for the online use case."

"I reviewed what I did and agree it was not great, given how differently get_local_outlines_guided_decoding_logits_processor and get_outlines_guided_decoding_logits_processor are implemented. But I tried something a little different instead of reverting everything, just to avoid code duplication. See if you agree; if not, I won't insist and can revert with no problem. I also updated the tests to check both the offline and online versions, so that all of these code paths are exercised, including the offline path."

"I think this is even more confusing now that there are three functions. I would prefer a revert, as it seems you have no other changes to this file? We can consider a refactor in another PR."
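For readers following the offline/online discussion above: the idea the author was reaching for is a single shared core wrapped by a synchronous (offline, `LLM`) entry point and an asynchronous (online server) one. A generic sketch of that pattern; the function names below are illustrative only and are not vLLM's actual entry points:

```python
import asyncio

def _build_processor(guided_params: dict, tokenizer: object):
    """Hypothetical shared core that would pick and build the real processor."""
    # A real implementation would dispatch on guided_params and return e.g.
    # an xgrammar- or outlines-backed logits processor; a placeholder closure
    # is enough to show the structure.
    return lambda input_ids, scores: scores

def get_local_processor(guided_params: dict, tokenizer: object):
    # Offline path: plain synchronous call.
    return _build_processor(guided_params, tokenizer)

async def get_processor(guided_params: dict, tokenizer: object):
    # Online path: same core, kept off the event loop.
    return await asyncio.to_thread(_build_processor, guided_params, tokenizer)

sync_proc = get_local_processor({"json": "{}"}, tokenizer=None)
async_proc = asyncio.run(get_processor({"json": "{}"}, tokenizer=None))
assert callable(sync_proc) and callable(async_proc)
```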