-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathSublayers.py
342 lines (263 loc) · 14.5 KB
/
Sublayers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
# Filename: Sublayers.py
# Date Created: 15-Mar-2019 2:42:12 pm
# Description: Sublayer functions used for attention mechanism.
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
def shape_list(x):
    """Return the dimensions of ``x`` as a plain Python list.

    Args:
        x: any object with a ``.shape`` attribute (torch tensor, numpy array).

    Returns:
        A ``list`` holding each entry of ``x.shape`` in order.
    """
    return [*x.shape]
def _relative_position_to_absolute_position_masked(x):
    """Convert relative-position logits/weights to absolute-position indexing.

    Helper for ``dot_product_self_attention_relative``.

    The last axis of the input is indexed by
    ``memory_position - query_position + length - 1``; the last axis of the
    output is indexed by ``memory_position``. Entries above the diagonal
    (memory after query) are not meaningful, so this only works with masked
    (left-looking) attention.

    Args:
        x: a Tensor with shape [batch, heads, length, length].

    Returns:
        a Tensor with shape [batch, heads, length, length].
    """
    batch, heads, length, _ = x.shape
    # Prepend one zero column: [batch, heads, length, length + 1].
    skewed = F.pad(x, (1, 0))
    # Re-read the same length * (length + 1) values as (length + 1) rows of
    # `length`, which shifts each row one step further than the row above.
    skewed = skewed.reshape(batch, heads, length + 1, length)
    # The first row only holds padding; drop it.
    return skewed[:, :, 1:, :]
def matmul_with_relative_keys(x, y, heads_share_relative_embedding):
    """Contract queries against relative-position key embeddings.

    Args:
        x: a 4-D tensor indexed ``[b, h, l, d]``.
        y: relative embeddings, ``[m, d]`` when shared across heads, otherwise
            ``[h, m, d]``.
        heads_share_relative_embedding: selects which layout ``y`` has.

    Returns:
        A tensor indexed ``[b, h, l, m]``.
    """
    subscripts = ("bhld,md -> bhlm" if heads_share_relative_embedding
                  else "bhld,hmd -> bhlm")
    return torch.einsum(subscripts, x, y)
def matmul_with_relative_time_pitch(x, y):
    """Contract queries with the diagonal of a square embedding table.

    The repeated ``m`` subscript in ``"mmd"`` makes einsum read only
    ``y[m, m, :]``, so the result is ``x @ diag(y).T`` per batch/head.

    Args:
        x: a 4-D tensor indexed ``[b, h, l, d]``.
        y: a 3-D tensor indexed ``[m, m, d]`` (first two dims equal).

    Returns:
        A tensor indexed ``[b, h, l, m]``.
    """
    subscripts = "bhld,mmd -> bhlm"
    result = torch.einsum(subscripts, x, y)
    return result
def get_relative_embeddings_pitch_time(max_relative_position, length, depth,
                                       relative_time_embeddings=None,
                                       relative_pitch_embeddings=None):
    """Instantiate or retrieve time/pitch relative embeddings, sliced to length.

    Used for the masked case, where relative attention only looks left.

    Args:
        max_relative_position: an Integer for the number of entries along each
            of the first two axes of the embedding tables, which corresponds to
            the max relative distance that is considered.
        length: an Integer, the length of the input sequence for which the
            embeddings are retrieved.
        depth: an Integer, the depth of the relative embeddings.
        relative_time_embeddings: tensor of shape
            [max_relative_position, max_relative_position, depth]; if None, a
            new one is drawn from N(0, depth ** -0.5) with numpy's global RNG.
        relative_pitch_embeddings: same as above, for pitch.

    Returns:
        A 4-tuple ``(used_time, used_pitch, time_table, pitch_table)``.
        ``used_*`` have shape [length, length, depth]. ``*_table`` are the full
        (possibly freshly created) tables, returned so the caller can reuse them.
    """
    initializer_stddev = depth ** -0.5
    embedding_shape = (max_relative_position, max_relative_position, depth)
    # Created time-first, then pitch, so numpy's RNG stream is consumed in
    # that order. These are plain tensors with requires_grad=False
    # (`Variable(...)` is a deprecated no-op wrapper in modern PyTorch).
    if relative_time_embeddings is None:
        relative_time_embeddings = torch.from_numpy(
            np.random.normal(0.0, initializer_stddev, embedding_shape).astype('f'))
    if relative_pitch_embeddings is None:
        relative_pitch_embeddings = torch.from_numpy(
            np.random.normal(0.0, initializer_stddev, embedding_shape).astype('f'))

    # If the sequence is longer than the table, zero-pad the front of both
    # relative axes. If it is shorter, keep only the last `length` entries.
    pad_length = max(length - max_relative_position, 0)
    slice_start_position = max(max_relative_position - length, 0)
    # The SAME window is used on both relative axes and for both tables.
    # The time table was previously sliced [start:length] on axis 0, which is
    # too short (or empty) whenever max_relative_position > length.
    window = slice(slice_start_position, slice_start_position + length)

    def _pad_and_slice(embeddings):
        # F.pad pairs run from the last dim backwards: depth, axis 1, axis 0.
        padded = F.pad(embeddings, (0, 0, pad_length, 0, pad_length, 0))
        return padded[window, window, :]

    return (_pad_and_slice(relative_time_embeddings),
            _pad_and_slice(relative_pitch_embeddings),
            relative_time_embeddings,
            relative_pitch_embeddings)
