# -*- coding: utf-8 -*-
# Time : 2022/9/17 17:19
# Author : Regulus
# FileName: mae.py
# Explain: Implementation of MAE (Masked Autoencoder) on top of vit_pytorch
# Software: PyCharm
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat

from vit_pytorch import ViT
from vit_pytorch.vit import Transformer  # Transformer is defined in vit_pytorch.vit, not re-exported at the package root
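

# MAE wraps an existing ViT from vit_pytorch and trains it self-supervised:
# a large fraction of the image patches is masked out, only the visible patches
# are encoded, and a lightweight decoder reconstructs the pixel values of the
# masked patches. After pre-training, the decoder is typically discarded and the
# encoder (the ViT passed in) is kept for downstream tasks.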
class MAE(nn.Module):
    def __init__(
        self,
        *,
        encoder,
        decoder_dim,
        masking_ratio = 0.75,
        decoder_depth = 1,
        decoder_heads = 8,
        decoder_dim_head = 64
    ):
        super().__init__()
        assert 0 < masking_ratio < 1, 'masking ratio must be kept between 0 and 1'
        self.masking_ratio = masking_ratio

        # extract some hyperparameters and functions from encoder (vision transformer to be trained)
        self.encoder = encoder
        num_patches, encoder_dim = encoder.pos_embedding.shape[-2:]  # encoder.pos_embedding.shape (1, num_patches + 1, encoder_dim); the +1 is the CLS position
        self.to_patch, self.patch_to_emb = encoder.to_patch_embedding[:2]
        pixel_values_per_patch = self.patch_to_emb.weight.shape[-1]  # self.patch_to_emb.weight.shape (encoder_dim, patch_size * patch_size * 3)

        # decoder parameters
        self.decoder_dim = decoder_dim
        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim) if encoder_dim != decoder_dim else nn.Identity()
        self.mask_token = nn.Parameter(torch.randn(decoder_dim))
        self.decoder = Transformer(dim = decoder_dim, depth = decoder_depth, heads = decoder_heads, dim_head = decoder_dim_head, mlp_dim = decoder_dim * 4)
        self.decoder_pos_emb = nn.Embedding(num_patches, decoder_dim)
        self.to_pixels = nn.Linear(decoder_dim, pixel_values_per_patch)
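
    # forward pass: patchify the image, randomly mask out `masking_ratio` of the
    # patches, encode only the visible patches with the ViT, scatter encoded tokens
    # and learned mask tokens back to their original positions, decode, and return
    # the MSE reconstruction loss over the masked patches only.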
    def forward(self, img):
        device = img.device  # img.shape (8, 3, 224, 224)

        # get patches
        patches = self.to_patch(img)  # patches.shape (8, 196, 16*16*3)
        batch, num_patches, *_ = patches.shape

        # patch to encoder tokens and add positions
        tokens = self.patch_to_emb(patches)  # tokens.shape (8, 196, 1024)
        tokens = tokens + self.encoder.pos_embedding[:, 1:(num_patches + 1)]  # self.encoder.pos_embedding.shape (1, 197, 1024); the slice skips the CLS position

        # calculate the number of patches to be masked, get random indices, and split them into masked vs. unmasked
        num_masked = int(self.masking_ratio * num_patches)
        rand_indices = torch.rand(batch, num_patches, device = device).argsort(dim = -1)
        masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

        # get the unmasked tokens to be encoded
        batch_range = torch.arange(batch, device = device)[:, None]
        tokens = tokens[batch_range, unmasked_indices]

        # get the patches to be masked for the final reconstruction loss
        masked_patches = patches[batch_range, masked_indices]

        # attend with vision transformer
        encoded_tokens = self.encoder.transformer(tokens)

        # project encoder to decoder dimensions, if they are not equal - the paper says you can get away with a smaller dimension for decoder
        decoder_tokens = self.enc_to_dec(encoded_tokens)  # decoder_tokens.shape (8, 49, 512)

        # reapply decoder position embedding to unmasked tokens
        unmasked_decoder_tokens = decoder_tokens + self.decoder_pos_emb(unmasked_indices)

        # repeat mask tokens for number of masked, and add the positions using the masked indices derived above
        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)  # mask_tokens.shape (512,) -> (8, 147, 512)
        mask_tokens = mask_tokens + self.decoder_pos_emb(masked_indices)

        # scatter the unmasked decoder tokens and the mask tokens back to their original patch positions, then attend with decoder
        decoder_tokens = torch.zeros(batch, num_patches, self.decoder_dim, device = device)  # decoder_tokens.shape (8, 196, 512)
        decoder_tokens[batch_range, unmasked_indices] = unmasked_decoder_tokens
        decoder_tokens[batch_range, masked_indices] = mask_tokens
        decoded_tokens = self.decoder(decoder_tokens)

        # splice out the mask tokens and project to pixel values
        mask_tokens = decoded_tokens[batch_range, masked_indices]  # mask_tokens.shape (8, 147, 512)
        pred_pixel_values = self.to_pixels(mask_tokens)  # self.to_pixels: nn.Linear(512, 768)

        # calculate reconstruction loss (MSE over the masked patches only, as in the MAE paper)
        recon_loss = F.mse_loss(pred_pixel_values, masked_patches)
        return recon_loss


if __name__ == "__main__":
    v = ViT(
        image_size = 224,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 8,
        mlp_dim = 2048
    )

    mae = MAE(
        encoder = v,
        masking_ratio = 0.75,   # the paper recommended 75% masked patches
        decoder_dim = 512,      # paper showed good results with just 512
        decoder_depth = 6       # anywhere from 1 to 8
    )

    images = torch.randn(8, 3, 224, 224)
    loss = mae(images)
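
    # A minimal optimization-step sketch (illustrative only: the optimizer, learning
    # rate, and save path below are assumptions, not part of the original script).
    # In practice this would run inside a training loop over a large image dataset.
    opt = torch.optim.AdamW(mae.parameters(), lr = 1.5e-4)
    opt.zero_grad()
    loss.backward()
    opt.step()

    # after pre-training, only the encoder is usually kept for downstream tasks, e.g.
    # torch.save(v.state_dict(), './trained-vit.pt')
    print(f'reconstruction loss: {loss.item():.4f}')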