-
Notifications
You must be signed in to change notification settings - Fork 0
/
AGN.m
130 lines (108 loc) · 3.14 KB
/
AGN.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
% Matrix-sensing experiment setup: problem sizes, solver settings, the
% ground-truth singular subspaces, and the random sensing matrices.
clear; close all; clc;
warning off;   % silences every warning for the rest of the script

% Problem dimensions and solver settings.
n = 30;                 % the ground-truth matrix is n-by-n
r = 2;                  % rank of the ground-truth matrix
kappa_list = [1000];    % condition numbers to sweep over
m = 3*n*r;              % number of linear measurements
T = 10000;              % maximum number of iterations per method
eta = 0.5;              % step size shared by both methods
% An iteration loop stops as soon as the error leaves [thresh_low, thresh_up].
thresh_up = 1e10; thresh_low = 1e-14;

% One row of error history per condition number; entries after an early
% stop stay at zero.
errors_GD = zeros(length(kappa_list), T);
errors_AGN = zeros(length(kappa_list), T);

% Orthonormal bases for the left/right singular subspaces, obtained from
% the left singular vectors of random +/-1 matrices.
U_seed = sign(rand(n, r) - 0.5);
[U_star, ~, ~] = svds(U_seed, r);
V_seed = sign(rand(n, r) - 0.5);
% Fix: V_star must be derived from V_seed. It was previously computed from
% U_seed, which left V_seed unused and made the right subspace a copy of
% the left one.
[V_star, ~, ~] = svds(V_seed, r);

% Gaussian sensing matrices A_k, scaled by 1/sqrt(m).
As = cell(m, 1);
for k = 1:m
    As{k} = randn(n, n)/sqrt(m);
end
% For every condition number in kappa_list: build the rank-r ground truth
% X_star, take the m linear measurements y_k = <A_k, X_star>, then run GD
% and AGN on the over-parameterized factorization X = L*R' (factor width
% d = 2*r), recording the Frobenius error ||X - X_star||_F per iteration.

% Stack the sensing matrices once: row k of A_mat is vec(A_k)', so that
% A_mat*X(:) yields all m inner products <A_k, X> in one product.
A_mat = zeros(m, n*n);
for k = 1:m
    A_mat(k, :) = As{k}(:)';
end

for i_kappa = 1:length(kappa_list)
    kappa = kappa_list(i_kappa);

    % Ground truth with singular values spaced linearly from 1 to 1/kappa.
    sigma_star = linspace(1, 1/kappa, r);
    L_star = U_star*diag(sqrt(sigma_star));
    R_star = V_star*diag(sqrt(sigma_star));
    X_star = L_star*R_star';

    % Measurements y_k = <A_k, X_star>.
    y = A_mat*X_star(:);

    % Z(X) = sum_k (<A_k, X> - y_k) * A_k, shared by both methods.
    % (The handle captures the current y, so it is rebuilt per kappa.)
    residual_backproj = @(X) reshape(A_mat'*(A_mat*X(:) - y), n, n);

    % Both methods start from small random factors. The previously
    % accumulated "spectral initialization" matrix was never used and has
    % been removed.
    d = 2*r;

    %% GD
    L = randn(n,d)/10;
    R = randn(n,d)/10;
    for t = 1:T
        X = L*R';
        % Named err (not error) so the builtin error() is not shadowed.
        err = norm(X - X_star, 'fro');
        errors_GD(i_kappa, t) = err;
        if ~isfinite(err) || err > thresh_up || err < thresh_low
            break;
        end
        Z = residual_backproj(X);
        % Simultaneous update: both steps use the factors from before
        % this iteration.
        L_plus = L - eta*Z*R;
        R_plus = R - eta*Z'*L;
        L = L_plus;
        R = R_plus;
    end

    %% AGN
    L = randn(n,d)/10;
    R = randn(n,d)/10;
    for t = 1:T
        X = L*R';
        err = norm(X - X_star, 'fro');
        errors_AGN(i_kappa, t) = err;
        if ~isfinite(err) || err > thresh_up || err < thresh_low
            break;
        end

        % Update L: R \ Z' is the least-squares solution D of R*D = Z',
        % so D' equals Z*R*inv(R'*R) when R has full column rank.
        Z = residual_backproj(X);
        Delta_L = R \ Z';
        L = L - eta*Delta_L';

        % Update R: alternating scheme, the residual is recomputed with
        % the freshly updated L before solving L*D = Z.
        X = L*R';
        Z = residual_backproj(X);
        Delta_R = L \ Z;
        R = R - eta*Delta_R';
    end
end
% Plot the recorded error histories of GD and AGN on a log-scale y-axis:
% one curve per (method, condition number) pair, all GD curves first.
clrs = {[.5,0,.5], [1,.5,0], [1,0,0], [0,.5,0], [0,0,1]};
mks = {'o', 'x', 'p', 's', 'd'};
figure('Position', [0,0,800,600], 'DefaultAxesFontSize', 20);
lgd = {};

% Method i_method uses colour clrs{i_method}; condition number i_kappa
% uses marker mks{i_kappa} (so at most numel(mks) condition numbers).
method_errors = {errors_GD, errors_AGN};
method_names = {'GD', 'AGN'};
for i_method = 1:numel(method_errors)
    for i_kappa = 1:length(kappa_list)
        kappa = kappa_list(i_kappa);
        errors = method_errors{i_method}(i_kappa, :);
        % Drop the zero entries left after an early stop (and any value
        % at or below thresh_low, which a log axis could not show anyway).
        errors = errors(errors > thresh_low);
        % The x-axis counts iterations starting from 0.
        semilogy(0:length(errors)-1, errors, 'Color', clrs{i_method}, 'Marker', mks{i_kappa}, 'MarkerSize', 9);
        hold on; grid on;
        lgd{end+1} = sprintf('$\\mathrm{%s}~\\kappa=%d$', method_names{i_method}, kappa);
    end
end
xlabel('Iteration count');
% The stored values are un-normalized Frobenius errors, so the axis is
% labelled accordingly (it previously read 'Relative error').
ylabel('Frobenius error');
legend(lgd, 'Location', 'northeast', 'Interpreter', 'latex', 'FontSize', 24);
% Base name for the figure; not used by any code in this script.
fig_name = sprintf('MS_n=%d_r=%d_m=%d', n, r, m);