def get_relative_embeddings_left(max_relative_position, length, depth,
                                 num_heads,
                                 heads_share_relative_embedding,
                                 relative_embeddings=None):
    """Instantiate or retrieve key relative embeddings, sliced to ``length``.

    Used for the masked case, where relative attention only looks left.

    Args:
        max_relative_position: an Integer for the number of entries in the
            relative embedding, which corresponds to the max relative distance
            that is considered.
        length: an Integer, the length of the input sequence for which the
            embedding is retrieved.
        depth: an Integer, the depth of the relative embeddings.
        num_heads: an Integer, the number of heads (used only for the shape of
            a freshly created, non-shared table).
        heads_share_relative_embedding: a Boolean; True means one table of
            shape [max_relative_position, depth], False means one per head,
            [num_heads, max_relative_position, depth].
        relative_embeddings: an existing table to reuse; if None, a new one is
            drawn from N(0, depth ** -0.5) with numpy's global RNG.

    Returns:
        ``(used, table)``: ``used`` has ``length`` entries along the relative
        axis; ``table`` is the full (possibly freshly created) table object.
    """
    stddev = depth ** -0.5
    if relative_embeddings is None:
        if heads_share_relative_embedding:
            table_shape = (max_relative_position, depth)
        else:
            table_shape = (num_heads, max_relative_position, depth)
        sample = np.random.normal(0.0, stddev, table_shape).astype('f')
        relative_embeddings = Variable(torch.from_numpy(sample))

    # Zero-pad the front of the relative axis when the sequence is longer than
    # the table; otherwise keep only the last `length` entries.
    front_pad = max(length - max_relative_position, 0)
    start = max(max_relative_position - length, 0)
    stop = start + length

    if not heads_share_relative_embedding:
        padded = F.pad(relative_embeddings, (0, 0, front_pad, 0, 0, 0))
        return padded[:, start:stop, :], relative_embeddings

    padded = F.pad(relative_embeddings, (0, 0, front_pad, 0))
    return padded[start:stop, :], relative_embeddings
