tokenizer.py · 190 lines (159 loc) · 7.11 KB
import math
import pickle

import numpy as np
from tqdm import tqdm

from Finch.layers import Populate, CapPopulation, SortByFitness
from Finch.environments import Sequential

from .layers import *
from .genepool import *

class TrieNode:
    def __init__(self):
        self.children = {}
        self.token_index = -1  # -1 indicates no token ends here
        self.is_end_of_token = False  # Indicates if a complete token ends at this node


class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, token, index):
        node = self.root
        for char in token:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.token_index = index
        node.is_end_of_token = True  # Mark the end of a complete token

    def search(self, text):
        tokens = []
        start_index = 0  # Start index of the current token being processed
        while start_index < len(text):
            node = self.root  # Reset node to root for each new starting character
            longest_token_index = -1  # -1 indicates no token found yet
            longest_token_length = 0  # Length of the longest token found
            for i in range(start_index, len(text)):
                char = text[i]
                if char in node.children:
                    node = node.children[char]
                    if node.is_end_of_token:
                        longest_token_index = node.token_index  # Update when a longer token is found
                        longest_token_length = i - start_index + 1
                else:
                    break  # Stop if the current character is not a child of this node
            if longest_token_index != -1:
                tokens.append(longest_token_index)  # Append the longest token index found
                start_index += longest_token_length  # Move past the matched token
            else:
                start_index += 1  # No token starts here; move to the next character
        return tokens
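

# Illustrative only: a tiny self-check for the greedy longest-match search above.
# The token strings below are made up for the example, not part of any trained
# vocabulary, and this helper is not called anywhere else in the module.
def _trie_demo():
    trie = Trie()
    for i, tok in enumerate(["he", "hell", "hello", " ", "world"]):
        trie.insert(tok, i)
    assert trie.search("hello world") == [2, 3, 4]  # longest match wins at each position
    assert trie.search("helix") == [0]  # "lix" never matches, so those chars are skipped
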
# TODO: identify gaps based on the halving rule
class GeneticTokenizer:
    def __init__(self, min_range=2, max_range=6, max_population=11, start_population=10, mutate_amount=5,
                 families=2, step_epochs: int = 15, existing_tokens: list = None, right_freezable=False,
                 left_freezable=True):
        self.fitness_results = {}  # cache of token -> score, for a speed boost
        self.tokens = existing_tokens if existing_tokens is not None else []  # avoid a shared mutable default
        self.step_epochs = step_epochs
        self.min_range = min_range
        self.max_range = max_range
        self.last_iteration = 0
        self.max_population = max_population
        self.start_population = start_population
        self.mutate_amount = mutate_amount
        self.families = families
        self.trie = Trie()
        self.right_freezable = right_freezable
        self.left_freezable = left_freezable
        for i, token in enumerate(self.tokens):
            self.trie.insert(token, i)  # Populate the Trie with existing tokens
    def evolve(self, dataset):
        total = len(dataset)
        with tqdm(total=total, desc="Evolving Tokenizer") as pbar:
            for text in dataset:
                self.step(text)
                pbar.update(1)
    def step(self, text: str):
        pool = RangePool(min_range=self.min_range, max_range=self.max_range,
                         source_text=text, right_freezable=self.right_freezable,
                         left_freezable=self.left_freezable)
        max_population = self.max_population
        start_population = self.start_population
        # Create the environment
        environment = Sequential(
            layers=[
                Populate(gene_pool=pool, population=start_population),
                MutateToken(individual_selection=self.mutate_amount),
                ParentToken(families=self.families, gene_pool=pool),
                SortByFitness(),
                CapPopulation(max_population)
            ]
        )
        environment.compile(self.fitness, verbose_every=False)
        environment.iteration = self.last_iteration
        environment.evolve(self.step_epochs)
        for individual in environment.individuals:
            if individual.token not in self.tokens:
                self.tokens.append(individual.token)
                self.trie.insert(individual.token, len(self.tokens) - 1)
        self.last_iteration = environment.iteration
    def fitness(self, individual: RangeToken):
        token = individual.token
        if token in self.fitness_results:
            return self.fitness_results[token]
        source_text = individual.source
        count = source_text.count(token)
        percent = count / individual.length
        score = (len(token) + len(self.tokenize(token))) * percent  # the less tokenized the text, the higher fitness
        self.fitness_results.update({token: score})
        return score
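
    # Worked example with illustrative numbers (not taken from the repository):
    # a candidate token "the " (length 4) appearing 50 times in a 1,000-character
    # source gives percent = 50 / 1000 = 0.05; if the current vocabulary needs
    # 3 tokens to cover "the ", the score is (4 + 3) * 0.05 = 0.35.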
    def tokenize(self, text):
        indices = self.trie.search(text)
        return indices

    def detokenize(self, indices, join="|"):
        """
        Map a list of token indices back to their token strings, joined with the
        given separator.
        """
        return join.join(self.tokens[i] for i in indices)

    def interface(self):
        while True:
            toks = self.tokenize(input("tokens: "))
            print("tokens: ", toks)
            print("detokens: ", self.detokenize(toks))
    def save(self, filename):
        """
        Save the state of the GeneticTokenizer object to a file, excluding the Trie.
        """
        with open(filename + ".gentok", 'wb') as f:
            pickle.dump({
                'fitness_results': self.fitness_results,
                'tokens': self.tokens,
                # 'trie': self.trie,  # Do not save the Trie structure
                'min_range': self.min_range,
                'max_range': self.max_range,
                'max_population': self.max_population,
                'start_population': self.start_population,
                'mutate_amount': self.mutate_amount,
                'families': self.families,
                'right_freezable': self.right_freezable,
                'left_freezable': self.left_freezable,
            }, f)
    def load(self, filename):
        """
        Load the state of the GeneticTokenizer object from a file, excluding the Trie.
        Reconstruct the Trie from the loaded tokens.
        """
        with open(filename + ".gentok", 'rb') as f:
            data = pickle.load(f)
        self.tokens = data['tokens']
        self.fitness_results = data['fitness_results']
        # self.trie = data['trie']  # Do not load the Trie structure
        self.trie = Trie()  # Reconstruct the Trie
        for i, token in enumerate(self.tokens):
            self.trie.insert(token, i)
        # Load additional attributes
        self.min_range = data['min_range']
        self.max_range = data['max_range']
        self.max_population = data['max_population']
        self.start_population = data['start_population']
        self.mutate_amount = data['mutate_amount']
        self.families = data['families']
        self.right_freezable = data['right_freezable']
        self.left_freezable = data['left_freezable']
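

# Minimal usage sketch (illustrative only).  Assumptions: the Finch library and the
# sibling `layers`/`genepool` modules are importable, `corpus.txt` is a hypothetical
# newline-separated text file not shipped with this code, and the module is run with
# `python -m ...` so the relative imports above resolve.
if __name__ == "__main__":
    tokenizer = GeneticTokenizer(min_range=2, max_range=6)
    with open("corpus.txt", encoding="utf-8") as f:
        dataset = f.read().splitlines()  # one training text per line
    tokenizer.evolve(dataset)  # runs one evolutionary step() per text
    ids = tokenizer.tokenize("hello world")  # greedy longest-match token indices
    print(ids, "->", tokenizer.detokenize(ids))
    tokenizer.save("my_tokenizer")  # writes my_tokenizer.gentok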