forked from render-examples/flask-hello-world
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path tensor.py
41 lines (34 loc) · 1.31 KB
/
tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import tensorflow as tf
import numpy as np
from PIL import Image, ImageOps, ImageEnhance
# Method to process the image and return the variable in a format that the model can understand
def import_to_array(img_path):
    """Load an image file and convert it into the model's input array.

    Processing steps:
    1. Auto-rotate the image according to its EXIF orientation tag.
    2. Convert it to 8-bit greyscale.
    3. Boost brightness and contrast by a factor of 2.0 each.
    4. Resize it to 28x28.
    5. Invert its colours.
    6. Scale the pixels to [0, 1] and flatten them into a single row.

    Args:
        img_path: Path (or file object) passed to ``PIL.Image.open``.

    Returns:
        A ``float32`` NumPy array of shape ``(1, 784)``.
    """
    # Image.open is lazy and keeps the file handle open until garbage
    # collection; the context manager closes it once the pixels have been
    # copied into a new image by exif_transpose/convert.
    with Image.open(img_path) as src:
        # Rotate first: the EXIF orientation tag belongs to the original
        # opened image, so read it there before any mode conversion.
        # NOTE(review): convert('L') drops any alpha channel without
        # compositing, so transparent pixels take their underlying RGB
        # value — confirm inputs are opaque.
        img = ImageOps.exif_transpose(src).convert('L')
    img = ImageEnhance.Brightness(img).enhance(2.0)
    img = ImageEnhance.Contrast(img).enhance(2.0)
    img = img.resize((28, 28))
    # 'L' mode yields uint8 pixels, so the bitwise invert maps p -> 255 - p
    # (presumably to match white-on-black training digits — confirm).
    img_array = np.invert(np.array(img))
    img_array = (img_array / 255.0).astype('float32')
    return img_array.reshape(1, 784)
# Load the TFLite model once, at import time, and allocate its tensors so
# predict() can call set_tensor/invoke directly.
# NOTE(review): this path is resolved against the current working
# directory, not this file's directory — confirm the process is started
# from the folder that holds the model.
_MODEL_PATH = "number_classification_model.tflite"
interpreter = tf.lite.Interpreter(model_path=_MODEL_PATH)
interpreter.allocate_tensors()

# Tensor metadata used by predict(): element [0] of each holds the 'index'
# of the model's input / output tensor.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)
# Apply the model to the image
def predict(img_path):
    """Classify the image at *img_path* and return the top-scoring class index.

    The image is preprocessed with ``import_to_array`` and fed to the
    module-level TFLite ``interpreter``. The arg-max of the output tensor
    is returned as a NumPy integer.
    """
    pixels = import_to_array(img_path)
    interpreter.set_tensor(input_details[0]['index'], pixels)
    interpreter.invoke()
    # NOTE(review): the interpreter is shared module state; if this runs
    # under a multi-threaded server, concurrent calls could interleave
    # set_tensor/invoke/get_tensor — confirm the deployment model.
    scores = interpreter.get_tensor(output_details[0]['index'])
    return scores.argmax()
# Execute the model
# input_img = 'test7.png'
# print(predict(input_img))