def dot_product_self_attention_relative(q,
                                        k,
                                        v,
                                        mask=None,
                                        bias=None,
                                        max_relative_position=None,
                                        dropout=None,
                                        heads_share_relative_embedding=False,
                                        relative_embeddings=None,
                                        relative_time_pitch=False,
                                        relative_time_embeddings=None,
                                        relative_pitch_embeddings=None):
    """Masked self-attention with relative-position logits.

    logits = q @ k^T (masked) + skewed(q . relative_key_embeddings)
             [+ q . diag(time + pitch embeddings) if relative_time_pitch]
             [+ bias]
    output = dropout(softmax(logits)) @ v

    Args:
        q, k, v: 4-D tensors; ``k`` is unpacked as (_, heads, length, depth_k).
        mask: optional mask; ``mask.unsqueeze(1)`` must broadcast against the
            logits. Positions where it equals 0 are filled with -1e9.
        bias: optional tensor added to the logits.
        max_relative_position: positive int, size of the relative tables.
        dropout: optional callable applied to the attention weights.
        heads_share_relative_embedding: share one key relative table across heads.
        relative_embeddings: key relative table to reuse (created if None).
        relative_time_pitch: if truthy, also add the time/pitch relative term.
        relative_time_embeddings, relative_pitch_embeddings: tables to reuse
            for that term (created if None).

    Returns:
        ``(output, relative_embeddings)``, or, when ``relative_time_pitch`` is
        truthy, ``(output, relative_embeddings, relative_time_embeddings,
        relative_pitch_embeddings)``. The tables are returned so the caller
        can cache them.

    Raises:
        ValueError: if ``max_relative_position`` is None or not > 0.

    NOTE(review): unlike ``attention``, these logits are never divided by
    sqrt(depth_k), and nothing in this function pre-scales ``q``. Confirm
    whether the caller is expected to scale ``q``.
    """
    # The old `not max_relative_position` check let negative values through,
    # despite the message below.
    if max_relative_position is None or max_relative_position <= 0:
        raise ValueError("Max relative position (%s) should be > 0 when using "
                         "relative self attention." % (max_relative_position))

    _, heads, length, depth_k = shape_list(k)
    logits = torch.matmul(q, k.transpose(-2, -1))
    if mask is not None:
        # Insert a heads axis so the mask broadcasts across heads.
        mask = mask.unsqueeze(1)
        logits = logits.masked_fill(mask == 0, -1e9)

    # Relative-key term. The embeddings may live on CPU, so move them to q's device.
    key_relative_embeddings, relative_embeddings = get_relative_embeddings_left(
        max_relative_position, length, depth_k, heads,
        heads_share_relative_embedding, relative_embeddings)
    key_relative_embeddings = key_relative_embeddings.to(q.device)
    relative_logits = matmul_with_relative_keys(q, key_relative_embeddings,
                                                heads_share_relative_embedding)
    relative_logits = _relative_position_to_absolute_position_masked(relative_logits)

    if relative_time_pitch:
        (used_time_embeddings, used_pitch_embeddings,
         relative_time_embeddings, relative_pitch_embeddings) = \
            get_relative_embeddings_pitch_time(max_relative_position, length,
                                               depth_k,
                                               relative_time_embeddings,
                                               relative_pitch_embeddings)
        time_pitch_sum = (used_time_embeddings + used_pitch_embeddings).to(q.device)
        relative_logits = relative_logits + matmul_with_relative_time_pitch(q, time_pitch_sum)

    logits += relative_logits
    if bias is not None:
        logits += bias
    weights = F.softmax(logits, dim=-1)
    # Dropping out the attention links for each of the heads.
    if dropout is not None:
        weights = dropout(weights)
    output = torch.matmul(weights, v)

    if relative_time_pitch:
        return output, relative_embeddings, relative_time_embeddings, relative_pitch_embeddings
    return output, relative_embeddings
