-
Notifications
You must be signed in to change notification settings - Fork 19
/
main.py
42 lines (32 loc) · 1.24 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from LanguageLoader import *
from RNN import *
# Corpus archives handed to LanguageLoader (English path first, French second).
en_path, fr_path = 'data/en.zip', 'data/fr.zip'

# Sentence-length cap and vocabulary size, both passed to LanguageLoader.
max_length, vocab_size = 20, 15000

# Training schedule used by main(): every epoch requests
# batch_size * num_batches sentences, delivered batch_size at a time.
num_epochs, num_batches, batch_size = 1000, 750, 100
def main():
    """Train the RNN on the parallel corpus, printing progress periodically.

    Builds the data loader and the model from the module-level settings, then
    runs ``num_epochs`` passes.  Each pass requests
    ``batch_size * num_batches`` sentences from ``data.sentences`` in chunks of
    ``batch_size``.  Every 100 steps it prints the loss, the reference sentence
    and the model's output for the current batch.
    """
    data = LanguageLoader(en_path, fr_path, vocab_size, max_length)
    rnn = RNN(data.input_size, data.output_size)

    for epoch in range(num_epochs):
        print("=" * 50 + (" EPOCH %i " % epoch) + "=" * 50)
        batches = data.sentences(batch_size * num_batches, batch_size)
        # `source` instead of `input` so the builtin is not shadowed.
        for i, (source, target) in enumerate(batches):
            loss, outputs = rnn.train(
                Variable(torch.from_numpy(source).long()),
                Variable(torch.from_numpy(target).long()),
            )

            # `==`, not `is`: identity comparison of ints only works by accident
            # of CPython's small-int cache and is a SyntaxWarning on 3.8+.
            if i % 100 == 0:
                print("Loss at step %d: %.2f" % (i, loss))
                print("Truth: \"%s\"" % data.vec_to_sentence(target))
                print("Guess: \"%s\"\n" % data.vec_to_sentence(outputs))

        # NOTE(review): the scraped source lost its indentation, so where this
        # call originally sat is ambiguous; a once-per-epoch checkpoint is
        # assumed here — confirm against the repository.
        rnn.save()
def translate(sentence="the president is here <EOS>"):
    """Run one sentence through the RNN and print the decoded output.

    Args:
        sentence: Text handed to ``data.sentence_to_vec`` unchanged.  The
            default ends with an explicit ``<EOS>`` token; presumably callers
            should append one too — confirm against LanguageLoader.
    """
    data = LanguageLoader(en_path, fr_path, vocab_size, max_length)
    rnn = RNN(data.input_size, data.output_size)
    # NOTE(review): nothing visible here restores the weights written by
    # rnn.save(); unless the RNN constructor loads them, this runs an
    # untrained model — confirm.
    vecs = data.sentence_to_vec(sentence)
    translation = rnn.eval(vecs)
    print(data.vec_to_sentence(translation))
# Run training only when executed as a script, so importing this module has no
# side effects.  Swap the two calls below to run translation instead.
if __name__ == "__main__":
    main()
    # translate()