forked from norouzi/mih
-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_lsh_codes.m
147 lines (132 loc) · 4.94 KB
/
create_lsh_codes.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
% Configuration and path setup for generating LSH binary codes.
%
% dataset_name should be one of: sift_1M / sift_1B / gist_1M / gist_80M.
% You should download the datasets separately.
dataset_name = 'sift_1B';

% nb controls the length of the binary codes generated.
% nb can be set from outside: it is only defaulted here when no variable
% named nb already exists in the workspace.
if ~exist('nb', 'var')
    nb = 64;
end

% Where the corresponding datasets are stored:
TINY_HOME = 'data/tiny';    % the root of 80 million tiny images dataset
INRIA_HOME = 'data/inria';  % the root of INRIA BIGANN datasets

% Where the output matrix of binary codes should be stored:
outputdir = 'codes/lsh';

% CACHE_DIR is used to store the data mean for the datasets.
CACHE_DIR = 'cache';

% Create the output and cache folders when they are missing. The 'dir'
% option is used so that only an existing folder (not a plain file with the
% same name) suppresses the mkdir call.
if ~exist(outputdir, 'dir')
    mkdir(outputdir);
end
if ~exist(CACHE_DIR, 'dir')
    mkdir(CACHE_DIR);
end

addpath matlab;

% Pick the dataset-specific helper code. 'sift_1B_tr' is listed with the
% other INRIA datasets because it is read from INRIA_HOME with the same
% reader as 'sift_1B'; previously it fell through to the tiny-images path.
% NOTE(review): assumes the *vecs readers live under INRIA_HOME/matlab --
% confirm against the local dataset layout.
if any(strcmp(dataset_name, {'sift_1M', 'sift_1B', 'sift_1B_tr', 'gist_1M'}))
    addpath([INRIA_HOME, '/matlab']);
else
    addpath([TINY_HOME, '/code']);
end
% Resolve, for the selected dataset_name:
%   dataset  - name of the dataset folder,
%   datahome - root directory the dataset is stored under,
%   N        - number of vectors that will be encoded.
% An unknown name is rejected here so the script stops with a clear message
% instead of failing later on an undefined N / datahome.
switch dataset_name
    case 'sift_1M'
        dataset = 'ANN_SIFT1M';
        datahome = INRIA_HOME;
        N = 10^6;
    case 'sift_1B'
        dataset = 'ANN_SIFT1B';
        datahome = INRIA_HOME;
        N = 10^9;
    case 'sift_1B_tr'
        % Same dataset folder as sift_1B, but only 10^8 vectors are encoded.
        dataset = 'ANN_SIFT1B';
        datahome = INRIA_HOME;
        N = 10^8;
    case 'gist_1M'
        dataset = 'ANN_GIST1M';
        datahome = INRIA_HOME;
        N = 10^6;
    case 'gist_80M'
        dataset = '80M';
        datahome = TINY_HOME;
        N = 79*10^6;
    otherwise
        error('create_lsh_codes:unsupportedDataset', ...
              'dataset not supported: %s', dataset_name);
end
% Obtain learn_mean, the column vector subtracted from every data point
% before hashing. It is computed from the dataset's training vectors the
% first time and cached in CACHE_DIR/<dataset_name>_mean.mat; later runs
% load it from the cache. For gist_80M the cache also holds the index
% permutation 'perm'.
mean_file = [CACHE_DIR, '/', dataset_name, '_mean'];
if ~exist([mean_file, '.mat'], 'file')
    fprintf('Computing the data mean for the %s dataset... \n', dataset_name);
    if strcmp(dataset_name, 'sift_1M')
        trdata = fvecs_read([datahome, '/ANN_SIFT1M/sift/sift_learn.fvecs']);
        learn_mean = mean(trdata, 2);
        save(mean_file, 'learn_mean');
    elseif strcmp(dataset_name, 'sift_1B')
        % The learning set is read in chunks of nbuffer vectors; per-chunk
        % sums are accumulated in double precision and averaged at the end.
        Ntraining = 10^8;
        nbuffer = 10^6;
        nchunks = floor(Ntraining/nbuffer);
        for i = 1:nchunks
            fprintf('%d/%d\r', i, nchunks);
            trdatai = b2fvecs_read([datahome, '/ANN_SIFT1B/bigann_learn.bvecs'], ...
                                   [(i-1)*nbuffer+1, i*nbuffer]);
            chunk_sum = sum(trdatai, 2, 'double');
            if i == 1
                % Allocate once the dimensionality is known. This avoids
                % growing the matrix every iteration and discards any stale
                % learn_meani left in the workspace by an earlier run.
                learn_meani = zeros(numel(chunk_sum), nchunks);
            end
            learn_meani(:, i) = chunk_sum;
        end
        learn_mean = sum(learn_meani, 2, 'double');
        learn_mean = learn_mean / Ntraining;
        clear trdatai learn_meani chunk_sum;
        save(mean_file, 'learn_mean');
    elseif strcmp(dataset_name, 'gist_1M')
        trdata = fvecs_read([datahome, '/ANN_GIST1M/gist/gist_learn.fvecs']);
        learn_mean = mean(trdata, 2);
        save(mean_file, 'learn_mean');
    elseif strcmp(dataset_name, 'gist_80M')
        trdata = read_tiny_gist_binary(1:10^7);
        learn_mean = mean(trdata, 2);
        clear trdata;
        % Random permutation of the 79302017 item indices; the first
        % 79*10^6 entries are put in ascending order.
        perm = randperm(79302017);
        perm(1:79*10^6) = sort(perm(1:79*10^6));
        % Write the mean and the permutation together so an interrupted run
        % cannot leave behind a cache file that lacks 'perm'.
        save(mean_file, 'learn_mean', 'perm');
    else
        % 'continue' is only valid inside a for/while loop, so abort with an
        % explicit error instead.
        error('create_lsh_codes:unsupportedDataset', ...
              'dataset not supported: %s', dataset_name);
    end
    fprintf('done. \n');