def attention(q, v, k, d_k, mask=None, dropout=None):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.

    NOTE: the positional order is (q, v, k) — values BEFORE keys. Pass
    ``k=`` and ``v=`` by keyword to avoid silently swapping them.

    Args:
        q: query tensor.
        v: value tensor.
        k: key tensor.
        d_k: per-head key depth used for the 1/sqrt(d_k) scaling.
        mask: optional mask, used as-is (no unsqueeze); it must broadcast
            against the score tensor. Positions where it equals 0 get -1e9.
        dropout: optional callable applied to the attention weights.

    Returns:
        The attention-weighted sum of ``v``.
    """
    logits = q.matmul(k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        logits = logits.masked_fill(mask == 0, -1e9)
    probs = F.softmax(logits, dim=-1)
    if dropout is not None:
        probs = dropout(probs)
    return probs.matmul(v)
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a choice of scoring function.

    ``attention_type == "Baseline"`` uses scaled dot-product ``attention``.
    Any other value uses ``dot_product_self_attention_relative``, which adds
    relative-position logits, plus the time/pitch term when
    ``relative_time_pitch`` is truthy.

    Args:
        heads: number of attention heads; ``d_k = d_model // heads``.
        d_model: model / embedding width (input and output of every Linear).
        dropout: dropout probability applied to the attention weights.
        attention_type: "Baseline", or anything else for relative attention.
        bias: optional tensor forwarded as the relative-attention logit bias.
        max_relative_position: size of the relative-position tables.
        heads_share_relative_embedding: share one key relative table across heads.
        relative_time_pitch: also use the time/pitch relative term.

    NOTE(review): the relative tables are cached in plain attributes
    (``self.relative_embeddings`` etc.), initialised to None and later set from
    the attention function's return values. They are not ``nn.Parameter``s,
    so presumably they are not trained, saved in ``state_dict()`` or moved by
    ``.to()``. Confirm this is intended.
    """

    def __init__(self, heads, d_model, dropout=0.0, attention_type="Baseline",
                 bias=None,
                 max_relative_position=512,
                 heads_share_relative_embedding=False,
                 relative_time_pitch=False):
        super().__init__()
        self.d_model = d_model
        # Each head gets d_model // heads features of the split embedding.
        self.d_k = d_model // heads
        self.h = heads
        self.attention_type = attention_type
        self.bias = bias
        self.max_relative_position = max_relative_position
        self.heads_share_relative_embedding = heads_share_relative_embedding
        self.relative_time_pitch = relative_time_pitch
        # Lazily filled on the first relative-attention forward pass.
        self.relative_embeddings = None
        self.relative_time_embeddings = None
        self.relative_pitch_embeddings = None
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Project q/k/v, attend per head, merge heads and project out.

        Args:
            q, k, v: tensors with batch as dim 0 and ``d_model`` as last dim.
            mask: optional mask. It is passed unchanged to ``attention`` on the
                Baseline path. The relative path unsqueezes it at dim 1.
                NOTE(review): the two paths therefore presumably expect
                different mask ranks — confirm with callers.

        Returns:
            ``self.out(...)`` of the merged heads, with an extra singleton
            dimension inserted at position 1.
        """
        bs = q.size(0)
        # (bs, seq_len, d_model) -> (bs, seq_len, h, d_k) -> (bs, h, seq_len, d_k)
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k).transpose(1, 2)

        if self.attention_type == "Baseline":
            # `attention` declares its parameters as (q, v, k, ...). A positional
            # call `attention(q, k, v, ...)` would bind our keys to its `v` and
            # our values to its `k`, i.e. compute softmax(q v^T) k. Pass by keyword.
            attended = attention(q, k=k, v=v, d_k=self.d_k, mask=mask,
                                 dropout=self.dropout)
        elif self.relative_time_pitch:
            (attended, self.relative_embeddings, self.relative_time_embeddings,
             self.relative_pitch_embeddings) = dot_product_self_attention_relative(
                q, k, v, mask,
                self.bias,
                self.max_relative_position,
                self.dropout,
                self.heads_share_relative_embedding,
                self.relative_embeddings,
                self.relative_time_pitch,
                self.relative_time_embeddings,
                self.relative_pitch_embeddings)
        else:
            attended, self.relative_embeddings = dot_product_self_attention_relative(
                q, k, v, mask,
                self.bias,
                self.max_relative_position,
                self.dropout,
                self.heads_share_relative_embedding,
                self.relative_embeddings)

        # (bs, h, seq_len, d_k) -> (bs, seq_len, h * d_k) and final projection.
        concat = attended.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        # NOTE(review): kept as-is; callers presumably rely on the extra dim.
        return output.unsqueeze(1)
class Norm(nn.Module):
    """Layer normalisation over the last (feature) dimension.

    y = alpha * (x - mean) / (std + eps) + bias, where mean/std are taken over
    the last dim (``std`` is PyTorch's default, unbiased estimator) and
    ``alpha`` / ``bias`` are learnable per-feature vectors of size ``d_model``.

    Args:
        d_model: size of the last dimension of the input.
        eps: added to the std to avoid division by zero.
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.size = d_model
        # Two learnable parameters to calibrate the normalisation.
        self.alpha = nn.Parameter(torch.ones(self.size))
        # Bias starts at zero so a fresh layer is a pure normalisation.
        # torch.ones shifted every output by +1.
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` along its last dimension, for any input rank >= 1."""
        # dim=-1 matches the per-feature alpha/bias. The previous hard-coded
        # dim=2 is identical for 3-D input but normalises the wrong axis for
        # 4-D input.
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias