# Filename: Models.py
# Date Created: 16-Mar-2019 2:17:09 pm
# Description: Combine all sublayers into one transformer model.
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

from Encode_Decode_Layers import EncoderLayer, DecoderLayer
from Embedding import Embedder, PositionalEncoder, PositionalEncoderConcat
from Sublayers import Norm


def get_clones(module, N):
    """Return a ModuleList of N independent (deep-copied) clones of module."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
class Encoder(nn.Module):
    """Stack of N encoder layers with token embedding and positional encoding."""

    def __init__(self, vocab_size, opt):
        super().__init__()
        self.N = opt.n_layers
        self.embed = Embedder(vocab_size, opt.d_model)

        # Concatenating the positional sinusoid doubles the model width;
        # adding it keeps d_model unchanged.
        if opt.concat_pos_sinusoid is True:
            self.pe = PositionalEncoderConcat(opt.d_model, opt.dropout, opt.max_seq_len)
            self.d_model = 2 * opt.d_model
        else:
            self.pe = PositionalEncoder(opt.d_model, opt.dropout, opt.max_seq_len)
            self.d_model = opt.d_model

        if opt.relative_time_pitch is True:
            # First layer uses absolute attention; the remaining N-1 layers
            # use relative time/pitch attention.
            self.layers = get_clones(EncoderLayer(self.d_model, opt.heads, opt.d_ff,
                                                  opt.dropout, opt.attention_type,
                                                  opt.relative_time_pitch,
                                                  max_relative_position=opt.max_relative_position),
                                     opt.n_layers - 1)
            self.layers.insert(0, copy.deepcopy(EncoderLayer(self.d_model, opt.heads, opt.d_ff,
                                                             opt.dropout, opt.attention_type,
                                                             relative_time_pitch=False,
                                                             max_relative_position=opt.max_relative_position)))
        else:
            self.layers = get_clones(EncoderLayer(self.d_model, opt.heads, opt.d_ff,
                                                  opt.dropout, opt.attention_type,
                                                  opt.relative_time_pitch,
                                                  max_relative_position=opt.max_relative_position),
                                     opt.n_layers)
        self.norm = Norm(self.d_model)

    def forward(self, src, mask):
        x = self.embed(src)
        x = self.pe(x)
        for i in range(self.N):
            x = self.layers[i](x.float(), mask)
        return self.norm(x)
class Decoder(nn.Module):
    """Stack of N decoder layers with token embedding and positional encoding."""

    def __init__(self, vocab_size, opt):
        super().__init__()
        self.N = opt.n_layers
        self.embed = Embedder(vocab_size, opt.d_model)

        if opt.concat_pos_sinusoid is True:
            self.pe = PositionalEncoderConcat(opt.d_model, opt.dropout, opt.max_seq_len)
            self.d_model = 2 * opt.d_model
        else:
            self.pe = PositionalEncoder(opt.d_model, opt.dropout, opt.max_seq_len)
            self.d_model = opt.d_model

        if opt.relative_time_pitch is True:
            # First layer uses absolute attention; the remaining N-1 layers
            # use relative time/pitch attention.
            self.layers = get_clones(DecoderLayer(self.d_model, opt.heads, opt.d_ff,
                                                  opt.dropout, opt.attention_type,
                                                  opt.relative_time_pitch,
                                                  max_relative_position=opt.max_relative_position),
                                     opt.n_layers - 1)
            self.layers.insert(0, copy.deepcopy(DecoderLayer(self.d_model, opt.heads, opt.d_ff,
                                                             opt.dropout, opt.attention_type,
                                                             relative_time_pitch=False,
                                                             max_relative_position=opt.max_relative_position)))
        else:
            self.layers = get_clones(DecoderLayer(self.d_model, opt.heads, opt.d_ff,
                                                  opt.dropout, opt.attention_type,
                                                  opt.relative_time_pitch,
                                                  max_relative_position=opt.max_relative_position),
                                     opt.n_layers)
        self.norm = Norm(self.d_model)

    def forward(self, trg, e_outputs, src_mask, trg_mask):
        x = self.embed(trg)
        x = self.pe(x)
        for i in range(self.N):
            x = self.layers[i](x.float(), e_outputs, src_mask, trg_mask)
        return self.norm(x)
class Transformer(nn.Module):
    """Full encoder-decoder transformer with a linear projection to the target vocabulary."""

    def __init__(self, src_vocab_size, trg_vocab_size, opt):
        super().__init__()
        self.encoder = Encoder(src_vocab_size, opt)
        self.decoder = Decoder(trg_vocab_size, opt)
        if opt.concat_pos_sinusoid is True:
            self.d_model = 2 * opt.d_model
        else:
            self.d_model = opt.d_model
        self.linear = nn.Linear(self.d_model, trg_vocab_size)

    def forward(self, src, trg, src_mask, trg_mask):
        e_outputs = self.encoder(src, src_mask)
        d_output = self.decoder(trg, e_outputs, src_mask, trg_mask)
        output = self.linear(d_output)
        return output
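# A minimal sketch of how the masks passed to Transformer.forward can be built.
# The padding index (assumed 0), the boolean mask dtype, and the exact shapes
# expected by the attention sublayers are assumptions here; adapt this to the
# project's own batching/masking code if it differs.
def example_create_masks(src, trg, pad_idx=0):
    # Source mask: True where src holds a real token, shape (batch, 1, src_len).
    src_mask = (src != pad_idx).unsqueeze(-2)
    # Target mask: hide padding and future positions (causal "no peek" mask).
    trg_pad_mask = (trg != pad_idx).unsqueeze(-2)                 # (batch, 1, trg_len)
    trg_len = trg.size(1)
    nopeak = torch.tril(torch.ones(trg_len, trg_len, dtype=torch.bool, device=trg.device))
    trg_mask = trg_pad_mask & nopeak                              # (batch, trg_len, trg_len)
    return src_mask, trg_mask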
def get_model(opt, vocab_size):
    """Build the Transformer from opt, optionally load pretrained weights, and move it to opt.device."""
    # Ensure the provided arguments are valid.
    assert opt.d_model % opt.heads == 0
    assert opt.dropout < 1

    print('Attention type: ' + opt.attention_type)

    # Initialize the transformer model.
    model = Transformer(vocab_size, vocab_size, opt)

    if opt.load_weights is not None:
        print("loading pretrained weights...")
        checkpoint = torch.load(f'{opt.load_weights}/{opt.weights_name}', map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

    model = model.to(opt.device)
    return model
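# Illustrative usage sketch: in practice `opt` comes from the project's argparse
# options. The attribute values below, the vocabulary size, the attention_type
# string, and the mask shapes are assumptions chosen only to show the fields
# this module reads, not the project's defaults.
if __name__ == '__main__':
    from types import SimpleNamespace

    opt = SimpleNamespace(
        n_layers=6, d_model=512, d_ff=2048, heads=8, dropout=0.1,
        max_seq_len=1024, attention_type='Baseline',
        concat_pos_sinusoid=False, relative_time_pitch=False,
        max_relative_position=512, load_weights=None, weights_name=None,
        device='cpu',
    )
    model = get_model(opt, vocab_size=400)

    src = torch.randint(1, 400, (2, 16))              # (batch, src_len)
    trg = torch.randint(1, 400, (2, 16))              # (batch, trg_len)
    src_mask, trg_mask = example_create_masks(src, trg)
    out = model(src, trg, src_mask, trg_mask)         # (batch, trg_len, vocab_size)
    print(out.shape)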