Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
timlautk authored Jan 26, 2018
1 parent 074df6a commit cb5f8af
Show file tree
Hide file tree
Showing 45 changed files with 519 additions and 0 deletions.
13 changes: 13 additions & 0 deletions Algorithms/AdaptiveVb.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function [Vnew,b4new,beta1] = AdaptiveVb(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2,W3,V,Vstar,b1,b2,b3,b4,b4star,beta,t)
%ADAPTIVEVB Adaptive relaxation step for the pair (V, b4).
%   The relaxed point is V + beta*(Vstar - V), b4 + beta*(b4star - b4).
%   loss_fun_2 (defined outside this file) is evaluated at the candidate
%   (Vstar, b4star) and at the relaxed point, all other arguments fixed.
%
%   Outputs
%     Vnew, b4new : the candidate (Vstar, b4star) when its loss is not
%                   larger than the loss at the relaxed point; otherwise
%                   the relaxed point.
%     beta1       : t*beta when the candidate is taken, beta otherwise.

    % Relaxed point between the current iterate and the candidate.
    Vnew  = V  + beta*(Vstar  - V);
    b4new = b4 + beta*(b4star - b4);

    lossAtStar    = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2,W3,Vstar,b1,b2,b3,b4star);
    lossAtRelaxed = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2,W3,Vnew,b1,b2,b3,b4new);

    % Default: keep the relaxation factor (a min(beta/t,1) shrink is not used here).
    beta1 = beta;
    if lossAtStar <= lossAtRelaxed
        beta1 = t*beta;
        Vnew  = Vstar;
        b4new = b4star;
    end
end
13 changes: 13 additions & 0 deletions Algorithms/AdaptiveVb_2.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function [Vnew,b4new,beta1] = AdaptiveVb_2(gamma1,gamma2,gamma3,rho1,rho2,rho3,y,a0,a1,a2,a3,z1,z2,z3,W1,W2,W3,V,Vstar,b1,b2,b3,b4,b4star,beta,t,act)
%ADAPTIVEVB_2 Adaptive relaxation step for (V, b4) using loss_fun_3.
%   Same scheme as a plain relaxation V + beta*(Vstar - V), but the
%   acceptance test uses loss_fun_3 (defined outside this file), which
%   additionally receives rho1..rho3, z1..z3 and act.
%
%   Outputs
%     Vnew, b4new : (Vstar, b4star) when their loss is not larger than the
%                   loss at the relaxed point; otherwise the relaxed point.
%     beta1       : t*beta when the candidate is taken, beta otherwise.

    % Relaxed point between the current iterate and the candidate.
    Vnew  = V  + beta*(Vstar  - V);
    b4new = b4 + beta*(b4star - b4);

    lossAtStar    = loss_fun_3(gamma1,gamma2,gamma3,rho1,rho2,rho3,y,a0,a1,a2,a3,z1,z2,z3,W1,W2,W3,Vstar,b1,b2,b3,b4star,act);
    lossAtRelaxed = loss_fun_3(gamma1,gamma2,gamma3,rho1,rho2,rho3,y,a0,a1,a2,a3,z1,z2,z3,W1,W2,W3,Vnew,b1,b2,b3,b4new,act);

    % Default: relaxation factor unchanged (no min(beta/t,1) shrink here).
    beta1 = beta;
    if lossAtStar <= lossAtRelaxed
        beta1 = t*beta;
        Vnew  = Vstar;
        b4new = b4star;
    end
end
13 changes: 13 additions & 0 deletions Algorithms/AdaptiveVb_3.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function [W2new,b2new,beta1] = AdaptiveVb_3(lambda,y,a1,W2,W2star,b2,b2star,beta,t)
%ADAPTIVEVB_3 Relaxed update of (W2, b2) with an adaptive factor.
%   Objective used for the test:
%       cross_entropy(y,a1,W,b) + lambda*norm(W,'fro')^2
%   where cross_entropy is defined outside this file.
%
%   Outputs
%     W2new, b2new : always the relaxed point W2 + beta*(W2star - W2),
%                    b2 + beta*(b2star - b2). In this variant the candidate
%                    (W2star, b2star) is never substituted for it.
%     beta1        : t*beta when the objective at (W2star, b2star) is not
%                    larger than at the relaxed point; beta otherwise.

    W2new = W2 + beta*(W2star - W2);
    b2new = b2 + beta*(b2star - b2);

    objAtStar    = cross_entropy(y,a1,W2star,b2star)+lambda*norm(W2star,'fro')^2;
    objAtRelaxed = cross_entropy(y,a1,W2new,b2new)+lambda*norm(W2new,'fro')^2;

    % Only the factor reacts to the test; the iterate stays relaxed.
    beta1 = beta;
    if objAtStar <= objAtRelaxed
        beta1 = t*beta;
    end
