
[llama] AutoTokenizer does not add eos_token at the end #23833

Closed

csyourui opened this issue May 28, 2023 · 10 comments

@csyourui

System Info


  • transformers version: 4.29.2
  • Platform: Linux-3.10.0-1160.42.2.el7.x86_64-x86_64-with-glibc2.35
  • Python version: 3.9.16
  • Huggingface_hub version: 0.14.1
  • Safetensors version: not installed
  • PyTorch version (GPU?): 2.0.1+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

code:

from transformers import AutoTokenizer, LlamaTokenizer

auto_tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", add_eos_token=True, use_fast=True)
llama_tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", add_eos_token=True, use_fast=True)

print(auto_tokenizer.decode(auto_tokenizer.encode("auto_tokenizer", add_special_tokens = True)))
print(llama_tokenizer.decode(llama_tokenizer.encode("llama_tokenizer", add_special_tokens = True)))

results:

<s> auto_tokenizer
<s> llama_tokenizer</s>

Expected behavior

The eos token should be added at the end, like:

<s> auto_tokenizer</s>
<s> llama_tokenizer</s>
@NielsRogge
Contributor

NielsRogge commented May 30, 2023

Hi,

Note that it doesn't make sense to pass use_fast to the slow (Python-based) LlamaTokenizer. It only makes sense to pass use_fast to the AutoTokenizer class, which can either load the fast (Rust-based) LlamaTokenizerFast class or the slow (Python-based) LlamaTokenizer.

In the code snippet above, auto_tokenizer will be an instance of LlamaTokenizerFast and llama_tokenizer will be an instance of LlamaTokenizer:

>>> type(auto_tokenizer)
<class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>
>>> type(llama_tokenizer)
<class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>
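
For example (a small illustrative snippet, not part of the original comment; it assumes the same huggyllama/llama-7b checkpoint), passing use_fast=False to AutoTokenizer returns the slow class instead:

>>> slow_tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", use_fast=False)
>>> type(slow_tokenizer)
<class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>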

Pinging @ArthurZucker regarding the eos_token issue

@csyourui
Author


Thank you so much for explaining this ~~~

@ArthurZucker
Collaborator

Hey!
Thanks for reporting. The quickest fix I can give you is to initialise the fast tokenizer from the slow one, using the correct arguments.

from transformers import LlamaTokenizerFast

fast = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b", add_eos_token=True, from_slow=True)

This will produce the expected outputs:

>>> fast.encode("auto_tokenizer", add_special_tokens = True)
[1, 4469, 29918, 6979, 3950, 2]

The reason behind this is that the post_processor is responsible for adding the eos and bos tokens. The processor is initialised when the slow tokenizer is converted to the fast version, and changing the argument on the fly does not update the processor.
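
As an illustration of this (a hedged sketch, not part of the original comment; it assumes the huggyllama/llama-7b checkpoint and the TemplateProcessing API from the tokenizers library), the backend post_processor can be rebuilt by hand so that </s> is appended:

from tokenizers.processors import TemplateProcessing
from transformers import LlamaTokenizerFast

fast = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
# Rebuild the processor so both <s> and </s> are added around a single sequence.
fast.backend_tokenizer.post_processor = TemplateProcessing(
    single="<s> $A </s>",
    pair="<s> $A </s> <s> $B </s>",
    special_tokens=[("<s>", fast.bos_token_id), ("</s>", fast.eos_token_id)],
)
print(fast.encode("auto_tokenizer", add_special_tokens=True))
# [1, 4469, 29918, 6979, 3950, 2] -- now ends with the eos id (2)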

I'll open a PR to make sure that changing the eos and bos tokens updates the processor. Thanks for reporting.

@tt6746690

For transformers v4.35.0, LlamaTokenizerFast still cannot encode </s> properly. I wonder if there are plans to fix this issue?

@ArthurZucker
Collaborator

Hello, this seems to work fine for me:

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> tokenizer.encode("</s>", add_special_tokens = False)
>>> tokenizer.encode("Hey</s>sir", add_special_tokens = False)
[18637, 2, 8889]
>>> tokenizer.tokenize("Hey</s>", add_special_tokens = False)
['▁Hey', '</s>']

For such an important model we try to fix issues like this as soon as possible, since they can impact training, for example. Would you mind sharing a reproducer? 🤗

@tt6746690

tt6746690 commented Nov 6, 2023

@ArthurZucker

I don't have access to "meta-llama/Llama-2-7b-hf", but the following two Llama / Llama 2 models give me the same results.

Transformers is installed via pip install . at commit b8f1cde, with tokenizers==0.14.1.

import transformers
print(transformers.__version__)  # 4.35.0.dev0
from transformers import AutoTokenizer
# Both of these checkpoints give the same results:
s = "huggyllama/llama-7b"
s = "NousResearch/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(s)
print(tokenizer.encode("</s>", add_special_tokens = False))    # [2]
print(tokenizer.tokenize("</s>", add_special_tokens = False))  # ['▁</s>']
print(tokenizer.encode("Hey</s>sir", add_special_tokens = False))    # [18637, 829, 29879, 29958, 29879, 381]
print(tokenizer.tokenize("Hey</s>sir", add_special_tokens = False))  # ['▁Hey', '</', 's', '>', 's', 'ir']

@ArthurZucker
Collaborator

That's expected if they did not update the tokenizer.json file to the correct normalisation. I would recommend opening an issue on the Hub, as I don't maintain those repositories 🤗
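
One way to check this on a given checkpoint (a hedged sketch, not part of the original reply; it relies on the added_tokens_decoder attribute available in recent transformers versions):

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
>>> eos = tokenizer.added_tokens_decoder[tokenizer.eos_token_id]
>>> eos.content, eos.normalized
('</s>', True)   # normalized=True is the old setting that breaks "Hey</s>sir"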

@tt6746690

tt6746690 commented Nov 7, 2023

Thanks for the information. Just wondering, what is the correct normalisation? I tried setting normalized=False for the special token </s>, but that does not help.

@ArthurZucker
Collaborator

normalized=False should be used. Here is how to set it on an already initialized tokenizer:

  • Simple way:
>>> tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf", from_slow=True)
>>> print(tokenizer.tokenize("Hey</s>sir", add_special_tokens = False))
['▁Hey', '</s>', '▁sir']
  • After init:
>>> from transformers import AddedToken
>>> tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf") 
>>> tokenizer.add_tokens(AddedToken("</s>", normalized=False, special=True), special_tokens=True)

>>> tokenizer.save_pretrained("/tmp/tokenizer-llama")
>>> tokenizer = AutoTokenizer.from_pretrained("/tmp/tokenizer-llama") 
>>> print(tokenizer.tokenize("Hey</s>sir", add_special_tokens = False))
['▁Hey', '</s>', '▁sir']

That is because fast tokenizers are meant to be fixed after initialization. I'm planning to support updating them without having to save and reload the tokenizer, but this was never possible before either.

@tt6746690

It works! Thanks!
