-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlinformer.py
338 lines (246 loc) · 10.6 KB
/
linformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import math
import torch
import torch.nn as nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states
# for routing arguments into the functions of the reversible layer
def route_args(router, args, depth):
    """Distribute keyword arguments across the ``f`` / ``g`` halves of each layer.

    Args:
        router: mapping from argument name to a per-layer sequence of
            ``(route_to_f, route_to_g)`` pairs; a truthy entry means the
            argument is passed to that half of the layer.
        args: keyword arguments supplied by the caller. Names that do not
            appear in ``router`` are ignored.
        depth: number of layers to build argument dicts for.

    Returns:
        A list of ``depth`` tuples ``(f_kwargs, g_kwargs)``.
    """
    routed_args = [({}, {}) for _ in range(depth)]

    for key, val in args.items():
        if key not in router:
            continue
        # zip stops at the shorter input, so layers beyond the route list are left untouched.
        for layer_idx, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
            route_to_f, route_to_g = routes
            new_f_args = {key: val} if route_to_f else {}
            new_g_args = {key: val} if route_to_g else {}
            routed_args[layer_idx] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})

    return routed_args
def layer_drop(layers, prob):
    """Randomly drop entries of ``layers``, each with probability ``prob``.

    One uniform sample in [0, 1) is drawn per layer; a layer is dropped when
    its sample is below ``prob``.  If every layer would be dropped, the first
    layer is kept so the result is never empty for a non-empty input.
    """
    drop_mask = torch.empty(len(layers)).uniform_(0, 1) < prob

    kept = []
    for layer, is_dropped in zip(layers, drop_mask):
        if is_dropped:
            continue
        kept.append(layer)

    if len(kept) == 0:
        return layers[:1]
    return kept
# following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    """Wrap ``net`` so the RNG state of one call can be captured and replayed.

    Calling with ``record_rng=True`` snapshots the CPU RNG state (and the
    device RNG states when CUDA is initialised) before running ``net``.
    Calling with ``set_rng=True`` restores that snapshot inside a forked RNG
    scope and runs ``net`` again.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net
        # Snapshot slots, filled by record_rng().
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        """Snapshot the current RNG state; ``args`` are passed to ``get_device_states``."""
        self.cpu_state = torch.get_rng_state()
        if not torch.cuda._initialized:
            return
        self.cuda_in_fwd = True
        self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng=False, set_rng=False, **kwargs):
        """Run ``net``, optionally recording or replaying the RNG state first."""
        if record_rng:
            self.record_rng(*args)

        if set_rng:
            forked_devices = self.gpu_devices if self.cuda_in_fwd else []
            with torch.random.fork_rng(devices=forked_devices, enabled=True):
                torch.set_rng_state(self.cpu_state)
                if self.cuda_in_fwd:
                    set_device_states(self.gpu_devices, self.gpu_states)
                return self.net(*args, **kwargs)

        return self.net(*args, **kwargs)
# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is confirmed working, refactor and send PR back to source
class ReversibleBlock(nn.Module):
    """Reversible residual block built from two sub-networks ``f`` and ``g``.

    The input is split in half along dim 2 into ``(x1, x2)`` and mapped to
    ``y1 = x1 + f(x2)`` and ``y2 = x2 + g(y1)``.  ``backward_pass`` inverts
    that mapping to recover the input from the output while computing the
    input gradient.
    """

    def __init__(self, f, g):
        super().__init__()
        # Deterministic allows f / g to be re-run in backward_pass with the recorded RNG state.
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args=None, g_args=None):
        """Compute the block output under ``no_grad``.

        Args:
            x: input tensor, split in two along dim 2.
            f_args: optional keyword arguments forwarded to ``f``.
            g_args: optional keyword arguments forwarded to ``g``.

        Returns:
            ``cat([y1, y2], dim=2)``.
        """
        f_args = {} if f_args is None else f_args
        g_args = {} if g_args is None else g_args

        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        with torch.no_grad():
            # RNG state is only recorded in training mode.
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args=None, g_args=None):
        """Reconstruct the block input and its gradient from the output.

        Args:
            y: block output, as returned by ``forward``.
            dy: gradient with respect to ``y``.
            f_args: optional keyword arguments forwarded to ``f``.
            g_args: optional keyword arguments forwarded to ``g``.

        Returns:
            A tuple ``(x, dx)``: the reconstructed input and its gradient.
        """
        f_args = {} if f_args is None else f_args
        g_args = {} if g_args is None else g_args

        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        # Re-run g on y1 with the recorded RNG state and backpropagate dy2 through it.
        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            # Invert y2 = x2 + g(y1).
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        # Re-run f on the recovered x2 and backpropagate dx1 through it.
        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            # Invert y1 = x1 + f(x2).
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx
class _ReversibleFunction(Function):
    """Autograd function that runs a chain of blocks and backpropagates via ``backward_pass``.

    Only the detached final output is kept on ``ctx``; ``backward`` walks the
    blocks in reverse order, calling each block's ``backward_pass`` with the
    current output and gradient.
    """

    @staticmethod
    def forward(ctx, x, blocks, args):
        ctx.blocks = blocks
        ctx.args = args

        out = x
        for blk, blk_kwargs in zip(blocks, args):
            out = blk(out, **blk_kwargs)

        ctx.y = out.detach()
        return out

    @staticmethod
    def backward(ctx, dy):
        current = ctx.y
        grad = dy
        reversed_pairs = zip(ctx.blocks[::-1], ctx.args[::-1])
        for blk, blk_kwargs in reversed_pairs:
            current, grad = blk.backward_pass(current, grad, **blk_kwargs)
        # No gradients for the ``blocks`` and ``args`` inputs.
        return grad, None, None