else
    % Loads every variable stored in the cache file into the workspace.
    load(mean_file);
end
% Hash the base vectors with random projections.
%
% W has one row per output bit: nd Gaussian weights followed by a zero in
% the last column. Each data chunk gets a row of ones appended before the
% multiplication, so that last column acts as a (zero) offset term.
% Random projection-based hashing (LSH) preserves angles.
% One can load W from outside too.
nd = size(learn_mean, 1);
W = [randn(nb, nd), zeros(nb, 1)];

% The data is processed nbuffer vectors at a time; B receives one column of
% ceil(nb/8) uint8 values per vector.
nbuffer = 10^6;
B = zeros(ceil(nb/8), N, 'uint8');

fprintf('Computing %d-bit binary codes...\n', nb);
for i = 1:floor(N/nbuffer)
    fprintf('%d/%d\r', i, floor(N/nbuffer));

    % Read vectors (i-1)*nbuffer+1 .. i*nbuffer of the selected dataset.
    switch dataset_name
        case 'sift_1M'
            base = fvecs_read([datahome '/ANN_SIFT1M/sift/sift_base.fvecs'], ...
                              [(i-1)*nbuffer+1, i*nbuffer]);
        case 'sift_1B'
            base = b2fvecs_read([datahome '/ANN_SIFT1B/bigann_base.bvecs'], ...
                                [(i-1)*nbuffer+1, i*nbuffer]);
        case 'sift_1B_tr'
            base = b2fvecs_read([datahome '/ANN_SIFT1B/bigann_learn.bvecs'], ...
                                [(i-1)*nbuffer+1, i*nbuffer]);
        case 'gist_1M'
            base = fvecs_read([datahome '/ANN_GIST1M/gist/gist_base.fvecs'], ...
                              [(i-1)*nbuffer+1, i*nbuffer]);
        case 'gist_80M'
            base = read_tiny_gist_binary(perm(((i-1)*nbuffer+1):(i*nbuffer)));
    end

    % Center the chunk in double precision, keep the sign of each
    % projection as one bit, and store the compacted codes.
    % NOTE(review): compactbit presumably packs 8 bits per uint8 -- confirm.
    base = bsxfun(@minus, double(base), learn_mean);
    B1 = compactbit((W * [base; ones(1, size(base, 2))]) > 0);
    B(:, ((i-1)*nbuffer+1):(i*nbuffer)) = B1;
end
% Hash the query vectors with the same mean and projection matrix W that
% were used for the base vectors. Datasets without a query branch below
% produce an empty Q.
query = [];
switch dataset_name
    case 'sift_1M'
        query = fvecs_read([datahome, '/ANN_SIFT1M/sift/sift_query.fvecs']);
    case 'sift_1B'
        query = b2fvecs_read([datahome, '/ANN_SIFT1B/bigann_query.bvecs']);
    case 'gist_1M'
        query = fvecs_read([datahome, '/ANN_GIST1M/gist/gist_query.fvecs']);
    case 'gist_80M'
        % The last 10000 entries of the permutation are used as queries.
        query = read_tiny_gist_binary(perm((79302017-10000+1):79302017));
end

if isempty(query)
    Q = [];
else
    % Cast to double exactly as is done for the base vectors, so query and
    % base codes come from the same arithmetic. Without the cast the
    % projection runs in whatever class the reader returned, which can flip
    % bits whose projection is close to zero.
    query = double(query);
    query = bsxfun(@minus, query, learn_mean);
    Q = compactbit((W * [query; ones(1, size(query, 2))]) > 0);
end
% Write the base codes (B), query codes (Q), projection matrix (W) and the
% data mean to <outputdir>/lsh_<nb>_<dataset_name>, then release the large
% matrices from the workspace.
fprintf('storing the codes in the file %s ...', ...
        [outputdir '/lsh_' num2str(nb) '_' dataset_name]);
save([outputdir '/lsh_' num2str(nb) '_' dataset_name], ...
     'B', 'Q', 'W', 'learn_mean', '-v7.3');
clear('B', 'Q', 'W');
fprintf('done.\n');