-
Notifications
You must be signed in to change notification settings - Fork 1
/
predict.py
46 lines (35 loc) · 1.68 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import tensorflow as tf
import sys
import numpy as np
import csv
import cv2
from PIL import Image
from IPython.display import display
def _load_normalized_image(imagePath, width, height):
    """Read an image, convert it to RGB, resize it and standardize it.

    Args:
        imagePath: Path of the image file to read with OpenCV.
        width: Target width in pixels (first element of cv2's ``dsize``).
        height: Target height in pixels.

    Returns:
        A float array of shape ``(height, width, 3)``, shifted to zero mean and
        divided by its standard deviation. Both statistics are computed over all
        pixels and channels together.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``imagePath``.
    """
    img = cv2.imread(imagePath)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError("Could not read image: " + imagePath)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # cv2.resize takes dsize as (width, height); the result is (height, width, 3).
    resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
    std = resized.std()
    # A constant image has zero std. Its centered values are all zero anyway,
    # so divide by 1 instead of producing NaNs.
    if std == 0:
        std = 1.0
    return (resized - resized.mean()) / std


def _pad_to_batch(image, batchSize):
    """Return ``image`` as element 0 of a batch of ``batchSize`` images.

    The remaining batch slots are zero-filled with the same shape and dtype as
    ``image``. A ``batchSize`` below 1 still yields a batch of one image.
    """
    batch = np.zeros((max(batchSize, 1),) + image.shape, dtype=image.dtype)
    batch[0] = image
    return batch


def _colorize(predClasses, numClasses, classToRGB):
    """Map a flat array of class indices to a flat ``(N, 3)`` color array.

    Indices outside ``range(numClasses)`` are left as black (0, 0, 0).
    """
    colors = np.zeros((predClasses.size, 3))
    for cl in range(numClasses):
        colors[predClasses == cl] = classToRGB[cl]
    return colors


def predict(sess, config, data, graph):
    """Run the prediction op on one image, then save and display a color-coded result.

    The input image is read from ``../results/predict<config["dataset"]>.jpg``.
    The colorized prediction is written as PNG to
    ``../results/<name><x><y><neuralNetwork>.png`` and shown with IPython's
    ``display``.

    Args:
        sess: Session used to evaluate ``graph["prediction"]``. Presumably a
            TF1-style ``tf.Session``; confirm against the caller.
        config: Run configuration. Reads the keys "dataset", "batchSize",
            "classes" and "neuralNetwork".
        data: Object whose ``config`` dict provides "x" (width), "y" (height),
            "ClassToRGB" (class index -> RGB color) and "name".
        graph: Dict with the "imagePlaceholder" input tensor and the
            "prediction" output tensor.
    """
    width = data.config["x"]
    height = data.config["y"]

    imagePath = "../results/predict"+config["dataset"]+".jpg"
    imgRes = _load_normalized_image(imagePath, width, height)

    # The graph presumably has a fixed batch dimension, so the single real
    # image is padded out with zero images (TODO confirm against the model).
    inputData = _pad_to_batch(imgRes, config["batchSize"])
    feed_dict = {graph["imagePlaceholder"]: inputData}
    predClasses = sess.run(graph["prediction"], feed_dict=feed_dict)

    # Only batch element 0 is the real image. Assumes the prediction holds one
    # class index per pixel (width * height values) - verify against the model.
    predClasses = predClasses[0].reshape(width * height)
    predImg = _colorize(predClasses, config["classes"], data.config["ClassToRGB"])

    # The output is always an RGB image, so reshape to 3 channels regardless of
    # how many channels the model's input uses.
    predImg = predImg.reshape((height, width, 3)).astype("uint8")

    savePath = "../results/"+data.config["name"]+str(data.config["x"])+str(data.config["y"])+config["neuralNetwork"]+".png"
    savedImage = Image.fromarray(predImg, "RGB")
    savedImage.save(savePath)
    display(savedImage)