-
Notifications
You must be signed in to change notification settings - Fork 38
/
model.ts
87 lines (79 loc) · 2.59 KB
/
model.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import * as tf from '@tensorflow/tfjs';
import * as mobileNet from '@tensorflow-models/mobilenet';
// Request a video-only media stream and attach it to the `cam` video element.
navigator.mediaDevices
  .getUserMedia({
    video: true,
    audio: false
  })
  .then(stream => {
    video.srcObject = stream;
  })
  .catch(err => {
    // getUserMedia rejects when access is denied or no matching device exists;
    // log it instead of leaving an unhandled promise rejection.
    console.error('Unable to access the camera', err);
  });
/** Identifier handed to `infer` as its second argument on every call. */
const Layer = 'global_average_pooling2d_1';

// Page elements used by the capture loop: the webcam <video>, the canvas that
// receives the raw frame, and the canvas that receives the scaled-down frame.
const video = <HTMLVideoElement>document.getElementById('cam');
const canvas = <HTMLCanvasElement>document.getElementById('canvas');
const crop = <HTMLCanvasElement>document.getElementById('crop');

/**
 * Binds a model to a one-argument function: the returned function forwards its
 * argument to `m.infer` together with `Layer` and returns the resulting tensor.
 */
const mobilenetInfer = m => {
  return (p): tf.Tensor<tf.Rank> => {
    return m.infer(p, Layer);
  };
};
/**
 * Size, in pixels, of the destination rectangle frames are scaled into on the
 * `crop` canvas; its Width/Height ratio also defines the source region's shape.
 */
const ImageSize = { Width: 100, Height: 56 };
/**
 * Converts the full contents of `canvas` to grayscale in place.
 *
 * Reads the pixel buffer for the whole canvas, overwrites the first three bytes
 * of every 4-byte pixel with their arithmetic mean (the fourth byte is left
 * untouched), and writes the buffer back at the origin.
 *
 * @param canvas Canvas whose 2D context is read from and written back to.
 */
const grayscale = (canvas: HTMLCanvasElement) => {
  const frame = canvas.getContext('2d').getImageData(0, 0, canvas.width, canvas.height);
  const pixels = frame.data;

  let offset = 0;
  while (offset < pixels.length) {
    const mean = (pixels[offset] + pixels[offset + 1] + pixels[offset + 2]) / 3;
    // Overwrite the three colour channels; fill's end index is exclusive.
    pixels.fill(mean, offset, offset + 3);
    offset += 4;
  }

  canvas.getContext('2d').putImageData(frame, 0, 0);
};
/**
 * Feature extractor bound to the loaded MobileNet instance. Assigned once
 * `mobileNet.load()` resolves; called on every tick of the prediction loop.
 */
let mobilenet: (p: any) => tf.Tensor<tf.Rank>;

// Load the classifier from the local server, then MobileNet. Once both are
// available, reveal the playground and start classifying frames every 100 ms.
tf.loadModel('http://localhost:5000/model.json')
  .then(model =>
    // Returned so that a MobileNet load failure reaches the `.catch` below.
    mobileNet.load().then((mn: any) => {
      mobilenet = mobilenetInfer(mn);
      document.getElementById('playground').style.display = 'table';
      document.getElementById('loading-page').style.display = 'none';
      console.log('MobileNet created');

      setInterval(() => {
        // Copy the current video frame onto the full-size canvas.
        canvas.getContext('2d').drawImage(video, 0, 0);
        // NOTE(review): this draw targets the same destination rectangle as the
        // next call and looks redundant — confirm before removing.
        crop.getContext('2d').drawImage(canvas, 0, 0, ImageSize.Width, ImageSize.Height);
        crop
          .getContext('2d')
          .drawImage(
            canvas,
            0,
            0,
            canvas.width,
            // Source height chosen so the source region has ImageSize's aspect ratio.
            canvas.width / (ImageSize.Width / ImageSize.Height),
            0,
            0,
            ImageSize.Width,
            ImageSize.Height
          );
        grayscale(crop);

        // Run inference inside tf.tidy so the pixel tensor, the MobileNet
        // activation and the prediction tensor are disposed after each tick;
        // without it every tick leaks tensors. Only plain numbers leave the scope.
        let scores: number[] = [];
        tf.tidy(() => {
          const prediction = model.predict(mobilenet(tf.fromPixels(crop))) as tf.Tensor1D;
          scores = Array.from(prediction.dataSync() as Float32Array);
        });

        // The classifier output is read as [punch, kick, nothing].
        const [punch, kick, nothing] = scores;
        const detect = (window as any).Detect;
        // Skip the frame when the "nothing" score reaches 0.4.
        if (nothing >= 0.4) {
          return;
        }
        console.log(punch.toFixed(2), kick.toFixed(2));
        // A move is reported only if it beats the other move and scores at least 0.35.
        if (kick > punch && kick >= 0.35) {
          console.log('%cKick: ' + kick.toFixed(2), 'color: red; font-size: 30px');
          detect.onKick();
          return;
        }
        if (punch > kick && punch >= 0.35) {
          console.log('%cPunch: ' + punch.toFixed(2), 'color: blue; font-size: 30px');
          detect.onPunch();
          return;
        }
      }, 100);
    })
  )
  .catch(err => {
    // Either model failed to load (or the UI elements were missing); report it
    // rather than leaving an unhandled rejection.
    console.error('Failed to start the pose classifier', err);
  });