-
Notifications
You must be signed in to change notification settings - Fork 37
/
server.py
118 lines (95 loc) · 3.39 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Mute tensorflow debugging information on console.
# NOTE(review): the variable is set ahead of the keras import further down,
# presumably so it is already in place when the backend starts up - confirm
# before reordering these lines.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from flask import Flask, request, render_template, jsonify
from scipy.misc import imsave, imread, imresize
import numpy as np
import argparse
from keras.models import model_from_yaml
import re
import base64
import pickle
# Flask application object; the @app.route decorators below register the
# request handlers on it and the __main__ block starts it with app.run().
app = Flask(__name__)
def load_model(bin_dir):
    ''' Load model from .yaml and the weights from .h5

        Arguments:
            bin_dir: The directory of the bin (normally bin/)

        Returns:
            Loaded model from file

        Raises:
            OSError: if <bin_dir>/model.yaml cannot be opened for reading
    '''
    # load YAML and create model; the with-block closes the file even when
    # read() raises, which the manual open()/close() pair did not
    with open('%s/model.yaml' % bin_dir, 'r') as yaml_file:
        loaded_model_yaml = yaml_file.read()
    model = model_from_yaml(loaded_model_yaml)

    # load weights into new model
    model.load_weights('%s/model.h5' % bin_dir)
    return model
@app.route("/")
def index():
    ''' Serve the landing page: answer a request for / with the rendered
        index.html template.
    '''
    page = render_template('index.html')
    return page
@app.route('/predict/', methods=['GET','POST'])
def predict():
    ''' Called when user presses the predict button.
        Processes the canvas and handles the image.
        Passes the loaded image into the neural network and it makes
        class prediction.

        Relies on the module-level names `model` and `mapping`, which are
        assigned in the __main__ block before the server starts.

        Returns:
            JSON with 'prediction' (the predicted character) and
            'confidence' (highest network output times 100, as a string cut
            to 6 characters). When the request body carries no decodable
            base64 image, JSON with an 'error' message and HTTP status 400.
    '''

    # Local functions
    def crop(x):
        # Experimental: remove every row of x that contains only zeros.
        # Currently unused; see the commented-out calls further down.
        _len = len(x) - 1
        # x[::-1] is evaluated once, on the original array, so rebinding x
        # inside the loop does not disturb the iteration. Rows are visited
        # bottom-up, so a deletion never shifts the index of a row that is
        # still to be visited.
        for index, row in enumerate(x[::-1]):
            if not any(item != 0 for item in row):
                x = np.delete(x, _len - index, 0)
        return x

    def parseImage(imgData):
        # parse canvas bytes and save as output.png
        # Raises ValueError when there is no 'base64,' payload or the
        # payload is not valid base64 (binascii.Error is a ValueError).
        match = re.search(b'base64,(.*)', imgData)
        if match is None:
            raise ValueError('request body contains no base64 image data')
        # Decode before opening the file so a bad payload does not leave a
        # truncated output.png behind.
        decoded = base64.decodebytes(match.group(1))
        with open('output.png','wb') as output:
            output.write(decoded)

    # get data from drawing canvas and save as image
    try:
        parseImage(request.get_data())
    except ValueError as err:
        # e.g. a plain GET on this route, which has no canvas data at all;
        # answer with a client error instead of crashing with a 500
        return jsonify({'error': str(err)}), 400

    # read parsed image back in 8-bit, black and white mode (L)
    x = imread('output.png', mode='L')
    x = np.invert(x)

    ### Experimental
    # Crop on rows
    # x = crop(x)
    # x = x.T
    # Crop on columns
    # x = crop(x)
    # x = x.T

    # Visualize new array
    imsave('resized.png', x)
    x = imresize(x,(28,28))

    # reshape image data for use in neural network
    x = x.reshape(1,28,28,1)

    # Convert type to float32
    x = x.astype('float32')

    # Normalize to prevent issues with model
    x /= 255

    # Predict from model
    out = model.predict(x)

    # Generate response: the index of the highest output is looked up in
    # mapping to get a character code, which chr() turns into the character
    response = {'prediction': chr(mapping[(int(np.argmax(out, axis=1)[0]))]),
                'confidence': str(max(out[0]) * 100)[:6]}

    return jsonify(response)
if __name__ == '__main__':
    # Parse optional arguments
    parser = argparse.ArgumentParser(description='A webapp for testing models generated from training.py on the EMNIST dataset')
    parser.add_argument('--bin', type=str, default='bin', help='Directory to the bin containing the model yaml and model h5 files')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='The host to run the flask server on')
    parser.add_argument('--port', type=int, default=5000, help='The port to run the flask server on')
    args = parser.parse_args()

    # Overhead: both names are module globals that predict() reads on every
    # request, so they must be assigned before the server starts.
    model = load_model(args.bin)

    # mapping is indexed by predicted class and yields the character code
    # that predict() passes to chr().
    # NOTE: unpickling can execute arbitrary code, so --bin must only point
    # at a trusted directory. The with-block closes the file handle, which
    # the bare pickle.load(open(...)) left open.
    with open('%s/mapping.p' % args.bin, 'rb') as mapping_file:
        mapping = pickle.load(mapping_file)

    app.run(host=args.host, port=args.port)