-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnn.js
128 lines (122 loc) · 3.54 KB
/
nn.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
/**
 * Wrapper class around the TensorFlow.js API: a network with one hidden
 * layer plus the copy / crossover / mutate operations used for
 * neuroevolution.
 *
 * Depends on globals that are not defined in this file: `tf`
 * (TensorFlow.js) everywhere, plus `random` and `randomGaussian` in
 * mutate() (presumably p5.js — confirm against the page that loads this).
 */
class NeuralNetwork {
  /**
   * Supports two call forms:
   *   new NeuralNetwork(inputNodes, hiddenNodes, outputNodes)
   *     builds a freshly initialised model via createModel();
   *   new NeuralNetwork(model, inputNodes, hiddenNodes, outputNodes)
   *     wraps an existing tf.Sequential as-is (this is how copy() and
   *     crossover() return their results). The sizes are stored but not
   *     checked against the model.
   *
   * @param {tf.Sequential | number} a - Existing model, or the input size.
   * @param {number} b - Input size if `a` is a model, else the hidden size.
   * @param {number} c - Hidden size if `a` is a model, else the output size.
   * @param {number} [d] - Output size; only read when `a` is a model.
   */
  constructor(a, b, c, d) {
    if (a instanceof tf.Sequential) {
      this.model = a;
      this.input_nodes = b;
      this.hidden_nodes = c;
      this.output_nodes = d;
    } else {
      this.input_nodes = a;
      this.hidden_nodes = b;
      this.output_nodes = c;
      this.model = this.createModel();
    }
  }

  /**
   * Creates an independent network with the same layer sizes and a copy of
   * this network's current weights.
   *
   * @returns {NeuralNetwork} A new network; this one is not modified.
   */
  copy() {
    return tf.tidy(() => {
      const modelCopy = this.createModel();
      // Clone each weight tensor so the new model never shares tensors with
      // this one; the clones are intermediates that tf.tidy() cleans up.
      modelCopy.setWeights(this.model.getWeights().map((w) => w.clone()));
      return new NeuralNetwork(
        modelCopy,
        this.input_nodes,
        this.hidden_nodes,
        this.output_nodes
      );
    });
  }

  /**
   * Produces a child network whose weights come from `this` and `other`.
   * Selection is per weight tensor, not per individual weight: each tensor
   * returned by getWeights() is taken whole from `other` with probability
   * 0.5, otherwise from `this`.
   *
   * Assumes `other` has the same architecture; the child is built with this
   * network's layer sizes.
   *
   * @param {NeuralNetwork} other - The second parent.
   * @returns {NeuralNetwork} A new network; neither parent is modified.
   */
  crossover(other) {
    return tf.tidy(() => {
      const modelCopy = this.createModel();
      const otherWeights = other.model.getWeights();
      const childWeights = this.model
        .getWeights()
        .map((w, i) => (Math.random() < 0.5 ? otherWeights[i] : w).clone());
      modelCopy.setWeights(childWeights);
      return new NeuralNetwork(
        modelCopy,
        this.input_nodes,
        this.hidden_nodes,
        this.output_nodes
      );
    });
  }

  /**
   * Mutates this network's weights in place. Each individual weight is,
   * with probability `rate` (tested via random(1) < rate), shifted by a
   * value drawn from randomGaussian().
   *
   * @param {number} rate - Probability that any single weight is mutated.
   */
  mutate(rate) {
    tf.tidy(() => {
      const mutatedWeights = this.model.getWeights().map((tensor) => {
        // Work on a copy: dataSync() may expose the tensor's own buffer, and
        // the existing tensor must not be edited behind TensorFlow's back.
        const values = tensor.dataSync().slice();
        for (let j = 0; j < values.length; j++) {
          if (random(1) < rate) {
            values[j] += randomGaussian();
          }
        }
        return tf.tensor(values, tensor.shape);
      });
      this.model.setWeights(mutatedWeights);
    });
  }

  /**
   * Disposes the wrapped TensorFlow.js model. Call this when the network is
   * no longer needed; the network should not be used afterwards.
   */
  dispose() {
    this.model.dispose();
  }

  /**
   * Runs a forward pass on a single sample.
   *
   * NOTE: this returns the raw output tensor, not a number[]. Tensors
   * returned from tf.tidy() survive the tidy scope, so the caller must read
   * the result (e.g. with dataSync()) and then dispose() it.
   *
   * @param {number[]} inputs - One sample; its length should equal
   *   input_nodes.
   * @returns {tf.Tensor} The model's output for a batch of one (shape
   *   [1, output_nodes] for models built by createModel()).
   */
  predict(inputs) {
    return tf.tidy(() => {
      const xs = tf.tensor2d([inputs]);
      return this.model.predict(xs);
    });
  }

  /**
   * Builds the default architecture:
   *   input_nodes -> dense(hidden_nodes, relu) -> dense(output_nodes, softmax)
   *
   * @returns {tf.Sequential} A new, untrained model.
   */
  createModel() {
    const model = tf.sequential();
    model.add(
      tf.layers.dense({
        units: this.hidden_nodes,
        inputShape: [this.input_nodes],
        activation: 'relu',
      })
    );
    model.add(
      tf.layers.dense({
        units: this.output_nodes,
        activation: 'softmax',
      })
    );
    return model;
  }
}