
Context Free Grammar Constrained Decoding (ebnf interface, compatible with llama-cpp) #27557

Open · wants to merge 19 commits into base: main

Conversation

Saibo-creator
Contributor

@Saibo-creator Saibo-creator commented Nov 17, 2023

What does this PR do?

This PR adds a new feature (Context Free Grammar Constrained Decoding) to the library.
There is already one WIP PR for this feature (#26520), but this one has a different motivation and implementation.

This implementation is inspired by and adapted from https://github.com/Shopify/torch-grammar and https://github.com/ggerganov/llama.cpp/pull/1773/files

This implementation aims to achieve the following goals:

  • CFG-constrained decoding
  • EBNF notation as the interface for the grammar
  • standalone implementation of the grammar parser (recursive-descent parsing, the same as in llama.cpp)
  • compatibility with the grammars in the llama.cpp library (https://github.com/ggerganov/llama.cpp/tree/master/grammars)
  • incremental parsing as well as non-incremental parsing (in my experiments, some tokenizers do not support incremental parsing)
  • Unicode support for the grammar (not trivial, but important for any multilingual model)

The two main differences from PR #26520 :

  • PR #26520 depends on lark, which is not necessarily a bad thing, but in my experience it reduces flexibility and can be hard to adapt to our specific needs, e.g. Unicode grammar support.
  • EBNF interface: this PR supports the same EBNF notation as llama.cpp, so users can migrate directly from llama.cpp.

Challenges for this PR:

  • compatibility with all the tokenizers in the transformers library.

Current status:

  • The grammar parser is implemented and works well with the example grammars from the llama.cpp library.
  • A few integration tests are added to test the combination of grammar and tokenizer.
  • no Unicode support yet, which means it will probably fail when constraining with emoji or other Unicode characters.
  • greedy search
  • sampling, top-k, top-p
  • beam search

TODO:

  • Batching support
  • compatible with greedy decoding and sampling under beam=1
  • the grammar parser fails to parse llama.cpp's JSON grammar (more precisely, the string rule). Currently a slightly simplified version of the JSON grammar is used (now fixed)
  • the grammar parser requires the last rule to end with a newline, otherwise a parsing error is raised. This is not user-friendly and should be fixed
  • the EOS token does not always seem to be included in the allowed tokens even when it should be, maybe due to the nature of recursive-descent parsing?
  • compatibility with beam_search and beam_sample (currently throws RuntimeError: probability tensor contains either inf, nan or element < 0). A good reference is the ConstrainedBeamSearchScorer
  • Unicode support
  • properly test with different tokenizers (BPE, WordPiece, Unigram, etc.)

Fixes #25778
Related to PR #26520

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@gante

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@jvhoffbauer

I think it is a great idea to be compatible with llama.cpp!

@abhinavkulkarni

Hey @Saibo-creator,

I tried running grammar_utils.py on json.gbnf, but I get the following error:

Traceback (most recent call last):
  File "/home/user/grammar.py", line 657, in <module>
    state = parse_ebnf(input_text)
  File "/home/user/grammar.py", line 249, in parse_ebnf
    grammar_repr = parse_rule(state, grammar_repr)
  File "/home/user/grammar.py", line 231, in parse_rule
    pos = parse_alternates(state, pos, name, rule_id, False)
  File "/home/user/grammar.py", line 212, in parse_alternates
    while pos[0] == "|":
IndexError: string index out of range

@Saibo-creator
Contributor Author

Hello @abhinavkulkarni,
This is probably because of the missing newline at the end of the grammar.

Try

root   ::= object

object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws

value  ::= object | array | string | number | ("true" | "false" | "null") ws

array  ::= "[" ws ( value ("," ws value)* )? "]" ws

string ::= "\"" ( [a-zA-Z0-9] )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws


ws ::= ([ \t\n] ws)?

instead of the following (the only difference being that the last rule does not end with a newline):

root   ::= object

object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws

value  ::= object | array | string | number | ("true" | "false" | "null") ws

array  ::= "[" ws ( value ("," ws value)* )? "]" ws

string ::= "\"" ( [a-zA-Z0-9] )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws


ws ::= ([ \t\n] ws)?

Let me know if this doesn't work
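
As a stopgap, you can also just make sure the grammar string ends with a newline before passing it in; a minimal sketch:

with open("./json.gbnf", "r") as file:
    grammar_str = file.read()

# the current parser requires the last rule to end with a newline,
# so make sure one is present before constructing the grammar
if not grammar_str.endswith("\n"):
    grammar_str += "\n"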

@abhinavkulkarni

abhinavkulkarni commented Nov 17, 2023

Thanks @Saibo-creator, that works.

I have the following piece of code:

from transformers import AutoModelForCausalLM, LlamaTokenizerFast, TextStreamer
# classes from this PR (module paths assumed to match the PR at this point)
from transformers.generation.grammar_utils import IncrementalGrammarAcceptor
from transformers.generation.logits_process import GrammarConstrainedLogitsProcessor

model_id = "TheBloke/zephyr-7B-alpha-AWQ"
tokenizer = LlamaTokenizerFast.from_pretrained(model_id)
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0")

with open("./json.gbnf", "r") as file:
    grammar_str = file.read()
    grammar = IncrementalGrammarAcceptor(grammar_str, "root", tokenizer)
    logits_processor = GrammarConstrainedLogitsProcessor(grammar, batch_size=2, num_beams=1)

prompt = f'''What is the difference between nuclear fusion and fission?
###Response:'''

input_ids = tokenizer(prompt, return_tensors='pt').input_ids.cuda()
output = model.generate(
    inputs=input_ids, 
    # do_sample=True,
    # temperature=0.7,
    # top_p=0.15,
    # top_k=0,
    max_new_tokens=512,
    repetition_penalty=1.1,
    eos_token_id=tokenizer.eos_token_id,
    logits_processor=[logits_processor],
    streamer=streamer)

I get a response that starts with:

{

"Nuclear" 

but then continues to output \n till it reaches max token limit.

Please note, if I don't specify custom logits_processor, I get a pretty valid output:

What is the difference between nuclear fusion and fission?
###Response:
Nuclear fusion and fission are two different processes that occur in the nucleus of an atom. 

1. Nuclear Fusion: In this process, two or more atomic nuclei combine to form a heavier nucleus. This process releases a tremendous amount of energy, which is used as a source of power in stars and in controlled environments like nuclear fusion reactors. The most common example of nuclear fusion is the reaction that occurs inside the sun.

2. Nuclear Fission: In this process, a heavy nucleus splits into two lighter nuclei, releasing a significant amount of energy. This process is used to generate electricity in nuclear power plants. However, it also has the potential for catastrophic consequences if not properly controlled.

In summary, while both nuclear fusion and fission involve changes in the nucleus of an atom, they differ in terms of the number of nuclei involved and the type of reaction that takes place.

@Saibo-creator
Contributor Author

Saibo-creator commented Nov 17, 2023

@abhinavkulkarni

Thank you for testing!

I will look into this issue!

By the way, I just integrated the GCD (grammar-constrained decoding) feature into the generation API, so now you can run it with:

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.grammar_utils import IncrementalGrammarConstraint


if __name__ == '__main__':
    torch.manual_seed(2)

    model_id = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id)
    with open("examples/grammars/json.gbnf", "r") as file:
        grammar_str = file.read()
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)

    prefix1= "This is a valid json string for email:"
    prefix2= "This is a valid json string for shopping cart:"
    input_ids = tokenizer([prefix1, prefix2],add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"]

    output = model.generate(input_ids, do_sample=False, max_length=30, num_beams=2, grammar=grammar,
                            num_return_sequences=2)
    # decode output
    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)

    """
    'This is a valid json string for email:{ "title": "Theory", "text": "Theory", "type": "text", "text": "Theory", "type',
    'This is a valid json string for shopping cart:{ "name": "MyCart", "price": "10", "price": "10", "price": "10", "price": "'
    """

If you have time, could you try calling it via the above API and confirm whether the problem remains?

For GPT-2 it works as expected, so this may be related to the specific implementation of the Llama tokenizer. I will try to fix it ASAP.

@abhinavkulkarni

For prompt:

prompt = f"A sample JSON for employee record in a database: "

I do get a JSON-looking response, but then again, the model continues to output newlines until it hits the token limit:

{
   "id": 1,
   "name": "John",
   "age": 25,
   "salary": 30000,
   "department": {
       "id": 1,
       "name": "Sales"
   }
empty stack
}
....
....
....

@Saibo-creator
Contributor Author

@abhinavkulkarni
I'm able to reproduce the "strange behavior" you reported, and it is actually not a bug but rather expected behavior.

In the json grammar, we have
object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}" ws

and the last ws basically allows the model to generate arbitrary whitespace (including newlines) after the JSON object, because such whitespace doesn't break the JSON syntax.

This may not be desired behavior, so I removed that ws from the JSON grammar and it now seems to work correctly.
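
In other words, the object rule becomes (only the final ws dropped, everything else unchanged):

object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}"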

But it does surprise me that the model didn't pick EOS after finishing the json object. Maybe the EOS was not added to the allowed token list due to a bug.
I will double-check if I treated the EOS correctly in the grammar implementation.

@abhinavkulkarni

abhinavkulkarni commented Nov 18, 2023

@Saibo-creator: Thanks for the changes.

Removing the whitespace ws fixes the newline problem.

For the simple prompt, prompt = f"A sample JSON for employee record in a database: ", I still see a WARN log line:

WARNING:grammar_utils:empty stack

A few points:

  1. Should the grammar processor not reset its state after one call of model.generate? Calling model.generate on the same grammar processor throws an exception. It would be expensive to have to parse the grammar afresh for every single model.generate call.
  2. Should amping up the repetition_penalty not fix the whitespace issue? Unless there is the bug you alluded to, where EOS is not included in the state transition machine.

@Saibo-creator
Contributor Author

@abhinavkulkarni

Regarding resetting the state of the grammar processor, here is my reasoning:

Currently the GrammarConstrainedLogitsProcessor contains the parsing state, and I think it may be useful not to reset the state after every generation, because this allows the user to continue a grammar-constrained generation; see the code example below.

And if the user wants to start a new generation, a new instance of the LogitsProcessor is indeed needed (here we could also add a reset method to make it more user-friendly).

It would be expensive to have to parse the grammar afresh for every single model.generate call.

I don't get this point though.

  • If the user's goal is to start another generation from scratch, then the grammar has to be parsed afresh; I don't think there is a way to avoid that.
  • If the user's goal is to continue the generation, then the example shown below should solve the problem: the parsing does not need to start from scratch but simply continues from the old parsing state.

Does this sound reasonable to you?

Regarding the design choice to put the parsing state inside the LogitProcessor, I'm not sure if this is the best way to do it. So I would like to have your opinion @gante :)

from transformers import AutoModelForCausalLM, AutoTokenizer,TextStreamer, set_seed
from transformers.generation.grammar_utils import IncrementalGrammarConstraint
from transformers.generation.logits_process import GrammarConstrainedLogitsProcessor


if __name__ == '__main__':

    import logging
    logging.getLogger("transformers.generation").setLevel(logging.INFO)

    # model_id = "saibo/llama-1B"
    model_id = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    streamer = TextStreamer(tokenizer, skip_special_tokens=True)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id)
    with open("examples/grammars/json.gbnf", "r") as file:
        grammar_str = file.read()
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)

    prefix1= "This is a valid json string for email:"
    prefix2= "This is a valid json string for shopping cart:"
    input_ids = tokenizer([prefix2],add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"]

    logits_processor = GrammarConstrainedLogitsProcessor(grammar)

    ###################################################
    # generation under the Grammar constraint for 10 tokens
    ##################################################

    output = model.generate(input_ids, do_sample=False, max_new_tokens=10, num_beams=2, logits_processor=[logits_processor],
                            num_return_sequences=1, repetition_penalty=1.5)

    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)
    # 'This is a valid json string for shopping cart:{ "name": "MyCart", "price'

    ###################################################
    # Continue the generation under the same constraint for 10 tokens
    #
    # 1. Need to use the output of the previous generation as the input for the next generation
    # 2. Reuse the same logits_processor because the parser state is stored in the logits_processor
    #
    ##################################################

    output = model.generate(output[0].unsqueeze(0), do_sample=False, max_new_tokens=10, num_beams=2, logits_processor=[logits_processor],
                            num_return_sequences=1, repetition_penalty=1.5)
    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)
    # 'This is a valid json string for shopping cart:{ "name": "MyCart", "price": "10", "description": "MyCart'


    ###################################################
    # We want to generate another valid json string
    #
    # 1. Create a new logits_processor with empty parser state
    # 2. Use the same prompt as the input
    ##################################################

    logits_processor = GrammarConstrainedLogitsProcessor(grammar)

    output = model.generate(input_ids, do_sample=True, max_new_tokens=20, num_beams=2, logits_processor=[logits_processor],
                            num_return_sequences=1, repetition_penalty=1.5)

    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)
    # 'This is a valid json string for shopping cart:{ "name": "MyCart", "price": "10", "description": "MyCart'

@abhinavkulkarni

abhinavkulkarni commented Nov 18, 2023

Does this sound reasonable to you?

Thanks @Saibo-creator, it makes sense not to reset the grammar state so that the user can continue the generation.

One more minor correction: the rule for generating string in the JSON grammar should be:

string ::= "\"" ( [ a-zA-Z0-9] )* "\"" ws

instead of

string ::= "\"" ( [a-zA-Z0-9] )* "\"" ws

(the difference is the leading space inside the character class, which lets strings contain spaces).

@arshadshk left a comment

Looks good!

@arshadshk

Can we add a Python gbnf file too? We can take inspiration from https://github.com/ggerganov/llama.cpp/blob/master/grammars/c.gbnf

@Saibo-creator Saibo-creator changed the title [Draft] Context Free Grammar Constrained Decoding (ebnf interface, compatible with llama-cpp) [WIP] Context Free Grammar Constrained Decoding (ebnf interface, compatible with llama-cpp) Nov 20, 2023
@shermansiu
Contributor

shermansiu commented Nov 21, 2023

This should be related to the constrained decoding in Picard and Synchromesh.

@Saibo-creator
Contributor Author

Saibo-creator commented Nov 23, 2023

Hello @gante @ArthurZucker

I'm excited to share that the feature is now in great shape, and I'm eager to hear your thoughts on it.

The implementation of the grammar-constrained decoding feature is quite complex, as we aim to make it compatible with

  • beam search,
  • sampling,
  • all tokenizers,
  • Unicode
  • etc

It's relatively straightforward to integrate it with greedy search or greedy sampling. This leads me to my first question: Should we break down this feature into multiple versions, starting with a basic one, or would it be better to aim for a comprehensive solution in a single merge? From my perspective, once we have thoroughly tested greedy decoding and greedy sampling, it might be beneficial to merge them first, as they already cater to a wide range of use cases.

Additionally, I'm facing some challenges in devising tests for this feature. Currently, I have a setup similar to what's outlined here, where I create simple grammars and verify the accuracy of the generation. However, establishing a systematic testing approach is tricky. For example, if we want to test the json grammar compatibility with all models, running the model with actual weights becomes necessary. Without the weights, the model might generate nonsensical but syntactically correct json outputs, which doesn't help in effective testing. While using actual weights does lead to valid json generation, it significantly slows down the process.

I'd appreciate your insights on how to navigate these testing challenges. In the meantime, I'll continue refining the feature.

: )

(commit) 2. remove `grammar` from the generation argument list, use `GrammarLogitsProcessor` instead
@arshadshk

@Saibo-creator I suggest we break this feature down into multiple versions, starting with a basic one. This creates motivation and encourages more people to collaborate; greedy search for JSON sounds good for a start.

@Saibo-creator Saibo-creator changed the title [WIP] Context Free Grammar Constrained Decoding (ebnf interface, compatible with llama-cpp) Context Free Grammar Constrained Decoding (ebnf interface, compatible with llama-cpp) Nov 24, 2023
@Saibo-creator
Contributor Author

Saibo-creator commented Nov 24, 2023

Thank you @arshadshk for the feedback. I agree with you! In terms of greedy search and random sampling-based decoding, this feature should already be solid enough.

And indeed json is the most popular use case for this feature, so we can add Unicode support a bit later.

Now I'm working on crafting tests. It's a bit challenging to write tests for this feature. For example, I really want to have a TestMixin that tries to test every model to generate json objects. But as I explained above, this seems non-trivial.

I will start with more atomic tests like this
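
For illustration, a minimal sketch of what such an atomic test could look like (the grammar, model and assertion here are only illustrative; the class names follow the API shown earlier in this thread):

import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.grammar_utils import IncrementalGrammarConstraint
from transformers.generation.logits_process import GrammarConstrainedLogitsProcessor


class GrammarConstrainedDecodingTest(unittest.TestCase):
    def test_output_matches_simple_grammar(self):
        # toy grammar that only accepts one or more "a" characters
        # (note the trailing newline, which the parser currently requires)
        grammar_str = 'root ::= "a" ("a")*\n'

        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
        logits_processor = GrammarConstrainedLogitsProcessor(grammar)

        input_ids = tokenizer("Generate:", return_tensors="pt").input_ids
        output = model.generate(
            input_ids,
            do_sample=False,
            max_new_tokens=5,
            logits_processor=[logits_processor],
        )

        generated = tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True)
        # every generated character must be allowed by the grammar
        self.assertTrue(all(char == "a" for char in generated))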

@Saibo-creator
Contributor Author

By the way, @arshadshk, if you have time, could you also have a look at #27676? That PR tries to fix a bug that is important for this CFG feature to work properly. Thanks!

@arshadshk

arshadshk commented Nov 24, 2023

@Saibo-creator the #27676 fix makes sense. I wonder if we should open up probabilities for the <eos> token too, along with the <pad> token; we might need to terminate the generation when nothing more can be generated.

@gante
Member

gante commented Nov 30, 2023

Hi @Saibo-creator 👋

It's great to see a project with a working example! I'd love to add it to transformers at some point, but we don't have the capacity to maintain a new text generation project at the moment -- you can probably see from my response time in the PRs that our bandwidth at the moment is quite limited :) Since transformers is used in production, we can't add features if we don't have the capacity to maintain them.

My suggestion: let's add the code as is under /examples/research_projects/grammar, for which the transformers team has 0 maintenance guarantees, and move it into the main transformers folder as soon as we have capacity on our end. Does it sound good to you? 🤗

P.S.: as a research project, you'd be able to make any changes you want with pretty much no barriers on our side ;)

@Saibo-creator
Contributor Author

Saibo-creator commented Nov 30, 2023

Hi @Saibo-creator 👋

It's great to see a project with a working example! I'd love to add it to transformers at some point, but we don't have the capacity to maintain a new text generation project at the moment -- you can probably see from my response time in the PRs that our bandwidth at the moment is quite limited :) Since transformers is used in production, we can't add features if we don't have the capacity to maintain them.

My suggestion: let's add the code as is under /examples/research_projects/grammar, for which the transformers team has 0 maintenance guarantees, and move it into the main transformers folder as soon as we have capacity on our end. Does it sound good to you? 🤗

P.S.: as a research project, you'd be able to make any changes you want with pretty much no barriers on our side ;)

Sounds great! Thank you @gante !
I'm asking a couple of friends to test it. When it's ready I would be happy to write a blog like this one to introduce this feature.

@gante
Member

gante commented Nov 30, 2023

@Saibo-creator sounds great!

And don't let my conservative approach to your suggestions lower your enthusiasm; I'm enjoying your contributions :D

@oobabooga
Contributor

I have tested the code in this PR and found that it works very nicely, so I borrowed it for my repository (with due credits): oobabooga/text-generation-webui#4953

It is more robust than the torch-grammar EBNF implementation that I was previously using, which would throw an error half the time while importing a seemingly valid grammar.

Being able to generate structured output like JSON and lists for a given prompt has many practical applications, and this logits processor makes that easy to set up, so I find it extremely valuable.

@abhinavkulkarni

Thanks @oobabooga, I was also able to test it successfully in HuggingFace TGI. It does work very well.

@Emekaborisama

I did set this up on FastAPI and it only returns a result once.

@Saibo-creator
Contributor Author

Could you give a working example showing the problem? I would be happy to investigate it.

@Saibo-creator
Contributor Author

Since Transformers will not merge this in the near future, I have written a small extension library. Its use is very straightforward:

https://github.com/Saibo-creator/transformers-CFG/tree/main
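
A rough sketch of how it is used (the exact import paths may differ slightly, please check the repository README; the classes mirror the ones in this PR):

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

with open("examples/grammars/json.gbnf", "r") as file:
    grammar_str = file.read()

grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
logits_processor = GrammarConstrainedLogitsProcessor(grammar)

input_ids = tokenizer("This is a valid json string for a person:", return_tensors="pt").input_ids
output = model.generate(input_ids, max_new_tokens=30, logits_processor=[logits_processor])
print(tokenizer.batch_decode(output, skip_special_tokens=True))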

@Saibo-creator
Contributor Author

Hello! My use case requires the grammar to be dependent on the input text. I'm wondering if the current implementation supports passing a batch of grammars along with the batch of inputs and constraining the output based on the different grammars?

Hey! This is an interesting use case and I'm working on it. Will keep you updated.

@huggingface huggingface deleted a comment from github-actions bot Jan 30, 2024
@huggingface huggingface deleted a comment from github-actions bot Feb 27, 2024
@ArthurZucker ArthurZucker added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Feb 27, 2024
Labels: WIP
9 participants