end
13 changes: 13 additions & 0 deletions Algorithms/AdaptiveVb_4.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function [W2new,b2new,beta1] = AdaptiveVb_4(lambda,gamma,y,a1,W2,W2star,b2,b2star,beta,t)
%ADAPTIVEVB_4 Relaxed update of (W2, b2) with a squared-error test.
%   Objective used for the test:
%       gamma*norm(W*a1+b-y,'fro')^2 + lambda*norm(W,'fro')^2
%
%   Outputs
%     W2new, b2new : always the relaxed point W2 + beta*(W2star - W2),
%                    b2 + beta*(b2star - b2); the candidate is never
%                    substituted in this variant.
%     beta1        : t*beta when the objective at (W2star, b2star) is not
%                    larger than at the relaxed point; beta otherwise.

    W2new = W2 + beta*(W2star - W2);
    b2new = b2 + beta*(b2star - b2);

    objAtStar    = gamma*norm(W2star*a1+b2star-y,'fro')^2+lambda*norm(W2star,'fro')^2;
    objAtRelaxed = gamma*norm(W2new*a1+b2new-y,'fro')^2+lambda*norm(W2new,'fro')^2;

    % Only the factor reacts to the test; the iterate stays relaxed.
    beta1 = beta;
    if objAtStar <= objAtRelaxed
        beta1 = t*beta;
    end
end
13 changes: 13 additions & 0 deletions Algorithms/AdaptiveWb1.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function [W1new,b1new,beta7] = AdaptiveWb1(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2,W3,V,W1star,b1,b2,b3,b4,b1star,beta,t)
%ADAPTIVEWB1 Adaptive relaxation step for the pair (W1, b1).
%   The relaxed point is W1 + beta*(W1star - W1), b1 + beta*(b1star - b1).
%   loss_fun_2 (defined outside this file) is evaluated at the candidate
%   (W1star, b1star) and at the relaxed point, all other arguments fixed.
%
%   Outputs
%     W1new, b1new : (W1star, b1star) when their loss is not larger than
%                    the loss at the relaxed point; otherwise the relaxed
%                    point.
%     beta7        : t*beta when the candidate is taken, beta otherwise.

    W1new = W1 + beta*(W1star - W1);
    b1new = b1 + beta*(b1star - b1);

    lossAtStar    = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1star,W2,W3,V,b1star,b2,b3,b4);
    lossAtRelaxed = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1new,W2,W3,V,b1new,b2,b3,b4);

    % Default: relaxation factor unchanged (no min(beta/t,1) shrink here).
    beta7 = beta;
    if lossAtStar <= lossAtRelaxed
        beta7 = t*beta;
        W1new = W1star;
        b1new = b1star;
    end
end
13 changes: 13 additions & 0 deletions Algorithms/AdaptiveWb1_2.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function [W1new,b1new,beta7] = AdaptiveWb1_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2,W3,V,W1star,b1,b2,b3,b4,b1star,beta,t)
%ADAPTIVEWB1_2 Adaptive relaxation step for the pair (W1, b1).
%   This file is AdaptiveWb1_2.m; the declaration previously read
%   "AdaptiveWb1" (left over from copying AdaptiveWb1.m). MATLAB dispatches
%   on the file name, so callers are unaffected; the declared name now
%   matches the file so the two no longer disagree.
%
%   The relaxed point is W1 + beta*(W1star - W1), b1 + beta*(b1star - b1).
%   loss_fun_2 (defined outside this file) is evaluated at the candidate
%   (W1star, b1star) and at the relaxed point, all other arguments fixed.
%
%   Outputs
%     W1new, b1new : (W1star, b1star) when their loss is not larger than
%                    the loss at the relaxed point; otherwise the relaxed
%                    point.
%     beta7        : t*beta when the candidate is taken, beta otherwise.

    W1new = W1 + beta*(W1star - W1);
    b1new = b1 + beta*(b1star - b1);

    lossAtStar    = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1star,W2,W3,V,b1star,b2,b3,b4);
    lossAtRelaxed = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1new,W2,W3,V,b1new,b2,b3,b4);

    if lossAtStar <= lossAtRelaxed
        % Candidate is at least as good: take it and grow the factor.
        beta7 = t*beta;
        W1new = W1star;
        b1new = b1star;
    else
        % Keep the relaxed point and leave the factor unchanged.
        beta7 = beta;
    end
end
13 changes: 13 additions & 0 deletions Algorithms/AdaptiveWb1_3.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function [W1new,b1new,beta7] = AdaptiveWb1_3(a0,a1,W1,W1star,b1,b1star,beta,t)
%ADAPTIVEWB1_3 Relaxed update of (W1, b1) with grow/shrink factor.
%   Objective used for the test: norm(W*a0+b-a1,'fro')^2.
%
%   Outputs
%     W1new, b1new : always the relaxed point W1 + beta*(W1star - W1),
%                    b1 + beta*(b1star - b1); the candidate is never
%                    substituted in this variant.
%     beta7        : t*beta when the objective at (W1star, b1star) is not
%                    larger than at the relaxed point; otherwise
%                    min(beta/t, 1).

    W1new = W1 + beta*(W1star - W1);
    b1new = b1 + beta*(b1star - b1);

    residAtStar    = norm(W1star*a0+b1star-a1,'fro')^2;
    residAtRelaxed = norm(W1new*a0+b1new-a1,'fro')^2;

    % Shrink by default, grow when the candidate passes the test.
    beta7 = min(beta/t,1);
    if residAtStar <= residAtRelaxed
        beta7 = t*beta;
    end
