-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathresmlp.py
120 lines (94 loc) · 3.51 KB
/
resmlp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from typing import Any, Optional

import torch
import torch.nn as nn

from .ops import blocks
from .utils import export, config, load_from_local_or_url
class Affine(nn.Module):
    """Per-channel affine transform: ``y = alpha * x + beta``.

    Used by ResMLP in place of normalization layers. Parameters are shaped
    ``(1, 1, dim)`` so they broadcast over batch and sequence dimensions,
    and are initialized to the identity (alpha = 1, beta = 0).
    """

    def __init__(self, dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        # addcmul(beta, alpha, x) == beta + alpha * x
        return torch.addcmul(self.beta, self.alpha, x)
class ResMlpBlock(nn.Module):
    """One ResMLP layer: cross-patch (token) mixing, then per-patch channel MLP.

    Each branch is pre-scaled by an :class:`Affine` layer, multiplied by a
    learnable LayerScale vector, and regularized with stochastic depth before
    the residual add.

    Args:
        hidden_dim: channel (embedding) dimension of each token.
        sequence_len: number of tokens/patches (size of the token-mixing linear).
        layerscale_init: initial value for both LayerScale vectors.
        dropout_rate: dropout rate inside the channel MLP.
        drop_path_rate: stochastic-depth drop probability for both branches.
    """

    def __init__(
        self,
        hidden_dim,
        sequence_len,
        layerscale_init: float = 1e-4,
        dropout_rate: float = 0.,
        drop_path_rate: float = 0.
    ):
        super().__init__()
        survival = 1.0 - drop_path_rate

        # Branch 1: linear layer applied across the patch (token) dimension.
        self.affine_1 = Affine(hidden_dim)
        self.linear_patches = nn.Linear(sequence_len, sequence_len)
        self.layerscale_1 = nn.Parameter(layerscale_init * torch.ones(hidden_dim))
        self.drop1 = blocks.StochasticDepth(survival)

        # Branch 2: per-token MLP over channels with 4x expansion.
        self.affine_2 = Affine(hidden_dim)
        self.mlp_channels = blocks.MlpBlock(hidden_dim, hidden_dim * 4, dropout_rate=dropout_rate)
        self.layerscale_2 = nn.Parameter(layerscale_init * torch.ones(hidden_dim))
        self.drop2 = blocks.StochasticDepth(survival)

    def forward(self, x):
        # Token mixing: transpose to (B, dim, seq), mix patches, transpose back.
        mixed = self.linear_patches(self.affine_1(x).transpose(1, 2)).transpose(1, 2)
        x = x + self.drop1(self.layerscale_1 * mixed)
        # Channel mixing on the updated tokens.
        x = x + self.drop2(self.layerscale_2 * self.mlp_channels(self.affine_2(x)))
        return x
@export
class ResMLP(nn.Module):
    """ResMLP: an image classifier built entirely from MLPs over patch tokens.

    The image is split into non-overlapping patches by a strided convolution,
    the resulting tokens pass through a stack of :class:`ResMlpBlock` layers,
    and the class logits come from a linear head on the mean-pooled tokens.

    Args:
        image_size: input resolution (assumed square).
        in_channels: number of input image channels.
        num_classes: size of the classification head output.
        patch_size: side length of each square patch.
        hidden_dim: token embedding dimension.
        depth: number of ResMlpBlock layers.
        dropout_rate: dropout inside each block's channel MLP.
        drop_path_rate: stochastic-depth rate for every block.
    """

    def __init__(
        self,
        image_size: int = 224,
        in_channels: int = 3,
        num_classes: int = 1000,
        patch_size: int = 32,
        hidden_dim: int = 768,
        depth: int = 12,
        dropout_rate: float = 0.,
        drop_path_rate: float = 0.,
        **kwargs: Any
    ):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2

        # Patchify stem: each patch becomes one hidden_dim-dimensional token.
        self.stem = nn.Conv2d(in_channels, hidden_dim,
                              kernel_size=patch_size, stride=patch_size)

        stages = [
            ResMlpBlock(
                hidden_dim,
                num_patches,
                dropout_rate=dropout_rate,
                drop_path_rate=drop_path_rate
            )
            for _ in range(depth)
        ]
        self.blocks = nn.Sequential(*stages)

        self.affine = Affine(hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.stem(x)                  # (B, C, H, W) -> (B, D, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # -> (B, num_patches, D)
        x = self.blocks(x)
        x = self.affine(x)
        x = x.mean(dim=1)                 # global average pool over tokens
        return self.classifier(x)
def _resmlp(
    image_size: int = 224,
    patch_size: int = 16,
    hidden_dim: int = 768,
    depth: int = 12,
    pretrained: bool = False,
    pth: Optional[str] = None,
    progress: bool = True,
    **kwargs: Any
):
    """Build a :class:`ResMLP` and optionally load pretrained weights.

    Args:
        image_size: input resolution (square).
        patch_size: side length of each non-overlapping patch.
        hidden_dim: token embedding dimension.
        depth: number of ResMlpBlock layers.
        pretrained: if True, load weights from ``pth`` or ``kwargs['url']``.
        pth: local checkpoint path; used in preference to a download URL.
        progress: display a progress bar when downloading weights.
        **kwargs: forwarded to ``ResMLP`` (may also carry ``url``).

    Returns:
        The constructed (and possibly pretrained) model.
    """
    model = ResMLP(image_size, patch_size=patch_size,
                   hidden_dim=hidden_dim, depth=depth, **kwargs)
    if pretrained:
        load_from_local_or_url(model, pth, kwargs.get('url', None), progress)
    return model
@export
def resmlp_s12_224(pretrained: bool = False, pth: Optional[str] = None, progress: bool = True, **kwargs: Any):
    """ResMLP-S12: 224x224 input, 16x16 patches, width 384, depth 12."""
    return _resmlp(224, 16, 384, 12, pretrained, pth, progress, **kwargs)
@export
def resmlp_s24_224(pretrained: bool = False, pth: Optional[str] = None, progress: bool = True, **kwargs: Any):
    """ResMLP-S24: 224x224 input, 16x16 patches, width 384, depth 24."""
    return _resmlp(224, 16, 384, 24, pretrained, pth, progress, **kwargs)
@export
def resmlp_b24_224(pretrained: bool = False, pth: Optional[str] = None, progress: bool = True, **kwargs: Any):
    """ResMLP-B24: 224x224 input, 16x16 patches, width 768, depth 24."""
    return _resmlp(224, 16, 768, 24, pretrained, pth, progress, **kwargs)