
Extend tokenizer vocabulary with new words #627

Closed
anferico opened this issue Feb 11, 2021 · 13 comments

@anferico

Suppose I have a pre-trained tokenizer, e.g. a BertWordPieceTokenizer, with its own vocabulary. My goal is to use it to tokenize some technical text which will likely contain unknown words (represented as "[UNK]" tokens).

Is there a way to fine-tune the tokenizer so that unknown words are automatically added to its vocabulary? I have found similar issues in the transformers repository (transformers/issues/2691 and transformers/issues/1413), but what they suggest is to manually add unknown tokens, whereas I would like them to be added automatically.

Here's a pseudo-code representation of what I would need:

pre_trained_tokenizer = ...
vocab = pre_trained_tokenizer.get_vocab()

technical_text = [
  'some text with unknown words',
  'some other text with unknown words',
  ...
]

updated_tokenizer = pre_trained_tokenizer.train(
  technical_text,
  initial_vocabulary=vocab
)

new_vocab = updated_tokenizer.get_vocab()  # 'new_vocab' contains all words in 'vocab' plus some new words

Can I do that with huggingface/tokenizers and/or huggingface/transformers?
I thought it would be an easy thing to do, but I wasn't able to find anything useful.

@n1t0
Member

n1t0 commented Feb 12, 2021

No, that's not possible; you'll indeed have to add the tokens manually.

@anferico
Author

Thanks for the reply. Just to clarify, is it a missing feature of the library or is it a limitation of the tokenization algorithm?

@n1t0
Member

n1t0 commented Feb 18, 2021

It depends on the specific tokenization algorithm, but the tokenizer doesn't save all of the training state that would be needed to pick training back up where it left off.

@jowagner

Most off-the-shelf models have plenty of unused vocabulary entries that you could repurpose:

  1. Train a new vocabulary on the target domain corpus from scratch
  2. Find the new vocabulary entries that are not in the old vocabulary
  3. If the number of new entries is outside the desired range, change the settings and/or the corpus and repeat from step 1
  4. Replace the last unused entries with the new entries
  5. Fine-tune as usual

If your application needs some unused entries for itself, you must of course leave a sufficient number of them free.
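Here is a rough sketch of those steps (my own illustration, not from this thread), assuming a BERT-style WordPiece vocab.txt with "[unusedN]" placeholder lines; the corpus file, vocab path, and vocab_size are placeholders:

from tokenizers import BertWordPieceTokenizer

# 1. Train a fresh vocabulary on the target-domain corpus
domain_tok = BertWordPieceTokenizer()
domain_tok.train(files=["domain_corpus.txt"], vocab_size=5000)

# 2. Find the entries the old vocabulary does not have
with open("bert-base-uncased-vocab.txt", encoding="utf-8") as f:
    old_vocab = [line.rstrip("\n") for line in f]
new_entries = [t for t in domain_tok.get_vocab() if t not in set(old_vocab)]

# 3. Make sure the new entries fit into the available unused slots
unused_idx = [i for i, t in enumerate(old_vocab) if t.startswith("[unused")]
assert len(new_entries) <= len(unused_idx), "adjust vocab_size/corpus and retrain"

# 4. Replace the last unused entries; token ids stay fixed, so the model's
#    embedding matrix still lines up and can be fine-tuned as usual (step 5)
for i, entry in zip(reversed(unused_idx), new_entries):
    old_vocab[i] = entry

with open("patched-vocab.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(old_vocab) + "\n")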

@juanjucm

juanjucm commented Mar 1, 2021

> Here's a pseudo-code representation of what I would need: […]

Hi @anferico, I don't know if this is what you were looking for, but here is a possible approach to your problem:

  1. First, you need to extract tokens from your data while applying the same preprocessing steps used by the tokenizer. To do so, you can just use the tokenizer itself:
    new_tokens = tokenizer.basic_tokenizer.tokenize(' '.join(technical_text))
  2. Now you just add the new tokens to the tokenizer vocabulary:
    tokenizer.add_tokens(new_tokens)
    This method only adds new tokens, which means you don't have to worry about words already present in the tokenizer's vocab.

The result is a tokenizer that has your domain-specific tokens on top of the original vocabulary. Of course, you can encapsulate this in a function and use it like in your pseudocode.

Remember that for your model to work, you will need to update the embedding layer with the new augmented vocabulary:
model.resize_token_embeddings(len(tokenizer))
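Putting those steps together, a minimal end-to-end sketch (assuming a BERT-style tokenizer and model; the checkpoint name and technical_text are illustrative):

from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

technical_text = [
    'some text with unknown words',
    'some other text with unknown words',
]

# 1. Pre-tokenize the domain text with the tokenizer's own preprocessing
new_tokens = tokenizer.basic_tokenizer.tokenize(' '.join(technical_text))

# 2. add_tokens() skips anything already in the vocab and returns the count actually added
num_added = tokenizer.add_tokens(new_tokens)

# 3. Grow the embedding matrix so the model accepts the new token ids
model.resize_token_embeddings(len(tokenizer))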

Hope it helps!! :)

@harveyaot

I think this official tutorial can help with your question: Training a new tokenizer from an old one.
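For reference, the core of that tutorial is train_new_from_iterator, which retrains the same tokenization pipeline on your own corpus. A minimal sketch (checkpoint and vocab_size are illustrative; note that this builds a fresh vocabulary rather than extending the old one):

from transformers import AutoTokenizer

old_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # any fast tokenizer

technical_text = [
    'some text with unknown words',
    'some other text with unknown words',
]

# Same algorithm and preprocessing as the old tokenizer, but a brand-new vocabulary
new_tokenizer = old_tokenizer.train_new_from_iterator(technical_text, vocab_size=30000)
new_vocab = new_tokenizer.get_vocab()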

@savanth14

savanth14 commented Apr 16, 2024

@ArthurZucker @younesbelkada @Narsil @n1t0 I tried to add new vocabulary to the existing Mistral tokenizer using the add_tokens() method. Everything went fine until I used the extended-vocab tokenizer to decode the encoded text: in the decoded output, the spaces are completely missing and all the decoded tokens are merged into a single string. Can you please help me resolve this issue? Here's the sample code:

import sentencepiece as spm
from transformers import AutoTokenizer

sp = spm.SentencePieceProcessor(model_file='mistral_tok.model')
tokenizer1 = AutoTokenizer.from_pretrained("mistralai/mistral-7b-v0.1")

vocab = [sp.id_to_piece(idx) for idx in range(sp.get_piece_size())]

new_tokens = set(vocab) - set(tokenizer1.vocab.keys())

tokenizer1.add_tokens(list(new_tokens))
# output: 14756

print("After adding new tokens, length of mistral tokenizer:", len(tokenizer1))
# output: 46756

tel_text = "నేను బాగున్నాను. మీరు ఏలా ఉన్నారు?" # original text

mistral_encode_ids = tokenizer1.encode(tel_text)

mistral_decode_text = tokenizer1.decode(mistral_encode_ids, skip_special_tokens=True)

print(mistral_decode_text)

# output: నేనుబాగున్నాను.మీరుఏలాఉన్నారు? # decoded text with missing spaces

To dig further into the problem, I re-initialised the Mistral tokenizer from its original checkpoint "mistralai/mistral-7b-v0.1". Then I added 3 manually defined random tokens using the same add_tokens method. When I used this extended-vocab tokenizer to encode and decode some text, it worked fine: the decoded text retained the spacing of the original. Here's the code for this experiment:

from transformers import AutoTokenizer

mistral_tok = AutoTokenizer.from_pretrained("mistralai/mistral-7b-v0.1")

new_tokens = ["yoyoyo", "xoxoxo", "z0z0z0"]

mistral_tok.add_tokens(list(new_tokens))

print("After adding new tokens, length of mistral tokenizer:", len(mistral_tok))

random_text = "yoyoyo xoxoxo z0z0z0!"

random_text_2 = "This is my new yoyoyo style xoxoxo of z0z0z0 writing!"

mistral_encode_ids = mistral_tok.encode(random_text)

mistral_decode_text = mistral_tok.decode(mistral_encode_ids, skip_special_tokens=True)

mistral_encode_ids_2 = mistral_tok.encode(random_text_2)

mistral_decode_text_2 = mistral_tok.decode(mistral_encode_ids_2, skip_special_tokens=True)

print(mistral_decode_text)
# output: yoyoyo xoxoxo z0z0z0! # decoded text with spacing intact

print(mistral_decode_text_2) 
# This is my new yoyoyo style xoxoxo of z0z0z0 writing! # decoded text with spacing intact

Where is the problem? Why does the extended-vocab tokenizer fail to decode properly when the added vocab comes from a different tokenizer, yet decode correctly when the new tokens are added manually?

@savanth14

Hello @bezir, thanks for your comment. I figured out how to successfully merge two SentencePiece BPE tokenizers without losing tokenization efficiency. Here's the code:

import os
import re

from huggingface_hub import hf_hub_download
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model

# Load the pre-trained tokenizer to be extended
original_tokenizer_path = hf_hub_download(repo_id="mistralai/mistral-7b-v0.1", filename="tokenizer.model", local_dir="original_tokenizer")
original_tokenizer_spm = sp_pb2_model.ModelProto()
original_tokenizer_spm.ParseFromString(open(original_tokenizer_path, "rb").read())

# Load the newly trained tokenizer
new_tokenizer_spm = sp_pb2_model.ModelProto()
new_tokenizer_spm.ParseFromString(open("/content/mistral_tel_tokenizer.model", "rb").read())


# Check whether a piece contains English/ASCII characters
def contains_eng(text):
    eng_pattern = re.compile(r"[\u0020-\u007E]+")
    return bool(eng_pattern.search(text))


original_tokenizer_tokenset = set(p.piece for p in original_tokenizer_spm.pieces)
print(f"Number of tokens before merge: {len(original_tokenizer_tokenset)}")
for p in new_tokenizer_spm.pieces:
    piece = p.piece
    if piece not in original_tokenizer_tokenset and not contains_eng(piece):
        new_p = sp_pb2_model.ModelProto().SentencePiece()
        new_p.piece = piece
        new_p.score = 0
        original_tokenizer_spm.pieces.append(new_p)
print(f"Number of tokens after merge: {len(original_tokenizer_spm.pieces)}")

# Save the extended tokenizer to a checkpoint
extended_tokenizer_save_path = "/content/english-telugu-tokenizer"
os.makedirs(extended_tokenizer_save_path, exist_ok=True)
with open(os.path.join(extended_tokenizer_save_path, "tokenizer.model"), "wb") as f:
    f.write(original_tokenizer_spm.SerializeToString())

I adapted this code from this source: https://github.com/google/sentencepiece/blob/master/python/add_new_vocab.ipynb
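To sanity-check the merged model, it can be loaded back and used to tokenize some text (a hedged sketch; I'm assuming LlamaTokenizer will pick up the saved tokenizer.model from that directory):

from transformers import LlamaTokenizer

merged_tok = LlamaTokenizer.from_pretrained(extended_tokenizer_save_path)
print("Merged vocab size:", merged_tok.vocab_size)
print(merged_tok.tokenize("నేను బాగున్నాను. మీరు ఏలా ఉన్నారు?"))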

However, I have a new problem now. I trained a tiktoken-style byte-level tokenizer using the code from this repo: https://github.com/gautierdag/tokenizer-bench
When I merged this new tokenizer with a pre-trained one, Telugu encoding performance degraded: the merged tokenizer splits Telugu text into too many fine-grained tokens. Encoding and decoding for English, and decoding for Telugu, still work fine. Here's the code I used for merging:

import json

def merge_tokenizers(file1, file2, output_file):
    # Load the tokenizers
    with open(file1, 'r') as f:
        tokenizer1 = json.load(f)
    with open(file2, 'r') as f:
        tokenizer2 = json.load(f)

    # Combine the vocabs and merges: give each unseen token the next free id (ids are 0-based)
    combined_vocab = tokenizer1['model']['vocab'].copy()
    for token in tokenizer2['model']['vocab']:
        if token not in combined_vocab:
            combined_vocab[token] = len(combined_vocab)


    combined_merges = tokenizer1['model']['merges'].copy()
    for merge in tokenizer2['model']['merges']:
        if merge not in combined_merges:
            combined_merges.append(merge)

    # Update the vocab and merges in tokenizer1
    tokenizer1['model']['vocab'] = combined_vocab
    tokenizer1['model']['merges'] = combined_merges

    # Save the updated tokenizer
    with open(output_file, 'w') as f:
        json.dump(tokenizer1, f)

# Usage
merge_tokenizers("/content/gpt_32k.json", "/content/telugu_tokenizer_tiktoken.json", 'tokenizer_18.json')

The new tokenizer's training corpus is completely different from that of the pre-trained one, and I made sure to remove any duplicate merges. Still, the encoding performance is poor.

@ArthurZucker
Collaborator

A few things here. The extra spaces that get removed are due to the legacy behaviour, which was fixed on main: initialize the tokenizer with legacy=False.
Missing spaces can also be fixed by setting normalized=False when adding the token as an AddedToken:

from transformers import AddedToken, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/mistral-7b-v0.1")
tokenizer.add_tokens(AddedToken("<bbb>", normalized=True), special_tokens=True)
tokenizer.decode(tokenizer.encode(". <bbb>"))
# output: '<s> .<bbb>'

vs

from transformers import AddedToken, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/mistral-7b-v0.1")
tokenizer.add_tokens(AddedToken("<bbb>", normalized=False), special_tokens=True)
tokenizer.decode(tokenizer.encode(". <bbb>"))
# output: '<s> . <bbb>'
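For the legacy fix mentioned above, the flag is passed when loading the tokenizer (a one-line sketch):

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/mistral-7b-v0.1", legacy=False)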

@ArthurZucker
Collaborator

> The merged tokenizer is splitting telugu text into too many fine-grained tokens

Try setting the ignore_merges parameter of the BPE tokenizer to True. See this PR:
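One way to try that on the merged JSON file from the earlier snippet (a sketch, assuming your installed tokenizers version already supports the flag; the file name comes from that snippet):

import json

with open("tokenizer_18.json") as f:
    tok = json.load(f)

# With ignore_merges, tokens already present in the vocab are used as-is
# instead of being re-split through the merge rules
tok["model"]["ignore_merges"] = True

with open("tokenizer_18.json", "w") as f:
    json.dump(tok, f, ensure_ascii=False)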


This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions bot added the Stale label on May 31, 2024
github-actions bot closed this as not planned on Jun 6, 2024