end
13 changes: 13 additions & 0 deletions Algorithms/AdaptiveWb1_4.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function [W1new,b1new,beta7] = AdaptiveWb1_4(lambda,gamma,a0,a1,W1,W1star,b1,b1star,u1,beta,t)
%ADAPTIVEWB1_4 Relaxed update of (W1, b1) with an offset term u1.
%   Objective used for the test:
%       gamma*norm(W*a0+b-a1+u1,'fro')^2 + lambda*norm(W,'fro')^2
%
%   Outputs
%     W1new, b1new : always the relaxed point W1 + beta*(W1star - W1),
%                    b1 + beta*(b1star - b1); the candidate is never
%                    substituted in this variant.
%     beta7        : t*beta when the objective at (W1star, b1star) is not
%                    larger than at the relaxed point; beta otherwise.

    W1new = W1 + beta*(W1star - W1);
    b1new = b1 + beta*(b1star - b1);

    objAtStar    = gamma*norm(W1star*a0+b1star-a1+u1,'fro')^2+lambda*norm(W1star,'fro')^2;
    objAtRelaxed = gamma*norm(W1new*a0+b1new-a1+u1,'fro')^2+lambda*norm(W1new,'fro')^2;

    % Only the factor reacts to the test; the iterate stays relaxed.
    beta7 = beta;
    if objAtStar <= objAtRelaxed
        beta7 = t*beta;
    end
end
13 changes: 13 additions & 0 deletions Algorithms/AdaptiveWb2.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function [W2new,b2new,beta5] = AdaptiveWb2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2,W3,V,W2star,b1,b2,b3,b4,b2star,beta,t)
%ADAPTIVEWB2 Adaptive relaxation step for the pair (W2, b2).
%   The relaxed point is W2 + beta*(W2star - W2), b2 + beta*(b2star - b2).
%   loss_fun_2 (defined outside this file) is evaluated at the candidate
%   (W2star, b2star) and at the relaxed point, all other arguments fixed.
%
%   Outputs
%     W2new, b2new : (W2star, b2star) when their loss is not larger than
%                    the loss at the relaxed point; otherwise the relaxed
%                    point.
%     beta5        : t*beta when the candidate is taken, beta otherwise.

    W2new = W2 + beta*(W2star - W2);
    b2new = b2 + beta*(b2star - b2);

    lossAtStar    = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2star,W3,V,b1,b2star,b3,b4);
    lossAtRelaxed = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2new,W3,V,b1,b2new,b3,b4);

    % Default: relaxation factor unchanged (no min(beta/t,1) shrink here).
    beta5 = beta;
    if lossAtStar <= lossAtRelaxed
        beta5 = t*beta;
        W2new = W2star;
        b2new = b2star;
    end
end
13 changes: 13 additions & 0 deletions Algorithms/AdaptiveWb2_2.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function [W2new,b2new,beta5] = AdaptiveWb2_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2,W3,V,W2star,b1,b2,b3,b4,b2star,beta,t)
%ADAPTIVEWB2_2 Adaptive relaxation step for the pair (W2, b2).
%   This file is AdaptiveWb2_2.m; the declaration previously read
%   "AdaptiveWb2" (left over from copying AdaptiveWb2.m). MATLAB dispatches
%   on the file name, so callers are unaffected; the declared name now
%   matches the file so the two no longer disagree.
%
%   The relaxed point is W2 + beta*(W2star - W2), b2 + beta*(b2star - b2).
%   loss_fun_2 (defined outside this file) is evaluated at the candidate
%   (W2star, b2star) and at the relaxed point, all other arguments fixed.
%
%   Outputs
%     W2new, b2new : (W2star, b2star) when their loss is not larger than
%                    the loss at the relaxed point; otherwise the relaxed
%                    point.
%     beta5        : t*beta when the candidate is taken, beta otherwise.

    W2new = W2 + beta*(W2star - W2);
    b2new = b2 + beta*(b2star - b2);

    lossAtStar    = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2star,W3,V,b1,b2star,b3,b4);
    lossAtRelaxed = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2new,W3,V,b1,b2new,b3,b4);

    if lossAtStar <= lossAtRelaxed
        % Candidate is at least as good: take it and grow the factor.
        beta5 = t*beta;
        W2new = W2star;
        b2new = b2star;
    else
        % Keep the relaxed point and leave the factor unchanged.
        beta5 = beta;
    end
end
13 changes: 13 additions & 0 deletions Algorithms/AdaptiveWb3.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function [W3new,b3new,beta3] = AdaptiveWb3(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2,W3,V,W3star,b1,b2,b3,b4,b3star,beta,t)
%ADAPTIVEWB3 Adaptive relaxation step for the pair (W3, b3).
%   The relaxed point is W3 + beta*(W3star - W3), b3 + beta*(b3star - b3).
%   loss_fun_2 (defined outside this file) is evaluated at the candidate
%   (W3star, b3star) and at the relaxed point, all other arguments fixed.
%
%   Outputs
%     W3new, b3new : (W3star, b3star) when their loss is not larger than
%                    the loss at the relaxed point; otherwise the relaxed
%                    point.
%     beta3        : t*beta when the candidate is taken, beta otherwise.

    W3new = W3 + beta*(W3star - W3);
    b3new = b3 + beta*(b3star - b3);

    lossAtStar    = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2,W3star,V,b1,b2,b3star,b4);
    lossAtRelaxed = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2,W3new,V,b1,b2,b3new,b4);

    % Default: relaxation factor unchanged (no min(beta/t,1) shrink here).
    beta3 = beta;
    if lossAtStar <= lossAtRelaxed
        beta3 = t*beta;
        W3new = W3star;
        b3new = b3star;
    end
