-
Notifications
You must be signed in to change notification settings - Fork 18
/
eval_pool_imagenet.m
44 lines (36 loc) · 1.33 KB
/
eval_pool_imagenet.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
function outputs = eval_pool_imagenet(net, imdb)
% EVAL_POOL_IMAGENET  Collect 'pool_final' activations for every image in IMDB.
%
%   OUTPUTS = EVAL_POOL_IMAGENET(NET, IMDB) resizes the images stored along
%   the 4th dimension of IMDB.images.data to 224x224, subtracts the
%   per-channel values IMDB.meta.dataMean(1:3), evaluates NET on the GPU in
%   batches of at most 256 images, and returns the squeezed output of the
%   first layer whose name contains 'pool_final'.
%
%   Inputs:
%     net  - network object exposing device, move, eval, layers, vars and
%            conserveMemory (presumably a MatConvNet DagNN - confirm).
%     imdb - struct with fields images.data and meta.dataMean.
%
%   Output:
%     outputs - single-precision matrix with one column per image; column k
%               holds the activations of image k.
%               NOTE(review): assumes the layer yields one feature vector
%               per image (e.g. a 1x1xD pooled map) - confirm.
%
%   Side effects: NET is moved to the GPU if it is on the CPU, its
%   conserveMemory flag is set to 0, and it is moved back to the CPU before
%   returning.

batchSize = 256;
numImages = size(imdb.images.data, 4);
if numImages == 0
    error('eval_pool_imagenet:noImages', ...
        'imdb.images.data contains no images.');
end

if strcmp(net.device, 'cpu')
    net.move('gpu');
end
% Keep intermediate variable values so the layer output can be read back.
net.conserveMemory = 0;

%% Get output size.
% Probe the network with the first image only to learn the feature size.
probe = gpuArray(single(imresize(imdb.images.data(:, :, :, 1), [224 224])));
net.eval({'image', probe});

% Locate the output variable of the first layer named like 'pool_final'.
% This does not depend on the batch, so it is resolved once, not per batch.
matches = strfind({net.layers.name}, 'pool_final');
index = find(not(cellfun('isempty', matches)));
if isempty(index)
    error('eval_pool_imagenet:layerNotFound', ...
        'No layer whose name contains ''pool_final'' was found in the network.');
end
varIndex = net.layers(index(1)).outputIndexes(1);
x = squeeze(gather(net.vars(varIndex).value));

% Reserve memory: same shape as one result, last dimension = image count.
sz = size(x);
sz(end) = numImages;
outputs = zeros(sz, 'single');

%% Evaluate all images batch by batch.
for first = 1:batchSize:numImages
    % The last batch may be shorter than batchSize (possibly one image).
    last = min(first + batchSize - 1, numImages);
    images = gpuArray(single(imresize(imdb.images.data(:, :, :, first:last), [224 224])));
    for c = 1:3
        images(:, :, c, :) = images(:, :, c, :) - imdb.meta.dataMean(c);
    end
    net.eval({'image', images});
    % Store results in the columns of the images that were just evaluated.
    outputs(:, first:last) = squeeze(gather(net.vars(varIndex).value));
end

if strcmp(net.device, 'gpu')
    net.move('cpu');
end