-
Notifications
You must be signed in to change notification settings - Fork 46
/
model.py
181 lines (143 loc) · 6.68 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import torch
import timm
import numpy as np
from einops import repeat, rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block
def random_indexes(size : int):
    """Return a random permutation of ``range(size)`` together with its inverse.

    Draws from NumPy's global RNG (``np.random.shuffle``), so results are
    reproducible under ``np.random.seed``.

    Returns:
        (forward_indexes, backward_indexes): two 1-D arrays where
        ``forward_indexes[backward_indexes]`` equals ``np.arange(size)``.
    """
    permutation = np.arange(size)
    np.random.shuffle(permutation)  # shuffles in place
    inverse = permutation.argsort()
    return permutation, inverse
def take_indexes(sequences, indexes):
    """Gather whole token vectors along the first (token) axis.

    Args:
        sequences: 3-D tensor laid out as (T, B, C).
        indexes: 2-D integer tensor of shape (T', B); ``indexes[t, b]`` picks
            which token of column ``b`` lands at output position ``t``.

    Returns:
        Tensor of shape (T', B, C) with ``out[t, b] = sequences[indexes[t, b], b]``.
    """
    channels = sequences.shape[-1]
    # Broadcast each index over the channel axis so gather copies full C-vectors.
    index = indexes.unsqueeze(-1).expand(-1, -1, channels)
    return torch.gather(sequences, 0, index)
class PatchShuffle(torch.nn.Module):
    """Randomly permute patch tokens per sample and keep only a leading fraction.

    Args:
        ratio: fraction of the ``T`` tokens to drop; ``T * (1 - ratio)``
            (rounded down) tokens are kept.
    """

    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        """Shuffle ``patches`` independently for every batch element, then truncate.

        Args:
            patches: token-first tensor of shape (T, B, C).

        Returns:
            (kept, forward_indexes, backward_indexes): ``kept`` has shape
            (remain_T, B, C); both index tensors are (T, B) ``torch.long`` on
            ``patches.device``. ``forward_indexes`` is the permutation applied
            and ``backward_indexes`` is its inverse.
        """
        T, B, C = patches.shape
        # Round before truncating: e.g. 5 * (1 - 0.8) evaluates to
        # 0.9999999999999998 in binary floating point, and a bare int() would
        # keep 0 tokens instead of 1. Truly fractional counts (e.g. 2.5) still
        # round down exactly as before.
        remain_T = int(round(T * (1 - self.ratio), 6))
        # One independent permutation per batch element (NumPy global RNG).
        indexes = [random_indexes(T) for _ in range(B)]
        forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
        patches = take_indexes(patches, forward_indexes)
        patches = patches[:remain_T]
        return patches, forward_indexes, backward_indexes
class MAE_Encoder(torch.nn.Module):
    """MAE encoder: patchify an image, drop a random subset of patches, encode the rest.

    Args:
        image_size: input side length in pixels; the positional table assumes
            an (image_size // patch_size) x (image_size // patch_size) grid.
        patch_size: patch side length (conv kernel and stride).
        emb_dim: token embedding width.
        num_layer: number of timm ``Block`` layers.
        num_head: attention heads per ``Block``.
        mask_ratio: fraction of patch tokens dropped by ``PatchShuffle``.
    """

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # Token tables are stored token-first, (tokens, 1, emb_dim), to broadcast over batch.
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros(num_patches, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)
        # Non-overlapping patch embedding from 3 input channels: kernel == stride == patch_size.
        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)
        self.transformer = torch.nn.Sequential(*(Block(emb_dim, num_head) for _ in range(num_layer)))
        self.layer_norm = torch.nn.LayerNorm(emb_dim)
        self.init_weight()

    def init_weight(self):
        """Fill the cls token and positional embedding with ``trunc_normal_(std=.02)``."""
        for param in (self.cls_token, self.pos_embedding):
            trunc_normal_(param, std=.02)

    def forward(self, img):
        """Encode the randomly kept patches of ``img``.

        Args:
            img: image batch fed to ``patchify`` (B, 3, H, W).

        Returns:
            (features, backward_indexes): ``features`` is (1 + kept, B, emb_dim)
            with the cls token first; ``backward_indexes`` (T, B) undoes the shuffle.
        """
        tokens = rearrange(self.patchify(img), 'b c h w -> (h w) b c') + self.pos_embedding
        tokens, _, backward_indexes = self.shuffle(tokens)
        batch = tokens.shape[1]
        tokens = torch.cat([self.cls_token.expand(-1, batch, -1), tokens], dim=0)
        encoded = self.transformer(rearrange(tokens, 't b c -> b t c'))
        features = rearrange(self.layer_norm(encoded), 'b t c -> t b c')
        return features, backward_indexes
