Skip to content

Commit

Permalink
Change some internal names from tdvp to alternating_update, sweep_upd…
Browse files Browse the repository at this point in the history
…ate, region_update
  • Loading branch information
mtfishman committed Feb 16, 2024
1 parent d830f37 commit 755bc92
Show file tree
Hide file tree
Showing 11 changed files with 109 additions and 112 deletions.
14 changes: 7 additions & 7 deletions examples/05_utils.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using ITensors: MPS, maxlinkdim
using ITensorTDVP: TDVPOrder, process_sweeps, tdvp_solver, tdvp_step, process_sweeps
using ITensorTDVP: ITensorTDVP
using Observers: observer, update!
using Printf: @printf

Expand All @@ -20,12 +20,14 @@ function tdvp_nonuniform_timesteps(
kwargs...,
)
nsweeps = length(time_steps)
maxdim, mindim, cutoff, noise = process_sweeps(; nsweeps, maxdim, mindim, cutoff, noise)
tdvp_order = TDVPOrder(order, Base.Forward)
maxdim, mindim, cutoff, noise = ITensorTDVP.process_sweeps(;
nsweeps, maxdim, mindim, cutoff, noise
)
tdvp_order = ITensorTDVP.TDVPOrder(order, Base.Forward)
current_time = time_start
for sw in 1:nsweeps
sw_time = @elapsed begin
psi, PH, info = tdvp_step(
psi, PH, info = ITensorTDVP.sweep_update(
tdvp_order,
solver,
PH,
Expand All @@ -42,9 +44,7 @@ function tdvp_nonuniform_timesteps(
)
end
current_time += time_steps[sw]

update!(step_observer!; psi, sweep=sw, outputlevel, current_time)

if outputlevel 1
print("After sweep ", sw, ":")
print(" maxlinkdim=", maxlinkdim(psi))
Expand All @@ -70,7 +70,7 @@ function tdvp_nonuniform_timesteps(
kwargs...,
)
return tdvp_nonuniform_timesteps(
tdvp_solver(
ITensorTDVP.tdvp_solver(
exponentiate;
ishermitian,
issymmetric,
Expand Down
4 changes: 2 additions & 2 deletions src/ITensorTDVP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ include("update_observer.jl")
include("solver_utils.jl")
include("tdvporder.jl")
include("tdvpinfo.jl")
include("tdvp_step.jl")
include("tdvp_generic.jl")
include("sweep_update.jl")
include("alternating_update.jl")
include("tdvp.jl")
include("dmrg.jl")
include("dmrg_x.jl")
Expand Down
85 changes: 19 additions & 66 deletions src/tdvp_generic.jl → src/alternating_update.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
using ITensors:
AbstractObserver,
MPO,
MPS,
ProjMPOSum,
check_hascommoninds,
checkdone!,
disk,
linkind,
maxlinkdim,
permute
using ITensors: AbstractObserver, MPS, checkdone!, disk, maxlinkdim

function _tdvp_compute_nsweeps(t; time_step=default_time_step(t), nsweeps=default_nsweeps())
function _compute_nsweeps(t; time_step=default_time_step(t), nsweeps=default_nsweeps())
if isinf(t) && isnothing(nsweeps)
nsweeps = 1
elseif !isnothing(nsweeps) && time_step != t
error("Cannot specify both time_step and nsweeps in tdvp")
error("Cannot specify both time_step and nsweeps in alternating_update")
elseif isfinite(time_step) && abs(time_step) > 0 && isnothing(nsweeps)
nsweeps = convert(Int, ceil(abs(t / time_step)))
if !(nsweeps * time_step t)
Expand Down Expand Up @@ -44,7 +34,7 @@ function process_sweeps(; nsweeps, maxdim, mindim, cutoff, noise)
return (; maxdim, mindim, cutoff, noise)
end

function tdvp(
function alternating_update(
solver,
PH,
t::Number,
Expand All @@ -66,9 +56,9 @@ function tdvp(
cutoff=default_cutoff(Float64),
noise=default_noise(),
)
nsweeps = _tdvp_compute_nsweeps(t; time_step, nsweeps)
nsweeps = _compute_nsweeps(t; time_step, nsweeps)
maxdim, mindim, cutoff, noise = process_sweeps(; nsweeps, maxdim, mindim, cutoff, noise)
tdvp_order = TDVPOrder(order, Base.Forward)
forward_order = TDVPOrder(order, Base.Forward)
psi = copy(psi0)
# Keep track of the start of the current time step.
# Helpful for tracking the total time, for example
Expand All @@ -87,8 +77,8 @@ function tdvp(
PH = disk(PH)
end
sweep_time = @elapsed begin
psi, PH, info = tdvp_step(
tdvp_order,
psi, PH, info = sweep_update(
forward_order,
solver,
PH,
time_step,
Expand Down Expand Up @@ -118,7 +108,7 @@ function tdvp(
end
isdone = false
if !isnothing(checkdone)
isdone = checkdone(; psi, sweep, outputlevel) #, kwargs...)
isdone = checkdone(; psi, sweep, outputlevel)
elseif observer! isa AbstractObserver
isdone = checkdone!(observer!; psi, sweep, outputlevel)
end
Expand All @@ -127,66 +117,29 @@ function tdvp(
return psi
end

"""
tdvp(H::MPO,psi0::MPS,t::Number; kwargs...)
tdvp(H::MPO,psi0::MPS,t::Number; kwargs...)
Use the time dependent variational principle (TDVP) algorithm
to compute `exp(t*H)*psi0` using an efficient algorithm based
on alternating optimization of the MPS tensors and local Krylov
exponentiation of H.
Returns:
* `psi::MPS` - time-evolved MPS
# Convenience wrapper to not have to specify time step.
# Use a time step of `Inf` as a convention, since TDVP
# with an infinite time step corresponds to DMRG.
function alternating_update(solver, H, psi0::MPS; kwargs...)
return alternating_update(solver, H, ITensors.scalartype(psi0)(Inf), psi0; kwargs...)
end

Optional keyword arguments:
* `outputlevel::Int = 1` - larger outputlevel values resulting in printing more information and 0 means no output
* `observer` - object implementing the [Observer](@ref observer) interface which can perform measurements and stop early
* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations
"""
function tdvp(solver, H::MPO, t::Number, psi0::MPS; kwargs...)
function alternating_update(solver, H::MPO, t::Number, psi0::MPS; kwargs...)
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
# Permute the indices to have a better memory layout
# and minimize permutations
H = permute(H, (linkind, siteinds, linkind))
PH = ProjMPO(H)
return tdvp(solver, PH, t, psi0; kwargs...)
end

function tdvp(solver, t::Number, H, psi0::MPS; kwargs...)
return tdvp(solver, H, t, psi0; kwargs...)
return alternating_update(solver, PH, t, psi0; kwargs...)
end

function tdvp(solver, H, psi0::MPS, t::Number; kwargs...)
return tdvp(solver, H, t, psi0; kwargs...)
end

"""
tdvp(Hs::Vector{MPO},psi0::MPS,t::Number; kwargs...)
tdvp(Hs::Vector{MPO},psi0::MPS,t::Number, sweeps::Sweeps; kwargs...)
Use the time dependent variational principle (TDVP) algorithm
to compute `exp(t*H)*psi0` using an efficient algorithm based
on alternating optimization of the MPS tensors and local Krylov
exponentiation of H.
This version of `tdvp` accepts a representation of H as a
Vector of MPOs, Hs = [H1,H2,H3,...] such that H is defined
as H = H1+H2+H3+...
Note that this sum of MPOs is not actually computed; rather
the set of MPOs [H1,H2,H3,..] is efficiently looped over at
each step of the algorithm when optimizing the MPS.
Returns:
* `psi::MPS` - time-evolved MPS
"""
function tdvp(solver, Hs::Vector{MPO}, t::Number, psi0::MPS; kwargs...)
function alternating_update(solver, Hs::Vector{MPO}, t::Number, psi0::MPS; kwargs...)
for H in Hs
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
end
Hs .= ITensors.permute.(Hs, Ref((linkind, siteinds, linkind)))
PHs = ProjMPOSum(Hs)
return tdvp(solver, PHs, t, psi0; kwargs...)
return alternating_update(solver, PHs, t, psi0; kwargs...)
end
10 changes: 2 additions & 8 deletions src/contract_mpo_mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,10 @@ function ITensors.contract(
if n == 1
return MPS([A[1] * psi0[1]])
end

any(i -> isempty(i), siteinds(commoninds, A, psi0)) &&
error("In `contract(A::MPO, x::MPS)`, `A` and `x` must share a set of site indices")

# In case A and psi0 have the same link indices
A = sim(linkinds, A)

# Fix site and link inds of init_mps
init_mps = deepcopy(init_mps)
init_mps = sim(linkinds, init_mps)
Expand All @@ -53,13 +50,10 @@ function ITensors.contract(
end
end
replace_siteinds!(init_mps, ti)

t = Inf
reverse_step = false
PH = ProjMPOApply(psi0, A)
psi = tdvp(
contractmpo_solver(; kwargs...), PH, t, init_mps; nsweeps, reverse_step, kwargs...
psi = alternating_update(
contractmpo_solver(; kwargs...), PH, init_mps; nsweeps, reverse_step, kwargs...
)

return psi
end
4 changes: 1 addition & 3 deletions src/dmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ function dmrg(
solver_verbosity=default_solver_verbosity(),
kwargs...,
)
t = Inf # DMRG is TDVP with an infinite timestep and no reverse step
reverse_step = false
psi = tdvp(
psi = alternating_update(
dmrg_solver(
eigsolve;
solver_which_eigenvalue,
Expand All @@ -51,7 +50,6 @@ function dmrg(
solver_verbosity,
),
H,
t,
psi0;
reverse_step,
kwargs...,
Expand Down
3 changes: 1 addition & 2 deletions src/dmrg_x.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ function dmrg_x_solver(PH, t, psi0; current_time, outputlevel)
end

function dmrg_x(PH, psi0::MPS; reverse_step=false, kwargs...)
t = ITensors.scalartype(psi0)(Inf)
psi = tdvp(dmrg_x_solver, PH, t, psi0; reverse_step, kwargs...)
psi = alternating_update(dmrg_x_solver, PH, psi0; reverse_step, kwargs...)
return psi
end
4 changes: 1 addition & 3 deletions src/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ function KrylovKit.linsolve(
x, info = linsolve(P, b, x₀, a₀, a₁; solver_kwargs...)
return x, nothing
end

t = Inf
P = ProjMPO_MPS2(A, b)
return tdvp(linsolve_solver, P, t, x₀; reverse_step=false, tdvp_kwargs...)
return alternating_update(linsolve_solver, P, x₀; reverse_step=false, tdvp_kwargs...)
end
26 changes: 13 additions & 13 deletions src/tdvp_step.jl → src/sweep_update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ using LinearAlgebra: norm, normalize!, svd
using Observers: update!
using Printf: @printf

function tdvp_step(
function sweep_update(
order::TDVPOrder, solver, PH, time_step::Number, psi::MPS; current_time=0, kwargs...
)
order_orderings = orderings(order)
order_sub_time_steps = eltype(time_step).(sub_time_steps(order))
order_sub_time_steps *= time_step
info = nothing
for substep in 1:length(order_sub_time_steps)
psi, PH, info = tdvp_sweep(
psi, PH, info = sub_sweep_update(
order_orderings[substep],
solver,
PH,
Expand Down Expand Up @@ -47,7 +47,7 @@ function is_half_sweep_done(direction, b, n; ncenter)
is_reverse_done(direction, b, n; ncenter)
end

function tdvp_sweep(
function sub_sweep_update(
direction::Base.Ordering,
solver,
PH,
Expand All @@ -71,7 +71,7 @@ function tdvp_sweep(
psi = copy(psi)
if length(psi) == 1
error(
"`tdvp` currently does not support system sizes of 1. You can diagonalize the MPO tensor directly with tools like `LinearAlgebra.eigen`, `KrylovKit.exponentiate`, etc.",
"`tdvp`, `dmrg`, `linsolve`, etc. currently does not support system sizes of 1. You can diagonalize the MPO tensor directly with tools like `LinearAlgebra.eigen`, `KrylovKit.exponentiate`, etc.",
)
end
N = length(psi)
Expand All @@ -92,7 +92,7 @@ function tdvp_sweep(
maxtruncerr = 0.0
info = nothing
for b in sweep_bonds(direction, N; ncenter=nsite)
current_time, maxtruncerr, spec, info = tdvp_site_update!(
current_time, maxtruncerr, spec, info = region_update!(
solver,
PH,
psi,
Expand Down Expand Up @@ -149,7 +149,7 @@ function tdvp_sweep(
return psi, PH, TDVPInfo(maxtruncerr)
end

function tdvp_site_update!(
function region_update!(
solver,
PH,
psi,
Expand All @@ -169,7 +169,7 @@ function tdvp_site_update!(
mindim,
maxtruncerr,
)
return tdvp_site_update!(
return region_update!(
Val(nsite),
Val(reverse_step),
solver,
Expand All @@ -191,7 +191,7 @@ function tdvp_site_update!(
)
end

function tdvp_site_update!(
function region_update!(
nsite_val::Val{1},
reverse_step_val::Val{false},
solver,
Expand Down Expand Up @@ -230,7 +230,7 @@ function tdvp_site_update!(
return current_time, maxtruncerr, spec, info
end

function tdvp_site_update!(
function region_update!(
nsite_val::Val{1},
reverse_step_val::Val{true},
solver,
Expand Down Expand Up @@ -290,7 +290,7 @@ function tdvp_site_update!(
return current_time, maxtruncerr, spec, info
end

function tdvp_site_update!(
function region_update!(
nsite_val::Val{2},
reverse_step_val::Val{false},
solver,
Expand Down Expand Up @@ -342,7 +342,7 @@ function tdvp_site_update!(
return current_time, maxtruncerr, spec, info
end

function tdvp_site_update!(
function region_update!(
nsite_val::Val{2},
reverse_step_val::Val{true},
solver,
Expand Down Expand Up @@ -407,7 +407,7 @@ function tdvp_site_update!(
return current_time, maxtruncerr, spec, info
end

function tdvp_site_update!(
function region_update!(
::Val{nsite},
::Val{reverse_step},
solver,
Expand All @@ -428,6 +428,6 @@ function tdvp_site_update!(
maxtruncerr,
) where {nsite,reverse_step}
return error(
"`tdvp` with `nsite=$nsite` and `reverse_step=$reverse_step` not implemented."
"`tdvp`, `dmrg`, `linsolve`, etc. with `nsite=$nsite` and `reverse_step=$reverse_step` not implemented.",
)
end
Loading

0 comments on commit 755bc92

Please sign in to comment.