-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tokenizer.py
35 lines (29 loc) · 1.23 KB
/
tokenizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import pickle
# Reserved special-token ids; ids 0-3 are fixed so the first learned
# character always receives id 4 (see CharTokenizer.next_id).
UNK_TOKEN_ID = 0  # unknown character
PAD_TOKEN_ID = 1  # sequence padding
SOS_TOKEN_ID = 2  # start-of-sequence marker (prepended by encode)
EOS_TOKEN_ID = 3  # end-of-sequence marker (appended by encode)
class CharTokenizer:
def __init__(self):
self.char_to_id = {'<PAD>': PAD_TOKEN_ID, '<SOS>': SOS_TOKEN_ID, '<EOS>': EOS_TOKEN_ID, '<UNK>': UNK_TOKEN_ID}
self.id_to_char = {PAD_TOKEN_ID: '<PAD>', SOS_TOKEN_ID: '<SOS>', EOS_TOKEN_ID: '<EOS>', UNK_TOKEN_ID: '<UNK>'}
self.next_id = 4
def encode(self, text):
encoded = [SOS_TOKEN_ID]
for char in text:
if char not in self.char_to_id:
self.char_to_id[char] = self.next_id
self.id_to_char[self.next_id] = char
self.next_id += 1
encoded.append(self.char_to_id[char])
encoded.append(EOS_TOKEN_ID)
return encoded
def decode(self, ids):
return ''.join(self.id_to_char[id_] for id_ in ids if id_ not in (UNK_TOKEN_ID, SOS_TOKEN_ID, EOS_TOKEN_ID, PAD_TOKEN_ID))
def save(self, file_path):
with open(file_path, 'wb') as file:
pickle.dump((self.char_to_id, self.id_to_char, self.next_id), file)
def load(self, file_path):
with open(file_path, 'rb') as file:
self.char_to_id, self.id_to_char, self.next_id = pickle.load(file)
return