import os
from functools import partial

import timm.models.vision_transformer
import torch
import torch.nn as nn
from iopath.common.file_io import PathManagerFactory

pathmgr = PathManagerFactory.get()


# Modified based on MVP by Tete Xiao, thanks!
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """Vision Transformer encoder (classifier head removed).

    Reference:
    - MAE: https://github.com/facebookresearch/mae/blob/main/models_vit.py
    - timm: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    """

    def __init__(self, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)
        # Remove the classifier head; this model is used as a feature encoder only.
        if hasattr(self, "pre_logits"):
            del self.pre_logits
        del self.head

    def extract_feat(self, x):
        """Return patch-token features (cls token dropped), before the final norm."""
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        # x = x[:, 0].detach().float()
        return x[:, 1:]

    def forward_norm(self, x):
        return self.norm(x)

    def forward(self, x):
        return self.forward_norm(self.extract_feat(x))

    def freeze(self):
        """Freeze the positional embedding, cls token, patch embedding, and transformer blocks."""
        self.pos_embed.requires_grad = False
        self.cls_token.requires_grad = False

        def _freeze_module(m):
            for p in m.parameters():
                p.requires_grad = False

        _freeze_module(self.patch_embed)
        _freeze_module(self.blocks)

        trainable_params = []
        for name, p in self.named_parameters():
            if p.requires_grad:
                trainable_params.append(name)
        # print("Trainable parameters in the encoder:")
        # print(trainable_params)
def vit_s16(pretrained, **kwargs):
    """ViT-Small/16 encoder; `pretrained` is a checkpoint path or "none"."""
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    assert os.path.exists(pretrained) or pretrained.startswith("none")
    # load from checkpoint
    if not pretrained.startswith("none"):
        load_checkpoint(pretrained, model)
        # print("Loaded encoder from: {}".format(pretrained))
    hidden_dim = 384
    return model, hidden_dim


def vit_scratch16(**kwargs):
    """ViT-Small/16 encoder trained from scratch (no checkpoint loading)."""
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    hidden_dim = 384
    return model, hidden_dim


def vit_scratch_base16(**kwargs):
    """ViT-Base/16 encoder trained from scratch (no checkpoint loading)."""
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    hidden_dim = 768
    return model, hidden_dim


def vit_b16(pretrained, **kwargs):
    """ViT-Base/16 encoder; `pretrained` is a checkpoint path or "none"."""
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    assert os.path.exists(pretrained) or pretrained.startswith("none")
    # load from checkpoint
    if not pretrained.startswith("none"):
        load_checkpoint(pretrained, model)
        print("Loaded encoder from: {}".format(pretrained))
    hidden_dim = 768
    return model, hidden_dim


def vit_l16(pretrained, **kwargs):
    """ViT-Large/16 encoder; `pretrained` is a checkpoint path or "none"."""
    model = VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    assert os.path.exists(pretrained) or pretrained.startswith("none")
    # load from checkpoint
    if not pretrained.startswith("none"):
        load_checkpoint(pretrained, model)
        print("Loaded encoder from: {}".format(pretrained))
    hidden_dim = 1024
    return model, hidden_dim
def unwrap_model(model):
    """Remove the DistributedDataParallel wrapper if present."""
    wrapped = isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel)
    return model.module if wrapped else model


def load_checkpoint(checkpoint_file, model):
    """Load encoder weights from a checkpoint file into the model (non-strict)."""
    assert pathmgr.exists(checkpoint_file), "Checkpoint '{}' not found".format(
        checkpoint_file
    )
    with pathmgr.open(checkpoint_file, "rb") as f:
        checkpoint = torch.load(f, map_location="cpu")
    state_dict = checkpoint["model"]
    r = unwrap_model(model).load_state_dict(state_dict, strict=False)
    if r.unexpected_keys or r.missing_keys:
        print(f"Loading weights, unexpected keys: {r.unexpected_keys}")
        print(f"Loading weights, missing keys: {r.missing_keys}")