-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_model.py
272 lines (213 loc) · 10.2 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
from typing import Tuple

import jax
import jax.numpy as jnp
import jax_resnet
import optax
from flax import linen as nn
def get_loss(*, logits, labels, num_classes, l2_reg=False, params=None, l2_alpha=4e-5):
    """Mean softmax cross-entropy, optionally plus an L2 penalty on ``params``.

    Args:
        logits: Unnormalized class scores; passed to
            ``optax.softmax_cross_entropy``, which applies the softmax itself.
        labels: Integer class ids, one-hot encoded with ``num_classes``.
        num_classes: Number of classes for the one-hot encoding.
        l2_reg: If True, add ``l2_loss(params=params, alpha=l2_alpha)``.
        params: Pytree of parameters to regularize; required when ``l2_reg``.
        l2_alpha: Weight of the L2 term (default is the previously
            hard-coded 0.00004).

    Returns:
        Scalar loss.

    Raises:
        ValueError: If ``l2_reg`` is True but ``params`` is None.
    """
    labels_one_hot = jax.nn.one_hot(labels, num_classes=num_classes)
    loss = optax.softmax_cross_entropy(logits=logits, labels=labels_one_hot).mean()
    if not l2_reg:
        return loss
    if params is None:
        # None is an empty pytree, so l2_loss(None) would silently return 0
        # and the requested regularization would never be applied.
        raise ValueError("get_loss: l2_reg=True requires `params`.")
    return loss + l2_loss(params=params, alpha=l2_alpha)
@jax.jit
def l2_loss(params, alpha):
    """Return ``alpha * sum(mean(leaf ** 2))`` over every leaf of ``params``.

    Each leaf contributes the *mean* of its squared entries rather than the
    sum, so a tensor's contribution does not grow with its element count.

    Args:
        params: Any pytree of arrays. NOTE(review): if callers pass the whole
            ``model.init`` result of a BatchNorm model, ``batch_stats`` are
            penalized too. Confirm that only the ``"params"`` collection is
            meant.
        alpha: Scalar weight. It is a traced (non-static) jit argument, so new
            values do not cause recompilation.

    Returns:
        Scalar array; 0.0 for a pytree with no leaves.
    """
    # Use jax.tree_util.tree_leaves: the top-level jax.tree_leaves alias is
    # deprecated and has been removed from recent JAX releases.
    leaves = jax.tree_util.tree_leaves(params)
    return sum((alpha * jnp.square(leaf).mean() for leaf in leaves), 0.0)
# -----RESNET50-----
def get_resnet(no_params=False, num_classes=10, input_shape=(1, 32, 32, 3), seed=0):
    """Build a ``jax_resnet`` ResNet-50 and optionally initialize it.

    Args:
        no_params: If True, return only the model, without initializing it.
        num_classes: Size of the classification head (``n_classes``).
        input_shape: Shape of the dummy batch passed to ``model.init``. It
            should match the images the model will be applied to. The
            default is the previously hard-coded single 32x32x3 image.
        seed: Seed for the ``jax.random.PRNGKey`` used by ``model.init``.

    Returns:
        ``model`` if ``no_params`` else ``(variables, model)``, where
        ``variables`` is whatever ``model.init`` returns.
    """
    model = jax_resnet.ResNet50(n_classes=num_classes)  # third-party implementation
    if no_params:
        return model
    variables = model.init(jax.random.PRNGKey(seed), jnp.ones(input_shape))
    return variables, model
# -----INCEPTION V4-----
class conv2d_bnorm(nn.Module):
    """Conv -> BatchNorm (no learned scale) -> ReLU, the unit used by every Inception-v4 block.

    Attributes:
        nb_filter: Number of output channels of the convolution.
        num_row: Kernel height.
        num_col: Kernel width.
        padding: Flax padding mode, e.g. "SAME" or "VALID".
        strides: Convolution strides.
        use_bias: Whether the convolution has a bias term.
        use_running_average: Forwarded to ``nn.BatchNorm``. The default, False
            (previously hard-coded), normalizes with batch statistics and
            updates ``batch_stats``, which must then be mutable in ``apply``.
            Set True to normalize with the stored running averages (inference).
    """
    nb_filter: int
    num_row: int
    num_col: int
    padding: str = "SAME"
    strides: Tuple[int, int] = (1, 1)
    use_bias: bool = False
    use_running_average: bool = False

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(
            self.nb_filter,
            (self.num_row, self.num_col),
            strides=self.strides,
            padding=self.padding,
            use_bias=self.use_bias,
            # Normal init with variance 2 / fan_in (He-style, suits ReLU).
            kernel_init=jax.nn.initializers.variance_scaling(
                scale=2.0, mode="fan_in", distribution="normal"
            ),
        )(x)
        x = nn.BatchNorm(
            use_running_average=self.use_running_average,
            axis=-1,
            momentum=0.9997,
            use_scale=False,
        )(x)
        return nn.relu(x)
