Skip to content

Commit

Permalink
tt_gmres output refactored to the debug structure of dmrg and amen
Browse files Browse the repository at this point in the history
  • Loading branch information
dolgov committed Jun 1, 2014
1 parent f7927bb commit 0b398d1
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 113 deletions.
4 changes: 2 additions & 2 deletions fmex/Makefile.in
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ ifeq ($(CPU),dc-laptop-intel)
CC = icc
COPT = -O2 -fPIC
LIB = -mkl
MATLAB = /home/dc/matlab_inst/bin/matlab
MATLAB = matlab
MCFLAGS = -fexceptions -fPIC -fno-omit-frame-pointer -pthread
MLLIB2 = /opt/intel/composerxe/lib/intel64/libifcoremt_pic.a
MLLIB2 = /opt/intel/composerxe-2013.2.144/compiler/lib/intel64/libifcoremt_pic.a
endif
ifeq ($(CPU),mpg-intel)
FOPT = -i8 -O2 -fPIC -vec-report=0 -Itt-fort/
Expand Down
2 changes: 1 addition & 1 deletion fmex/tt-fort
Submodule tt-fort updated 2 files
+6 −2 timef.f90
+31 −6 tt.f90
159 changes: 49 additions & 110 deletions solve/tt_gmres.m
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@


