-
Notifications
You must be signed in to change notification settings - Fork 43
/
detect.py
132 lines (110 loc) · 4.74 KB
/
detect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import argparse
import shutil
import time
from pathlib import Path
from sys import platform
from models import *
from utils.datasets import *
from utils.utils import *
def _load_model(cfg, weights, img_size, device):
    """Build ``Darknet(cfg, img_size)``, load ``weights`` into it and return it on ``device`` in eval mode.

    ``.pt`` files are read with ``torch.load`` (the checkpoint's ``'model'`` entry);
    any other extension goes through ``load_darknet_weights``. A missing file whose
    name ends in ``yolov3.pt`` is first downloaded with ``wget`` on macOS/Linux.
    """
    import subprocess  # local import: only needed for the optional download below

    model = Darknet(cfg, img_size)
    if weights.endswith('.pt'):  # pytorch format
        if weights.endswith('yolov3.pt') and not os.path.exists(weights):
            if platform in ('darwin', 'linux'):
                # Argument list (no shell): an unusual weights path cannot inject commands
                # or break quoting. Like the previous os.system call, failures are not raised here.
                subprocess.run(['wget', 'https://storage.googleapis.com/ultralytics/yolov3.pt', '-O', weights])
        model.load_state_dict(torch.load(weights, map_location='cpu')['model'])
    else:  # darknet format
        load_darknet_weights(model, weights)
    model.to(device).eval()
    return model


def _annotate(pred, im0, img_size, conf_thres, nms_thres, classes, colors, txt_path=None):
    """Run NMS on one image's candidate boxes, print per-class counts and draw the boxes on ``im0``.

    Args:
        pred: Candidate boxes of a single image that already passed ``conf_thres``.
        im0: Original image; boxes are drawn onto it in place.
        img_size: Network input size the box coordinates are expressed in.
        conf_thres: Confidence threshold forwarded to ``non_max_suppression``.
        nms_thres: IoU threshold forwarded to ``non_max_suppression``.
        classes: Class names indexed by class id.
        colors: One color per class id.
        txt_path: If given, one ``x1 y1 x2 y2 cls score`` line per detection is appended here.
    """
    detections = non_max_suppression(pred.unsqueeze(0), conf_thres, nms_thres)[0]
    # NOTE(review): guard in case NMS yields None for an image with no surviving boxes;
    # previously that surfaced as an opaque error. Confirm against utils.utils.non_max_suppression.
    if detections is None or len(detections) == 0:
        return

    # Rescale boxes from the network input size to the true image size
    detections[:, :4] = scale_coords(img_size, detections[:, :4], im0.shape)

    # Print per-class counts to screen
    class_ids = detections[:, -1].cpu()
    for c in class_ids.unique():
        n = (class_ids == c).sum()
        print('%g %ss' % (n, classes[int(c)]), end=', ')

    # Open the text file once per image instead of once per detection.
    if txt_path:
        with open(txt_path, 'a') as file:
            for x1, y1, x2, y2, conf, cls_conf, cls in detections:
                file.write('%g %g %g %g %g %g\n' % (x1, y1, x2, y2, cls, cls_conf * conf))

    # Draw bounding boxes and labels of detections
    for x1, y1, x2, y2, conf, _, cls in detections:
        label = '%s %.2f' % (classes[int(cls)], conf)
        plot_one_box([x1, y1, x2, y2], im0, label=label, color=colors[int(cls)])