class block_inception_a(nn.Module):
    """Inception-A block: four parallel branches joined along axis 3.

    Every branch ends in 96 channels, so the output has 384 channels. All
    convolutions and the average pool use stride 1 with SAME padding, so the
    spatial size is unchanged.
    """

    @nn.compact
    def __call__(self, x):
        def cbr(inp, filters, rows, cols):
            return conv2d_bnorm(nb_filter=filters, num_row=rows, num_col=cols)(inp)

        # Submodules are created in the original order, which keeps the
        # auto-generated parameter names (conv2d_bnorm_0, _1, ...) stable.
        b0 = cbr(x, 96, 1, 1)

        b1 = cbr(x, 64, 1, 1)
        b1 = cbr(b1, 96, 3, 3)

        b2 = cbr(x, 64, 1, 1)
        b2 = cbr(b2, 96, 3, 3)
        b2 = cbr(b2, 96, 3, 3)

        pooled = nn.avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding="SAME")
        b3 = cbr(pooled, 96, 1, 1)

        return jax.lax.concatenate([b0, b1, b2, b3], dimension=3)
class block_reduction_a(nn.Module):
    """Reduction-A block: stride-2 VALID branches that shrink the spatial size.

    Branches:
      * a 3x3/2 conv with 384 filters;
      * 1x1 (192) -> 3x3 (224) -> 3x3/2 (256);
      * a 3x3/2 max-pool, which passes the input channels through.

    The output has 384 + 256 + C_in channels. Filter counts follow the
    Inception-v4 paper (k, l, m, n = 192, 224, 256, 384).
    """

    @nn.compact
    def __call__(self, x):
        branch_0 = conv2d_bnorm(nb_filter=384, num_row=3, num_col=3, strides=(2, 2), padding="VALID")(x)

        branch_1 = conv2d_bnorm(nb_filter=192, num_row=1, num_col=1)(x)
        # 224 (paper's "l"); this was previously mistyped as 244.
        branch_1 = conv2d_bnorm(nb_filter=224, num_row=3, num_col=3)(branch_1)
        branch_1 = conv2d_bnorm(nb_filter=256, num_row=3, num_col=3, strides=(2, 2), padding="VALID")(branch_1)

        branch_2 = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="VALID")

        return jax.lax.concatenate([branch_0, branch_1, branch_2], 3)
class block_inception_b(nn.Module):
    """Inception-B block, with 7x7 receptive fields factorized into 1x7 / 7x1 convs.

    The output has 384 + 256 + 256 + 128 = 1024 channels. All ops use stride
    1 with SAME padding, so the spatial size is preserved. Filter counts
    follow the Inception-v4 paper; the three 224-filter layers were
    previously mistyped as 244.
    """

    @nn.compact
    def __call__(self, x):
        branch_0 = conv2d_bnorm(nb_filter=384, num_row=1, num_col=1)(x)

        branch_1 = conv2d_bnorm(nb_filter=192, num_row=1, num_col=1)(x)
        branch_1 = conv2d_bnorm(nb_filter=224, num_row=1, num_col=7)(branch_1)
        branch_1 = conv2d_bnorm(nb_filter=256, num_row=7, num_col=1)(branch_1)

        branch_2 = conv2d_bnorm(nb_filter=192, num_row=1, num_col=1)(x)
        branch_2 = conv2d_bnorm(nb_filter=192, num_row=7, num_col=1)(branch_2)
        branch_2 = conv2d_bnorm(nb_filter=224, num_row=1, num_col=7)(branch_2)
        branch_2 = conv2d_bnorm(nb_filter=224, num_row=7, num_col=1)(branch_2)
        branch_2 = conv2d_bnorm(nb_filter=256, num_row=1, num_col=7)(branch_2)

        branch_3 = nn.avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding="SAME")
        branch_3 = conv2d_bnorm(nb_filter=128, num_row=1, num_col=1)(branch_3)

        return jax.lax.concatenate([branch_0, branch_1, branch_2, branch_3], 3)
