mnist.py
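# Train a small convolutional network on MNIST with pyjet: two Conv2D blocks
# with max-pooling, two fully connected layers, SGD with a one-cycle learning
# rate schedule, checkpointing on validation accuracy, and a final pass over
# the test set with a simple prediction viewer.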
import _pickle as cPickle
import gzip
import wget
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pyjet.models import SLModel
from pyjet.data import NpDataset
import pyjet.backend as J
from pyjet.layers import Conv2D, MaxPooling2D, FullyConnected, Input
from pyjet.callbacks import ModelCheckpoint, Plotter, OneCycleScheduler
from pyjet.hooks import model_sizes
from pyjet.metrics import Metric
import logging
logging.basicConfig(level=logging.INFO)
# Load the dataset
try:
    f = gzip.open("mnist_py3k.pkl.gz", "rb")
except OSError:
    print("Could not find MNIST, downloading the dataset")
    wget.download("http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist_py3k.pkl.gz")
    f = gzip.open("mnist_py3k.pkl.gz", "rb")
(xtr, ytr), (xval, yval), (xte, yte) = cPickle.load(f)
f.close()
# Reshape the flat images into PyTorch's NCHW layout: (batch, channels, height, width)
xtr = xtr.reshape((-1, 1, 28, 28))
xval = xval.reshape((-1, 1, 28, 28))
xte = xte.reshape((-1, 1, 28, 28))
print("Maximum Pixel value in training set:", np.max(xtr))
print("Training Data Shape:", xtr.shape)
print("Training Labels Shape:", ytr.shape)
print("Validation Data Shape: ", xval.shape)
print("Validation Labels Shape: ", yval.shape)
# Visualize an image
ind = np.random.randint(xtr.shape[0])
# plt.imshow(xtr[ind, 0, :, :], cmap='gray')
# plt.title("Digit = %s" % ytr[ind])
# plt.show()
# Create the model
class MNISTModel(SLModel):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = Conv2D(
            20, kernel_size=5, activation="relu", batchnorm=True, input_batchnorm=True
        )
        self.conv2 = Conv2D(30, kernel_size=5, activation="relu", batchnorm=True)
        self.mp = MaxPooling2D(2)
        self.fc1 = FullyConnected(50, activation="relu")
        self.fc2 = FullyConnected(10)
        print(model_sizes(self, Input(1, 28, 28)))
        self.infer_inputs(Input(1, 28, 28))

    def forward(self, x):
        x = self.conv1(x)
        x = self.mp(x)
        x = self.conv2(x)
        x = self.mp(x)
        x = J.flatten(x)
        x = self.fc1(x)
        self.loss_in = self.fc2(x)
        return F.softmax(self.loss_in, dim=-1)
model = MNISTModel()
model.add_loss(nn.CrossEntropyLoss())
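# Note: forward() stores the raw logits in self.loss_in and returns softmax
# probabilities. As used here, the added nn.CrossEntropyLoss is presumably
# applied to those stored logits (it expects unnormalized scores and applies
# log-softmax internally), while the softmax output feeds predictions/metrics.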
# This will save the best scoring model weights to the current directory
best_model = ModelCheckpoint(
    "mnist_pyjet" + ".state",
    monitor="val_accuracy",
    mode="max",
    verbose=1,
    save_best_only=True,
)
# This will plot the model's accuracy during training
plotter = Plotter(scale="linear", monitor="accuracy")
# Turn the numpy dataset into a BatchGenerator
train_datagen = NpDataset(xtr, y=ytr).flow(batch_size=64, shuffle=True, seed=1234)
# Turn the val data into a BatchGenerator
val_datagen = NpDataset(xval, y=yval).flow(batch_size=1000, shuffle=True, seed=1234)
# Set up the optimizer
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
model.add_optimizer(optimizer)
# Add the LR scheduler
one_cycle = OneCycleScheduler(
    optimizer, (1e-4, 1e-2), (0.95, 0.85), train_datagen.steps_per_epoch * 5
)
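# One-cycle policy: from the arguments above, the schedule appears to sweep the
# learning rate between 1e-4 and 1e-2 and the momentum inversely between 0.95
# and 0.85 over train_datagen.steps_per_epoch * 5 optimizer steps (roughly half
# of the 10 training epochs configured below).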
class LR(Metric):
    def __init__(self, onecycle):
        super().__init__()
        self.onecycle = onecycle

    def __call__(self, y_pred, y_true):
        return self.score(y_pred, y_true)

    def score(self, y_pred, y_true):
        return J.tensor(self.onecycle.lr)

    def accumulate(self):
        return self.onecycle.lr

    def reset(self):
        return self


class Momentum(Metric):
    def __init__(self, onecycle):
        super().__init__()
        self.onecycle = onecycle

    def __call__(self, y_pred, y_true):
        return self.score(y_pred, y_true)

    def score(self, y_pred, y_true):
        return J.tensor(self.onecycle.momentum)

    def accumulate(self):
        return self.onecycle.momentum

    def reset(self):
        return self
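# LR and Momentum are not real metrics: they ignore y_pred/y_true and simply
# report the scheduler's current learning rate and momentum, so the one-cycle
# schedule can be tracked alongside accuracy in the training output.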
# Fit the model
model.fit_generator(
    train_datagen,
    epochs=10,
    steps_per_epoch=train_datagen.steps_per_epoch,
    validation_data=val_datagen,
    validation_steps=val_datagen.steps_per_epoch,
    metrics=[LR(one_cycle), Momentum(one_cycle), "accuracy", "top3_accuracy"],
    callbacks=[one_cycle, best_model, plotter],
)
# Load the best model
model = MNISTModel()
print("Loading the model")
model.load_state("mnist_pyjet.state")
# Test it on the test set
test_datagen = NpDataset(xte).flow(batch_size=1000, shuffle=False)
test_preds = model.predict_generator(test_datagen, test_datagen.steps_per_epoch)
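# test_preds should be the stacked softmax outputs: one row of 10 class
# probabilities per test image, so np.argmax over a row gives the predicted digit.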
num_test = xte.shape[0]
# Visualize an image and its prediction
while True:
    ind = np.random.randint(xte.shape[0])
    plt.imshow(xte[ind, 0, :, :], cmap="gray")
    test_pred = test_preds[ind]
    plt.title("Prediction = %s" % np.argmax(test_pred))
    plt.show()