end
13 changes: 13 additions & 0 deletions Algorithms/AdaptiveWb3_2.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function [W3new,b3new,beta3] = AdaptiveWb3_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2,W3,V,W3star,b1,b2,b3,b4,b3star,beta,t)
%ADAPTIVEWB3_2 Adaptive relaxation step for the pair (W3, b3).
%   This file is AdaptiveWb3_2.m; the declaration previously read
%   "AdaptiveWb3" (left over from copying AdaptiveWb3.m). MATLAB dispatches
%   on the file name, so callers are unaffected; the declared name now
%   matches the file so the two no longer disagree.
%
%   The relaxed point is W3 + beta*(W3star - W3), b3 + beta*(b3star - b3).
%   loss_fun_2 (defined outside this file) is evaluated at the candidate
%   (W3star, b3star) and at the relaxed point, all other arguments fixed.
%
%   Outputs
%     W3new, b3new : (W3star, b3star) when their loss is not larger than
%                    the loss at the relaxed point; otherwise the relaxed
%                    point.
%     beta3        : t*beta when the candidate is taken, beta otherwise.

    W3new = W3 + beta*(W3star - W3);
    b3new = b3 + beta*(b3star - b3);

    lossAtStar    = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2,W3star,V,b1,b2,b3star,b4);
    lossAtRelaxed = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,W1,W2,W3new,V,b1,b2,b3new,b4);

    if lossAtStar <= lossAtRelaxed
        % Candidate is at least as good: take it and grow the factor.
        beta3 = t*beta;
        W3new = W3star;
        b3new = b3star;
    else
        % Keep the relaxed point and leave the factor unchanged.
        beta3 = beta;
    end
end
13 changes: 13 additions & 0 deletions Algorithms/AdaptiveWb_ResNet.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function [W1new,b1new,beta7] = AdaptiveWb_ResNet(lambda,gamma,x,a0,a1,W1,W1star,b1,b1star,beta,t)
%ADAPTIVEWB_RESNET Relaxed update of (W1, b1) with an additive term x.
%   Objective used for the test:
%       gamma*norm(W*a0+b+x-a1,'fro')^2 + lambda*norm(W,'fro')^2
%
%   Outputs
%     W1new, b1new : always the relaxed point W1 + beta*(W1star - W1),
%                    b1 + beta*(b1star - b1); the candidate is never
%                    substituted in this variant.
%     beta7        : t*beta when the objective at (W1star, b1star) is not
%                    larger than at the relaxed point; otherwise
%                    min(beta/t, 1).

    W1new = W1 + beta*(W1star - W1);
    b1new = b1 + beta*(b1star - b1);

    objAtStar    = gamma*norm(W1star*a0+b1star+x-a1,'fro')^2+lambda*norm(W1star,'fro')^2;
    objAtRelaxed = gamma*norm(W1new*a0+b1new+x-a1,'fro')^2+lambda*norm(W1new,'fro')^2;

    % Shrink by default, grow when the candidate passes the test.
    beta7 = min(beta/t,1);
    if objAtStar <= objAtRelaxed
        beta7 = t*beta;
    end
end
11 changes: 11 additions & 0 deletions Algorithms/Adaptivea1.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function [a1new,beta6] = Adaptivea1(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,a1star,W1,W2,W3,V,b1,b2,b3,b4,beta,t)
%ADAPTIVEA1 Adaptive relaxation step for a1.
%   The relaxed point is a1 + beta*(a1star - a1). loss_fun_2 (defined
%   outside this file) is evaluated with a1star and with the relaxed point
%   in the a1 slot, all other arguments fixed.
%
%   Outputs
%     a1new : a1star when its loss is not larger than the loss at the
%             relaxed point; otherwise the relaxed point.
%     beta6 : t*beta when a1star is taken, beta otherwise.

    a1new = a1 + beta*(a1star - a1);

    lossAtStar    = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1star,a2,a3,W1,W2,W3,V,b1,b2,b3,b4);
    lossAtRelaxed = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1new,a2,a3,W1,W2,W3,V,b1,b2,b3,b4);

    % Default: relaxation factor unchanged (no min(beta/t,1) shrink here).
    beta6 = beta;
    if lossAtStar <= lossAtRelaxed
        beta6 = t*beta;
        a1new = a1star;
    end
