-
Notifications
You must be signed in to change notification settings - Fork 0
/
Vision Transformer.py
301 lines (202 loc) · 8.78 KB
/
Vision Transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
import torch
import torch.nn as nn
class PatchEmbedding(nn.Module):
    """
    Split an image into non-overlapping square patches and embed each one.

    Splitting and linear projection happen in a single step: a Conv2d whose
    kernel size and stride both equal ``patch_size``.

    Parameters
    ----------
    img_size : int
        Side length of the (square) input image.
    patch_size : int
        Side length of each (square) patch.
    emb_dim : int, optional
        Embedding dimension. The default is 768.
    in_channels : int, optional
        Number of channels of the input image. The default is 3.

    Returns
    -------
    torch.Tensor of shape (num_samples, num_patches, emb_dim).
    """

    def __init__(self, img_size, patch_size, emb_dim=768, in_channels=3):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        # img_size // patch_size patches per side; remainder pixels are
        # dropped by the unpadded strided convolution.
        self.num_patches = (img_size // patch_size) ** 2
        self.projection = nn.Conv2d(
            in_channels=in_channels,
            out_channels=emb_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        # (N, C, H, W) -> (N, emb_dim, H // p, W // p)
        x = self.projection(x)
        # -> (N, emb_dim, num_patches) -> (N, num_patches, emb_dim)
        return x.flatten(2).transpose(1, 2)
class SelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention.

    Parameters
    ----------
    dim : int
        Input and output dimension of each token. Must be divisible by
        ``num_heads``.
    num_heads : int, optional
        Number of attention heads. The default is 12.
    bias : bool, optional
        If True, the fused q/k/v projection has a bias. The default is True.
    attn_dropout_p : float, optional
        Dropout probability applied to the attention weights (after softmax).
        The default is 0.
    proj_dropout_p : float, optional
        Dropout probability applied to the output projection. The default is 0.

    Returns
    -------
    torch.Tensor of shape (num_samples, num_tokens, dim).

    Raises
    ------
    ValueError
        At construction if ``dim`` is not divisible by ``num_heads``; in
        ``forward`` if the last input dimension is not ``dim``.
    """

    def __init__(self, dim, num_heads=12, bias=True, attn_dropout_p=0., proj_dropout_p=0.):
        super().__init__()
        # Fail early instead of with an opaque reshape error in forward().
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # 1 / sqrt(head_dim) scaling of the dot products.
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.attn_dropout = nn.Dropout(attn_dropout_p)
        self.projection = nn.Linear(dim, dim)
        self.proj_dropout = nn.Dropout(proj_dropout_p)

    def forward(self, x):
        num_samples, num_tokens, dim = x.shape
        if dim != self.dim:
            raise ValueError(
                f"expected last input dimension {self.dim}, got {dim}"
            )
        # (N, T, 3 * dim) -> (3, N, heads, T, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(num_samples, num_tokens, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]
        # (N, heads, T, T) attention weights.
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        attn = self.attn_dropout(attn)
        # (N, heads, T, head_dim) -> (N, T, heads, head_dim) -> (N, T, dim)
        out = (attn @ v).transpose(1, 2).flatten(2)
        return self.proj_dropout(self.projection(out))
class MLP(nn.Module):
    """
    Two-layer feed-forward network.

    Pipeline: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Parameters
    ----------
    in_c : int
        Number of input features.
    hidden_c : int
        Number of hidden features.
    out_c : int
        Number of output features.
    dropout_p : float, optional
        Dropout probability used after each linear layer. The default is 0.0.

    Returns
    -------
    torch.Tensor with the same leading dimensions as the input and a last
    dimension of ``out_c``.
    """

    def __init__(self, in_c, hidden_c, out_c, dropout_p=0.0):
        super().__init__()
        self.fc1 = nn.Linear(in_c, hidden_c)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_c, out_c)
        # One Dropout module, reused after both linear layers.
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        hidden = self.dropout(self.gelu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
class TransformerBlock(nn.Module):
    """
    Pre-norm transformer encoder block.

        x = x + SelfAttention(LayerNorm_1(x))
        x = x + MLP(LayerNorm_2(x))

    Parameters
    ----------
    dim : int
        Embedding dimension.
    num_heads : int
        Number of attention heads.
    ratio : float, optional
        Hidden width of the MLP as a multiple of ``dim``. The default is 4.0.
    bias : bool, optional
        If True, the q/k/v projection has a bias. The default is True.
    dropout_p : float, optional
        Dropout probability for the attention output projection and the MLP.
        The default is 0.0.
    attn_dropout_p : float, optional
        Dropout probability applied to the attention weights. The default is 0.0.

    Returns
    -------
    torch.Tensor with the same shape as the input (num_samples, num_tokens, dim).
    """

    def __init__(self, dim, num_heads, ratio=4.0, bias=True, dropout_p=0.0, attn_dropout_p=0.0):
        super().__init__()
        # Pre-attention LayerNorm (attribute name kept for compatibility).
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Dropout arguments were previously accepted but never forwarded.
        self.attn = SelfAttention(
            dim,
            num_heads=num_heads,
            bias=bias,
            attn_dropout_p=attn_dropout_p,
            proj_dropout_p=dropout_p,
        )
        # Separate pre-MLP LayerNorm so each sublayer has its own affine
        # parameters, instead of sharing one LayerNorm between both.
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden_c = int(dim * ratio)
        self.mlp = MLP(dim, hidden_c, dim, dropout_p=dropout_p)

    def forward(self, x):
        # Residual connection around each pre-normalized sublayer.
        x = x + self.attn(self.norm(x))
        x = x + self.mlp(self.norm2(x))
        return x
class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) image classifier.

    Images are cut into patches and embedded. A learnable class token is
    prepended and learnable position embeddings are added. The sequence then
    passes through a stack of TransformerBlocks, and the normalized
    class-token output is mapped to class logits.

    Parameters
    ----------
    img_size : int, optional
        Side length of the square input image. The default is 384.
    patch_size : int, optional
        Side length of each square patch. The default is 16.
    num_classes : int, optional
        Number of output classes. The default is 1000.
    emb_dim : int, optional
        Embedding dimension. The default is 768.
    depth : int, optional
        Number of TransformerBlocks. The default is 12.
    num_heads : int, optional
        Number of attention heads per block. The default is 12.
    ratio : float, optional
        MLP hidden width as a multiple of ``emb_dim``. The default is 4.0.
    bias : bool, optional
        If True, the q/k/v projections have a bias. The default is True.
    dropout_p : float, optional
        Dropout probability applied after adding the position embeddings;
        also forwarded to every TransformerBlock. The default is 0.0.
    attn_dropout_p : float, optional
        Forwarded to every TransformerBlock as its attention dropout
        probability. The default is 0.0.

    Returns
    -------
    torch.Tensor of shape (num_samples, num_classes) with unnormalized logits.

    Notes
    -----
    Input images must produce exactly ``num_patches`` patches (i.e. have
    height and width ``img_size``); otherwise the token count does not
    match ``pos_embedding``.
    """

    def __init__(self, img_size=384, patch_size=16, num_classes=1000, emb_dim=768, depth=12,
                 num_heads=12, ratio=4., bias=True, dropout_p=0.0, attn_dropout_p=0.0):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, emb_dim)
        # One learnable position vector per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(
            torch.zeros(1, 1 + self.patch_embedding.num_patches, emb_dim)
        )
        self.pos_dropout = nn.Dropout(dropout_p)
        self.cls_tokens = nn.Parameter(torch.zeros(1, 1, emb_dim))
        # Forward every block hyper-parameter; these were previously dropped,
        # so ratio/bias/dropout settings had no effect.
        self.ViTBlocks = nn.ModuleList(
            [
                TransformerBlock(
                    dim=emb_dim,
                    num_heads=num_heads,
                    ratio=ratio,
                    bias=bias,
                    dropout_p=dropout_p,
                    attn_dropout_p=attn_dropout_p,
                )
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(emb_dim, eps=1e-6)
        self.cls_head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        num_samples = x.shape[0]
        # (N, C, H, W) -> (N, num_patches, emb_dim)
        x = self.patch_embedding(x)
        # Prepend the shared class token to every sample's sequence.
        cls_tokens = self.cls_tokens.expand(num_samples, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = self.pos_dropout(x + self.pos_embedding)
        for block in self.ViTBlocks:
            x = block(x)
        x = self.norm(x)
        # Classify from the class token's final representation only.
        return self.cls_head(x[:, 0])
def test():
    """Smoke test: build a default ViT, run one random 3x384x384 image, print results."""
    model = VisionTransformer()
    dummy_batch = torch.randn((1, 3, 384, 384))
    logits = model(dummy_batch)
    print(logits.shape)
    print(model)
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not on import.
    test()