-
Notifications
You must be signed in to change notification settings - Fork 2
/
salWtUpKSVD_graph.m
executable file
·318 lines (287 loc) · 15.5 KB
/
salWtUpKSVD_graph.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
function [Dictionary,output] = salWtUpKSVD_graph(...
    Data,... % an nXN matrix that contains N signals (Y), each of dimension n.
    param)
% =========================================================================
%            K-SVD algorithm with iteratively updated signal weights
% =========================================================================
% The K-SVD algorithm finds a dictionary for linear representation of
% signals. Given a set of signals, it searches for the best dictionary that
% can sparsely represent each signal. Detailed discussion on the algorithm
% and possible applications can be found in "The K-SVD: An Algorithm for
% Designing of Overcomplete Dictionaries for Sparse Representation", written
% by M. Aharon, M. Elad, and A.M. Bruckstein and appeared in the IEEE Trans.
% On Signal Processing, Vol. 54, no. 11, pp. 4311-4322, November 2006.
%
% This variant additionally keeps one weight per signal. The weights are
% passed to the atom update (and to OMPerrSal when param.errorFlag~=0) and
% are re-estimated after every iteration as
%     w <- normalize( param.aff * (w' - err_recon') )
% where err_recon is the scaled per-signal squared reconstruction error
% measured right after sparse coding. The update runs on iteration 1 and
% afterwards only while the previous update moved the weights by more than
% 0.005 (Euclidean norm).
%
% OMP and OMPerrSal are external helpers that are not defined in this file.
% =========================================================================
% INPUT ARGUMENTS:
% Data      an nXN matrix that contains N signals (Y), each of dimension n.
% param     structure that includes all required parameters for the
%           K-SVD execution. Fields:
%   K                     number of dictionary elements to train.
%   numIteration          number of iterations to perform.
%   errorFlag             (default 0) if =0, a fixed number of coefficients
%                         is used for each signal and param.L must be
%                         given. Otherwise an arbitrary number of atoms is
%                         used until the representation error reaches
%                         param.errorGoal, which must then be given.
%   preserveDCAtom        if >0 the first atom of the dictionary is a
%                         constant (DC) atom that never changes; only
%                         param.K-1 atoms are trained.
%   L                     (see errorFlag) maximum number of coefficients
%                         used by OMP.
%   errorGoal             (see errorFlag) allowed representation error.
%   InitializationMethod  'DataElements' (random non-zero signals) or
%                         'GivenMatrix' (param.initialDictionary).
%   initialDictionary     (see InitializationMethod) initial dictionary.
%   aff                   left-multiplies the column vector
%                         (weight' - err_recon') in the weight update.
%                         Presumably an NxN graph affinity matrix -
%                         TODO confirm with the caller.
%   weight                (optional) NxN matrix; only its diagonal is used,
%                         normalized to sum to 1. Default: uniform weights.
%   TrueDictionary        (optional) if given, the detection ratio against
%                         this dictionary is measured and displayed in
%                         every iteration.
%   displayProgress       (default 0) if =1 progress is displayed: the
%                         RMSE when errorFlag==0, otherwise the average
%                         number of coefficients per signal.
%   waitBarHandle, counterForWaitBar
%                         (optional) if waitBarHandle is present,
%                         waitbar(iterNum/param.counterForWaitBar) is
%                         called at the end of every iteration.
% =========================================================================
% OUTPUT ARGUMENTS:
% Dictionary   The extracted dictionary of size nX(param.K).
% output       Struct describing the run. It may include:
%   Dinit      normalized initial dictionary (without the DC atom).
%   Dict{it}   trained atoms after iteration it (without the DC atom).
%   W{it}      weight row vector produced by iteration it.
%   err{it}    column vector of scaled per-signal reconstruction errors.
%   errW(it)   see the NOTE(review) next to its computation.
%   CoefMatrix final coefficients (Data ~ Dictionary*output.CoefMatrix).
%   weightF    final weight row vector.
%   totalerr   RMSE per iteration (only if param.displayProgress=1).
%   numCoef    average number of coefficients per signal (only if
%              param.displayProgress=1 and param.errorFlag~=0).
%   ratio      detection ratios (only if param.TrueDictionary is given).
% =========================================================================

% ---- defaults -----------------------------------------------------------
if (~isfield(param,'displayProgress'))
    param.displayProgress = 0;
end
if (~isfield(param,'errorFlag'))
    param.errorFlag = 0;
end

% Make sure the second output exists on every return path, including the
% early "trivial solution" return below.
output = struct();

signalDim = size(Data,1);
numSignals = size(Data,2);