end
11 changes: 11 additions & 0 deletions Algorithms/Adaptivea1_2.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function [a1new,beta6] = Adaptivea1_2(gamma1,y,a0,a1,a1star,W1,W2,b1,b2,beta,t)
%ADAPTIVEA1_2 Adaptive relaxation step for a1 with a cross-entropy test.
%   Objective used for the test:
%       cross_entropy(y,a,W2,b2) + gamma1/2*norm(W1*a0+b1-a,'fro')^2
%   where cross_entropy is defined outside this file.
%
%   Outputs
%     a1new : a1star when its objective is not larger than the objective
%             at the relaxed point a1 + beta*(a1star - a1); otherwise the
%             relaxed point.
%     beta6 : t*beta when a1star is taken, beta otherwise.

    a1new = a1 + beta*(a1star - a1);

    objAtStar    = cross_entropy(y,a1star,W2,b2)+gamma1/2*norm(W1*a0+b1-a1star,'fro')^2;
    objAtRelaxed = cross_entropy(y,a1new,W2,b2)+gamma1/2*norm(W1*a0+b1-a1new,'fro')^2;

    % Default: relaxation factor unchanged (no min(beta/t,1) shrink here).
    beta6 = beta;
    if objAtStar <= objAtRelaxed
        beta6 = t*beta;
        a1new = a1star;
    end
end
11 changes: 11 additions & 0 deletions Algorithms/Adaptivea1_3.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function [a1new,beta6] = Adaptivea1_3(gamma1,gamma2,y,a0,a1,a1star,W1,W2,b1,b2,beta,t)
%ADAPTIVEA1_3 Relaxed update of a1 with a squared-error test.
%   Objective used for the test:
%       gamma1*norm(W1*a0+b1-a,'fro')^2 + gamma2*norm(W2*a+b2-y,'fro')^2
%
%   Outputs
%     a1new : always the relaxed point a1 + beta*(a1star - a1); a1star is
%             never substituted in this variant.
%     beta6 : t*beta when the objective at a1star is not larger than at
%             the relaxed point; beta otherwise.

    a1new = a1 + beta*(a1star - a1);

    objAtStar    = gamma1*norm(W1*a0+b1-a1star,'fro')^2+gamma2*norm(W2*a1star+b2-y,'fro')^2;
    objAtRelaxed = gamma1*norm(W1*a0+b1-a1new,'fro')^2+gamma2*norm(W2*a1new+b2-y,'fro')^2;

    % Only the factor reacts to the test; the iterate stays relaxed.
    beta6 = beta;
    if objAtStar <= objAtRelaxed
        beta6 = t*beta;
    end
end
11 changes: 11 additions & 0 deletions Algorithms/Adaptivea2.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function [a2new,beta4] = Adaptivea2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,a2star,W1,W2,W3,V,b1,b2,b3,b4,beta,t)
%ADAPTIVEA2 Adaptive relaxation step for a2.
%   The relaxed point is a2 + beta*(a2star - a2). loss_fun_2 (defined
%   outside this file) is evaluated with a2star and with the relaxed point
%   in the a2 slot, all other arguments fixed.
%
%   Outputs
%     a2new : a2star when its loss is not larger than the loss at the
%             relaxed point; otherwise the relaxed point.
%     beta4 : t*beta when a2star is taken, beta otherwise.

    a2new = a2 + beta*(a2star - a2);

    lossAtStar    = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2star,a3,W1,W2,W3,V,b1,b2,b3,b4);
    lossAtRelaxed = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2new,a3,W1,W2,W3,V,b1,b2,b3,b4);

    % Default: relaxation factor unchanged (no min(beta/t,1) shrink here).
    beta4 = beta;
    if lossAtStar <= lossAtRelaxed
        beta4 = t*beta;
        a2new = a2star;
    end
end
11 changes: 11 additions & 0 deletions Algorithms/Adaptivea3.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function [a3new,beta2] = Adaptivea3(gamma1,gamma2,gamma3,y,a0,a1,a2,a3,a3star,W1,W2,W3,V,b1,b2,b3,b4,beta,t)
%ADAPTIVEA3 Adaptive relaxation step for a3.
%   The relaxed point is a3 + beta*(a3star - a3). loss_fun_2 (defined
%   outside this file) is evaluated with a3star and with the relaxed point
%   in the a3 slot, all other arguments fixed.
%
%   Outputs
%     a3new : a3star when its loss is not larger than the loss at the
%             relaxed point; otherwise the relaxed point.
%     beta2 : t*beta when a3star is taken, beta otherwise.

    a3new = a3 + beta*(a3star - a3);

    lossAtStar    = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3star,W1,W2,W3,V,b1,b2,b3,b4);
    lossAtRelaxed = loss_fun_2(gamma1,gamma2,gamma3,y,a0,a1,a2,a3new,W1,W2,W3,V,b1,b2,b3,b4);

    % Default: relaxation factor unchanged (no min(beta/t,1) shrink here).
    beta2 = beta;
    if lossAtStar <= lossAtRelaxed
        beta2 = t*beta;
        a3new = a3star;
    end
