-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
82 lines (70 loc) · 2.04 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import numpy as np
from bp.mnist import mnist
from bp.bpModel import bpNet
from bp.checkFile import checkFile
import time
import matplotlib.pyplot as plt
"""
# Tunable parameters:
## 1) hiddenLayersSize=[100, 50] -- bpNet constructor
## 2) weightInitStd=0.01 -- bpNet constructor
## 3) learning rate -- the `lr` argument of bpNet.update
## 4) batchSize=100 -- minibatch size used in main()
"""
def main(config, fileName='data'):
    """Train a backpropagation network on MNIST and plot its accuracy curves.

    Args:
        config: dict with keys ``'batchSize'``, ``'itersNum'``,
            ``'learningRate'`` and ``'hiddenLayersSize'``.
        fileName: passed to ``checkFile`` before any data is loaded.
            NOTE(review): presumably the MNIST data location -- confirm
            against ``bp.checkFile``.

    Returns:
        ``(trainA, testA)``: lists of training / test accuracy, one entry
        recorded at every epoch boundary (iteration 0, iterPerEpoch,
        2 * iterPerEpoch, ...).
    """
    # The original called an undefined ``testFile``; the helper this module
    # actually imports is ``checkFile``.
    checkFile(fileName)
    start = time.time()
    msTrain = mnist()
    msTest = mnist(kind="t10k")
    # Extract MNIST data
    trainImg = msTrain.images
    trainLabel = msTrain.oneHotLabels
    testImg = msTest.images
    testLabel = msTest.oneHotLabels
    trainSize = trainImg.shape[0]
    batchSize = config['batchSize']
    itersNum = config['itersNum']
    learningRate = config['learningRate']
    # Integer division: a float period (e.g. 60000 / 128 = 468.75) would make
    # ``index % iterPerEpoch == 0`` true only every few epochs.
    iterPerEpoch = max(trainSize // batchSize, 1)
    network = bpNet(hiddenLayersSize=config['hiddenLayersSize'])
    trainA = []
    testA = []
    print("iterPerEpoch:", iterPerEpoch)
    for index in range(itersNum):
        if index % iterPerEpoch == 0:
            # Evaluate on the full train and test sets once per epoch.
            trainAcc = network.accuracy(trainImg, trainLabel)
            testAcc = network.accuracy(testImg, testLabel)
            trainA.append(trainAcc)
            testA.append(testAcc)
            print("Train Acc: %.5f Test Acc: %.5f" % (trainAcc, testAcc))
        # Draw a random minibatch of indices (np.random.choice samples
        # with replacement by default).
        batchIdx = np.random.choice(trainSize, batchSize)
        network.update(trainImg[batchIdx], trainLabel[batchIdx], lr=learningRate)
    end = time.time()
    print("BP Time Consuming:%.2fs" % (end - start))
    _plot_accuracy(trainA, testA)
    return trainA, testA


def _plot_accuracy(trainA, testA):
    """Plot the per-epoch train/test accuracy histories on one figure and show it."""
    plt.figure(figsize=(6, 6))
    for history, color, name in zip((trainA, testA),
                                    ("blue", "orange"),
                                    ("trainAcc", "testAcc")):
        plt.plot(history, color=color, label=name)
    plt.xlabel("epoch")
    plt.ylabel("accuracy")
    plt.legend(loc=4)
    # Without show() the figure is silently discarded when run as a script.
    plt.show()
if __name__ == '__main__':
    # Hyper-parameters for a single training run, consumed by main().
    config = {
        'batchSize': 100,
        'itersNum': 100000,
        'learningRate': 0.1,
        'hiddenLayersSize': [100, 50, 50],
    }
    main(config)