-
Notifications
You must be signed in to change notification settings - Fork 24
/
lstm.py
324 lines (259 loc) · 10.6 KB
/
lstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
#!/usr/bin/env python
"""
Implementation of LSTM variants.
For now, they only support a sequence size of 1, and are ideal for RL use-cases.
Besides that, they are a stripped-down version of PyTorch's RNN layers.
(no bidirectional, no num_layers, no batch_first)
"""
import math
from typing import Tuple
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from .normalize import LayerNorm
class SlowLSTM(nn.Module):
    """
    A pedagogic implementation of Hochreiter & Schmidhuber:
    'Long-Short Term Memory'
    http://www.bioinf.jku.at/publications/older/2604.pdf

    Each gate owns separate weight matrices and a bias vector so the LSTM
    equations can be read straight off the code.

    Weight layout:
        * input-to-hidden ``w_x*``: ``(hidden_size, input_size)``, applied as
          ``x @ w.T`` (the ``nn.Linear`` storage convention);
        * hidden-to-hidden ``w_h*``: ``(hidden_size, hidden_size)``, applied as
          ``h @ w``.
    """

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True, dropout: float = 0.0):
        super(SlowLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.dropout = dropout
        # input to hidden weights
        self.w_xi = Parameter(th.empty(hidden_size, input_size))
        self.w_xf = Parameter(th.empty(hidden_size, input_size))
        self.w_xo = Parameter(th.empty(hidden_size, input_size))
        self.w_xc = Parameter(th.empty(hidden_size, input_size))
        # hidden to hidden weights
        self.w_hi = Parameter(th.empty(hidden_size, hidden_size))
        self.w_hf = Parameter(th.empty(hidden_size, hidden_size))
        self.w_ho = Parameter(th.empty(hidden_size, hidden_size))
        self.w_hc = Parameter(th.empty(hidden_size, hidden_size))
        # Bias terms. When bias=True they are trainable parameters. Otherwise they
        # are constant zeros, registered as non-persistent buffers. As buffers they
        # move with the module on .to()/.cuda(); a plain tensor attribute would stay
        # on the CPU. Being non-persistent, they stay out of state_dict(), as before.
        for name in ("b_i", "b_f", "b_o", "b_c"):
            zeros = th.zeros(hidden_size)
            if bias:
                self.register_parameter(name, Parameter(zeros))
            else:
                self.register_buffer(name, zeros, persistent=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise every trainable parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(self, x: th.Tensor, hidden: Tuple[th.Tensor, th.Tensor]) -> Tuple[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
        """
        Run a single LSTM step.

        Args:
            x: input; flattened to ``(x.size(0), -1)``, i.e. dim 0 is treated as the batch.
            hidden: ``(h, c)`` previous states; each is flattened with its own dim 0 as batch.

        Returns:
            ``(h_t, (h_t, c_t))`` with both states reshaped to ``(batch, 1, -1)``.
            If ``dropout > 0`` and the module is training, dropout is applied to
            ``h_t``. The same tensor is both the output and the recurrent state.
        """
        h, c = hidden
        h = h.view(h.size(0), -1)
        c = c.view(c.size(0), -1)
        x = x.view(x.size(0), -1)
        # Gate pre-activations. The input weights are stored (hidden, input), so
        # they are transposed here; without .t() the product only type-checks
        # when input_size == hidden_size.
        i_t = th.mm(x, self.w_xi.t()) + th.mm(h, self.w_hi) + self.b_i
        f_t = th.mm(x, self.w_xf.t()) + th.mm(h, self.w_hf) + self.b_f
        o_t = th.mm(x, self.w_xo.t()) + th.mm(h, self.w_ho) + self.b_o
        # activations
        i_t = th.sigmoid(i_t)
        f_t = th.sigmoid(f_t)
        o_t = th.sigmoid(o_t)
        # cell computations: candidate g, new cell, new hidden
        g_t = th.tanh(th.mm(x, self.w_xc.t()) + th.mm(h, self.w_hc) + self.b_c)
        c_t = th.mul(c, f_t) + th.mul(i_t, g_t)
        h_t = th.mul(o_t, th.tanh(c_t))
        # Reshape for compatibility
        h_t = h_t.view(h_t.size(0), 1, -1)
        c_t = c_t.view(c_t.size(0), 1, -1)
        if self.dropout > 0.0:
            h_t = F.dropout(h_t, p=self.dropout, training=self.training)
        return h_t, (h_t, c_t)

    def sample_mask(self):
        """No-op kept for interface parity with ``LSTM``; this class has no dropout mask."""
        pass
class LSTM(nn.Module):
    """
    An implementation of Hochreiter & Schmidhuber:
    'Long-Short Term Memory'
    http://www.bioinf.jku.at/publications/older/2604.pdf

    Special args:
        dropout_method: one of (case-insensitive)
            * pytorch: default dropout implementation
            * gal: uses GalLSTM's dropout
            * moon: uses MoonLSTM's dropout
            * semeniuta: uses SemeniutaLSTM's dropout

    The "gal" and "moon" methods reuse a single Bernoulli keep-mask stored in
    ``self.mask`` on every step. Call ``sample_mask()`` to draw a fresh one.
    """

    def __init__(
        self, input_size: int, hidden_size: int, bias: bool = True, dropout: float = 0.0, dropout_method: str = "pytorch"
    ):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.dropout = dropout
        self.i2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        self.reset_parameters()
        method = dropout_method.lower()
        valid_methods = ("pytorch", "gal", "moon", "semeniuta")
        if method not in valid_methods:
            raise ValueError(f"dropout_method must be one of {valid_methods}, got {dropout_method!r}")
        # Store the normalised name. forward() compares against lowercase literals,
        # so storing the raw string let e.g. "Gal" silently disable dropout.
        self.dropout_method = method
        # forward() reads self.mask for gal/moon, so draw one up front.
        if method in ("gal", "moon"):
            self.sample_mask()

    def sample_mask(self):
        """Draw a new ``(1, hidden_size)`` Bernoulli mask with keep probability ``1 - dropout``."""
        keep = 1.0 - self.dropout
        weight = self.h2h.weight
        # Create the mask on the parameters' device/dtype so it works on GPU modules.
        self.mask = th.bernoulli(th.full((1, self.hidden_size), keep, device=weight.device, dtype=weight.dtype))

    def reset_parameters(self):
        """Initialise every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(self, x: th.Tensor, hidden: Tuple[th.Tensor, th.Tensor]) -> Tuple[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
        """
        Run a single LSTM step.

        Args:
            x: input; flattened to ``(x.size(1), -1)``, i.e. dim 1 is treated as
                the batch (e.g. a ``(1, batch, input_size)`` tensor).
            hidden: ``(h, c)`` previous states, flattened the same way.

        Returns:
            ``(h_t, (h_t, c_t))`` with both states reshaped to ``(1, batch, -1)``.
            Dropout (per ``dropout_method``) is only applied in training mode
            with ``dropout > 0``.
        """
        do_dropout = self.training and self.dropout > 0.0
        h, c = hidden
        h = h.view(h.size(1), -1)
        c = c.view(c.size(1), -1)
        x = x.view(x.size(1), -1)
        # Linear mappings; the 4*hidden pre-activation is laid out as [i | f | o | g].
        preact = self.i2h(x) + self.h2h(h)
        # activations
        gates = preact[:, : 3 * self.hidden_size].sigmoid()
        g_t = preact[:, 3 * self.hidden_size :].tanh()
        i_t = gates[:, : self.hidden_size]
        f_t = gates[:, self.hidden_size : 2 * self.hidden_size]
        o_t = gates[:, -self.hidden_size :]
        # cell computations
        if do_dropout and self.dropout_method == "semeniuta":
            g_t = F.dropout(g_t, p=self.dropout, training=self.training)
        c_t = th.mul(c, f_t) + th.mul(i_t, g_t)
        if do_dropout and self.dropout_method == "moon":
            # Differentiable masking. The former `.data` edits hid the mask from
            # autograd, so gradients ignored the dropout.
            c_t = th.mul(c_t, self.mask.to(c_t)) * (1.0 / (1.0 - self.dropout))
        h_t = th.mul(o_t, c_t.tanh())
        if do_dropout:
            if self.dropout_method == "pytorch":
                h_t = F.dropout(h_t, p=self.dropout, training=self.training)
            elif self.dropout_method == "gal":
                h_t = th.mul(h_t, self.mask.to(h_t)) * (1.0 / (1.0 - self.dropout))
        # Reshape for compatibility
        h_t = h_t.view(1, h_t.size(0), -1)
        c_t = c_t.view(1, c_t.size(0), -1)
        return h_t, (h_t, c_t)
class GalLSTM(LSTM):
    """
    LSTM with the recurrent dropout of Gal & Ghahramani:
    'A Theoretically Grounded Application of Dropout in Recurrent Neural Networks'
    http://papers.nips.cc/paper/6241-a-theoretically-grounded-application-of-dropout-in-recurrent-neural-networks.pdf

    Takes the same arguments as ``LSTM``. The dropout method is always forced to
    "gal", and a first mask is drawn during construction.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Override whatever method was passed in, then draw the initial mask.
        self.dropout_method = "gal"
        self.sample_mask()
class MoonLSTM(LSTM):
    """
    LSTM with the cell-state dropout of Moon & al.:
    'RNNDrop: A Novel Dropout for RNNs in ASR'
    https://www.stat.berkeley.edu/~tsmoon/files/Conference/asru2015.pdf

    Takes the same arguments as ``LSTM``. The dropout method is always forced to
    "moon", and a first mask is drawn during construction.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Override whatever method was passed in, then draw the initial mask.
        self.dropout_method = "moon"
        self.sample_mask()
class SemeniutaLSTM(LSTM):
    """
    LSTM with the candidate-update dropout of Semeniuta & al.:
    'Recurrent Dropout without Memory Loss'
    https://arxiv.org/pdf/1603.05118.pdf

    Takes the same arguments as ``LSTM``. The dropout method is always forced to
    "semeniuta". No persistent mask is needed: fresh dropout is drawn every step.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dropout_method = "semeniuta"
class LayerNormLSTM(LSTM):
    """
    Layer Normalization LSTM, based on Ba & al.:
    'Layer Normalization'
    https://arxiv.org/pdf/1607.06450.pdf

    Special args:
        ln_preact: whether to Layer Normalize the pre-activations. The i2h and
            h2h projections are normalised separately before being summed.
        learnable: whether the LN alpha & gamma should be used (passed through
            to ``LayerNorm``).

    The cell state is always layer-normalised after the (optional) moon dropout.
    The normalised cell is what is returned as the next cell state.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        dropout: float = 0.0,
        dropout_method: str = "pytorch",
        ln_preact: bool = True,
        learnable: bool = True,
    ):
        super(LayerNormLSTM, self).__init__(
            input_size=input_size, hidden_size=hidden_size, bias=bias, dropout=dropout, dropout_method=dropout_method
        )
        if ln_preact:
            self.ln_i2h = LayerNorm(4 * hidden_size, learnable=learnable)
            self.ln_h2h = LayerNorm(4 * hidden_size, learnable=learnable)
        self.ln_preact = ln_preact
        self.ln_cell = LayerNorm(hidden_size, learnable=learnable)

    def forward(self, x: th.Tensor, hidden: Tuple[th.Tensor, th.Tensor]) -> Tuple[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
        """
        Run a single layer-normalised LSTM step.

        Args:
            x: input; flattened to ``(x.size(1), -1)``, i.e. dim 1 is treated as the batch.
            hidden: ``(h, c)`` previous states, flattened the same way.

        Returns:
            ``(h_t, (h_t, c_t))`` with both states reshaped to ``(1, batch, -1)``.
            ``c_t`` is the layer-normalised cell.
        """
        do_dropout = self.training and self.dropout > 0.0
        h, c = hidden
        h = h.view(h.size(1), -1)
        c = c.view(c.size(1), -1)
        x = x.view(x.size(1), -1)
        # Linear mappings, optionally normalised separately before being summed.
        i2h = self.i2h(x)
        h2h = self.h2h(h)
        if self.ln_preact:
            i2h = self.ln_i2h(i2h)
            h2h = self.ln_h2h(h2h)
        preact = i2h + h2h
        # activations; layout [i | f | o | g]
        gates = preact[:, : 3 * self.hidden_size].sigmoid()
        g_t = preact[:, 3 * self.hidden_size :].tanh()
        i_t = gates[:, : self.hidden_size]
        f_t = gates[:, self.hidden_size : 2 * self.hidden_size]
        o_t = gates[:, -self.hidden_size :]
        # cell computations
        if do_dropout and self.dropout_method == "semeniuta":
            g_t = F.dropout(g_t, p=self.dropout, training=self.training)
        c_t = th.mul(c, f_t) + th.mul(i_t, g_t)
        if do_dropout and self.dropout_method == "moon":
            # Differentiable masking. The former `.data` edits hid the mask from
            # autograd, so gradients ignored the dropout.
            c_t = th.mul(c_t, self.mask.to(c_t)) * (1.0 / (1.0 - self.dropout))
        c_t = self.ln_cell(c_t)
        h_t = th.mul(o_t, c_t.tanh())
        if do_dropout:
            if self.dropout_method == "pytorch":
                h_t = F.dropout(h_t, p=self.dropout, training=self.training)
            elif self.dropout_method == "gal":
                h_t = th.mul(h_t, self.mask.to(h_t)) * (1.0 / (1.0 - self.dropout))
        # Reshape for compatibility
        h_t = h_t.view(1, h_t.size(0), -1)
        c_t = c_t.view(1, c_t.size(0), -1)
        return h_t, (h_t, c_t)
class LayerNormGalLSTM(LayerNormLSTM):
    """
    Layer-normalised LSTM combined with GalLSTM's recurrent dropout.

    Takes the same arguments as ``LayerNormLSTM``. The dropout method is always
    forced to "gal", and a first mask is drawn during construction.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Override whatever method was passed in, then draw the initial mask.
        self.dropout_method = "gal"
        self.sample_mask()
class LayerNormMoonLSTM(LayerNormLSTM):
    """
    Layer-normalised LSTM combined with MoonLSTM's cell-state dropout.

    Takes the same arguments as ``LayerNormLSTM``. The dropout method is always
    forced to "moon", and a first mask is drawn during construction.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Override whatever method was passed in, then draw the initial mask.
        self.dropout_method = "moon"
        self.sample_mask()
class LayerNormSemeniutaLSTM(LayerNormLSTM):
    """
    Layer-normalised LSTM combined with SemeniutaLSTM's candidate-update dropout.

    Takes the same arguments as ``LayerNormLSTM``. The dropout method is always
    forced to "semeniuta"; no persistent mask is drawn.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dropout_method = "semeniuta"