Skip to content

Commit

Permalink
Fix truncated words list when the replacement character is decoded (o…
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln authored and zackees committed May 5, 2023
1 parent ac8b7a8 commit 188726e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
10 changes: 10 additions & 0 deletions tests/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,13 @@ def test_tokenizer():
assert gpt2_tokenizer.decode(gpt2_tokens) == text
assert multilingual_tokenizer.decode(multilingual_tokens) == text
assert len(gpt2_tokens) > len(multilingual_tokens)


def test_split_on_unicode():
    """Check unicode-based splitting when U+FFFD is a real decoded character.

    The fifth token is expected to decode to the replacement character
    itself, so it must come back as its own word instead of being held
    back as an incomplete sequence (which would truncate the word list).
    """
    multilingual_tokenizer = get_tokenizer(multilingual=True)

    tokens = [8404, 871, 287, 6, 246, 526, 3210, 20378]
    words, word_tokens = multilingual_tokenizer.split_tokens_on_unicode(tokens)

    # The replacement character is written as an escape so the expectation
    # cannot be mistaken for (or corrupted into) mojibake by an editor or a
    # re-encoding step.
    assert words == [" elle", " est", " l", "'", "\ufffd", "é", "rit", "oire"]
    # Every token forms its own word here: one token list per word.
    assert word_tokens == [[8404], [871], [287], [6], [246], [526], [3210], [20378]]
12 changes: 11 additions & 1 deletion whisper/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,17 +279,27 @@ def split_to_word_tokens(self, tokens: List[int]):
return self.split_tokens_on_spaces(tokens)

def split_tokens_on_unicode(self, tokens: List[int]):
    """Group tokens into the shortest runs that decode to complete text.

    Tokens are accumulated one at a time and decoded with
    ``self.decode_with_timestamps``. A run is emitted as a word as soon as
    its decoding either contains no U+FFFD, or contains a U+FFFD that is
    also present at the same position in the decoding of the full token
    list (i.e. the replacement character is genuine output, not an
    artifact of cutting the sequence short).

    Args:
        tokens: Token ids to split.

    Returns:
        A ``(words, word_tokens)`` tuple of two parallel lists: the decoded
        string of each run and the list of token ids that produced it.
        Tokens left in an unfinished run at the end are not emitted.
    """
    # Decoding everything at once gives the reference text against which
    # partial decodings are compared.
    decoded_full = self.decode_with_timestamps(tokens)
    replacement_char = "\ufffd"

    words = []
    word_tokens = []
    current_tokens = []
    # Number of characters of decoded_full already covered by emitted words.
    unicode_offset = 0

    for token in tokens:
        current_tokens.append(token)
        decoded = self.decode_with_timestamps(current_tokens)

        # NOTE(review): presumably the decoder substitutes U+FFFD for an
        # incomplete byte sequence -- confirm against decode_with_timestamps.
        # A U+FFFD in the partial decoding only means "keep accumulating"
        # when the full decoding has a different character at that position.
        if (
            replacement_char not in decoded
            or decoded_full[unicode_offset + decoded.index(replacement_char)]
            == replacement_char
        ):
            words.append(decoded)
            word_tokens.append(current_tokens)
            # Rebind rather than clear(): the emitted list must stay intact.
            current_tokens = []
            unicode_offset += len(decoded)

    return words, word_tokens

Expand Down

0 comments on commit 188726e

Please sign in to comment.