if (isfield(param,'TrueDictionary'))
    displayErrorWithTrueDictionary = 1;
    ErrorBetweenDictionaries = zeros(param.numIteration+1,1);
    ratio = zeros(param.numIteration+1,1);
else
    displayErrorWithTrueDictionary = 0;
end

if (param.preserveDCAtom>0)
    FixedDictionaryElement = ones(signalDim,1)/sqrt(signalDim);
else
    FixedDictionaryElement = [];
end

% ---- per-signal weights (row vector, normalized to sum to 1) -------------
if (isfield(param,'weight'))
    weightDiag = (diag(param.weight))';
else
    weightDiag = ones(1,numSignals); % unspecified weights are all equal
end
weight = weightDiag/sum(weightDiag);
weight_i = weight;   % most recently computed weights
weight_t = weight;   % weights used by the previous iteration
optAff = param.aff;

% ---- dictionary initialization -------------------------------------------
if (numSignals < param.K)
    disp('Size of data is smaller than the dictionary size. Trivial solution...');
    Dictionary = Data(:,1:numSignals);
    output.CoefMatrix = eye(numSignals); % Data == Dictionary*CoefMatrix
    return;
elseif (strcmp(param.InitializationMethod,'DataElements'))
    data_ids = find(colnorms_squared(Data) > 1e-6); % ensure no zero data elements are chosen
    perm = randperm(length(data_ids));              % choose elements randomly from the data
    Dictionary = Data(:,data_ids(perm(1:param.K-param.preserveDCAtom)));
elseif (strcmp(param.InitializationMethod,'GivenMatrix'))
    Dictionary = param.initialDictionary(:,1:param.K-param.preserveDCAtom);
else
    error('salWtUpKSVD_graph:unknownInitialization',...
        'param.InitializationMethod must be ''DataElements'' or ''GivenMatrix''.');
end

% reduce the components in Dictionary that are spanned by the fixed
% elements
if (param.preserveDCAtom)
    tmpMat = FixedDictionaryElement \ Dictionary;
    Dictionary = Dictionary - FixedDictionaryElement*tmpMat;
end

% normalize the dictionary (explicit dim so a single-row Data still works).
Dictionary = Dictionary*diag(1./sqrt(sum(Dictionary.*Dictionary,1)));
% Make the first entry of every atom non-negative. sign(0) is 0, which
% would wipe out an atom whose first entry is exactly zero, so treat 0 as +1.
atomSigns = sign(Dictionary(1,:));
atomSigns(atomSigns==0) = 1;
Dictionary = Dictionary.*repmat(atomSigns,size(Dictionary,1),1);

% the K-SVD algorithm starts here.
output.Dinit = Dictionary;
numFixed = size(FixedDictionaryElement,2);
reconErrScale = 1/0.1; % scale applied to the per-signal reconstruction error