class block_reduction_b(nn.Module):
    """Reduction-B block: three stride-2 VALID branches concatenated on axis 3.

    The output has 192 + 320 + C_in channels; the max-pool branch carries
    the input channels through unchanged.
    """

    @nn.compact
    def __call__(self, x):
        def cbr(inp, filters, rows, cols, **kwargs):
            return conv2d_bnorm(nb_filter=filters, num_row=rows, num_col=cols, **kwargs)(inp)

        downsample = {"strides": (2, 2), "padding": "VALID"}

        # Creation order matches the original, so parameter names are stable.
        short = cbr(x, 192, 1, 1)
        short = cbr(short, 192, 3, 3, **downsample)

        long_ = cbr(x, 256, 1, 1)
        long_ = cbr(long_, 256, 1, 7)
        long_ = cbr(long_, 320, 7, 1)
        long_ = cbr(long_, 320, 3, 3, **downsample)

        pooled = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="VALID")

        return jax.lax.concatenate([short, long_, pooled], dimension=3)
class block_inception_c(nn.Module):
    """Inception-C block; the output has 256 + 512 + 512 + 256 = 1536 channels.

    Branches 1 and 2 each end in a parallel 1x3 / 3x1 pair whose outputs are
    concatenated. Everything uses stride 1 with SAME padding, so the spatial
    size is preserved.
    """

    @nn.compact
    def __call__(self, x):
        def cbr(inp, filters, rows, cols):
            return conv2d_bnorm(nb_filter=filters, num_row=rows, num_col=cols)(inp)

        def split_1x3_3x1(inp):
            # The 1x3 conv is created before the 3x1 conv, matching the original order.
            return jax.lax.concatenate([cbr(inp, 256, 1, 3), cbr(inp, 256, 3, 1)], dimension=3)

        b0 = cbr(x, 256, 1, 1)

        b1 = split_1x3_3x1(cbr(x, 384, 1, 1))

        b2 = cbr(x, 384, 1, 1)
        b2 = cbr(b2, 448, 3, 1)
        b2 = cbr(b2, 512, 1, 3)
        b2 = split_1x3_3x1(b2)

        pooled = nn.avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding="SAME")
        b3 = cbr(pooled, 256, 1, 1)

        return jax.lax.concatenate([b0, b1, b2, b3], dimension=3)
class inception_v4_base(nn.Module):
    """Inception-v4 feature extractor: stem, then 4xA, Reduction-A, 7xB, Reduction-B, 3xC.

    For a 299x299x3 input (the shape ``get_inception`` initializes with), the
    stem produces 149 -> 147 -> 147 -> 73x73x160 -> 71x71x192 -> 35x35x384.
    The tensor is 17x17x1024 after Reduction-A and 8x8x1536 at the output.
    """

    @nn.compact
    def __call__(self, x):
        def cbr(inp, filters, rows, cols, **kwargs):
            return conv2d_bnorm(nb_filter=filters, num_row=rows, num_col=cols, **kwargs)(inp)

        def pool_s2(inp):
            return nn.max_pool(inp, window_shape=(3, 3), strides=(2, 2), padding="VALID")

        s2_valid = {"strides": (2, 2), "padding": "VALID"}

        # ---- stem (module creation order matches the original) ----
        h = cbr(x, 32, 3, 3, **s2_valid)
        h = cbr(h, 32, 3, 3, padding="VALID")
        h = cbr(h, 64, 3, 3)
        h = jax.lax.concatenate([pool_s2(h), cbr(h, 96, 3, 3, **s2_valid)], dimension=3)

        left = cbr(h, 64, 1, 1)
        left = cbr(left, 96, 3, 3, padding="VALID")
        right = cbr(h, 64, 1, 1)
        right = cbr(right, 64, 1, 7)
        right = cbr(right, 64, 7, 1)
        right = cbr(right, 96, 3, 3, padding="VALID")
        h = jax.lax.concatenate([left, right], dimension=3)

        h = jax.lax.concatenate([cbr(h, 192, 3, 3, **s2_valid), pool_s2(h)], dimension=3)

        # ---- Inception / Reduction stack ----
        schedule = (
            (block_inception_a, 4),
            (block_reduction_a, 1),
            (block_inception_b, 7),
            (block_reduction_b, 1),
            (block_inception_c, 3),
        )
        for block_cls, repeats in schedule:
            for _ in range(repeats):
                h = block_cls()(h)
        return h