function [x,RESVEC,rw,rx] = tt_gmres(A, b, tol, maxout, maxin, eps_x, eps_z, M1, M2, M3, x0, verbose, varargin)
function [x,td] = tt_gmres(A, b, tol, maxout, maxin, eps_x, eps_z, M1, M2, M3, x0, verbose, varargin)
%TT-GMRES method
% [x] = TT_GMRES(A, B, TOL, MAXOUT, MAXIN, EPS_X,
% [x,td] = TT_GMRES(A, B, TOL, MAXOUT, MAXIN, EPS_X,
% EPS_Z, M1, M2, M3, X0, VERBOSE, VARARGIN)
% GMRES in TT format. Solve a system A*x=b, up to the accuracy tol,
% with maximum number of outer iterations maxout, the size of Krylov basis
% maxin, compression of solution eps_x, compression of Krylov vectors
% eps_z, optional LEFT preconditioner: [M1*[M2]*[M3]], initial guess [x0]
% (default is 1e-308*b), [verbose] - if 1 or unspecified - print messages,
% if 0 - silent mode
% if 0 - silent mode, if 2 - print messages, and also log solutions to td.
%
% Debug uutput td is currently syncronised with the amen and dmrg:
% td{1} - cumulative CPU time,
% td{2} - current solutions x,
% td{3} - local residuals
% If you want the older (also full gmres) syntax with local residuals as
% the second output, please git checkout to some commit before 01.06.2014
%
% TT-Toolbox 2.2, 2009-2012
%
Expand All @@ -30,12 +34,19 @@
max_zrank_factor = 4; % maximum rank of Krylov vectors with respect to the rank of X.
derr_tol_for_sp = 1.0; % minimal jump of residual at one iteration, which is considered as
% a stagnation.
use_err_trunc = 1; % use compression accuracy eps_z/err for Krylov vectors
err_trunc_power = 0.8; % theoretical estimate - 1 - sometimes sucks
compute_real_res = 1; % Compute the real residual on each step (expensive!)
use_err_trunc = 1; % use relaxation, i.e. compression accuracy eps_z/err for Krylov vectors
err_trunc_power = 1; % theoretical estimate - 1 - sometimes worse than something like 0.8
compute_real_res = 0; % Compute the real residual on each step (expensive!)
% extra param in tt_iterapp.m: matvec_alg
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if (nargout>1)
td = cell(3,1);
td{1} = zeros(maxin, maxout);
td{2} = cell(maxin, maxout);
td{3} = zeros(maxin, maxout);
end;

t0 = tic;

existsM1=1;
Expand Down Expand Up @@ -84,7 +95,6 @@
pre_b2 = tt_iterapp('mldivide',m3fun,m3type,m3fcnstr,pre_b2,min(eps_z, 0.5), max_rank, max_swp,varargin{:});
end;
end;
% pre_b = tt_stabilize(pre_b,0);

norm_f = tt_dot2(pre_b,pre_b)*0.5;
mod_norm_f = sign(norm_f)*mod(abs(norm_f), 10);
Expand All @@ -94,28 +104,12 @@
if (existsx0)
x = x0;
else
% x = pre_b;
% x = tt_scal2(b, log(1e-308), 1);
% x = tt_zeros(max(size(b)), tt_size(b));
x = core(tt_zeros(tt_size(b), max(size(b))));
end;

% x_old = x;

H = zeros(maxin+1, maxin);
v = cell(maxin, 1);

if (nargout>1)
RESVEC = ones(maxout, maxin);
end;
if (nargout>2)
rw = zeros(maxout,maxin);
end;
if (nargout>2)
rx = zeros(maxout,maxin);
end;
% z = cell(maxin, 1);

err=2;
max_err=Inf;
old_err = 1;
Expand All @@ -128,62 +122,45 @@
r = tt_axpy2(0,1, pre_b, 0,-1, Ax, min(eps_z/err_for_trunc, 0.5), max_rank);
if (existsM1)
r = tt_iterapp('mldivide',m1fun,m1type,m1fcnstr,r,min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
% Ax = tt_stabilize(Ax,0);
if (existsM2)
r = tt_iterapp('mldivide',m2fun,m2type,m2fcnstr,r,min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
% Ax = tt_stabilize(Ax,0);
end;
if (existsM3)
r = tt_iterapp('mldivide',m3fun,m3type,m3fcnstr,r,min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
% Ax = tt_stabilize(Ax,0);
end;
end;
% Ax = tt_stabilize(Ax,0);
% r = tt_add(pre_b, tt_scal(Ax, -1));
% r = tt_compr2(r, min(eps_z/err, 0.5), max_rank);
% beta = sqrt(tt_dot(r, r))
beta = tt_dot2(r,r)*0.5;
mod_beta = sign(beta)*mod(abs(beta), 10);
order_beta = beta - mod_beta;
cur_beta = exp(mod_beta);
if (verbose==1)
if (verbose>0)
real_beta = exp(beta);
fprintf(' cur_beta = %g, real_beta = %g\n', cur_beta, real_beta);
end;
% if (beta<tol) break; end;
if (nitout==1)
cur_normb = cur_beta;
order_normb = order_beta;
% normb = beta;
end;
% if (nitout==1)
% cur_normb = cur_beta;
% order_normb = order_beta;
% end;
v{1} = tt_scal2(r, -beta, 1);
% v{1}=tt_stabilize(v{1},0);
for j=1:maxin
max_w_rank = 0;
w = tt_iterapp('mtimes',afun,atype,afcnstr,v{j},min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
max_w_rank = max([max_w_rank; tt_ranks(w)]);
% w=tt_stabilize(w,0);
% max_Mx_rank = max(tt_ranks(w))
if (existsM1)
w = tt_iterapp('mldivide',m1fun,m1type,m1fcnstr,w,min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
max_w_rank = max([max_w_rank; tt_ranks(w)]);
% w=tt_stabilize(w,0);
if (existsM2)
w = tt_iterapp('mldivide',m2fun,m2type,m2fcnstr,w,min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
max_w_rank = max([max_w_rank; tt_ranks(w)]);
% w=tt_stabilize(w,0);
% max_Px_rank = max(tt_ranks(w))
end;
if (existsM3)
w = tt_iterapp('mldivide',m3fun,m3type,m3fcnstr,w,min(eps_z/err_for_trunc, 0.5), max_rank, max_swp,varargin{:});
max_w_rank = max([max_w_rank; tt_ranks(w)]);
% w=tt_stabilize(w,0);
end;
end;

max_wrank = max(tt_ranks(w));

% wnew = w;
for i=1:j
H(i,j)=tt_dot(w, v{i});
% w = tt_axpy2(0,1, w, log(abs(H(i,j))+1e-308), -1*sign(H(i,j)), v{i}, min(eps_z/err_for_trunc, 0.5), max_rank);
Expand All @@ -192,78 +169,49 @@
% Orthogonality muss sein!
% w = tt_axpy3(0,1, w, log(abs(H(i,j))+1e-308), -1*sign(H(i,j)), v{i});
end;
% w = tt_wround([], w, min(eps_z/err_for_trunc,0.5), 'rmax', max_rank, 'nswp', max_swp, 'verb', 1);
% for i=1:j
% wv = abs(tt_dot(w, v{i}));
% if (wv>eps_z)
% fprintf('Warning: dot(w,v{%d}) = %3.3e\n', i, wv);
% end;
% end;
% w=wnew;
if (nargout>2)
rw(nitout,j)=max_w_rank;
end;


H(j+1,j) = sqrt(tt_dot(w, w));
if (j<maxin)
v{j+1}=tt_scal2(w, -log(H(j+1,j)), 1);
% v{j+1}=tt_stabilize(v{j+1},0);
end;

[UH,SH,VH]=svd(H(1:j+1, 1:j), 0);
SH = diag(SH);
sigma_min_H = min(SH);
sigma_max_H = max(SH);
if (verbose==1)
if (verbose>0)
fprintf(' min(sigma(H)) = %g\n', sigma_min_H);
end;
SH(1:numel(find(SH>1e-100)))=1./SH(1:numel(find(SH>1e-100))); %pseudoinv(SH)
SH = diag(SH);
y = cur_beta*VH*SH*(UH(1,:)'); % y = pinv(H)*(beta*e1)
% y = beta*VH*SH*(UH(1,:)'); % y = pinv(H)*(beta*e1)

% err = norm(H(1:j+1, 1:j)*y-[beta zeros(1,j)]', 'fro')/normb;
% err = log(norm(H(1:j+1, 1:j)*y-[cur_beta zeros(1,j)]', 'fro')/cur_normb+1e-308);
% err = err+order_beta-order_normb;
err = log(norm(H(1:j+1, 1:j)*y-[cur_beta zeros(1,j)]', 'fro')/cur_norm_f+1e-308);
err = err+order_beta-order_norm_f;
err = exp(err);
if (use_err_trunc==1)
% err_for_trunc=err;
err_for_trunc = log(norm(H(1:j+1, 1:j)*y-[cur_beta zeros(1,j)]', 'fro')/cur_beta+1e-308);
err_for_trunc = exp(err_for_trunc);
err_for_trunc = (err_for_trunc*maxin*sigma_max_H/sigma_min_H)^(err_trunc_power);
err_for_trunc = min(err_for_trunc, 1);
% err_for_trunc = err;
end;
% err_to_f = log(norm(H(1:j+1, 1:j)*y-[cur_beta zeros(1,j)]', 'fro')/cur_norm_f+1e-308);
% err_to_f = err_to_f+order_beta-order_norm_f;
% err_to_f = exp(err_to_f);

% report the residual
if (nargout>1)
RESVEC(nitout,j)=err;
td{3}(j,nitout)=err;
end;


x_new = x;
max_x_rank = 0;
% dx = tt_scal2(v{j}, log(abs(y(j))+1e-308)+order_beta, sign(y(j)));
for i=j:-1:1
% dx = tt_add(dx, tt_scal2(v{i}, log(abs(y(i))+1e-308)+order_beta, sign(y(i))));
% dx = tt_axpy2(0,1, dx, log(abs(y(i))+1e-308)+order_beta, sign(y(i)), v{i}, eps_x);
x_new = tt_axpy2(0,1, x_new, log(abs(y(i))+1e-308)+order_beta, sign(y(i)), v{i}, eps_x);
% x_new = tt_axpy3(0,1, x_new, log(abs(y(i))+1e-308)+order_beta, sign(y(i)), v{i});
max_x_rank = max([max_x_rank; tt_ranks(x_new)]);
end;
% x_new = tt_axpy2(0,1,x,0,1,dx,eps_x);
if (nargout>3)
rx(nitout,j)=max_x_rank;
% Report the current solution, if required.
if (verbose>1)
td{2}{j,nitout} = x_new;
end;
% x_new = tt_wround([], x_new, eps_x, 'verb', 1);
% x_new = tt_add(x, dx);
% x_new = tt_compr2(x_new, eps_x);
% x_new = tt_axpy2(0,1,x,0,1,dx,eps_x);


max_xrank = max(tt_ranks(x_new));
max_rank = max_zrank_factor*max_xrank;

Expand All @@ -283,49 +231,40 @@

derr = old_err/err;
old_err = err;
% dx = tt_add(x_new, tt_scal(x_old, -1));
% normx = sqrt(tt_dot(x_new,x_new));
% x_old=x_new;
% dx_norm = sqrt(tt_dot(dx, dx))/normx;
if (derr<derr_tol_for_sp) stagpoints=stagpoints+1; end;
if (verbose==1)
% err_for_trunc
if (derr<derr_tol_for_sp); stagpoints=stagpoints+1; end;
if (verbose>0)
if (compute_real_res==1)
fprintf('iter = [%d,%d], derr = %3.2f, resid=%3.2e, real_res=%3.2e, rank_w=%d, rank_x=%d, sp=%d, time=%g\n', nitout, j, derr, err, res, max_wrank, max_xrank, stagpoints, toc(t0));
else
fprintf('iter = [%d,%d], derr = %3.2f, resid=%3.2e, rank_w=%d, rank_x=%d, sp=%d, time=%g\n', nitout, j, derr, err, max_w_rank, max_x_rank, stagpoints, toc(t0));
end;
end;

% Report time
td{1}(j,nitout) = toc(t0);

if (err<max_err)
x_good = x_new;
max_err=err;
end;
if (err<tol) break; end;
if (stagpoints>=maxin*max_sp_factor) break; end;
if (err<tol); break; end;
if (stagpoints>=maxin*max_sp_factor); break; end;
end;
x = x_good;

if (err<tol) break; end;
if (stagpoints>=maxin*max_sp_factor) break; end;
if (err<tol); break; end;
if (stagpoints>=maxin*max_sp_factor); break; end;
end;

if (nargout>1)
RESVEC=RESVEC(1:nitout,:);
if (nitout==1)
RESVEC=RESVEC(:,1:j);
end;
end;
if (nargout>2)
rw=rw(1:nitout,:);
if (nitout==1)
rw=rw(:,1:j);
end;
end;
if (nargout>3)
rx=rx(1:nitout,:);
td{1}=td{1}(:, 1:nitout);
td{2}=td{2}(:, 1:nitout);
td{3}=td{3}(:, 1:nitout);
if (nitout==1)
rx=rx(:,1:j);
td{1}=td{1}(1:j);
td{2}=td{2}(1:j);
td{2}=td{3}(1:j);
end;
end;

end
end

0 comments on commit 0b398d1

Please sign in to comment.