class SequentialSequence(nn.Module):
    """Apply a list of ``(f, g)`` layer pairs sequentially with residual connections.

    Args:
        layers: sequence of ``(f, g)`` pairs.
        args_route: optional mapping from keyword-argument name to a per-layer
            routing list (see ``route_args``).  Defaults to a fresh empty dict
            per instance.
        layer_dropout: probability of dropping each layer pair during training.
    """

    def __init__(self, layers, args_route=None, layer_dropout=0.):
        super().__init__()
        # A fresh dict per instance: a shared ``{}`` default would leak routes between instances.
        args_route = {} if args_route is None else args_route
        assert all(len(route) == len(layers) for route in
                   args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        """Run ``x`` through every kept layer pair: ``x = x + f(x)``, then ``x = x + g(x)``."""
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        # Layer dropout is applied only in training mode.
        if self.training and self.layer_dropout > 0:
            layers_and_args = layer_drop(layers_and_args, self.layer_dropout)

        for (f, g), (f_args, g_args) in layers_and_args:
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x
class ReversibleSequence(nn.Module):
    """Run a stack of ``(f, g)`` pairs as reversible blocks.

    Args:
        blocks: iterable of ``(f, g)`` pairs; each is wrapped in a
            ``ReversibleBlock``.
        args_route: optional mapping from keyword-argument name to a per-layer
            routing list (see ``route_args``).  Defaults to a fresh empty dict
            per instance.
        layer_dropout: probability of dropping each block during training.
    """

    def __init__(self, blocks, args_route=None, layer_dropout=0.):
        super().__init__()
        # A fresh dict per instance: a shared ``{}`` default would leak routes between instances.
        self.args_route = {} if args_route is None else args_route
        self.layer_dropout = layer_dropout
        self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])

    def forward(self, x, **kwargs):
        """Duplicate ``x`` on the last dim, run the blocks, and sum the two output halves."""
        x = torch.cat([x, x], dim=-1)

        blocks = self.blocks
        routed = route_args(self.args_route, kwargs, len(blocks))
        block_kwargs = [{'f_args': f_args, 'g_args': g_args} for f_args, g_args in routed]

        layers_and_args = list(zip(blocks, block_kwargs))

        # Layer dropout is applied only in training mode.
        if self.training and self.layer_dropout > 0:
            layers_and_args = layer_drop(layers_and_args, self.layer_dropout)
            blocks = [blk for blk, _ in layers_and_args]
            block_kwargs = [kw for _, kw in layers_and_args]

        out = _ReversibleFunction.apply(x, blocks, block_kwargs)
        return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)
# helper functions
def default(val, default_val):
    """Return ``val`` unless it is ``None``, in which case return ``default_val``."""
    if val is None:
        return default_val
    return val
def init_(tensor):
    """Fill ``tensor`` in place with uniform values in ``[-1/sqrt(d), 1/sqrt(d)]``.

    ``d`` is the size of the last dimension.  The same tensor object is returned.
    """
    last_dim = tensor.shape[-1]
    bound = 1 / math.sqrt(last_dim)
    tensor.uniform_(-bound, bound)
    return tensor
# helper classes
class Residual(nn.Module):
    """Add the input back onto the output of the wrapped callable: ``x + fn(x)``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return x + branch
class PreNorm(nn.Module):
    """Apply ``LayerNorm`` over the last dimension, then the wrapped module.

    Args:
        dim: size of the normalised (last) dimension.
        fn: module called on the normalised input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        """Normalise ``x`` and call ``fn`` on it.

        Extra keyword arguments are forwarded to ``fn`` unchanged, so routed
        arguments passed to this wrapper reach the wrapped module.
        """
        x = self.norm(x)
        return self.fn(x, **kwargs)
class GELU_(nn.Module):
    """Tanh-based GELU approximation, used when ``torch.nn`` has no ``GELU``."""

    def forward(self, x):
        cubic_term = x + 0.044715 * torch.pow(x, 3)
        inner = math.sqrt(2 / math.pi) * cubic_term
        return 0.5 * x * (1 + torch.tanh(inner))


