diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
index f4cd80b9f..09d0351e1 100644
--- a/tests/test_tokenizer.py
+++ b/tests/test_tokenizer.py
@@ -12,3 +12,13 @@ def test_tokenizer():
     assert gpt2_tokenizer.decode(gpt2_tokens) == text
     assert multilingual_tokenizer.decode(multilingual_tokens) == text
     assert len(gpt2_tokens) > len(multilingual_tokens)
+
+
+def test_split_on_unicode():
+    multilingual_tokenizer = get_tokenizer(multilingual=True)
+
+    tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
+    words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)
+
+    assert words == [" elle", " est", " l", "'", "�", "é", "rit", "oire"]
+    assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]
diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py
index dfa9f71eb..236f65e00 100644
--- a/whisper/tokenizer.py
+++ b/whisper/tokenizer.py
@@ -279,17 +279,27 @@ def split_to_word_tokens(self, tokens: List[int]):
         return self.split_tokens_on_spaces(tokens)
 
     def split_tokens_on_unicode(self, tokens: List[int]):
+        decoded_full = self.decode_with_timestamps(tokens)
+        replacement_char = "\ufffd"
+
         words = []
         word_tokens = []
         current_tokens = []
+        unicode_offset = 0
 
         for token in tokens:
             current_tokens.append(token)
             decoded = self.decode_with_timestamps(current_tokens)
-            if "\ufffd" not in decoded:
+
+            if (
+                replacement_char not in decoded
+                or decoded_full[unicode_offset + decoded.index(replacement_char)]
+                == replacement_char
+            ):
                 words.append(decoded)
                 word_tokens.append(current_tokens)
                 current_tokens = []
+                unicode_offset += len(decoded)
 
         return words, word_tokens
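
For context, a minimal standalone Python sketch (not part of the diff) of the two cases the new condition distinguishes: byte-level BPE tokens can end in the middle of a multi-byte UTF-8 character, so decoding a token prefix yields a spurious U+FFFD, whereas a U+FFFD that is genuinely in the transcript (as in the new test's expected words) survives the full decode.

```python
# Standalone illustration; assumes only Python's built-in UTF-8 handling,
# not Whisper's tokenizer.

full_bytes = "é".encode("utf-8")  # b'\xc3\xa9' — one character, two bytes

# Decoding a prefix that ends mid-character produces a *spurious* U+FFFD ...
print(full_bytes[:1].decode("utf-8", errors="replace"))  # '�'
# ... which disappears once the remaining byte arrives.
print(full_bytes.decode("utf-8", errors="replace"))      # 'é'

# A U+FFFD that is genuinely part of the text survives a full round trip;
# this is the case the `decoded_full[...] == replacement_char` comparison
# accepts, so a real replacement character no longer blocks a word boundary.
genuine = "�é".encode("utf-8").decode("utf-8", errors="replace")
print(genuine)  # '�é'
```

This is why the previous `"\ufffd" not in decoded` check could never find a clean split point for a transcript containing a real replacement character, while the new condition also splits when the replacement character in the partial decode lines up with one at the same position in `decoded_full`.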