forked from tinygrad/tinygrad
-
Notifications
You must be signed in to change notification settings - Fork 4
/
serious_mnist.py
136 lines (120 loc) · 4.54 KB
/
serious_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
#!/usr/bin/env python
#inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
import sys
import numpy as np
from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor
from tinygrad.nn import BatchNorm2d, optim
from tinygrad.helpers import getenv
from extra.datasets import fetch_mnist
from extra.augment import augment_img
from extra.training import train, evaluate
# Runtime switches, each read through getenv:
#   GPU   - when truthy, parameters are moved to the GPU
#   QUICK - when truthy, a shortened training schedule is used
#   DEBUG - when truthy, parameter shapes and their total count are printed
GPU, QUICK, DEBUG = (getenv(flag) for flag in ("GPU", "QUICK", "DEBUG"))
class SqueezeExciteBlock2D:
  """Channel-gating block: pool each channel to a scalar, pass it through a
  two-layer bottleneck and rescale the input channels by the sigmoid output.

  The bottleneck width is ``filters // 32``.
  """

  def __init__(self, filters):
    """Create the two dense layers (weights and biases) for ``filters`` channels."""
    self.filters = filters
    squeezed = self.filters//32
    # creation order is kept: weight1, bias1, weight2, bias2
    self.weight1 = Tensor.scaled_uniform(self.filters, squeezed)
    self.bias1 = Tensor.scaled_uniform(1, squeezed)
    self.weight2 = Tensor.scaled_uniform(squeezed, self.filters)
    self.bias2 = Tensor.scaled_uniform(1, self.filters)

  def __call__(self, input):
    """Return ``input`` multiplied channel-wise by the learned gate."""
    # pooling window spans dims 2 and 3 entirely, i.e. a global average pool
    window = (input.shape[2], input.shape[3])
    pooled = input.avg_pool2d(kernel_size=window).reshape(shape=(-1, self.filters))
    hidden = (pooled.dot(self.weight1) + self.bias1).relu()
    gate = (hidden.dot(self.weight2) + self.bias2).sigmoid()
    # trailing singleton dims let the gate broadcast over the spatial axes
    return input.mul(gate.reshape(shape=(-1, self.filters, 1, 1)))
class ConvBlock:
  """Three padded conv + bias + ReLU layers, then batch norm and squeeze-excite.

  Args:
    h: height of the incoming feature map.
    w: width of the incoming feature map.
    inp: number of input channels.
    filters: output channels of every convolution (default 128).
    conv: side length of the square convolution kernel (default 3).
  """

  def __init__(self, h, w, inp, filters=128, conv=3):
    self.h, self.w = h, w
    self.inp = inp
    # Padding that keeps the spatial size for odd kernels; equals the former
    # hard-coded value of 1 when conv == 3.
    self.pad = conv//2
    #init weights
    self.cweights = [Tensor.scaled_uniform(filters, inp if i==0 else filters, conv, conv) for i in range(3)]
    self.cbiases = [Tensor.scaled_uniform(1, filters, 1, 1) for i in range(3)]
    #init layers
    # Normalise over the real channel count rather than a fixed 128, so a
    # non-default `filters` no longer mismatches the batch-norm layer.
    self._bn = BatchNorm2d(filters)
    self._seb = SqueezeExciteBlock2D(filters)

  def __call__(self, input):
    """Run the block on ``input`` and return the gated feature map."""
    # (batch, channels, height, width); the result is identical to the previous
    # (w, h) order for the square maps used by the callers in this file.
    x = input.reshape(shape=(-1, self.inp, self.h, self.w))
    for cweight, cbias in zip(self.cweights, self.cbiases):
      x = x.pad2d(padding=[self.pad]*4).conv2d(cweight).add(cbias).relu()
    x = self._bn(x)
    x = self._seb(x)
    return x