# Use the built-in activation when this torch version provides one.
GELU = getattr(nn, 'GELU', GELU_)
class FeedForward(nn.Module):
    """Two-layer feed-forward network with an optional gated (GLU) hidden layer.

    Args:
        dim: input and output feature size.
        mult: hidden size multiplier; the hidden width is ``dim * mult``.
        dropout: dropout probability applied to the hidden activations.
        activation: activation class (instantiated with no arguments);
            ``GELU`` is used when ``None``.
        glu: when true, the first projection is twice as wide and one half
            gates the other.
    """

    def __init__(self, dim, mult=4, dropout=0., activation=None, glu=False):
        super().__init__()
        activation_cls = default(activation, GELU)
        hidden_dim = dim * mult
        in_proj_width = hidden_dim * (2 if glu else 1)

        self.glu = glu
        self.w1 = nn.Linear(dim, in_proj_width)
        self.act = activation_cls()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(hidden_dim, dim)

    def forward(self, x, **kwargs):
        """Project up, activate (optionally gated), apply dropout, project back down."""
        hidden = self.w1(x)
        if self.glu:
            # First half goes through the activation, second half multiplies it.
            gate_input, v = hidden.chunk(2, dim=-1)
            hidden = self.act(gate_input) * v
        else:
            hidden = self.act(hidden)

        hidden = self.dropout(hidden)
        return self.w2(hidden)
class LinformerSelfAttention(nn.Module):
    """Multi-head attention whose keys / values are projected from ``seq_len`` down to ``k`` positions.

    Args:
        dim: input and output feature size; must be divisible by ``heads``.
        seq_len: exact key / value sequence length required by ``forward``.
        k: projected sequence length for keys and values.
        heads: number of attention heads.
        dim_head: per-head size; defaults to ``dim // heads``.
        one_kv_head: when true, a single key / value head is shared by all query heads.
        share_kv: when true, values reuse the key projection and the key tensor.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, dim, seq_len, k=256, heads=8, dim_head=None, one_kv_head=False, share_kv=False, dropout=0.):
        super().__init__()
        assert (dim % heads) == 0, 'dimension must be divisible by the number of heads'

        self.seq_len = seq_len
        self.k = k
        self.heads = heads
        self.share_kv = share_kv
        self.dim_head = default(dim_head, dim // heads)

        inner_dim = self.dim_head * heads
        kv_dim = self.dim_head if one_kv_head else inner_dim

        self.to_q = nn.Linear(dim, inner_dim, bias=False)

        self.to_k = nn.Linear(dim, kv_dim, bias=False)
        self.proj_k = nn.Parameter(init_(torch.zeros(seq_len, k)))

        # Separate value projections exist only when keys and values are not shared.
        if not share_kv:
            self.to_v = nn.Linear(dim, kv_dim, bias=False)
            self.proj_v = nn.Parameter(init_(torch.zeros(seq_len, k)))

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, context=None, **kwargs):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when ``context`` is ``None``)."""
        b, n, _ = x.shape
        d_h, h, k = self.dim_head, self.heads, self.k

        kv_input = x if context is None else context
        kv_len = kv_input.shape[1]
        assert kv_len == self.seq_len, f'the sequence length of the key / values must be {self.seq_len} - {kv_len} given'

        queries = self.to_q(x)
        keys = self.to_k(kv_input)
        if self.share_kv:
            values = keys
            value_proj = self.proj_k
        else:
            values = self.to_v(kv_input)
            value_proj = self.proj_v

        # project keys and values along the sequence length dimension to k
        keys = torch.einsum('bnd,nk->bkd', keys, self.proj_k)
        values = torch.einsum('bnd,nk->bkd', values, value_proj)

        # move heads in front of the sequence dimension for queries and key / values
        queries = queries.reshape(b, n, h, -1).transpose(1, 2)

        def to_heads(t):
            # expand broadcasts a single key / value head across all query heads
            return t.reshape(b, k, -1, d_h).transpose(1, 2).expand(-1, h, -1, -1)

        keys = to_heads(keys)
        values = to_heads(values)

        # scaled dot-product attention over the k projected positions
        scale = d_h ** -0.5
        dots = torch.einsum('bhnd,bhkd->bhnk', queries, keys) * scale
        attn = self.dropout(dots.softmax(dim=-1))
        out = torch.einsum('bhnk,bhkd->bhnd', attn, values)

        # merge heads back into the feature dimension
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
class Linformer(nn.Module):
    """Stack of ``depth`` pre-normalised attention + feed-forward layer pairs.

    The pairs are executed by ``ReversibleSequence`` when ``reversible`` is
    true, otherwise by ``SequentialSequence``.  The remaining arguments are
    passed through to ``LinformerSelfAttention``; ``dropout`` is also passed
    to ``FeedForward``.
    """

    def __init__(self, dim, seq_len, depth, k=256, heads=8, dim_head=None, one_kv_head=False, share_kv=False,
                 reversible=False, dropout=0.):
        super().__init__()
        attn_kwargs = dict(k=k, heads=heads, dim_head=dim_head, one_kv_head=one_kv_head,
                           share_kv=share_kv, dropout=dropout)

        # Each entry is an (attention, feed-forward) pair, both wrapped in PreNorm.
        layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, LinformerSelfAttention(dim, seq_len, **attn_kwargs)),
                PreNorm(dim, FeedForward(dim, dropout=dropout)),
            ])
            for _ in range(depth)
        ])

        if reversible:
            execute_type = ReversibleSequence
        else:
            execute_type = SequentialSequence
        self.net = execute_type(layers)

    def forward(self, x):
        return self.net(x)