diff --git a/docs/source/en/model_doc/whisper.md b/docs/source/en/model_doc/whisper.md
index 15f9e91137be..8d73a5655fdf 100644
--- a/docs/source/en/model_doc/whisper.md
+++ b/docs/source/en/model_doc/whisper.md
@@ -34,8 +34,13 @@ The original code can be found [here](https://github.com/openai/whisper).
 - Inference is currently only implemented for short-form i.e. audio is pre-segmented into <=30s segments. Long-form (including timestamps) will be implemented in a future release.
 - One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted ID's back into text.
 
-This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts).
-The original code can be found [here](https://github.com/openai/whisper).
+- To convert the tokenizer, we recommend using the following command:
+
+```bash
+python src/transformers/models/whisper/convert_openai_to_hf.py --checkpoint_path "" --pytorch_dump_folder_path "Arthur/whisper-3" --convert_tokenizer True --whisper_version 3 --multilingual True
+```
+Here, `whisper_version` determines the number of languages: version `3` sets it to `100` to account for Cantonese (`yue`), which was added in `whisper-large-v3`, while versions `1` and `2` use `99`.
+
 
 ## Inference
 
diff --git a/src/transformers/models/whisper/convert_openai_to_hf.py b/src/transformers/models/whisper/convert_openai_to_hf.py
index 6eb7e0f233c8..1d016b598439 100755
--- a/src/transformers/models/whisper/convert_openai_to_hf.py
+++ b/src/transformers/models/whisper/convert_openai_to_hf.py
@@ -17,7 +17,9 @@
 import argparse
 import hashlib
 import io
+import json
 import os
+import tempfile
 import urllib
 import warnings
 
@@ -25,7 +27,9 @@
 from torch import nn
 from tqdm import tqdm
 
-from transformers import WhisperConfig, WhisperForConditionalGeneration
+from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperTokenizer
+from transformers.models.whisper.tokenization_whisper import LANGUAGES, bytes_to_unicode
+from transformers.utils.import_utils import _is_package_available
 
 
 _MODELS = {
@@ -41,6 +45,11 @@
     "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
 }
 
+_TOKENIZERS = {
+    "multilingual": "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/multilingual.tiktoken",
+    "english": "https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/gpt2.tiktoken",
+}
+
 
 def remove_ignore_keys_(state_dict):
     ignore_keys = ["layers", "blocks"]
@@ -178,11 +187,119 @@ def convert_openai_whisper_to_tfms(checkpoint_path, pytorch_dump_folder_path):
     model.save_pretrained(pytorch_dump_folder_path)
 
 
+# Adapted from https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960
+def _bpe(mergeable_ranks, token: bytes, max_rank=None) -> list[bytes]:
+    parts = [bytes([b]) for b in token]
+    while True:
+        min_idx = None
+        min_rank = None
+        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
+            rank = mergeable_ranks.get(pair[0] + pair[1])
+            if rank is not None and (min_rank is None or rank < min_rank):
+                min_idx = i
+                min_rank = rank
+        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
+            break
+        assert min_idx is not None
+        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :]
+    return parts
+
+
+def convert_tiktoken_bpe_to_hf(tiktoken_url: str):
+    bpe_ranks = load_tiktoken_bpe(tiktoken_url)
+    byte_encoder = bytes_to_unicode()
+
+    def token_bytes_to_string(b):
+ return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) + + merges = [] + vocab = {} + for token, rank in bpe_ranks.items(): + vocab[token_bytes_to_string(token)] = rank + if len(token) == 1: + continue + merged = tuple(_bpe(bpe_ranks, token, max_rank=rank)) + if len(merged) == 2: # account for empty token + merges.append(" ".join(map(token_bytes_to_string, merged))) + return vocab, merges + + +def convert_tiktoken_to_hf( + pytorch_dump_folder_path: str, multilingual: bool = True, num_languages: int = 100, time_precision=0.02 +) -> WhisperTokenizer: + # requires whisper, unless we use the path to the tiktoken file + tiktoken_tokenizer_path = _TOKENIZERS["multilingual" if multilingual else "english"] + start_of_transcript = ["<|endoftext|>", "<|startoftranscript|>"] + control_tokens = [ + "<|translate|>", + "<|transcribe|>", + "<|startoflm|>", + "<|startofprev|>", + "<|nocaptions|>", + "<|notimestamps|>", + ] + # these are special tokens, not normalized + language_tokens = [f"<|{k}|>" for k in list(LANGUAGES)[:num_languages]] + # These are not special but normalized + timestamp_tokens = [("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)] + + vocab, merges = convert_tiktoken_bpe_to_hf(tiktoken_tokenizer_path) + + with tempfile.TemporaryDirectory() as tmpdirname: + vocab_file = f"{tmpdirname}/vocab.json" + merge_file = f"{tmpdirname}/merges.txt" + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write("#version: 0.2\n") + for bpe_tokens in merges: + writer.write(bpe_tokens + "\n") + + hf_tokenizer = WhisperTokenizer(vocab_file, merge_file) + + hf_tokenizer.add_tokens(start_of_transcript + language_tokens + control_tokens, special_tokens=True) + hf_tokenizer.add_tokens(timestamp_tokens, special_tokens=False) + hf_tokenizer.save_pretrained(pytorch_dump_folder_path) + + if __name__ == "__main__": parser = argparse.ArgumentParser() # # Required parameters parser.add_argument("--checkpoint_path", type=str, help="Patht to the downloaded checkpoints") parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument( + "--convert_tokenizer", + type=bool, + default=False, + help="Whether or not the tokenizer should be converted along with the model.", + ) + parser.add_argument( + "--whisper_version", + type=int, + default=2, + help="Version of the whisper release", + ) + parser.add_argument( + "--multilingual", + type=bool, + default="store_true", + help="Whether or not the model is multilingual or english only", + ) args = parser.parse_args() + if args.convert_tokenizer: + try: + if not _is_package_available("tiktoken"): + raise """`tiktoken` is not installed, use `pip install tiktoken` to convert the tokenizer""" + except Exception: + pass + else: + from tiktoken.load import load_tiktoken_bpe + + NUM_LANGUAGES_PER_RELEASE = {1: 99, 2: 99, 3: 100} + convert_tiktoken_to_hf( + args.pytorch_dump_folder_path, args.multilingual, NUM_LANGUAGES_PER_RELEASE[args.whisper_version] + ) + convert_openai_whisper_to_tfms(args.checkpoint_path, args.pytorch_dump_folder_path) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 3fa1fe2755c2..a54103ccef8f 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ 
@@ -191,6 +191,7 @@ def get_pairs(word):
     "ba": "bashkir",
     "jw": "javanese",
     "su": "sundanese",
+    "yue": "cantonese",
 }
 
 # language code lookup by name, with a few language aliases
@@ -207,6 +208,7 @@ def get_pairs(word):
     "moldovan": "ro",
     "sinhalese": "si",
     "castilian": "es",
+    "mandarin": "zh",
 }
 
 TASK_IDS = ["translate", "transcribe"]
@@ -1206,7 +1208,7 @@ def _combine_tokens_into_words(
     if language is None:
         language = "english"
 
-    if language in {"chinese", "japanese", "thai", "lao", "myanmar"}:
+    if language in {"chinese", "japanese", "thai", "lao", "myanmar", "cantonese"}:
         # These languages don't typically use spaces.
         words, word_tokens, token_indices = _split_tokens_on_unicode(tokenizer, tokens)
    else:
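
For a quick sanity check of the converted artifacts, a minimal sketch along these lines can be used. It assumes the documentation command above was run with `--whisper_version 3` and that the example dump folder `Arthur/whisper-3` (the `--pytorch_dump_folder_path` value) now holds both the converted weights and the tokenizer files; adapt the path to your own run.

```python
from transformers import WhisperForConditionalGeneration, WhisperTokenizer

# Example dump folder from the docs command above; substitute your own path.
dump_folder = "Arthur/whisper-3"

# The conversion script saves the model weights and, with --convert_tokenizer, the tokenizer files here.
tokenizer = WhisperTokenizer.from_pretrained(dump_folder)
model = WhisperForConditionalGeneration.from_pretrained(dump_folder)

# With whisper_version 3 (100 languages), the Cantonese language token should be in the vocabulary.
assert "<|yue|>" in tokenizer.get_vocab()

# Timestamp tokens are added as regular (non-special) tokens, so they round-trip through encode/decode.
ids = tokenizer("<|0.00|> Hello world <|1.00|>").input_ids
print(tokenizer.decode(ids, decode_with_timestamps=True))
```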