class MAE_Decoder(torch.nn.Module):
    """MAE decoder: re-insert mask tokens, unshuffle, and predict every patch.

    Args:
        image_size: output side length; the grid is image_size // patch_size per side.
        patch_size: patch side length.
        emb_dim: token embedding width (must match the encoder output).
        num_layer: number of timm ``Block`` layers.
        num_head: attention heads per ``Block``.
    """

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()
        grid = image_size // patch_size
        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        # One extra position for the leading cls/global token.
        self.pos_embedding = torch.nn.Parameter(torch.zeros(grid ** 2 + 1, 1, emb_dim))
        self.transformer = torch.nn.Sequential(*(Block(emb_dim, num_head) for _ in range(num_layer)))
        # One output value per channel-pixel of a patch (3 * patch_size**2).
        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=grid)
        self.init_weight()

    def init_weight(self):
        """Fill the mask token and positional embedding with ``trunc_normal_(std=.02)``."""
        for param in (self.mask_token, self.pos_embedding):
            trunc_normal_(param, std=.02)

    def forward(self, features, backward_indexes):
        """Reconstruct the image from encoder ``features``.

        Args:
            features: (1 + kept, B, emb_dim) encoder output, cls token first.
            backward_indexes: (T, B) inverse permutation from ``PatchShuffle``.

        Returns:
            (img, mask): both laid out by ``patch2img`` as (B, 3, H, W); ``mask``
            is 1 on patches that were dropped by the encoder and 0 elsewhere.
        """
        visible = features.shape[0]  # cls token + kept patches
        batch = features.shape[1]
        # Index 0 is reserved for the cls token, so shift the patch indexes by one.
        cls_index = torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes)
        backward_indexes = torch.cat([cls_index, backward_indexes + 1], dim=0)
        filler = self.mask_token.expand(backward_indexes.shape[0] - visible, batch, -1)
        tokens = take_indexes(torch.cat([features, filler], dim=0), backward_indexes)
        tokens = tokens + self.pos_embedding
        decoded = rearrange(self.transformer(rearrange(tokens, 't b c -> b t c')), 'b t c -> t b c')
        patches = self.head(decoded[1:])  # drop the global (cls) token
        # Build the mask in shuffled order (first visible-1 kept, rest masked), then unshuffle it.
        mask = torch.zeros_like(patches)
        mask[visible - 1:] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        return self.patch2img(patches), self.patch2img(mask)
class MAE_ViT(torch.nn.Module):
    """Masked autoencoder built from an ``MAE_Encoder`` followed by an ``MAE_Decoder``.

    Both halves share ``image_size``, ``patch_size`` and ``emb_dim``; depth and
    head count are configured separately for each.
    """

    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()
        self.encoder = MAE_Encoder(image_size=image_size, patch_size=patch_size, emb_dim=emb_dim,
                                   num_layer=encoder_layer, num_head=encoder_head,
                                   mask_ratio=mask_ratio)
        self.decoder = MAE_Decoder(image_size=image_size, patch_size=patch_size, emb_dim=emb_dim,
                                   num_layer=decoder_layer, num_head=decoder_head)

    def forward(self, img):
        """Encode a random subset of ``img`` patches and decode; returns ``(predicted_img, mask)``."""
        features, backward_indexes = self.encoder(img)
        return self.decoder(features, backward_indexes)
class ViT_Classifier(torch.nn.Module):
    """Classifier that reuses an ``MAE_Encoder``'s modules plus a new linear head.

    The cls token, positional embedding, patch embedding, transformer and norm
    are the same objects held by ``encoder`` (shared, not copied). Unlike the
    encoder's own forward pass, no patches are dropped here.
    """

    def __init__(self, encoder : MAE_Encoder, num_classes=10) -> None:
        super().__init__()
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm
        emb_dim = self.pos_embedding.shape[-1]
        self.head = torch.nn.Linear(emb_dim, num_classes)

    def forward(self, img):
        """Return logits of shape (B, num_classes) for the image batch ``img``."""
        tokens = rearrange(self.patchify(img), 'b c h w -> (h w) b c') + self.pos_embedding
        cls = self.cls_token.expand(-1, tokens.shape[1], -1)
        tokens = torch.cat([cls, tokens], dim=0)
        encoded = self.layer_norm(self.transformer(rearrange(tokens, 't b c -> b t c')))
        # Classify from the cls token, which sits at sequence position 0.
        return self.head(encoded[:, 0])
if __name__ == '__main__':
    # Smoke test 1: PatchShuffle on a toy (T=16, B=2, C=10) sequence.
    shuffle = PatchShuffle(0.75)
    toy_tokens = torch.rand(16, 2, 10)
    kept, forward_indexes, backward_indexes = shuffle(toy_tokens)
    print(kept.shape)

    # Smoke test 2: encode/decode round trip on a random 2 x 3 x 32 x 32 batch.
    img = torch.rand(2, 3, 32, 32)
    encoder, decoder = MAE_Encoder(), MAE_Decoder()
    features, backward_indexes = encoder(img)
    # NOTE: these are the indexes from smoke test 1, not from the encoder call.
    print(forward_indexes.shape)
    predicted_img, mask = decoder(features, backward_indexes)
    print(predicted_img.shape)
    # Masked MSE, rescaled by the 0.75 mask ratio used above.
    loss = torch.mean((predicted_img - img) ** 2 * mask / 0.75)
    print(loss)