end
11 changes: 11 additions & 0 deletions Algorithms/Adaptivea3_2.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function [a3new,beta2] = Adaptivea3_2(gamma1,gamma2,gamma3,rho1,rho2,rho3,y,a0,a1,a2,a3,a3star,z1,z2,z3,W1,W2,W3,V,b1,b2,b3,b4,beta,t,act)
%ADAPTIVEA3_2 Adaptive relaxation step for a3 using loss_fun_3.
%   The relaxed point is a3 + beta*(a3star - a3). loss_fun_3 (defined
%   outside this file) is evaluated with a3star and with the relaxed point
%   in the a3 slot; it additionally receives rho1..rho3, z1..z3 and act.
%
%   Outputs
%     a3new : a3star when its loss is not larger than the loss at the
%             relaxed point; otherwise the relaxed point.
%     beta2 : t*beta when a3star is taken, beta otherwise.

    a3new = a3 + beta*(a3star - a3);

    lossAtStar    = loss_fun_3(gamma1,gamma2,gamma3,rho1,rho2,rho3,y,a0,a1,a2,a3star,z1,z2,z3,W1,W2,W3,V,b1,b2,b3,b4,act);
    lossAtRelaxed = loss_fun_3(gamma1,gamma2,gamma3,rho1,rho2,rho3,y,a0,a1,a2,a3new,z1,z2,z3,W1,W2,W3,V,b1,b2,b3,b4,act);

    % Default: relaxation factor unchanged (no min(beta/t,1) shrink here).
    beta2 = beta;
    if lossAtStar <= lossAtRelaxed
        beta2 = t*beta;
        a3new = a3star;
    end
end
12 changes: 12 additions & 0 deletions Algorithms/Adaptivea3_3.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
function [a3new,beta2] = Adaptivea3_3(gamma3,gamma4,y,a2,a3,a3star,W3,V,b3,c,beta,t)
%ADAPTIVEA3_3 Adaptive relaxation step for a3 with a squared-error test.
%   Objective used for the test (same expression on both sides):
%       gamma4/2*norm(V*a+c-y,'fro')^2 + gamma3/2*norm(W3*a2+b3-a,'fro')^2
%
%   Fix: the original weighted the data term by gamma4/2 for a3star but by
%   1/(2*N) (N = number of columns of a3) for the relaxed point, so the
%   test compared two different objectives whenever gamma4 ~= 1/N. Both
%   sides now use gamma4/2. Callers that relied on the 1/(2*N) weighting
%   should pass gamma4 = 1/N.
%
%   Outputs
%     a3new : a3star when its objective is not larger than the objective
%             at the relaxed point a3 + beta*(a3star - a3); otherwise the
%             relaxed point.
%     beta2 : t*beta when a3star is taken, beta otherwise.

    a3new = a3 + beta*(a3star - a3);

    objAtStar    = gamma4/2*norm(V*a3star+c-y,'fro')^2 + gamma3/2*norm(W3*a2+b3-a3star,'fro')^2;
    objAtRelaxed = gamma4/2*norm(V*a3new+c-y,'fro')^2  + gamma3/2*norm(W3*a2+b3-a3new,'fro')^2;

    if objAtStar <= objAtRelaxed
        % Candidate is at least as good: take it and grow the factor.
        beta2 = t*beta;
        a3new = a3star;
    else
        % Keep the relaxed point and leave the factor unchanged.
        beta2 = beta;
    end
end
4 changes: 4 additions & 0 deletions Algorithms/MomentumWb.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
function [Wnew,bnew] = MomentumWb(W,W0,b,b0,beta)
%MOMENTUMWB Move from (W0, b0) toward (W, b) by the factor beta.
%   Wnew = W0 + beta*(W - W0)
%   bnew = b0 + beta*(b - b0)

    stepW = W - W0;
    stepB = b - b0;
    Wnew = W0 + beta*stepW;
    bnew = b0 + beta*stepB;
