-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy pathmain.lua
124 lines (106 loc) · 4.29 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
require 'paths'
paths.dofile('util.lua')
paths.dofile('img.lua')
--------------------------------------------------------------------------------
-- Initialization
--------------------------------------------------------------------------------
-- The run mode is the first command-line argument. `arg` may be absent when
-- the script is not launched from the command line, so guard the lookup and
-- fall through to the usage message instead of raising an index error.
local mode = arg and arg[1]

-- NOTE: `a`, `m`, `idxs`, `nsamples` and `preds` are deliberately left as
-- globals; the later sections of this script read them by those names.
if mode == 'demo' or mode == 'predict-test' then
    -- Test set annotations do not have ground truth part locations, but provide
    -- information about the location and scale of people in each image.
    a = loadAnnotations('test')
elseif mode == 'predict-valid' or mode == 'eval' then
    -- Validation set annotations on the other hand, provide part locations,
    -- visibility information, normalization factors for final evaluation, etc.
    a = loadAnnotations('valid')
else
    -- Unknown or missing mode: print usage and stop the script here.
    print("Please use one of the following input arguments:")
    print(" demo - Generate and display results on a few demo images")
    print(" predict-valid - Generate predictions on the validation set (MPII images must be available in 'images' directory)")
    print(" predict-test - Generate predictions on the test set")
    print(" eval - Run basic evaluation on predictions from the validation set")
    return
end

-- Load the pre-trained model only for the modes that run the network.
-- 'eval' works purely from saved prediction files, so it should not require
-- the model file (or pay the cost of deserializing it).
if mode ~= 'eval' then
    m = torch.load('umich-stacked-hourglass.t7')
end

-- Select which annotation entries to process.
if mode == 'demo' then
    idxs = torch.Tensor({695, 3611, 2486, 7424, 10032, 5, 4829})
    -- If all the MPII images are available, use the following line to see a random sampling of images
    -- idxs = torch.randperm(a.nsamples):sub(1,10)
else
    idxs = torch.range(1, a.nsamples)
end

if mode == 'eval' then
    -- A sample count of zero makes the main loop a no-op in 'eval' mode.
    nsamples = 0
else
    nsamples = idxs:nElement()
    -- Displays a convenient progress bar
    xlua.progress(0, nsamples)
    -- One row per sample: 16 entries of 2 values each, filled by the main loop.
    preds = torch.Tensor(nsamples, 16, 2)
end
--------------------------------------------------------------------------------
-- Main loop
--------------------------------------------------------------------------------
-- The run mode does not change while the loop runs, so test it once up front.
local show_demo = (arg[1] == 'demo')

for i = 1, nsamples do
    -- Fetch the image and the person's center/scale for this sample
    local idx = idxs[i]
    local img = image.load('images/' .. a.images[idx])
    local person_center = a.center[idx]
    local person_scale = a.scale[idx]
    local cropped = crop(img, person_center, person_scale, 0, 256)

    -- Run the network on a single-image batch moved to the GPU
    local net_input = cropped:view(1, 3, 256, 256):cuda()
    local net_out = m:forward(net_input)
    cutorch.synchronize()

    -- Take the first element of the last network output and zero out
    -- every negative value in it
    local heatmaps = net_out[#net_out][1]:float()
    heatmaps[heatmaps:lt(0)] = 0

    -- Predictions come back in two coordinate spaces: heatmap and image
    local hm_coords, img_coords = getPreds(heatmaps, person_center, person_scale)
    preds[i]:copy(img_coords)

    xlua.progress(i, nsamples)

    -- In demo mode, draw the result and keep it on screen for a moment
    if show_demo then
        hm_coords:mul(4) -- Change to input scale
        local rendered = drawOutput(cropped, heatmaps, hm_coords[1])
        -- `w` stays global: the window is reused on every iteration and
        -- closed after the loop.
        w = image.display{image = rendered, win = w}
        sys.sleep(3)
    end

    collectgarbage()
end
-- Write the predictions to disk for the prediction modes; in demo mode just
-- close the display window instead.
local pred_paths = {
    ['predict-valid'] = 'preds/valid-example.h5',
    ['predict-test'] = 'preds/test.h5',
}

local out_path = pred_paths[arg[1]]
if out_path then
    -- Store the prediction tensor under the 'preds' dataset name
    local h5 = hdf5.open(out_path, 'w')
    h5:write('preds', preds)
    h5:close()
elseif arg[1] == 'demo' then
    w.window:close()
end
--------------------------------------------------------------------------------
-- Evaluation code
--------------------------------------------------------------------------------
-- Compare saved validation predictions and plot per-joint PCKh curves.
if arg[1] == 'eval' then
    -- Calculate distances given each set of predictions
    local labels = {'valid-example','valid-ours'}
    local dists = {}
    for i = 1,#labels do
        local predFile = hdf5.open('preds/' .. labels[i] .. '.h5','r')
        -- :all() reads the complete dataset into memory, so the file handle
        -- can (and should) be released immediately afterwards.
        local labelPreds = predFile:read('preds'):all()
        predFile:close()
        table.insert(dists, calcDists(labelPreds, a.part, a.normalize))
    end

    require 'gnuplot'
    gnuplot.raw('set bmargin 1')
    gnuplot.raw('set lmargin 3.2')
    gnuplot.raw('set rmargin 2')
    -- 2x3 multiplot grid: the six displayPCK calls below fill one cell each
    gnuplot.raw('set multiplot layout 2,3 title "MPII Validation Set Performance (PCKh)"')
    gnuplot.raw('set xtics font ",6"')
    gnuplot.raw('set ytics font ",6"')

    -- First three plots
    displayPCK(dists, {9,10}, labels, 'Head')
    displayPCK(dists, {2,5}, labels, 'Knee')
    displayPCK(dists, {1,6}, labels, 'Ankle')

    -- Margins are adjusted before the remaining three plots
    gnuplot.raw('set tmargin 2.5')
    gnuplot.raw('set bmargin 1.5')
    displayPCK(dists, {13,14}, labels, 'Shoulder')
    displayPCK(dists, {12,15}, labels, 'Elbow')
    -- NOTE(review): the trailing `true` presumably enables the plot key on
    -- the final panel -- confirm against displayPCK in util.lua.
    displayPCK(dists, {11,16}, labels, 'Wrist', true)

    gnuplot.raw('unset multiplot')
end