def detect(
        cfg,
        weights,
        images,
        output='output',  # output folder
        img_size=416,
        conf_thres=0.5,
        nms_thres=0.45,
        save_txt=False,
        save_images=True,
        webcam=False
):
    """Run Darknet/YOLO inference on images (or a webcam stream) and save or display the results.

    The ``output`` folder is deleted and recreated on every call.

    Args:
        cfg: Path to the model ``.cfg`` file passed to ``Darknet``.
        weights: Weights path; ``.pt`` files are loaded with ``torch.load``,
            anything else with ``load_darknet_weights``.
        images: Path handed to ``LoadImages`` (unused when ``webcam`` is True).
        output: Folder that receives the annotated images and optional ``.txt`` files.
        img_size: Inference size passed to the model and the data loader.
        conf_thres: Boxes whose column-4 score is not above this are dropped
            before NMS; the value is also forwarded to NMS.
        nms_thres: IoU threshold forwarded to ``non_max_suppression``.
        save_txt: If True, append detections to ``<output>/<image name>.txt``.
        save_images: If True, write each annotated image into ``output``.
            Forced to False in webcam mode.
        webcam: If True, read frames from ``LoadWebcam`` and show them live.

    Returns:
        None. When ``ONNX_EXPORT`` is set, the model is exported to
        ``weights/model.onnx`` on the first frame and the function returns early.
    """
    import subprocess  # local import: only needed to open the results on macOS

    device = torch_utils.select_device()
    if os.path.exists(output):
        shutil.rmtree(output)  # delete output folder
    os.makedirs(output)  # make new output folder

    model = _load_model(cfg, weights, img_size, device)

    # Set Dataloader
    if webcam:
        save_images = False
        dataloader = LoadWebcam(img_size=img_size)
    else:
        dataloader = LoadImages(images, img_size=img_size)

    # Get classes and one random color per class
    classes = load_classes(parse_data_cfg('cfg/coco.data')['names'])
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(classes))]

    save_path = None  # stays None when the dataloader yields nothing
    for i, (path, img, im0) in enumerate(dataloader):
        t = time.time()
        if webcam:
            print('webcam frame %g: ' % (i + 1), end='')
        else:
            print('image %g/%g %s: ' % (i + 1, len(dataloader), path), end='')
        save_path = str(Path(output) / Path(path).name)

        # Get detections
        img = torch.from_numpy(img).unsqueeze(0).to(device)
        if ONNX_EXPORT:
            torch.onnx.export(model, img, 'weights/model.onnx', verbose=True)
            return
        pred = model(img)
        pred = pred[pred[:, :, 4] > conf_thres]  # remove boxes < threshold

        if len(pred) > 0:
            try:
                _annotate(pred, im0, img_size, conf_thres, nms_thres, classes, colors,
                          save_path + '.txt' if save_txt else None)
            except Exception as e:
                # Best-effort per frame, but unlike a bare ``except:`` this lets
                # KeyboardInterrupt stop the (possibly endless) webcam loop, and it
                # reports what actually went wrong.
                print('detection failed (%s: %s)' % (type(e).__name__, e), end=' ')

        dt = time.time() - t
        print('Done. (%.3fs)' % dt)

        if save_images:  # Save generated image with detections
            cv2.imwrite(save_path, im0)

        if webcam:  # Show live webcam
            cv2.imshow("im", im0)
            # imshow only repaints while the HighGUI event loop runs; a 1 ms
            # waitKey keeps the window live without blocking the stream.
            cv2.waitKey(1)

    if save_images and save_path is not None and platform == 'darwin':  # macOS only
        # Argument list (no shell) so paths with spaces or shell metacharacters are safe.
        subprocess.run(['open', output, save_path])
if __name__ == '__main__':
    # Command-line entry point. Every option maps onto a detect() parameter; the new
    # --output/--save-txt/--webcam flags default to detect()'s own defaults, so
    # existing invocations behave exactly as before.
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='cfg file path')
    parser.add_argument('--weights', type=str, default='weights/best.pt', help='path to weights file')
    parser.add_argument('--images', type=str, default='data/samples', help='path to images')
    parser.add_argument('--output', type=str, default='output',
                        help='output folder (deleted and recreated on every run)')
    parser.add_argument('--img-size', type=int, default=416, help='size of each image dimension')
    parser.add_argument('--conf-thres', type=float, default=0.40, help='object confidence threshold')
    parser.add_argument('--nms-thres', type=float, default=0.45, help='iou threshold for non-maximum suppression')
    parser.add_argument('--save-txt', action='store_true', help='also write detections to <output>/<image>.txt')
    parser.add_argument('--webcam', action='store_true', help='read frames from the webcam instead of --images')
    opt = parser.parse_args()
    print(opt)

    # Inference only: disabling autograd avoids building graphs and saves memory.
    with torch.no_grad():
        detect(
            opt.cfg,
            opt.weights,
            opt.images,
            output=opt.output,
            img_size=opt.img_size,
            conf_thres=opt.conf_thres,
            nms_thres=opt.nms_thres,
            save_txt=opt.save_txt,
            webcam=opt.webcam
        )