class inception_v4(nn.Module):
    """Inception-v4 classifier: ``inception_v4_base`` followed by pool, dropout and a dense head.

    Attributes:
        num_classes: Number of output classes.
        dropout_keep: Passed directly as ``nn.Dropout(rate=...)``. Despite the
            name, it is therefore the probability of *dropping* a unit.
    """
    num_classes: int
    dropout_keep: float

    @nn.compact
    def __call__(self, x, training):
        """Return unnormalized class scores (logits) of shape (batch, num_classes).

        Args:
            x: Input image batch. The 8x8 average pool below assumes the base
                network yields an 8x8 feature map, which is the case for
                299x299 inputs.
            training: If True, dropout is active, and ``apply`` then needs a
                "dropout" PRNG stream. It is not forwarded to the BatchNorm
                layers inside the base network.
        """
        x = inception_v4_base()(x)
        x = nn.avg_pool(x, window_shape=(8, 8), padding="VALID")
        x = nn.Dropout(rate=self.dropout_keep, deterministic=not training)(x)
        x = x.reshape((x.shape[0], -1))  # flatten to (batch, features)
        # Return raw logits. get_loss takes `logits` and hands them to
        # optax.softmax_cross_entropy, which applies the softmax itself, so an
        # extra nn.softmax here would apply it twice. Use jax.nn.softmax on the
        # output if probabilities are needed.
        return nn.Dense(features=self.num_classes)(x)
def get_inception(no_params=False, num_classes=10, input_shape=(1, 299, 299, 3), seed=0):
    """Build the Inception-v4 model defined in this file and optionally initialize it.

    Args:
        no_params: If True, return only the model, without initializing it.
        num_classes: Size of the classification head.
        input_shape: Shape of the dummy batch passed to ``model.init``. It must
            match the real images. The head's 8x8 average pool expects an input
            that reduces to an 8x8 feature map, as 299x299 (the default) does.
        seed: Seed for the ``jax.random.PRNGKey`` used by ``model.init``.

    Returns:
        ``model`` if ``no_params`` else ``(variables, model)``, where
        ``variables`` is the result of ``model.init``.
    """
    # dropout_keep is used as nn.Dropout's rate, i.e. a 0.2 drop probability.
    model = inception_v4(num_classes=num_classes, dropout_keep=0.2)
    if no_params:
        return model
    # training=False keeps dropout deterministic, so init needs no "dropout" RNG.
    variables = model.init(jax.random.PRNGKey(seed), jnp.ones(input_shape), training=False)
    return variables, model
# -----VGG16-----
class VGG16(nn.Module):
    """VGG-16 (configuration D): 13 3x3 conv + ReLU layers in five stages, then 3 dense layers.

    Attributes:
        num_classes: Number of output classes.
    """
    num_classes: int

    @nn.compact
    def __call__(self, x):
        """Return unnormalized class scores (logits) of shape (batch, num_classes).

        Each stage ends in a 2x2 / stride-2 max-pool, so a 224x224 input reaches
        the dense layers as 7x7x512.
        """
        # Output channels of each 3x3 conv; "M" marks a max-pool. Convs and
        # dense layers are created in the same order as before, so parameter
        # names (Conv_0..Conv_12, Dense_0..Dense_2) are unchanged.
        stages = (64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
                  512, 512, 512, "M", 512, 512, 512, "M")
        for spec in stages:
            if spec == "M":
                # VGG pools with 2x2 windows at stride 2. The previous 3x3
                # window shrank maps faster; a 32x32 input was already 1x1
                # before the fifth pool.
                x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
            else:
                x = nn.relu(nn.Conv(spec, kernel_size=(3, 3))(x))
        x = x.reshape((x.shape[0], -1))
        x = nn.relu(nn.Dense(features=4096)(x))
        x = nn.relu(nn.Dense(features=4096)(x))
        # Raw logits. get_loss passes `logits` to optax.softmax_cross_entropy,
        # which applies the softmax itself.
        return nn.Dense(self.num_classes)(x)
def get_vgg16(no_params=False, num_classes=10, input_shape=(1, 224, 224, 3), seed=1):
    """Build the VGG16 model defined in this file and optionally initialize it.

    Args:
        no_params: If True, return only the model, without initializing it.
        num_classes: Size of the final dense layer.
        input_shape: Shape of the dummy batch passed to ``model.init``. The
            first dense layer's size depends on the spatial size left after the
            five max-pools, so this must match the real images. The default is
            the standard 224x224x3 VGG input. It replaces the previously
            hard-coded 244x244, presumably a typo; both give the same parameter
            shapes.
        seed: Seed for the ``jax.random.PRNGKey`` used by ``model.init``
            (kept at 1, as before).

    Returns:
        ``model`` if ``no_params`` else ``(variables, model)``, where
        ``variables`` is the result of ``model.init``.
    """
    model = VGG16(num_classes=num_classes)
    if no_params:
        return model
    variables = model.init(jax.random.PRNGKey(seed), jnp.ones(input_shape))
    return variables, model