[Whisper] Add conversion script for the tokenizer #27338

Merged · 25 commits · Nov 7, 2023
Changes from 20 commits
7 changes: 7 additions & 0 deletions docs/source/en/model_doc/whisper.md
@@ -34,6 +34,13 @@ The original code can be found [here](https://github.com/openai/whisper).
- Inference is currently only implemented for short-form audio, i.e. audio pre-segmented into <=30s segments. Long-form transcription (including timestamps) will be implemented in a future release.
- One can use [`WhisperProcessor`] to prepare audio for the model, and to decode the predicted IDs back into text.

- To convert the tokenizer, we recommend using the following:

```bash
python src/transformers/models/whisper/convert_openai_to_hf.py --checkpoint_path "" --pytorch_dump_folder_path "Arthur/whisper-3" --convert_tokenizer --num_languages 100 --multilingual
```
Here the number of languages is set to `100` to account for Cantonese (`yue`), which was added in `whisper-large-v3`.
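A quick sanity check on the converted tokenizer is to load it back and look up the new language token. A minimal sketch, reusing the output folder from the command above:

```python
from transformers import WhisperTokenizer

# Load the slow tokenizer written by the conversion script.
tokenizer = WhisperTokenizer.from_pretrained("Arthur/whisper-3")

# With --num_languages 100, the Cantonese language token should be in the vocab.
print("<|yue|>" in tokenizer.get_vocab())  # True
```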

## WhisperConfig

[[autodoc]] WhisperConfig
116 changes: 115 additions & 1 deletion src/transformers/models/whisper/convert_openai_to_hf.py
@@ -14,15 +14,19 @@

import argparse
import hashlib
import json
import os
import tempfile
import urllib
import warnings

import torch
from torch import nn
from tqdm import tqdm

from transformers import WhisperConfig, WhisperForConditionalGeneration
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperTokenizer
from transformers.models.whisper.tokenization_whisper import LANGUAGES, bytes_to_unicode
from transformers.utils.import_utils import _is_package_available


_MODELS = {
@@ -38,6 +42,11 @@
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
}

_TOKENIZERS = {
    "multilingual": "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/multilingual.tiktoken",
    "english": "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/gpt2.tiktoken",
}


def remove_ignore_keys_(state_dict):
ignore_keys = ["layers", "blocks"]
@@ -174,11 +183,116 @@ def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
    model.save_pretrained(pytorch_dump_folder_path)


# Adapted from https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960
def _bpe(mergeable_ranks, token: bytes, max_rank=None) -> list[bytes]:
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        assert min_idx is not None
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :]
    return parts
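To make the rank logic concrete, here is a toy illustration (not part of the PR; the ranks are made up for the example) of how `_bpe` recovers the two parts that merge into a token at a given rank:

```python
toy_ranks = {b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"abc": 4}

# With max_rank=4, merging stops just before the rank-4 merge is applied,
# leaving exactly the pair whose concatenation has rank 4.
print(_bpe(toy_ranks, b"abc", max_rank=4))  # [b'ab', b'c']
```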


def convert_tiktoken_bpe_to_hf(tiktoken_url: str):
    bpe_ranks = load_tiktoken_bpe(tiktoken_url)
    byte_encoder = bytes_to_unicode()

    def token_bytes_to_string(b):
        return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])

    merges = []
    vocab = {}
    for token, rank in bpe_ranks.items():
        vocab[token_bytes_to_string(token)] = rank
        if len(token) == 1:
            continue
        merged = tuple(_bpe(bpe_ranks, token, max_rank=rank))
        if len(merged) == 2:  # account for empty token
            merges.append(" ".join(map(token_bytes_to_string, merged)))
    return vocab, merges


def convert_tiktoken_to_hf(
    pytorch_dump_folder_path: str, multilingual: bool = True, num_languages: int = 100, time_precision=0.02
):
    # requires `tiktoken`, unless we use a local path to the tiktoken file
    tiktoken_tokenizer_path = _TOKENIZERS["multilingual" if multilingual else "english"]
    start_of_transcript = ["<|endoftext|>", "<|startoftranscript|>"]
    control_tokens = [
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nocaptions|>",
        "<|notimestamps|>",
    ]
    # these are special tokens, not normalized
    language_tokens = [f"<|{k}|>" for k in list(LANGUAGES)[:num_languages]]
    # These are not special but normalized; 1501 timestamp tokens cover 0 to 30 s in 0.02 s steps
    timestamp_tokens = [("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)]

    vocab, merges = convert_tiktoken_bpe_to_hf(tiktoken_tokenizer_path)

    with tempfile.TemporaryDirectory() as tmpdirname:
        vocab_file = f"{tmpdirname}/vocab.json"
        merge_file = f"{tmpdirname}/merges.txt"
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write("#version: 0.2\n")
            for bpe_tokens in merges:
                writer.write(bpe_tokens + "\n")

        hf_tokenizer = WhisperTokenizer(vocab_file, merge_file)
Contributor: Do we need to convert the fast tokenizer as well? Or all good with just the slow?

Collaborator (Author): Fast can always be converted from slow when loading with `AutoTokenizer`, so no need I'd say, but I can add a comment.
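To illustrate the point above, a minimal sketch (the folder name is taken from the docs example, not prescribed by the PR) of getting a fast tokenizer from the slow files:

```python
from transformers import AutoTokenizer

# AutoTokenizer returns a fast tokenizer by default, converting the slow
# files (vocab.json + merges.txt) on the fly when no tokenizer.json exists.
tokenizer = AutoTokenizer.from_pretrained("Arthur/whisper-3")
print(tokenizer.is_fast)  # True
```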


    hf_tokenizer.add_tokens(start_of_transcript + language_tokens + control_tokens, special_tokens=True)
    hf_tokenizer.add_tokens(timestamp_tokens, special_tokens=False)
    hf_tokenizer.save_pretrained(pytorch_dump_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--checkpoint_path", type=str, help="Path to the downloaded checkpoints")
    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
    parser.add_argument(
        "--convert_tokenizer",
        action="store_true",
        help="Whether or not the tokenizer should be converted along with the model.",
    )
    parser.add_argument(
        "--num_languages",
        type=int,
        default=99,
        help="Number of languages that are supported. Whisper-large-v3 supports 100 languages, 1 more than the previous releases.",
    )
    parser.add_argument(
        "--multilingual",
        action="store_true",
        help="Whether the model is multilingual or English-only.",
    )
    args = parser.parse_args()

    if args.convert_tokenizer:
Contributor: To me it's more intuitive to always convert the tokenizer, since we can't use the model without it.

Collaborator (Author): Yes, but that would not be backward compatible, because this requires `tiktoken`.

        try:
            if not _is_package_available("tiktoken"):
                raise ModuleNotFoundError(
                    "`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer"
                )
        except ModuleNotFoundError as e:
            print(e)
        else:
            from tiktoken.load import load_tiktoken_bpe

            convert_tiktoken_to_hf(args.pytorch_dump_folder_path, args.multilingual, args.num_languages)

    convert_openai_whisper_to_tfms(args.checkpoint_path, args.pytorch_dump_folder_path)
2 changes: 2 additions & 0 deletions src/transformers/models/whisper/tokenization_whisper.py
@@ -191,6 +191,7 @@ def get_pairs(word):
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
}

# language code lookup by name, with a few language aliases
@@ -207,6 +208,7 @@ def get_pairs(word):
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
"mandarin": "zh",
}

TASK_IDS = ["translate", "transcribe"]
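The effect of the two additions in this file can be checked directly against the mappings. A small sketch, assuming a `transformers` install that includes this change:

```python
from transformers.models.whisper.tokenization_whisper import LANGUAGES, TO_LANGUAGE_CODE

print(LANGUAGES["yue"])               # "cantonese"
print(TO_LANGUAGE_CODE["cantonese"])  # "yue" (the name lookup is built from LANGUAGES)
print(TO_LANGUAGE_CODE["mandarin"])   # "zh" (the explicit alias added here)
```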