-
Notifications
You must be signed in to change notification settings - Fork 321
/
t2t_vit.py
663 lines (580 loc) · 26.6 KB
/
t2t_vit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
T2T-ViT Transformer in Paddle
A Paddle Implementation of Tokens-to-Token ViT (T2T-ViT) as described in:
"Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet"
- Paper Link: https://arxiv.org/abs/2101.11986
"""
import math
import numpy as np
import paddle
import paddle.nn as nn
from droppath import DropPath
from utils import orthogonal
class Identity(nn.Layer):
    """Pass-through layer.

    Hands its input back unchanged. It is used as a stand-in for an
    optional layer so that forward methods do not need an if-condition.
    """

    def forward(self, x):
        # Nothing to compute: the output is the input itself.
        out = x
        return out
class PatchEmbedding(nn.Layer):
    """Patch embedding layer (tokens-to-token module)

    Turns an input image into a sequence of token embeddings using three
    successive "soft splits" (sliding windows with kernel/stride/padding of
    7/4/2, 3/2/1 and 3/2/1). Between the splits the tokens are mixed by a
    TokenPerformer ('performer'), a TokenTransformer ('transformer'), or the
    splits are replaced by plain strided convolutions ('convolution').

    Args:
        image_size: int, input image size (square images), default: 224
        token_type: string, type of token embedding, one of
            ['performer', 'transformer', 'convolution'], default: 'performer'
        in_channels: int, input image channels, default: 3
        embed_dim: int, output embedding dimension, default: 768
        token_dim: int, intermediate dim for patch_embedding module, default: 64

    Attributes:
        num_patches: int, number of tokens produced for an image of size
            image_size x image_size.

    Raises:
        ValueError: if token_type is not one of the supported values.
    """

    # (kernel_size, stride, padding) of the three soft splits, in order.
    _SPLITS = ((7, 4, 2), (3, 2, 1), (3, 2, 1))

    def __init__(self,
                 image_size=224,
                 token_type='performer',
                 in_channels=3,
                 embed_dim=768,
                 token_dim=64):
        super().__init__()
        # remembered so that forward can pick the matching code path
        self.token_type = token_type

        if token_type == 'transformer':
            # paddle v 2.1 has bugs on nn.Unfold, so the soft splits are done
            # with paddle.nn.functional.unfold in the forward method instead.
            self.attn1 = TokenTransformer(dim=in_channels * 7 * 7,
                                          in_dim=token_dim,
                                          num_heads=1,
                                          mlp_ratio=1.0)
            self.attn2 = TokenTransformer(dim=token_dim * 3 * 3,
                                          in_dim=token_dim,
                                          num_heads=1,
                                          mlp_ratio=1.0)
            w_attr_1, b_attr_1 = self._init_weights()  # init for linear
            self.proj = nn.Linear(token_dim * 3 * 3,
                                  embed_dim,
                                  weight_attr=w_attr_1,
                                  bias_attr=b_attr_1)
        elif token_type == 'performer':
            # paddle v 2.1 has bugs on nn.Unfold, so the soft splits are done
            # with paddle.nn.functional.unfold in the forward method instead.
            self.attn1 = TokenPerformer(dim=in_channels * 7 * 7,
                                        in_dim=token_dim,
                                        kernel_ratio=0.5)
            self.attn2 = TokenPerformer(dim=token_dim * 3 * 3,
                                        in_dim=token_dim,
                                        kernel_ratio=0.5)
            w_attr_1, b_attr_1 = self._init_weights()  # init for linear
            self.proj = nn.Linear(token_dim * 3 * 3,
                                  embed_dim,
                                  weight_attr=w_attr_1,
                                  bias_attr=b_attr_1)
        elif token_type == 'convolution':
            # 1st conv
            self.soft_split0 = nn.Conv2D(in_channels=in_channels,
                                         out_channels=token_dim,
                                         kernel_size=7,
                                         stride=4,
                                         padding=2)
            # 2nd conv
            self.soft_split1 = nn.Conv2D(in_channels=token_dim,
                                         out_channels=token_dim,
                                         kernel_size=3,
                                         stride=2,
                                         padding=1)
            # 3rd conv
            self.proj = nn.Conv2D(in_channels=token_dim,
                                  out_channels=embed_dim,
                                  kernel_size=3,
                                  stride=2,
                                  padding=1)
        else:
            raise ValueError(f'token_type: {token_type} is not supported!')

        # Follow the sliding-window output-size formula through the 3 splits
        # instead of using image_size // 16: both agree when image_size is a
        # multiple of 16, but only the formula matches the real token count
        # for other sizes (e.g. 230 -> 15 x 15, not 14 x 14).
        side = image_size
        for kernel, stride, padding in self._SPLITS:
            side = (side + 2 * padding - kernel) // stride + 1
        self.num_patches = side * side

    def _init_weights(self):
        """Return (weight_attr, bias_attr): truncated normal (std=.02) and zeros."""
        weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0.0))
        return weight_attr, bias_attr

    def _soft_split(self, x, stage):
        """Unfold x: [B, C, H, W] into window tokens: [B, L, C * k * k]."""
        kernel, stride, padding = self._SPLITS[stage]
        x = paddle.nn.functional.unfold(x,
                                        kernel_sizes=kernel,
                                        strides=stride,
                                        paddings=padding)
        # unfolded x: [B, C * k * k, L], L = number of sliding windows
        return x.transpose([0, 2, 1])

    @staticmethod
    def _tokens_to_image(x):
        """Reshape tokens [B, HW, C] back to a square feature map [B, C, H, W]."""
        B, HW, C = x.shape
        side = int(np.sqrt(HW))  # tokens are assumed to form a square grid
        x = x.transpose([0, 2, 1])
        return x.reshape([B, C, side, side])

    def forward(self, x):
        """Embed images x: [B, C, IMAGE_H, IMAGE_W] into tokens [B, L, embed_dim]."""
        if self.token_type == 'convolution':
            # three strided convolutions replace the unfold + attention stages
            x = self.soft_split0(x)
            x = self.soft_split1(x)
            x = self.proj(x)                     # [B, embed_dim, H', W']
            x = paddle.flatten(x, start_axis=2)  # [B, embed_dim, H' * W']
            return x.transpose([0, 2, 1])        # [B, H' * W', embed_dim]

        # stage 0: soft split -> token mixing -> back to feature map
        x = self._soft_split(x, 0)
        x = self.attn1(x)
        x = self._tokens_to_image(x)
        # stage 1: soft split -> token mixing -> back to feature map
        x = self._soft_split(x, 1)
        x = self.attn2(x)
        x = self._tokens_to_image(x)
        # stage 2: final soft split and linear projection to embed_dim
        x = self._soft_split(x, 2)
        x = self.proj(x)
        return x
class Mlp(nn.Layer):
    """ Feed-forward (MLP) module
    Two linear layers with a GELU in between. One dropout layer is shared
    and applied after each linear stage.
    Ops: fc -> act -> dropout -> fc -> dropout
    Attributes:
        fc1: nn.Linear, in_features -> hidden_features
        fc2: nn.Linear, hidden_features -> out_features
        act: GELU
        dropout: dropout applied after the activation and after fc2
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, dropout=0.):
        super().__init__()
        # missing (or falsy) sizes fall back to the input width
        out_dim = out_features or in_features
        hidden_dim = hidden_features or in_features

        fc1_weight, fc1_bias = self._init_weights()
        self.fc1 = nn.Linear(in_features, hidden_dim, weight_attr=fc1_weight, bias_attr=fc1_bias)

        fc2_weight, fc2_bias = self._init_weights()
        self.fc2 = nn.Linear(hidden_dim, out_dim, weight_attr=fc2_weight, bias_attr=fc2_bias)

        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def _init_weights(self):
        """Return (weight_attr, bias_attr): truncated normal (std=.02) and zeros."""
        trunc_normal = paddle.nn.initializer.TruncatedNormal(std=.02)
        zeros = paddle.nn.initializer.Constant(0.0)
        return paddle.ParamAttr(initializer=trunc_normal), paddle.ParamAttr(initializer=zeros)

    def forward(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
class Attention(nn.Layer):
    """ Self-Attention

    Multi-head self-attention. With skip_connection=True (the mode used by
    TokenTransformer) q/k/v are projected from `dim` to `in_dim` and the
    value tensor is added back to the projected output.

    Args:
        dim: int, input feature dimension (all heads)
        in_dim: int, qkv and output dimension; if None, dim is used, default: None
        num_heads: int, num of heads, default: 8
        qkv_bias: bool, if True, qkv linear layer is using bias, default: False
        qk_scale: float, if None, qk_scale is dim_head ** -0.5, default: None
        attention_dropout: float, dropout rate for attention dropout, default: 0.
        dropout: float, dropout rate for projection dropout, default: 0.
        skip_connection: bool, if True, use v to do skip connection,
            used in TokenTransformer, default: False
    """

    def __init__(self,
                 dim,
                 in_dim=None,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attention_dropout=0.,
                 dropout=0.,
                 skip_connection=False):
        super().__init__()
        self.num_heads = num_heads
        self.in_dim = in_dim or dim
        self.dim_head = dim // num_heads
        # NOTE: the scale is derived from dim // num_heads, also when in_dim
        # differs from dim. Same as original repo.
        self.scale = qk_scale or self.dim_head ** -0.5

        w_attr_1, b_attr_1 = self._init_weights()  # init for linear
        self.qkv = nn.Linear(dim,
                             self.in_dim * 3,
                             weight_attr=w_attr_1,
                             bias_attr=b_attr_1 if qkv_bias else False)
        self.attn_dropout = nn.Dropout(attention_dropout)
        w_attr_2, b_attr_2 = self._init_weights()  # init for linear
        self.proj = nn.Linear(self.in_dim,
                              self.in_dim,
                              weight_attr=w_attr_2,
                              bias_attr=b_attr_2)
        self.proj_dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(axis=-1)

        # use V to do skip connection, used in TokenTransformer
        self.skip = skip_connection

    def _init_weights(self):
        """Return (weight_attr, bias_attr): truncated normal (std=.02) and zeros."""
        weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0.0))
        return weight_attr, bias_attr

    def transpose_multihead(self, x):
        """Split the last axis into heads: [B, N, D] -> [B, num_heads, N, head_dim]."""
        # token transformer keeps the full in_dim per head,
        # regular attention uses dim // num_heads
        head_dim = self.in_dim if self.skip else self.dim_head
        # Build the target shape as a plain list. The previous code added a
        # list to a tuple in the token-transformer branch, which raises
        # TypeError before any reshape happens.
        new_shape = list(x.shape[:-1]) + [self.num_heads, head_dim]
        x = x.reshape(new_shape)
        x = x.transpose([0, 2, 1, 3])
        return x

    def forward(self, x):
        B, H, C = x.shape
        qkv = self.qkv(x).chunk(3, axis=-1)
        q, k, v = map(self.transpose_multihead, qkv)

        q = q * self.scale
        attn = paddle.matmul(q, k, transpose_y=True)
        attn = self.softmax(attn)
        attn = self.attn_dropout(attn)

        z = paddle.matmul(attn, v)
        z = z.transpose([0, 2, 1, 3])
        if self.skip:  # token transformer: output width is in_dim
            z = z.reshape([B, -1, self.in_dim])
        else:  # regular attention: output has the input's shape
            z = z.reshape([B, H, C])
        z = self.proj(z)
        z = self.proj_dropout(z)

        # skip connection: squeeze(1) drops the head axis of v, which only
        # has size 1 when num_heads == 1
        if self.skip:
            z = z + v.squeeze(1)
        return z
