Fix spm decoder multi-byte #1092

Merged · 1 commit · Nov 5, 2024
40 changes: 17 additions & 23 deletions llms/mlx_lm/tokenizer_utils.py
@@ -6,12 +6,6 @@
 REPLACEMENT_CHAR = "\ufffd"
 
 
-def _remove_space(x):
-    if x and x[0] == " ":
-        return x[1:]
-    return x
-
-
 class StreamingDetokenizer:
     """The streaming detokenizer interface so that we can detokenize one token at a time.
 
@@ -123,42 +117,42 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
 
     def __init__(self, tokenizer, trim_space=True):
         self.trim_space = trim_space
+        self._sep = "\u2581".encode()
 
         # Extract the tokens in a list from id to text
         self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
         for value, tokenid in tokenizer.vocab.items():
-            self.tokenmap[tokenid] = value
-
-        # Replace bytes with their value
-        for i in range(len(self.tokenmap)):
-            if self.tokenmap[i].startswith("<0x"):
-                self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))
+            if value.startswith("<0x"):
+                # Replace bytes with their value
+                self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
+            else:
+                self.tokenmap[tokenid] = value.encode()
 
         self.reset()
 
     def reset(self):
         self.offset = 0
-        self._unflushed = ""
+        self._unflushed = b""
         self.text = ""
         self.tokens = []
 
+    def _flush(self):
+        text = self._unflushed.replace(self._sep, b" ").decode("utf-8")
+        if not self.text and self.trim_space and text and text[0] == " ":
+            text = text[1:]
+        self.text += text
+
     def add_token(self, token):
         v = self.tokenmap[token]
-        if v[0] == "\u2581":
-            if self.text or not self.trim_space:
-                self.text += self._unflushed.replace("\u2581", " ")
-            else:
-                self.text = _remove_space(self._unflushed.replace("\u2581", " "))
+        if v.startswith(self._sep):
+            self._flush()
             self._unflushed = v
         else:
             self._unflushed += v
 
     def finalize(self):
-        if self.text or not self.trim_space:
-            self.text += self._unflushed.replace("\u2581", " ")
-        else:
-            self.text = _remove_space(self._unflushed.replace("\u2581", " "))
-        self._unflushed = ""
+        self._flush()
+        self._unflushed = b""
 
 
 class BPEStreamingDetokenizer(StreamingDetokenizer):
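
For context, a minimal standalone sketch (not part of the mlx_lm change) of why the detokenizer has to accumulate raw bytes instead of mapping each byte-fallback token through chr(): a multi-byte UTF-8 character such as "こ" is emitted by SentencePiece as the three byte tokens <0xE3>, <0x81>, <0x93>, and only decoding the joined bytes recovers the character.

# Standalone sketch, not mlx_lm code: SentencePiece byte-fallback tokens
# for the single character "こ" (UTF-8 bytes E3 81 93).
byte_tokens = ["<0xE3>", "<0x81>", "<0x93>"]

# Previous approach: map each token to one character with chr(),
# which treats every byte as its own code point and produces mojibake.
old = "".join(chr(int(t[3:5], 16)) for t in byte_tokens)
print(repr(old))   # 'ã\x81\x93'

# Approach in this PR: keep the raw bytes and decode the buffer once.
new = b"".join(bytes([int(t[3:5], 16)]) for t in byte_tokens).decode("utf-8")
print(new)         # こ

In the diff above, the same idea appears as storing bytes in self.tokenmap and self._unflushed, and decoding the whole buffer only in _flush().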
3 changes: 3 additions & 0 deletions llms/tests/test_tokenizers.py
@@ -42,6 +42,9 @@ def check(tokens):
             text += detokenizer.last_segment
             self.assertEqual(text, expected_text)
 
+        tokens = tokenizer.encode("こんにちは!私の名前はAI")
+        check(tokens)
+
         tokens = tokenizer.encode("a ,b")
         check(tokens)
 
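
The added test encodes a multi-byte (Japanese) string and runs it through the existing check helper, which streams the tokens through the detokenizer and compares the result with tokenizer.decode. A rough standalone version of the same round trip, where "some/spm-model" is a placeholder repo name (any SentencePiece-based tokenizer that mlx_lm can load would do):

# Rough standalone version of what the new test exercises.
from mlx_lm import load

_, tokenizer = load("some/spm-model")  # placeholder model name

tokens = tokenizer.encode("こんにちは!私の名前はAI")
detokenizer = tokenizer.detokenizer
detokenizer.reset()

text = ""
for t in tokens:
    detokenizer.add_token(t)
    text += detokenizer.last_segment
detokenizer.finalize()
text += detokenizer.last_segment

# With the byte-level fix, the streamed text matches a one-shot decode.
assert text == tokenizer.decode(tokens)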