-
Notifications
You must be signed in to change notification settings - Fork 7
/
prediction.py
122 lines (99 loc) · 4.12 KB
/
prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# -*- coding: utf-8 -*-
import cPickle as pickle
import random
import numpy as np
from chainer import cuda, Variable
import chainer.functions as F
import chainer.links as L
from chainer import serializers
import chainer
class Network(chainer.Chain):
    """Stacked five-layer LSTM network built on Chainer.

    Structure:
    - An ``EmbedID`` layer maps token ids to ``n_units``-dimensional vectors.
    - Five ``LSTM`` links of width ``n_units`` (``l1``..``l5``) are stacked on
      top of the embedding.
    - A final ``Linear`` link ``l6`` projects back to ``n_vocab`` outputs.

    ``F.dropout`` with ``self.dropout_ratio`` is applied to the input of every
    LSTM link and to the input of ``l6``. Its ``train`` flag is ``self.train``
    (Chainer v1 dropout API).

    Args:
        n_vocab: Vocabulary size (rows of the embedding, width of the output).
        n_units: Width of the embedding and of every LSTM layer.
        dropout_ratio: Dropout ratio passed to ``F.dropout``.
        train: Initial value of ``self.train``, forwarded to ``F.dropout``.
    """

    def __init__(self, n_vocab, n_units, dropout_ratio=0.0, train=True):
        super(Network, self).__init__(
            embed=L.EmbedID(n_vocab, n_units),
            l1=L.LSTM(n_units, n_units),
            l2=L.LSTM(n_units, n_units),
            l3=L.LSTM(n_units, n_units),
            l4=L.LSTM(n_units, n_units),
            l5=L.LSTM(n_units, n_units),
            l6=L.Linear(n_units, n_vocab),
        )
        # Re-initialise every parameter uniformly in [-0.1, 0.1], overriding
        # the links' default initialisers.
        for param in self.params():
            param.data[...] = np.random.uniform(-0.1, 0.1, param.data.shape)
        self.train = train
        self.dropout_ratio = dropout_ratio

    def _lstm_layers(self):
        """Return the stacked LSTM links in bottom-to-top order."""
        return (self.l1, self.l2, self.l3, self.l4, self.l5)

    def _dropout(self, h):
        """Apply dropout with the configured ratio and ``self.train`` flag."""
        return F.dropout(h, ratio=self.dropout_ratio, train=self.train)

    def reset_state(self):
        """Reset the recurrent state of all five LSTM layers."""
        for lstm in self._lstm_layers():
            lstm.reset_state()

    def __call__(self, x):
        """Run one forward step and return the unnormalised ``l6`` output.

        Args:
            x: Token ids accepted by ``L.EmbedID`` (presumably an int32
               array/Variable of shape ``(batch,)`` -- TODO confirm).

        Returns:
            The output of ``l6`` (raw scores, before any softmax).
        """
        h = self.embed(x)
        # Dropout precedes each LSTM layer, in the same order as before, so
        # random-number consumption is unchanged.
        for lstm in self._lstm_layers():
            h = lstm(self._dropout(h))
        return self.l6(self._dropout(h))

    def predict(self, x):
        """Run one forward step and return ``F.softmax`` of the output.

        Shares the forward pass with ``__call__``, so the two cannot drift
        apart. Dropout still follows ``self.train``; set it to ``False`` for
        inference.
        """
        return F.softmax(self(x))
<