class Block(nn.Layer):
    """ Transformer encoder block
    Pre-norm block made of a self-attention sub-layer and an mlp sub-layer.
    Each sub-layer output passes through drop path and is added to the
    sub-layer input (residual connection).
    Args:
        dim: int, all heads dimension
        num_heads: int, num of heads
        mlp_ratio: ratio to multiply on mlp input dim as mlp hidden dim, default: 4.
        qkv_bias: bool, if True, qkv linear layer is using bias, default: False
        qk_scale: float, scale factor to replace dim_head ** -0.5, default: None
        dropout: float, dropout rate for projection dropout, default: 0.
        attention_dropout: float, dropout rate for attention dropout, default: 0.
        droppath: float, drop path rate, default: 0.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 dropout=0.,
                 attention_dropout=0.,
                 droppath=0.):
        super().__init__()
        # attention sub-layer
        norm1_weight, norm1_bias = self._init_weights_layernorm()
        self.norm1 = nn.LayerNorm(dim,
                                  epsilon=1e-6,
                                  weight_attr=norm1_weight,
                                  bias_attr=norm1_bias)
        self.attn = Attention(dim,
                              num_heads=num_heads,
                              qkv_bias=qkv_bias,
                              qk_scale=qk_scale,
                              dropout=dropout,
                              attention_dropout=attention_dropout)
        # drop path is only instantiated for a positive rate
        if droppath > 0.:
            self.drop_path = DropPath(droppath)
        else:
            self.drop_path = Identity()
        # mlp sub-layer
        norm2_weight, norm2_bias = self._init_weights_layernorm()
        self.norm2 = nn.LayerNorm(dim,
                                  epsilon=1e-6,
                                  weight_attr=norm2_weight,
                                  bias_attr=norm2_bias)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       dropout=dropout)

    def _init_weights_layernorm(self):
        """Return (weight_attr, bias_attr) for layernorm: ones and zeros."""
        ones = paddle.nn.initializer.Constant(1.0)
        zeros = paddle.nn.initializer.Constant(0.0)
        return paddle.ParamAttr(initializer=ones), paddle.ParamAttr(initializer=zeros)

    def forward(self, x):
        # attention sub-layer with residual
        attn_out = self.drop_path(self.attn(self.norm1(x)))
        x = x + attn_out
        # mlp sub-layer with residual
        mlp_out = self.drop_path(self.mlp(self.norm2(x)))
        return x + mlp_out
class TokenPerformer(nn.Layer):
    """ Token Performer layer
    Performer-style attention (softmax kernel approximated with random
    features) followed by an mlp, with norms and residual connections.
    This module is used in 'tokens-to-token', which converts an image into
    tokens and progressively re-tokenizes them.
    Args:
        dim: int, input feature dimension
        in_dim: int, qkv and out dimension in attention
        num_heads: int, num of heads, default: 1
        kernel_ratio: ratio to multiply on embed dim to get the number of
            random features, default: 0.5
        dropout: float, dropout rate for projection and mlp dropout, default: 0.1
    """

    def __init__(self, dim, in_dim, num_heads=1, kernel_ratio=0.5, dropout=0.1):
        super().__init__()
        self.embed_dim = in_dim * num_heads

        # fused k/q/v projection
        kqv_weight, kqv_bias = self._init_weights()
        self.kqv = nn.Linear(dim, 3 * self.embed_dim, weight_attr=kqv_weight, bias_attr=kqv_bias)
        self.dropout = nn.Dropout(dropout)

        # output projection
        proj_weight, proj_bias = self._init_weights()
        self.proj = nn.Linear(self.embed_dim,
                              self.embed_dim,
                              weight_attr=proj_weight,
                              bias_attr=proj_bias)
        self.num_heads = num_heads

        # layer norms before attention and before mlp
        norm1_weight, norm1_bias = self._init_weights_layernorm()
        norm2_weight, norm2_bias = self._init_weights_layernorm()
        self.norm1 = nn.LayerNorm(dim, epsilon=1e-6, weight_attr=norm1_weight, bias_attr=norm1_bias)
        self.norm2 = nn.LayerNorm(self.embed_dim,
                                  epsilon=1e-6,
                                  weight_attr=norm2_weight,
                                  bias_attr=norm2_bias)

        # mlp: fc -> gelu -> fc -> dropout
        mlp_w1, mlp_b1 = self._init_weights()
        mlp_w2, mlp_b2 = self._init_weights()
        mlp_in = nn.Linear(self.embed_dim, self.embed_dim, weight_attr=mlp_w1, bias_attr=mlp_b1)
        mlp_out = nn.Linear(self.embed_dim, self.embed_dim, weight_attr=mlp_w2, bias_attr=mlp_b2)
        self.mlp = nn.Sequential(mlp_in, nn.GELU(), mlp_out, nn.Dropout(dropout))

        # random feature matrix: orthogonalized random values scaled by
        # 1 / sqrt(m), stored as a parameter of shape [m, embed_dim]
        # NOTE(review): w is created as a regular parameter; confirm whether
        # it is meant to receive gradient updates.
        self.m = int(self.embed_dim * kernel_ratio)
        features = np.random.random(size=(self.m, self.embed_dim))
        features = orthogonal(features)
        self.w = paddle.create_parameter(
            shape=[self.m, self.embed_dim],
            dtype='float32',
            default_initializer=nn.initializer.Assign(features / math.sqrt(self.m)))

    def _init_weights_layernorm(self):
        """Return (weight_attr, bias_attr) for layernorm: ones and zeros."""
        ones = paddle.nn.initializer.Constant(1.0)
        zeros = paddle.nn.initializer.Constant(0.0)
        return paddle.ParamAttr(initializer=ones), paddle.ParamAttr(initializer=zeros)

    def _init_weights(self):
        """Return (weight_attr, bias_attr) for linear: truncated normal (std=.02) and zeros."""
        trunc_normal = paddle.nn.initializer.TruncatedNormal(std=.02)
        zeros = paddle.nn.initializer.Constant(0.0)
        return paddle.ParamAttr(initializer=trunc_normal), paddle.ParamAttr(initializer=zeros)

    # paddle version 2.1 does not support einsum
    def prm_exp(self, x):
        """Positive random feature map.

        x: [B, T, hs], self.w: [m, hs], returns: [B, T, m]
        """
        sq_norm = (x * x).sum(axis=-1, keepdim=True)
        target_shape = [paddle.shape(sq_norm)[0], paddle.shape(sq_norm)[1], self.m]
        half_sq_norm = sq_norm.expand(target_shape) / 2
        # same as einsum('bti,mi->btm', x, self.w)
        projected = paddle.matmul(x, self.w, transpose_y=True)
        return paddle.exp(projected - half_sq_norm) / math.sqrt(self.m)

    def single_attention(self, x):
        """Kernelized attention with a skip connection from v."""
        k, q, v = self.kqv(x).chunk(3, axis=-1)
        q_feat = self.prm_exp(q)
        k_feat = self.prm_exp(k)

        # normalizer, same as einsum('bti,bi->bt', qp, kp.sum(axis=1)).unsqueeze(2)
        denom = paddle.matmul(q_feat, k_feat.sum(axis=1).unsqueeze(2))
        # same as einsum('bin,bim->bnm', v, kp)
        kv = paddle.matmul(v, k_feat, transpose_x=True)
        # same as einsum('bti,bni->btn', qp, kptv)
        attended = paddle.matmul(q_feat, kv, transpose_y=True)
        denom_shape = [paddle.shape(denom)[0], paddle.shape(denom)[1], self.embed_dim]
        attended = attended / (denom.expand(denom_shape) + 1e-8)

        # projection, dropout, then skip connection from v
        attended = self.dropout(self.proj(attended))
        return v + attended

    def forward(self, x):
        attn_out = self.single_attention(self.norm1(x))
        mlp_out = self.mlp(self.norm2(attn_out))
        return attn_out + mlp_out
class TokenTransformer(nn.Layer):
    """ Token Transformer layer
    Self-attention (with skip connection from v) followed by an mlp, with
    norms and a residual connection around the mlp. This module is used in
    'tokens-to-token', which converts an image into tokens and
    progressively re-tokenizes them.
    Args:
        dim: int, input feature dimension
        in_dim: int, qkv and out dimension in attention
        num_heads: int, num of heads
        mlp_ratio: ratio to multiply on mlp input dim as mlp hidden dim, default: 1.
        qkv_bias: bool, if True, qkv linear layer is using bias, default: False
        qk_scale: float, scale factor to replace dim_head ** -0.5, default: None
        dropout: float, dropout rate for projection dropout, default: 0.
        attention_dropout: float, dropout rate for attention dropout, default: 0
        droppath: float, drop path rate, default: 0.
    """

    def __init__(self,
                 dim,
                 in_dim,
                 num_heads,
                 mlp_ratio=1.0,
                 qkv_bias=False,
                 qk_scale=None,
                 dropout=0.,
                 attention_dropout=0,
                 droppath=0.):
        super().__init__()
        # attention maps dim -> in_dim and adds v back to its output
        norm1_weight, norm1_bias = self._init_weights_layernorm()
        self.norm1 = nn.LayerNorm(dim,
                                  epsilon=1e-6,
                                  weight_attr=norm1_weight,
                                  bias_attr=norm1_bias)
        self.attn = Attention(dim,
                              in_dim=in_dim,
                              num_heads=num_heads,
                              qkv_bias=qkv_bias,
                              qk_scale=qk_scale,
                              dropout=dropout,
                              attention_dropout=attention_dropout,
                              skip_connection=True)
        # drop path is only instantiated for a positive rate
        if droppath > 0.:
            self.drop_path = DropPath(droppath)
        else:
            self.drop_path = Identity()
        # mlp works entirely in in_dim
        norm2_weight, norm2_bias = self._init_weights_layernorm()
        self.norm2 = nn.LayerNorm(in_dim,
                                  epsilon=1e-6,
                                  weight_attr=norm2_weight,
                                  bias_attr=norm2_bias)
        mlp_hidden_dim = int(in_dim * mlp_ratio)
        self.mlp = Mlp(in_features=in_dim,
                       hidden_features=mlp_hidden_dim,
                       out_features=in_dim,
                       dropout=dropout)

    def _init_weights_layernorm(self):
        """Return (weight_attr, bias_attr) for layernorm: ones and zeros."""
        ones = paddle.nn.initializer.Constant(1.0)
        zeros = paddle.nn.initializer.Constant(0.0)
        return paddle.ParamAttr(initializer=ones), paddle.ParamAttr(initializer=zeros)

    def forward(self, x):
        # no residual around attention here; Attention adds v internally
        attn_out = self.attn(self.norm1(x))
        # mlp sub-layer with drop path and residual
        mlp_out = self.drop_path(self.mlp(self.norm2(attn_out)))
        return attn_out + mlp_out
class T2TViT(nn.Layer):
    """ T2T-ViT model

    Tokens-to-token patch embedding, a class token plus positional
    embeddings, a stack of transformer blocks, a final layernorm and a
    linear classifier applied to the class token.

    Args:
        image_size: int, input image size, default: 224
        in_channels: int, input image channels, default: 3
        num_classes: int, num of classes; if not > 0 the head is an Identity, default: 1000
        token_type: string, type of token embedding ['performer', 'transformer'], default: 'performer'
        embed_dim: int, dim of each patch after patch embedding, default: 768
        depth: int, num of self-attention blocks, default: 12
        num_heads: int, num of attention heads, default: 12
        mlp_ratio: float, mlp hidden dim = mlp_ratio * mlp_in_dim, default: 4.
        qkv_bias: bool, if True, qkv projection is set with bias, default: True
        qk_scale: float, scale factor to replace dim_head ** -0.5, default: None
        dropout: float, dropout rate for linear projections, default: 0.
        attention_dropout: float, dropout rate for attention, default: 0.
        droppath: float, max drop path rate (rates grow linearly from 0 over depth), default: 0
        token_dim: int, intermediate dim for patch_embedding module, default: 64
    """

    def __init__(self,
                 image_size=224,
                 in_channels=3,
                 num_classes=1000,
                 token_type='performer',
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 dropout=0.,
                 attention_dropout=0.,
                 droppath=0,
                 token_dim=64):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        # convert image to patches: T2T-Module
        self.patch_embed = PatchEmbedding(image_size=image_size,
                                          token_type=token_type,
                                          in_channels=in_channels,
                                          embed_dim=embed_dim,
                                          token_dim=token_dim)
        num_patches = self.patch_embed.num_patches
        # token added for classification
        self.cls_token = paddle.create_parameter(
            shape=[1, 1, embed_dim],
            dtype='float32',
            default_initializer=nn.initializer.Constant(0.0))
        # positional embeddings for patch positions (+1 for the class token)
        self.pos_embed = paddle.create_parameter(
            shape=[1, num_patches + 1, embed_dim],
            dtype='float32',
            default_initializer=nn.initializer.Constant(0.0))
        # dropout for positional embeddings
        self.pos_dropout = nn.Dropout(dropout)
        # drop path decay rates, one per block. Converted to plain Python
        # floats so that each Block compares and stores an ordinary number
        # instead of a 1-element Tensor.
        depth_decay = paddle.linspace(0, droppath, depth).numpy().tolist()
        # create self-attention layers
        layer_list = []
        for rate in depth_decay:
            layer_list.append(Block(dim=embed_dim,
                                    num_heads=num_heads,
                                    mlp_ratio=mlp_ratio,
                                    qkv_bias=qkv_bias,
                                    qk_scale=qk_scale,
                                    dropout=dropout,
                                    attention_dropout=attention_dropout,
                                    droppath=rate))
        self.blocks = nn.LayerList(layer_list)
        w_attr_1, b_attr_1 = self._init_weights_layernorm()
        self.norm = nn.LayerNorm(embed_dim, epsilon=1e-6, weight_attr=w_attr_1, bias_attr=b_attr_1)
        # classifier head
        w_attr_2, b_attr_2 = self._init_weights()
        self.head = nn.Linear(embed_dim,
                              num_classes,
                              weight_attr=w_attr_2,
                              bias_attr=b_attr_2) if num_classes > 0 else Identity()

    def _init_weights_layernorm(self):
        """Return (weight_attr, bias_attr) for layernorm: ones and zeros."""
        weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1))
        bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
        return weight_attr, bias_attr

    def _init_weights(self):
        """Return (weight_attr, bias_attr) for linear: truncated normal (std=.02) and zeros."""
        weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
        return weight_attr, bias_attr

    def forward_features(self, x):
        """Return the normalized class-token feature for each image in x."""
        # Patch Embedding
        x = self.patch_embed(x)
        # one class token per sample, prepended to the patch tokens
        cls_tokens = self.cls_token.expand([paddle.shape(x)[0], 1, self.embed_dim])
        x = paddle.concat([cls_tokens, x], axis=1)
        x = x + self.pos_embed
        x = self.pos_dropout(x)
        # Self-Attention blocks
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        return x[:, 0]  # returns only cls_tokens

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
def build_t2t_vit(config):
    """Create a T2TViT model from a config object.

    Reads the image size from config.DATA and the architecture settings
    from config.MODEL; in_channels and token_dim are fixed to 3 and 64.
    """
    image_size = config.DATA.IMAGE_SIZE
    model_cfg = config.MODEL
    model_kwargs = dict(
        image_size=image_size,
        in_channels=3,
        num_classes=model_cfg.NUM_CLASSES,
        token_type=model_cfg.TOKEN_TYPE,
        embed_dim=model_cfg.EMBED_DIM,
        depth=model_cfg.DEPTH,
        num_heads=model_cfg.NUM_HEADS,
        mlp_ratio=model_cfg.MLP_RATIO,
        qk_scale=model_cfg.QK_SCALE,
        qkv_bias=model_cfg.QKV_BIAS,
        dropout=model_cfg.DROPOUT,
        attention_dropout=model_cfg.ATTENTION_DROPOUT,
        droppath=model_cfg.DROPPATH,
        token_dim=64,
    )
    return T2TViT(**model_kwargs)