for iterNum = 1:param.numIteration
    % ---- sparse coding ----------------------------------------------------
    if (param.errorFlag==0)
        CoefMatrix = OMP([FixedDictionaryElement,Dictionary],Data, param.L);
    else
        CoefMatrix = OMPerrSal([FixedDictionaryElement,Dictionary],Data, param.errorGoal,diag(weight)); % matrix form of weight
        param.L = 1;
    end

    % Per-signal reconstruction error used by the weight update. The fixed
    % DC atom has to be part of the product because CoefMatrix has one row
    % per atom of [FixedDictionaryElement,Dictionary].
    err_recon = reconErrScale*(sum((Data-[FixedDictionaryElement,Dictionary]*CoefMatrix).^2,1))/signalDim;

    % ---- atom update, in random order --------------------------------------
    replacedVectorCounter = 0;
    rPerm = randperm(size(Dictionary,2));
    for j = rPerm
        [betterDictionaryElement,CoefMatrix,addedNewVector] = I_findBetterDictionaryElement(Data,...
            [FixedDictionaryElement,Dictionary],j+numFixed,...
            CoefMatrix ,param.L,weight); % input vector weight
        Dictionary(:,j) = betterDictionaryElement;
        if (param.preserveDCAtom)
            % keep the trained atom orthogonal to the DC atom and unit norm
            tmpCoef = FixedDictionaryElement\betterDictionaryElement;
            Dictionary(:,j) = betterDictionaryElement - FixedDictionaryElement*tmpCoef;
            Dictionary(:,j) = Dictionary(:,j)./sqrt(Dictionary(:,j)'*Dictionary(:,j));
        end
        replacedVectorCounter = replacedVectorCounter+addedNewVector;
    end

    % ---- progress report ---------------------------------------------------
    if (iterNum>1 && param.displayProgress)
        output.totalerr(iterNum-1) = sqrt(sum(sum((Data-[FixedDictionaryElement,Dictionary]*CoefMatrix).^2))/numel(Data));
        if (param.errorFlag==0)
            disp(['Iteration   ',num2str(iterNum),'   Total error is: ',num2str(output.totalerr(iterNum-1))]);
        else
            output.numCoef(iterNum-1) = nnz(CoefMatrix)/numSignals;
            disp(['Iteration   ',num2str(iterNum),'   Average number of coefficients: ',num2str(output.numCoef(iterNum-1))]);
        end
    end
    if (displayErrorWithTrueDictionary )
        [ratio(iterNum+1),ErrorBetweenDictionaries(iterNum+1)] = I_findDistanseBetweenDictionaries(param.TrueDictionary,Dictionary);
        disp(strcat(['Iteration  ', num2str(iterNum),' ratio of restored elements: ',num2str(ratio(iterNum+1))]));
        output.ratio = ratio;
    end

    % replace atoms that are near-duplicates or hardly used
    Dictionary = I_clearDictionary(Dictionary,CoefMatrix(numFixed+1:end,:),Data);
    output.Dict{iterNum} = Dictionary;

    % ---- weight update -----------------------------------------------------
    % NOTE(review): weight and weight_i are equal at this point in every
    % iteration, so errW is always 0. Kept as-is because the intended
    % quantity is not clear from this file - TODO confirm.
    output.errW(iterNum) = sqrt(sum((weight-weight_i).^2));
    % Update on the first iteration, and afterwards only while the previous
    % update changed the weights noticeably; once the change drops below the
    % threshold the weights stay frozen.
    if (iterNum==1 || sqrt(sum((weight-weight_t).^2))>0.005)
        wei = optAff*((weight'-err_recon'));
        weight_i = (wei/sum(wei))';
    end
    weight_t = weight;
    weight = weight_i;
    output.W{iterNum} = weight_i;
    output.err{iterNum} = err_recon';

    if (isfield(param,'waitBarHandle'))
        waitbar(iterNum/param.counterForWaitBar);
    end
end

% ---- final sparse coding with the trained dictionary -----------------------
Dictionary = [FixedDictionaryElement,Dictionary];
if (isfield(param,'errorGoal'))
    output.CoefMatrix = OMPerrSal(Dictionary,Data, param.errorGoal,diag(weight)); % matrix form of weight
else
    % fixed-sparsity runs are not required to supply errorGoal
    output.CoefMatrix = OMP(Dictionary,Data, param.L);
end
output.weightF = weight;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% findBetterDictionaryElement
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [betterDictionaryElement,CoefMatrix,NewVectorAdded] =...
    I_findBetterDictionaryElement(Data,Dictionary,j,CoefMatrix,numCoefUsed,W)
% Weighted K-SVD update of a single atom.
%
% Inputs:
%   Data         n x N signal matrix.
%   Dictionary   full dictionary (including any fixed atom), n x K.
%   j            column of Dictionary / row of CoefMatrix to update.
%   CoefMatrix   K x N coefficient matrix.
%   numCoefUsed  unused; accepted so existing callers keep working.
%   W            vector with one weight per signal (length N).
%
% Outputs:
%   betterDictionaryElement  new unit-norm atom for column j.
%   CoefMatrix               coefficients with row j updated.
%   NewVectorAdded           1 if atom j was unused and has been replaced
%                            by a data signal, 0 if it was refined by SVD.

w = W(:)';  % 1 x N row, so it scales the columns of n x N matrices
relevantDataIndices = find(CoefMatrix(j,:)); % the data indices that use the j'th dictionary element.

if (isempty(relevantDataIndices))
    % Atom j is not used by any signal: replace it with the signal that has
    % the largest weighted representation error. Weighting is applied per
    % column; collapsing the data with W*Data' would leave a single column
    % and make the max-search meaningless.
    weightedResidual = bsxfun(@times,Data-Dictionary*CoefMatrix,w);
    ErrorNormVec = sum(weightedResidual.^2,1);
    [maxErr,i] = max(ErrorNormVec);
    betterDictionaryElement = w(i)*Data(:,i);
    betterDictionaryElement = betterDictionaryElement./sqrt(betterDictionaryElement'*betterDictionaryElement);
    % first entry non-negative; sign(0)==0 must not zero out the atom
    firstSign = sign(betterDictionaryElement(1));
    if (firstSign==0)
        firstSign = 1;
    end
    betterDictionaryElement = betterDictionaryElement.*firstSign;
    CoefMatrix(j,:) = 0;
    NewVectorAdded = 1;
    return;
end

NewVectorAdded = 0;
tmpCoefMatrix = CoefMatrix(:,relevantDataIndices);
tmpCoefMatrix(j,:) = 0; % the coefficients of the element we now improve are not relevant.
% weighted residual of the signals that use atom j, without atom j's share
errors = bsxfun(@times,Data(:,relevantDataIndices)-Dictionary*tmpCoefMatrix,w(relevantDataIndices));
% The better dictionary element and the values of beta are found using svd.
% This is because we would like to minimize || errors - beta*element ||_F^2,
% that is, to approximate the matrix 'errors' with a rank-one matrix. This
% is done using the largest singular value.
[betterDictionaryElement,singularValue,betaVector] = svds(errors,1);
% NOTE(review): these coefficients fit the *weighted* residual; they are
% not divided by the signal weights (the original had that step commented
% out) - TODO confirm this is intended.
CoefMatrix(j,relevantDataIndices) = singularValue*betaVector';
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% findDistanseBetweenDictionaries
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [ratio,totalDistances] = I_findDistanseBetweenDictionaries(original,new)
% Compare a trained dictionary against a reference dictionary.
%
% Every atom of both dictionaries is first flipped so that its first entry
% is non-negative. Each atom of 'original' is then matched to the closest
% atom of 'new' (squared Euclidean distance) and scored with
% 1-|<match,atom>|. An atom counts as recovered when that score is < 0.01.
%
% Outputs:
%   ratio           percentage (0..100) of atoms of 'original' recovered.
%   totalDistances  sum of the scores over all atoms of 'original'.
catchCounter = 0;
totalDistances = 0;

% sign(0) is 0 and would erase an atom whose first entry is exactly zero,
% so a zero first entry is treated as positive.
newSigns = sign(new(1,:));
newSigns(newSigns==0) = 1;
new = bsxfun(@times,new,newSigns);

for i = 1:size(original,2)
    refSign = sign(original(1,i));
    if (refSign==0)
        refSign = 1;
    end
    d = refSign*original(:,i);
    % explicit dim so one-dimensional atoms are not summed into a scalar
    distances = sum(bsxfun(@minus,new,d).^2,1);
    [minValue,index] = min(distances);
    errorOfElement = 1-abs(new(:,index)'*d);
    totalDistances = totalDistances+errorOfElement;
    catchCounter = catchCounter+(errorOfElement<0.01);
end
ratio = 100*catchCounter/size(original,2);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% I_clearDictionary
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function Dictionary = I_clearDictionary(Dictionary,CoefMatrix,Data)
% Replace atoms that are redundant or hardly used.
%
% An atom is replaced when its largest inner product with another atom
% exceeds 0.999, or when at most 3 signals use it with a coefficient
% larger than 1e-7 in magnitude. The replacement is the (normalized) signal
% with the largest remaining representation error; each signal is handed
% out at most once because its error is zeroed after it has been used.
maxCoherence = 0.999;
minUsage = 3;
numAtoms = size(Dictionary,2);

% squared representation error of every signal
residual = Data-Dictionary*CoefMatrix;
Er = sum(residual.^2,1);

% Gram matrix with the diagonal removed: atom-to-atom inner products
G = Dictionary'*Dictionary;
G = G-diag(diag(G));

for jj = 1:numAtoms
    tooSimilar = max(G(jj,:)) > maxCoherence;
    usageCount = nnz(abs(CoefMatrix(jj,:)) > 1e-7);
    if (tooSimilar | usageCount <= minUsage)
        [val,pos] = max(Er);
        worst = pos(1);
        Er(worst) = 0;
        Dictionary(:,jj) = Data(:,worst)/norm(Data(:,worst));
        % the dictionary changed, so the coherence table must be rebuilt
        G = Dictionary'*Dictionary;
        G = G-diag(diag(G));
    end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% misc functions %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function Y = colnorms_squared(X)
% Squared Euclidean norm of every column of X, returned as a 1 x size(X,2)
% row vector.
%
% The columns are processed in blocks of 2000 to conserve memory (the
% temporary X.^2 is only ever one block wide).
numCols = size(X,2);
Y = zeros(1,numCols);
blocksize = 2000;
for i = 1:blocksize:numCols
    blockids = i : min(i+blocksize-1,numCols);
    % Sum along dimension 1 explicitly: without it, a single-row X is
    % treated as a vector and the whole block collapses into one scalar.
    Y(blockids) = sum(X(:,blockids).^2,1);
end