forked from 1297rohit/RCNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathRCNN_print.py
executable file
·96 lines (77 loc) · 2.35 KB
/
RCNN_print.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python
# coding: utf-8
# In[3]:
import os,cv2, keras
import pandas as pd
import numpy as np
import tensorflow as tf
import pickle
from sklearn.model_selection import train_test_split
from keras.callbacks import ModelCheckpoint, EarlyStopping
# In[4]:
path = "Images"
test_path = "Test"
cv2.setUseOptimized(True);
ss = cv2.ximgproc.segmentation.createSelectiveSearchSegmentation()
im = cv2.imread(os.path.join(path,"42850.jpg"))
ss.setBaseImage(im)
ss.switchToSelectiveSearchFast()
rects = ss.process()
imOut = im.copy()
for i, rect in (enumerate(rects)):
x, y, w, h = rect
# print(x,y,w,h)
# imOut = imOut[x:x+w,y:y+h]
cv2.rectangle(imOut, (x, y), (x+w, y+h), (0, 255, 0), 1, cv2.LINE_AA)
# plt.figure()
plt.imshow(imOut)
plt.show()
# 从文件中加载模型
model_loaded = keras.models.load_model('ieeercnn_vgg16_1.h5')
model_loaded.summary()
# 从文件中加载训练历史
hist = pickle.load(open('history.dat', 'rb'))
# 显示训练历史
import matplotlib.pyplot as plt
# plt.plot(hist.history["acc"])
# plt.plot(hist.history['val_acc'])
plt.plot(hist.history['loss'])
plt.plot(hist.history['val_loss'])
plt.title("model loss")
plt.ylabel("Loss")
plt.xlabel("Epoch")
plt.legend(["Loss","Validation Loss"])
plt.savefig('chart loss.png')
plt.show()
X_test = pickle.load(open('X_test.dat', 'rb'))
im = X_test[1600]
plt.imshow(im)
img = np.expand_dims(im, axis=0)
out= model_loaded.predict(img)
if out[0][0] > out[0][1]:
print("plane")
else:
print("not plane")
plt.show()
print("Model predicting and saving results...")
z=0
for e,i in enumerate(os.listdir(path)):
if i.startswith("4"):
z += 1
img = cv2.imread(os.path.join(path,i))
ss.setBaseImage(img)
ss.switchToSelectiveSearchFast()
ssresults = ss.process()
imout = img.copy()
for e,result in enumerate(ssresults):
if e < 2000:
x,y,w,h = result
timage = imout[y:y+h,x:x+w]
resized = cv2.resize(timage, (224,224), interpolation = cv2.INTER_AREA)
img = np.expand_dims(resized, axis=0)
out= model_loaded.predict(img)
if out[0][0] > 0.65:
cv2.rectangle(imout, (x, y), (x+w, y+h), (0, 255, 0), 1, cv2.LINE_AA)
plt.figure()
plt.imshow(imout)
plt.savefig(os.path.join(test_path,i + '.png'))