class BigConvNet:
  """MNIST classifier: three ConvBlocks followed by two 128->10 linear heads,
  one fed by a global average pool and one by a global max pool."""

  def __init__(self):
    self.conv = [ConvBlock(28,28,1), ConvBlock(28,28,128), ConvBlock(14,14,128)]
    self.weight1 = Tensor.scaled_uniform(128,10)
    self.weight2 = Tensor.scaled_uniform(128,10)

  def parameters(self):
    """Return the model tensors collected by get_parameters.

    With DEBUG set, tensors explicitly marked ``requires_grad=False`` are
    excluded, and each shape plus the total element count is printed.
    """
    if not DEBUG:
      return get_parameters(self)
    #keeping this for a moment
    # `is not False` keeps tensors whose flag is still undecided (None); a plain
    # truthiness test would drop them and could hand the optimizer an empty list.
    # NOTE(review): assumes requires_grad may be None before an optimizer sees it - confirm.
    pars = [par for par in get_parameters(self) if par.requires_grad is not False]
    no_pars = 0
    for par in pars:
      print(par.shape)
      no_pars += np.prod(par.shape)
    print('no of parameters', no_pars)
    return pars

  def save(self, filename):
    """Write every parameter, in get_parameters order, to ``filename + '.npy'``.

    The arrays are appended one after another to a single file, so they must
    be read back in the same order (see ``load``).
    """
    with open(filename+'.npy', 'wb') as f:
      for par in get_parameters(self):
        np.save(f, par.numpy())

  def load(self, filename):
    """Read parameters from ``filename + '.npy'`` in get_parameters order.

    A failure on one parameter is reported and loading continues with the
    next (best effort). Failing to open the file raises to the caller.
    """
    with open(filename+'.npy', 'rb') as f:
      for par in get_parameters(self):
        try:
          # NOTE(review): relies on numpy() exposing the tensor's own storage so
          # the slice assignment writes through - confirm for the tinygrad in use.
          par.numpy()[:] = np.load(f)
          if GPU:
            # in-place move; the non-underscore gpu() result was being discarded
            par.gpu_()
        except Exception as err:
          # Exception (not a bare except) so Ctrl-C / SystemExit still propagate
          print('Could not load parameter', err)

  def forward(self, x):
    """Return the 10-way output for a batch ``x``."""
    x = self.conv[0](x)
    x = self.conv[1](x)
    x = x.avg_pool2d(kernel_size=(2,2))
    x = self.conv[2](x)
    x1 = x.avg_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
    x2 = x.max_pool2d(kernel_size=(14,14)).reshape(shape=(-1,128)) #global
    xo = x1.dot(self.weight1) + x2.dot(self.weight2)
    return xo
if __name__ == "__main__":
  # One learning rate per training stage, paired with that stage's epoch count.
  # QUICK selects a short schedule.
  lrs = [1e-4, 1e-5] if QUICK else [1e-3, 1e-4, 1e-5, 1e-5]
  epochss = [2, 1] if QUICK else [13, 3, 3, 1]
  BS = 32
  lmbd = 0.00025

  def lossfn(out, y):
    """Cross-entropy plus an L1 penalty (scaled by lmbd) on both output weight matrices."""
    # `model` is resolved at call time; it is created further down.
    return out.sparse_categorical_crossentropy(y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()

  X_train, Y_train, X_test, Y_test = fetch_mnist()
  X_train = X_train.reshape(-1, 28, 28).astype(np.uint8)
  X_test = X_test.reshape(-1, 28, 28).astype(np.uint8)
  steps = len(X_train)//BS
  np.random.seed(1337)
  if QUICK:
    # single step and a single test batch
    steps = 1
    X_test, Y_test = X_test[:BS], Y_test[:BS]

  model = BigConvNet()

  if len(sys.argv) > 1:
    # Only the load is guarded. An evaluation failure (or Ctrl-C) is no longer
    # misreported as a weight-loading problem.
    try:
      model.load(sys.argv[1])
    except Exception as err:
      print('could not load weights "'+sys.argv[1]+'".', err)
    else:
      print('Loaded weights "'+sys.argv[1]+'", evaluating...')
      evaluate(model, X_test, Y_test, BS=BS)

  if GPU:
    # plain loop: the in-place move is done purely for its side effect
    for param in get_parameters(model):
      param.gpu_()

  for lr, epochs in zip(lrs, epochss):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(1,epochs+1):
      #first epoch without augmentation
      X_aug = X_train if epoch == 1 else augment_img(X_train)
      train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
      accuracy = evaluate(model, X_test, Y_test, BS=BS)
      model.save(f'examples/checkpoint{accuracy * 1e6:.0f}')