end
27 changes: 27 additions & 0 deletions Algorithms/updateU.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
function [U1new,U2new,U3new,beta1] = updateU(gamma1,gamma2,gamma3,y,U0,U1,U2,U3,W10,W21,W32,V,alpha,beta,t)
%UPDATEU Joint update of U1, U2, U3 via a regularised linear solve.
%   Stacks U = [U0;U1;U2;U3], solves
%       (2/K*(V*P)'*(V*P) + Q + alpha*I) * Ustar = 2/K*(V*P)'*y + alpha*U
%   for Ustar, clips Ustar at zero from below, and forms the relaxed point
%   U + beta*(Ustar - U). N and K are the row and column counts of U0;
%   P = [0 0 0 I] selects the last N rows of the stack.
%
%   loss_fun (defined outside this file) is evaluated on the U1/U2/U3 row
%   blocks of Ustar and of the relaxed point.
%
%   Outputs
%     U1new, U2new, U3new : row blocks N+1:2N, 2N+1:3N and 3N+1:end of
%                           Ustar when its loss is not larger than the
%                           loss at the relaxed point; otherwise the same
%                           blocks of the relaxed point.
%     beta1               : t*beta when Ustar is taken; otherwise
%                           min(beta/t, 1).

    [N,K] = size(U0);
    I = speye(N);
    U = [U0;U1;U2;U3];

    % Block matrix of the quadratic coupling terms (entries kept verbatim).
    Q = 2*[
        gamma1/2*(W10'*W10)+(gamma2+gamma3)/2*I, gamma2*W21+gamma3/2*I, -gamma2/2*I+gamma3*W32, -gamma3/2*I;
        -gamma1*W10+gamma3/2*I, (gamma1+gamma3)/2*I+gamma2/2*(W21'*W21), gamma3*W32, -gamma3/2*I;
        -gamma2/2*I, -gamma2*W21, gamma2/2*I+gamma3/2*(W32'*W32), zeros(N);
        -gamma3/2*I, -gamma3/2*I, -gamma3*W32, gamma3/2*I
        ];
    P = sparse([zeros(N,3*N) I]);

    Ustar = (2/K*(V*P)'*(V*P)+Q+alpha*eye(4*N))\(2/K*(V*P)'*y+alpha*U);
    Ustar = max(0,Ustar);   % clip below at zero

    Urelaxed = U + beta*(Ustar-U);

    % Row blocks of the candidate and of the relaxed point.
    S1 = Ustar(N+1:2*N,:);
    S2 = Ustar(2*N+1:3*N,:);
    S3 = Ustar(3*N+1:end,:);
    U1new = Urelaxed(N+1:2*N,:);
    U2new = Urelaxed(2*N+1:3*N,:);
    U3new = Urelaxed(3*N+1:end,:);

    lossAtStar    = loss_fun(gamma1,gamma2,gamma3,y,U0,S1,S2,S3,W10,W21,W32,V);
    lossAtRelaxed = loss_fun(gamma1,gamma2,gamma3,y,U0,U1new,U2new,U3new,W10,W21,W32,V);

    % Shrink by default, grow and take the candidate when it passes.
    beta1 = min(beta/t,1);
    if lossAtStar <= lossAtRelaxed
        beta1 = t*beta;
        U1new = S1;
        U2new = S2;
        U3new = S3;
    end
end
27 changes: 27 additions & 0 deletions Algorithms/updateU2.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
function [U1new,U2new,U3new,beta1] = updateU2(gamma1,gamma2,gamma3,y,U0,U1,U2,U3,W10,W21,W32,V,alpha,beta,t)
%UPDATEU2 Joint update of U1, U2, U3 via an explicit step of size 1/alpha.
%   Stacks U = [U0;U1;U2;U3] and, with M = V*P, computes
%       Ustar = U - 1/alpha*((2/K*(M'*M) + Q)*U - 2/K*M'*y)
%   then clips Ustar at zero from below and forms the relaxed point
%   U + beta*(Ustar - U). N and K are the row and column counts of U0;
%   P = [0 0 0 I] selects the last N rows of the stack.
%
%   loss_fun (defined outside this file) is evaluated on the U1/U2/U3 row
%   blocks of Ustar and of the relaxed point.
%
%   Outputs
%     U1new, U2new, U3new : row blocks N+1:2N, 2N+1:3N and 3N+1:end of
%                           Ustar when its loss is not larger than the
%                           loss at the relaxed point; otherwise the same
%                           blocks of the relaxed point.
%     beta1               : t*beta when Ustar is taken; otherwise
%                           min(beta/t, 1).

    [N,K] = size(U0);
    I = speye(N);
    U = [U0;U1;U2;U3];

    % Block matrix of the quadratic coupling terms (entries kept verbatim).
    Q = 2*[
        gamma1/2*(W10'*W10)+(gamma2+gamma3)/2*I, gamma2*W21+gamma3/2*I, -gamma2/2*I+gamma3*W32, -gamma3/2*I;
        -gamma1*W10+gamma3/2*I, (gamma1+gamma3)/2*I+gamma2/2*(W21'*W21), gamma3*W32, -gamma3/2*I;
        -gamma2/2*I, -gamma2*W21, gamma2/2*I+gamma3/2*(W32'*W32), zeros(N);
        -gamma3/2*I, -gamma3/2*I, -gamma3*W32, gamma3/2*I
        ];
    P = sparse([zeros(N,3*N) I]);
    M = V*P;

    Ustar = U - 1/alpha*((2/K*(M'*M)+Q)*U-2/K*M'*y);
    Ustar = max(0,Ustar);   % clip below at zero

    Urelaxed = U + beta*(Ustar-U);

    % Row blocks of the candidate and of the relaxed point.
    S1 = Ustar(N+1:2*N,:);
    S2 = Ustar(2*N+1:3*N,:);
    S3 = Ustar(3*N+1:end,:);
    U1new = Urelaxed(N+1:2*N,:);
    U2new = Urelaxed(2*N+1:3*N,:);
    U3new = Urelaxed(3*N+1:end,:);

    lossAtStar    = loss_fun(gamma1,gamma2,gamma3,y,U0,S1,S2,S3,W10,W21,W32,V);
    lossAtRelaxed = loss_fun(gamma1,gamma2,gamma3,y,U0,U1new,U2new,U3new,W10,W21,W32,V);

    % Shrink by default, grow and take the candidate when it passes.
    beta1 = min(beta/t,1);
    if lossAtStar <= lossAtRelaxed
        beta1 = t*beta;
        U1new = S1;
        U2new = S2;
        U3new = S3;
    end
end
14 changes: 14 additions & 0 deletions Algorithms/updateV.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
function [Vnew,beta2] = updateV(gamma1,gamma2,gamma3,y,U0,U1,U2,U3,W10,W21,W32,V,alpha,beta,t)
%UPDATEV Update of V via a regularised least-squares solve.
%   With U = [U0;U1;U2;U3] and P = [0 0 0 I] (selecting the last N rows of
%   U), computes
%       Vstar = (2/K*(y*U.')*P' + alpha*V) * pinv(2/K*P*(U*U.')*P' + alpha*I)
%   and the relaxed point V + beta*(Vstar - V). N and K are the row and
%   column counts of U0. loss_fun is defined outside this file.
%
%   Fix: the acceptance test previously evaluated the loss at the OLD V
%   rather than at Vstar, so Vstar was adopted (and beta grown) exactly
%   when the relaxed step failed to improve on the old iterate, without
%   Vstar's own loss ever being checked. The test now compares Vstar with
%   the relaxed point, as updateU does for Ustar.
%
%   Outputs
%     Vnew  : Vstar when its loss is not larger than the loss at the
%             relaxed point; otherwise the relaxed point.
%     beta2 : t*beta when Vstar is taken; otherwise min(beta/t, 1).

    [N,K] = size(U0);
    U = [U0;U1;U2;U3];
    I = speye(N);
    P = sparse([zeros(N,3*N) I]);

    Vstar = (2/K*(y*U.')*P'+alpha*V)*pinv(2/K*P*(U*U.')*P'+alpha*I);
    Vnew = V + beta*(Vstar-V);

    lossAtStar    = loss_fun(gamma1,gamma2,gamma3,y,U0,U1,U2,U3,W10,W21,W32,Vstar);
    lossAtRelaxed = loss_fun(gamma1,gamma2,gamma3,y,U0,U1,U2,U3,W10,W21,W32,Vnew);

    if lossAtStar <= lossAtRelaxed
        % Candidate is at least as good: take it and grow the factor.
        beta2 = t*beta;
        Vnew = Vstar;
    else
        % Keep the relaxed point and shrink the factor (capped at 1).
        beta2 = min(beta/t,1);
    end
end
14 changes: 14 additions & 0 deletions Algorithms/updateV2.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
function [Vnew,beta2] = updateV2(gamma1,gamma2,gamma3,y,U0,U1,U2,U3,W10,W21,W32,V,alpha,beta,t)
%UPDATEV2 Update of V via an explicit step.
%   With U = [U0;U1;U2;U3] and P = [0 0 0 I] (selecting the last N rows of
%   U), computes
%       Vstar = V - 2/(alpha*K)*((V*P*(U*U.') - y*U.')*P')
%   and the relaxed point V + beta*(Vstar - V). N and K are the row and
%   column counts of U0. loss_fun is defined outside this file.
%
%   Fix: the acceptance test previously evaluated the loss at the OLD V
%   rather than at Vstar, so Vstar was adopted (and beta grown) exactly
%   when the relaxed step failed to improve on the old iterate, without
%   Vstar's own loss ever being checked. The test now compares Vstar with
%   the relaxed point, as updateU2 does for Ustar.
%
%   Outputs
%     Vnew  : Vstar when its loss is not larger than the loss at the
%             relaxed point; otherwise the relaxed point.
%     beta2 : t*beta when Vstar is taken; otherwise min(beta/t, 1).

    [N,K] = size(U0);
    U = [U0;U1;U2;U3];
    I = speye(N);
    P = sparse([zeros(N,3*N) I]);

    Vstar = V - 2/(alpha*K)*((V*P*(U*U.')-y*U.')*P');
    Vnew = V + beta*(Vstar-V);

    lossAtStar    = loss_fun(gamma1,gamma2,gamma3,y,U0,U1,U2,U3,W10,W21,W32,Vstar);
    lossAtRelaxed = loss_fun(gamma1,gamma2,gamma3,y,U0,U1,U2,U3,W10,W21,W32,Vnew);

    if lossAtStar <= lossAtRelaxed
        % Candidate is at least as good: take it and grow the factor.
        beta2 = t*beta;
        Vnew = Vstar;
    else
        % Keep the relaxed point and shrink the factor (capped at 1).
        beta2 = min(beta/t,1);
    end
end
5 changes: 5 additions & 0 deletions Algorithms/updateVb.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
function [Wstar,bstar] = updateVb(y,a,W,b,alpha,gamma,lambda)
%UPDATEVB One step of size 1/alpha on (W, b) for a softmax output layer.
%   With Z = W*a + b and Pr = softmax(Z):
%       Wstar = W - 1/alpha*(gamma*(Pr - y)*a' + lambda*W)
%       bstar = b - gamma/alpha*sum(Pr - y, 2)
%
%   Assumes columns of a / y are samples and softmax normalises each
%   column (as the toolbox softmax does) -- confirm if a project-local
%   softmax is on the path.
%
%   Fixes relative to the original:
%     * The stabilising shift used max(Z,[],2), i.e. one value per ROW
%       taken across samples. Subtracting a per-row value changes the
%       softmax output (with a single sample it makes every logit 0).
%       The shift is now the per-column max, which leaves softmax
%       unchanged.
%     * The bias step used -ones(K,1) in place of -sum(y,2), which matches
%       the weight step's (Pr - y) residual only when every row of y sums
%       to 1. Both steps now use the same residual.
%     * Z and softmax were each evaluated repeatedly; now once.

    Z = W*a + b;
    % Per-sample (column-wise) shift for numerical stability.
    Pr = softmax(Z - max(Z,[],1));
    resid = Pr - y;

    Wstar = W - 1/alpha*(gamma*(resid*a') + lambda*W);
    bstar = b - gamma/alpha*sum(resid,2);
end
Loading

0 comments on commit cb5f8af

Please sign in to comment.