
Behaviour between slow and fast LLaMa tokenizer not equivalent #23889

Closed
NielsRogge opened this issue May 31, 2023 · 13 comments
Labels
Core: Tokenization (Internals of the library; Tokenization)

Comments

@NielsRogge
Contributor

NielsRogge commented May 31, 2023

System Info

Transformers v4.29.2

Who can help?

@ArthurZucker

Reproduction

For a new model (#23460), I'd like to get equivalent behaviour between the slow and fast LLaMa tokenizers. The slow tokenizer was adapted from the original implementation, and now I'd like to translate this to the fast tokenizer as well.

However, as can be seen below, behaviour is not equivalent:

from transformers import LlamaTokenizer, LlamaTokenizerFast
import torch

tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", truncation_side="left")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer.add_special_tokens({"bos_token": "</s>"})
tokenizer.add_special_tokens({"eos_token": "</s>"})
tokenizer.add_special_tokens({"unk_token": "</s>"})

fast_tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", truncation_side="left")
fast_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
fast_tokenizer.add_special_tokens({"bos_token": "</s>"})
fast_tokenizer.add_special_tokens({"eos_token": "</s>"})
fast_tokenizer.add_special_tokens({"unk_token": "</s>"})

prompt = "What is unusual about this image?"

encoding = tokenizer(prompt, return_tensors="pt")

fast_encoding = fast_tokenizer(prompt, return_tensors="pt")

for k,v in encoding.items():
    assert torch.allclose(fast_encoding[k], v)
=> this assertion fails since the input_ids differ:

tensor([[    2,  1724,   338, 22910,  1048,   445,  1967, 29973]])
tensor([[    1,  1724,   338, 22910,  1048,   445,  1967, 29973]])

Expected behavior

I'd expect that the assertion above passes.

@ArthurZucker added the Core: Tokenization label May 31, 2023
@ArthurZucker
Collaborator

Thanks for reporting, will have a look

@NielsRogge mentioned this issue Jun 5, 2023
@ArthurZucker
Collaborator

Okay, what's happening here is that you are adding tokens that are already present in the model's vocabulary: </s> already has id 2.
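For reference, a quick way to see this (a minimal sketch using the same checkpoint as above):

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
# "</s>" is already part of the LLaMA vocabulary, so "adding" it does not create a new id
print(tokenizer.convert_tokens_to_ids("</s>"))  # 2, the existing EOS token
print(tokenizer.bos_token_id)                   # 1, the original BOS token "<s>"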

@ArthurZucker
Collaborator

The issue still reproduces on the latest version of transformers because you are relying on adding a token that is already in the vocabulary; the call should be ignored but is not, and the content on the Rust side gets modified.
Use this instead:
fast_tokenizer.bos_token = "</s>"

@ArthurZucker
Collaborator

(this will update the processor)

@NielsRogge
Contributor Author

NielsRogge commented Jun 8, 2023

Thanks for taking a look!

However, I'm using the latest version of Transformers and have added fast_tokenizer.bos_token = "</s>", but the assertion still fails for me.

@NielsRogge NielsRogge reopened this Jun 8, 2023
@NielsRogge
Contributor Author

Reproduced on main branch, here's a Colab notebook: https://colab.research.google.com/drive/1KA_mliTsvjnhOCO3SApVJkgVd2HEeVQZ?usp=sharing.

@ArthurZucker
Collaborator

Actually, with fast tokenizers there is no logic to properly update the template processor once it exists. The expected usage has always been to initialize the tokenizer with the correct tokens, meaning:
fast_tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", truncation_side="left", bos_token="</s>")
is what you should be using. The template processor only gets updated if you change "add_bos" or "add_eos"; otherwise the logic gets complicated, as we would have to overload the parent setters for bos_token as well as bos_token_id to update the template processing. I'm not in favor of that, so I'm leaving it as is and will improve the documentation for changing bos and eos on the fast tokenizer.
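For completeness, this is the only update path that does exist (a minimal sketch; it assumes the add_bos_token / add_eos_token setters exposed by LlamaTokenizerFast):

from transformers import LlamaTokenizerFast

tok = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
print(tok("hello").input_ids)   # starts with 1, the default BOS token <s>

# Changing add_bos_token goes through the template-processor update
tok.add_bos_token = False
print(tok("hello").input_ids)   # no BOS token is prepended anymore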

@NielsRogge
Contributor Author

Hmm, ok, so there's no way to have an equivalent fast tokenizer that makes the script above pass?

The reason I ask is that for the new InstructBLIP model (#23460), the processor class (InstructBlipProcessor) would normally use the AutoTokenizer class to load files from the hub. And since the AutoTokenizer API returns the fast tokenizer by default, I'm currently not getting results equivalent to those of the slow one.
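For context, this is the default I mean (a minimal sketch; use_fast is only shown to illustrate the behaviour, it is not something InstructBlipProcessor passes explicitly):

from transformers import AutoTokenizer

# AutoTokenizer resolves to the fast tokenizer by default
fast = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
print(type(fast).__name__)  # LlamaTokenizerFast

# The slow tokenizer is only returned when use_fast=False is passed
slow = AutoTokenizer.from_pretrained("huggyllama/llama-7b", use_fast=False)
print(type(slow).__name__)  # LlamaTokenizer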

@ArthurZucker
Collaborator

ArthurZucker commented Jun 9, 2023

No, there isn't. I am not in favor of introducing a very hacky behaviour when the fix should be done in Rust in that case.
The following works:

from transformers import LlamaTokenizer, LlamaTokenizerFast
import torch

fast_tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", truncation_side="left", bos_token = "</s>", unk_token = "</s>")
fast_tokenizer.add_special_tokens({"pad_token": "[PAD]"})


tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", truncation_side="left", bos_token = "</s>", unk_token = "</s>")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
prompt = "What is unusual about this image?"

encoding = tokenizer(prompt, return_tensors="pt")

fast_encoding = fast_tokenizer(prompt, return_tensors="pt")

for k,v in encoding.items():
    assert torch.allclose(fast_encoding[k], v)

@ArthurZucker
Collaborator

Also, once you have a tokenizer ready, you can save it, and it should have the correct post-processor.
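For example (a minimal sketch; the directory name is arbitrary and fast_tokenizer is the one initialized with bos_token="</s>" above):

# Saving writes tokenizer.json, which keeps the post-processor with the chosen BOS token
fast_tokenizer.save_pretrained("llama-7b-custom-tokenizer")

# Reloading from the saved directory gives back the same behaviour
from transformers import LlamaTokenizerFast
reloaded = LlamaTokenizerFast.from_pretrained("llama-7b-custom-tokenizer")
print(reloaded("What is unusual about this image?").input_ids)  # starts with 2, i.e. "</s>"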

@NielsRogge
Contributor Author

Ok, thanks a lot, it now works fine and I can use the fast tokenizer.

@dsdanielpark

dsdanielpark commented Oct 6, 2023

Hello, ArthurZucker.

Thank you for the great package and maintenance.
I wanted to ask whether Llama 2's fast tokenizer is currently functioning correctly based on the above code, or whether it still has a problem.

I noticed this message while training Llama 2:

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding

@ArthurZucker
Collaborator

The above message is pretty much unrelated; it is just a hint to help you improve performance when padding the input. Everything should work alright! 😉
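In other words, the warning is just pointing at the following pattern (a minimal sketch; the example texts and the pad-token choice are made up for illustration):

from transformers import LlamaTokenizerFast

tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = tokenizer.eos_token  # LLaMA ships without a pad token

texts = ["What is unusual about this image?", "Describe the scene."]

# Recommended: __call__ tokenizes and pads the whole batch in one go
batch = tokenizer(texts, padding=True, return_tensors="pt")

# Slower pattern the warning refers to: encode each text, then call pad() afterwards
encoded = [tokenizer.encode(t) for t in texts]
padded = tokenizer.pad({"input_ids": encoded}, padding=True, return_tensors="pt")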
