XLNet.py
# modified from the original code of https://github.com/huggingface/transformers
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.init import xavier_uniform_, xavier_normal_, constant_

# The original Hugging Face code imports these helpers from the rest of the
# library. Minimal stand-ins are defined here so this file is self-contained:
# XLNetLayerNorm is plain nn.LayerNorm, and ACT2FN maps activation names to
# functions (only the common activations are listed).
XLNetLayerNorm = nn.LayerNorm
ACT2FN = {"relu": F.relu, "gelu": F.gelu, "tanh": torch.tanh}
class XLNetRelativeAttention(nn.Module):
    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()
        if d_model % n_head != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (d_model, n_head)
            )
        self.output_attentions = False
        self.n_head = n_head
        self.d_head = d_model // n_head
        self.d_model = d_model
        self.scale = 1 / (self.d_head ** 0.5)
        self.q = nn.Parameter(torch.FloatTensor(d_model, self.n_head, self.d_head))
        self.k = nn.Parameter(torch.FloatTensor(d_model, self.n_head, self.d_head))
        self.v = nn.Parameter(torch.FloatTensor(d_model, self.n_head, self.d_head))
        self.o = nn.Parameter(torch.FloatTensor(d_model, self.n_head, self.d_head))
        self.r = nn.Parameter(torch.FloatTensor(d_model, self.n_head, self.d_head))
        self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        self.r_s_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head))
        self.seg_embed = nn.Parameter(torch.FloatTensor(2, self.n_head, self.d_head))
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        # torch.FloatTensor allocates uninitialised memory, so every parameter
        # needs an explicit init; the segment-related parameters are included
        # here as well (the upstream code initialises them at the model level).
        xavier_uniform_(self.q)
        xavier_uniform_(self.k)
        xavier_uniform_(self.v)
        xavier_uniform_(self.o)
        xavier_uniform_(self.r)
        xavier_normal_(self.r_r_bias)
        xavier_normal_(self.r_w_bias)
        xavier_normal_(self.r_s_bias)
        xavier_normal_(self.seg_embed)
    def prune_heads(self, heads):
        raise NotImplementedError

    @staticmethod
    def rel_shift(x, klen=-1):
        """Perform relative shift to form the relative attention score."""
        x_size = x.shape
        x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
        x = x[1:, ...]
        x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
        # x = x[:, 0:klen, :, :]
        x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))
        return x

    @staticmethod
    def rel_shift_bnij(x, klen=-1):
        """Same relative shift, but for scores laid out as (batch, n_head, i, j)."""
        x_size = x.shape
        x = x.reshape(x_size[0], x_size[1], x_size[3], x_size[2])
        x = x[:, :, 1:, :]
        x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3] - 1)
        # Note: the tensor-slice form was faster in my testing than torch.index_select.
        # However, tracing doesn't like the nature of the slice, and if klen changes
        # during the run then it'll fail, whereas index_select will be fine.
        x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
        # x = x[:, :, :, :klen]
        return x
    def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None, head_mask=None):
        """Core relative positional attention operations."""
        # content-based attention score
        ac = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_w_bias, k_head_h)
        # position-based attention score
        bd = torch.einsum("ibnd,jbnd->bnij", q_head + self.r_r_bias, k_head_r)
        bd = self.rel_shift_bnij(bd, klen=ac.shape[3])
        # segment-based attention score
        if seg_mat is None:
            ef = 0
        else:
            ef = torch.einsum("ibnd,snd->ibns", q_head + self.r_s_bias, self.seg_embed)
            ef = torch.einsum("ijbs,ibns->bnij", seg_mat, ef)
        # merge attention scores and perform masking
        attn_score = (ac + bd + ef) * self.scale
        if attn_mask is not None:
            # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
            if attn_mask.dtype == torch.float16:
                attn_score = attn_score - 65500 * torch.einsum("ijbn->bnij", attn_mask)
            else:
                attn_score = attn_score - 1e30 * torch.einsum("ijbn->bnij", attn_mask)
        # attention probability
        attn_prob = F.softmax(attn_score, dim=3)
        attn_prob = self.dropout(attn_prob)
        # mask heads if we want to
        if head_mask is not None:
            attn_prob = attn_prob * torch.einsum("ijbn->bnij", head_mask)
        # attention output
        attn_vec = torch.einsum("bnij,jbnd->ibnd", attn_prob, v_head_h)
        if self.output_attentions:
            return attn_vec, torch.einsum("bnij->ijbn", attn_prob)
        return attn_vec
    def post_attention(self, h, attn_vec, residual=True):
        """Post-attention processing."""
        # post-attention projection (back to `d_model`)
        attn_out = torch.einsum("ibnd,hnd->ibh", attn_vec, self.o)
        attn_out = self.dropout(attn_out)
        if residual:
            attn_out = attn_out + h
        output = self.layer_norm(attn_out)
        return output
    def forward(self, h, g, r, attn_mask_h=None, attn_mask_g=None, seg_mat=None, mems=None, target_mapping=None, head_mask=None):
        if g is not None:
            # Two-stream attention with relative positional encoding.
            if mems is not None and mems.dim() > 1:
                cat = torch.cat([mems, h], dim=0)
            else:
                cat = h
            # content-based key head
            k_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.k)
            # content-based value head
            v_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.v)
            # position-based key head
            k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)

            # h-stream
            # content-stream query head
            q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
            # core attention ops
            attn_vec_h = self.rel_attn_core(
                q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask
            )
            if self.output_attentions:
                attn_vec_h, attn_prob_h = attn_vec_h
            # post processing
            output_h = self.post_attention(h, attn_vec_h)

            # g-stream
            # query-stream query head
            q_head_g = torch.einsum("ibh,hnd->ibnd", g, self.q)
            # core attention ops
            if target_mapping is not None:
                q_head_g = torch.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
                attn_vec_g = self.rel_attn_core(
                    q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask
                )
                if self.output_attentions:
                    attn_vec_g, attn_prob_g = attn_vec_g
                attn_vec_g = torch.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
            else:
                attn_vec_g = self.rel_attn_core(
                    q_head_g, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask
                )
                if self.output_attentions:
                    attn_vec_g, attn_prob_g = attn_vec_g
            # post processing
            output_g = self.post_attention(g, attn_vec_g)
            if self.output_attentions:
                attn_prob = attn_prob_h, attn_prob_g
        else:
            # Multi-head attention with relative positional encoding
            if mems is not None and mems.dim() > 1:
                cat = torch.cat([mems, h], dim=0)
            else:
                cat = h
            # content heads
            q_head_h = torch.einsum("ibh,hnd->ibnd", h, self.q)
            k_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.k)
            v_head_h = torch.einsum("ibh,hnd->ibnd", cat, self.v)
            # positional heads
            k_head_r = torch.einsum("ibh,hnd->ibnd", r, self.r)
            # core attention ops
            attn_vec = self.rel_attn_core(
                q_head_h, k_head_h, v_head_h, k_head_r, seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask
            )
            if self.output_attentions:
                attn_vec, attn_prob = attn_vec
            # post processing
            output_h = self.post_attention(h, attn_vec)
            output_g = None

        outputs = (output_h, output_g)
        if self.output_attentions:
            outputs = outputs + (attn_prob,)
        return outputs
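# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): a minimal smoke test of
# XLNetRelativeAttention in single-stream mode. Shapes follow the einsum
# conventions above: activations are (seq_len, batch, d_model), and the
# relative position encoding `r` is assumed to span klen + qlen positions,
# i.e. 2 * qlen when no memory is attached. The concrete sizes are arbitrary.
# ---------------------------------------------------------------------------
def _demo_relative_attention():
    qlen, bsz, d_model, n_head = 4, 2, 16, 4
    attn = XLNetRelativeAttention(d_model, n_head)
    h = torch.randn(qlen, bsz, d_model)      # content stream
    r = torch.randn(2 * qlen, bsz, d_model)  # relative position encodings
    output_h, output_g = attn(h, None, r)    # g=None -> single-stream branch
    assert output_h.shape == (qlen, bsz, d_model) and output_g is None
    return output_h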
class XLNetFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_norm = XLNetLayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.layer_1 = nn.Linear(config.d_model, config.d_inner)
        self.layer_2 = nn.Linear(config.d_inner, config.d_model)
        self.dropout = nn.Dropout(config.dropout)
        if isinstance(config.ff_activation, str):
            self.activation_function = ACT2FN[config.ff_activation]
        else:
            self.activation_function = config.ff_activation

    def forward(self, inp):
        output = inp
        output = self.layer_1(output)
        output = self.activation_function(output)
        output = self.dropout(output)
        output = self.layer_2(output)
        output = self.dropout(output)
        output = self.layer_norm(output + inp)
        return output
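# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): XLNetFeedForward only
# reads a handful of config fields, so a SimpleNamespace stands in for the
# real XLNetConfig here. The field values are placeholders chosen for the demo.
# ---------------------------------------------------------------------------
def _demo_feed_forward():
    from types import SimpleNamespace

    config = SimpleNamespace(
        d_model=16, d_inner=64, dropout=0.1, layer_norm_eps=1e-12, ff_activation="gelu"
    )
    ff = XLNetFeedForward(config)
    x = torch.randn(4, 2, config.d_model)  # (seq_len, batch, d_model)
    out = ff(x)                            # residual + layer norm keeps the shape
    assert out.shape == x.shape
    return out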
class XLNetLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        # XLNetRelativeAttention above takes explicit arguments rather than a
        # config object, so unpack the relevant config fields here.
        self.rel_attn = XLNetRelativeAttention(config.d_model, config.n_head, dropout=config.dropout)
        self.ff = XLNetFeedForward(config)
        self.dropout = nn.Dropout(config.dropout)

    def forward(
        self, output_h, output_g, attn_mask_h, attn_mask_g, r, seg_mat, mems=None, target_mapping=None, head_mask=None
    ):
        # Note the argument order of the modified XLNetRelativeAttention.forward:
        # (h, g, r, ...), with the masks passed as keywords.
        outputs = self.rel_attn(
            output_h,
            output_g,
            r,
            attn_mask_h=attn_mask_h,
            attn_mask_g=attn_mask_g,
            seg_mat=seg_mat,
            mems=mems,
            target_mapping=target_mapping,
            head_mask=head_mask,
        )
        output_h, output_g = outputs[:2]
        if output_g is not None:
            output_g = self.ff(output_g)
        output_h = self.ff(output_h)
        outputs = (output_h, output_g) + outputs[2:]  # add attentions again if they are there
        return outputs
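# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): wiring a single
# XLNetLayer end to end. A SimpleNamespace again stands in for XLNetConfig,
# and the two-stream case is exercised by passing a query stream `g` alongside
# the content stream `h`. All sizes are placeholders.
# ---------------------------------------------------------------------------
def _demo_layer():
    from types import SimpleNamespace

    config = SimpleNamespace(
        d_model=16, n_head=4, d_inner=64, dropout=0.1, layer_norm_eps=1e-12, ff_activation="gelu"
    )
    layer = XLNetLayer(config)
    qlen, bsz = 4, 2
    h = torch.randn(qlen, bsz, config.d_model)      # content stream
    g = torch.randn(qlen, bsz, config.d_model)      # query stream
    r = torch.randn(2 * qlen, bsz, config.d_model)  # relative position encodings
    output_h, output_g = layer(h, g, None, None, r, None)
    assert output_h.shape == output_g.shape == (qlen, bsz, config.d_model)
    return output_h, output_g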