-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun.py
84 lines (64 loc) · 2.41 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np
import matplotlib.pyplot as plt
from nn import FCLayer, NN
def generate_sine_data(n_samples=2000):
    """ Generate noisy, shuffled samples of one period of the sine function.

    Parameters
    ----------
    n_samples : int, optional
        Number of evenly spaced points on [0, 2*pi] to sample (default 2000).

    Returns
    -------
    X : np.ndarray, shape (n_samples,)
        Sample positions rescaled from [0, 2*pi] to [0, 1], in shuffled order.
    Y : np.ndarray, shape (n_samples,)
        sin(x) plus uniform noise drawn from [0, 0.1), divided by the maximum
        of those noisy values.

    Raises
    ------
    ValueError
        If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be a positive integer")
    # Evenly spaced inputs over one full period, shuffled in place so that a
    # later positional head/tail split yields randomly drawn subsets.
    x = np.linspace(0, 2 * np.pi, n_samples)
    np.random.shuffle(x)
    # NOTE: np.random.rand is uniform on [0, 1), so this noise is
    # non-negative (mean ~0.05) rather than zero-mean.
    y = np.sin(x) + np.random.rand(x.shape[0]) / 10.0
    # Rescale inputs to [0, 1] and divide targets by their maximum.
    X = x / 2 / np.pi
    Y = y / y.max()
    return X, Y
def generate_split(X, Y, validation_split=0.15):
    """ Split paired samples into a leading training and trailing validation part.

    The split is positional: the first
    ``round(len(X) * (1 - validation_split))`` samples form the training set
    and the remainder forms the validation set. Shuffle the data beforehand
    if the two subsets should be randomly drawn.

    Parameters
    ----------
    X, Y : array-like supporting ``.shape`` and slicing
        Inputs and targets, aligned along the first axis.
    validation_split : float, optional
        Fraction of samples reserved for validation, in [0, 1]
        (default 0.15).

    Returns
    -------
    X_train, Y_train, X_val, Y_val

    Raises
    ------
    ValueError
        If ``validation_split`` is outside [0, 1].
    """
    # Honour the caller's fraction; reject values that would produce a
    # negative or out-of-range split index.
    if not 0.0 <= validation_split <= 1.0:
        raise ValueError("validation_split must be between 0 and 1")
    split_index = int(np.round(X.shape[0] * (1 - validation_split)))
    X_train = X[:split_index]
    Y_train = Y[:split_index]
    X_val = X[split_index:]
    Y_val = Y[split_index:]
    return X_train, Y_train, X_val, Y_val
def test_network():
    """ Fit a small fully connected network to noisy sine samples and plot it.

    Three FCLayer objects, constructed with sizes (1, 20), (20, 20) and
    (20, 1), are stacked into an NN and trained with the hyper-parameters
    below. Two figures are then shown: the fitted values against the
    ground-truth samples, and the training/validation metric histories
    returned by ``NN.train`` (labelled as RMSE in the legend).
    """
    X, Y = generate_sine_data()
    x_tr, y_tr, x_va, y_va = generate_split(X, Y)

    learning_rate = 0.1
    n_epochs = 500

    network = NN([
        FCLayer("input", 1, 20),
        FCLayer("hidden_1", 20, 20),
        FCLayer("output", 20, 1),
    ])

    train_hist, val_hist = network.train(
        learning_rate=learning_rate,
        max_epochs=n_epochs,
        x_train=x_tr,
        y_train=y_tr,
        x_val=x_va,
        y_val=y_va,
        batch_size=64,
        threshold=0.01,
    )

    raw_train = network.predict(x_tr)
    raw_val = network.predict(x_va)

    # Figure 1: ground truth, then fitted values on each subset
    # (first output column only).
    plt.figure()
    plt.scatter(X, Y)
    for x_subset, raw in ((x_tr, raw_train), (x_va, raw_val)):
        plt.scatter(x_subset, np.array(raw)[:, 0])
    plt.title('Neural network test on learning the sine function')
    plt.legend(['ground truth', 'predictions on training set',
                'predictions on validation set'])

    # Figure 2: metric histories reported by training.
    plt.figure()
    for history in (train_hist, val_hist):
        plt.plot(history)
    plt.title('Training metrics')
    plt.legend(['RMSE on training set', 'RMSE on validation set'])
    plt.show()


if __name__ == "__main__":
    # Running the file directly trains the demo network and shows the plots.
    test_network()