Skip to content

Commit

Permalink
v0.4
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed May 12, 2024
1 parent d7c90d5 commit 7f60c30
Show file tree
Hide file tree
Showing 16 changed files with 464 additions and 343 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorTDVP"
uuid = "25707e16-a4db-4a07-99d9-4d67b7af0342"
authors = ["Matthew Fishman <mfishman@flatironinstitute.org> and contributors"]
version = "0.3.1"
version = "0.4.0"

[deps]
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
Expand Down
4 changes: 2 additions & 2 deletions examples/01_tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ function main()

ϕ = tdvp(
H,
-1.0,
-20.0,
ψ;
nsweeps=20,
time_step=-1.0,
reverse_step=false,
normalize=true,
maxdim=30,
Expand Down
4 changes: 1 addition & 3 deletions examples/02_dmrg-x.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ function main()
initstate = rand(["", ""], n)
ψ = MPS(s, initstate)

dmrg_x_kwargs = (
nsweeps=10, reverse_step=false, normalize=true, maxdim=20, cutoff=1e-10, outputlevel=1
)
dmrg_x_kwargs = (nsweeps=10, normalize=true, maxdim=20, cutoff=1e-10, outputlevel=1)

e, ϕ = dmrg_x(H, ψ; dmrg_x_kwargs...)

Expand Down
120 changes: 64 additions & 56 deletions src/alternating_update.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,36 @@
using ITensors: permute
using ITensors: ITensors, permute
using ITensors.ITensorMPS:
AbstractObserver,
## AbstractObserver,
MPO,
MPS,
ProjMPO,
ProjMPOSum,
check_hascommoninds,
checkdone!,
## checkdone!,
disk,
linkind,
maxlinkdim,
siteinds

function _compute_nsweeps(t; time_step=default_time_step(t), nsweeps=default_nsweeps())
if isinf(t) && isnothing(nsweeps)
nsweeps = 1
elseif !isnothing(nsweeps) && time_step != t
error("Cannot specify both time_step and nsweeps in alternating_update")
elseif isfinite(time_step) && abs(time_step) > 0 && isnothing(nsweeps)
nsweeps = convert(Int, ceil(abs(t / time_step)))
if !(nsweeps * time_step t)
error("Time step $time_step not commensurate with total time t=$t")
end
end
return nsweeps
end
## function _compute_nsweeps(t; time_step=default_time_step(t), nsweeps=default_nsweeps())
##
## @show t, time_step, nsweeps
##
## if isinf(t) && isnothing(nsweeps)
## nsweeps = 1
## elseif !isnothing(nsweeps) && time_step != t
## error("Cannot specify both time_step and nsweeps in alternating_update")
## elseif isfinite(time_step) && abs(time_step) > 0 && isnothing(nsweeps)
## nsweeps = convert(Int, ceil(abs(t / time_step)))
## if !(nsweeps * time_step ≈ t)
## error("Time step $time_step not commensurate with total time t=$t")
## end
## end
##
## @show nsweeps
##
## return nsweeps
## end

function _extend_sweeps_param(param, nsweeps)
if param isa Number
Expand All @@ -48,30 +54,29 @@ end

function alternating_update(
solver,
PH,
t::Number,
psi0::MPS;
reduced_operator,
init::MPS;
nsweeps=default_nsweeps(),
checkdone=default_checkdone(),
write_when_maxdim_exceeds=default_write_when_maxdim_exceeds(),
nsite=default_nsite(),
reverse_step=default_reverse_step(),
time_start=default_time_start(),
time_step=default_time_step(t),
time_step=default_time_step(),
order=default_order(),
(observer!)=default_observer!(),
(step_observer!)=default_step_observer!(),
outputlevel=default_outputlevel(),
normalize=default_normalize(),
maxdim=default_maxdim(),
mindim=default_mindim(),
cutoff=default_cutoff(Float64),
cutoff=default_cutoff(ITensors.scalartype(init)),
noise=default_noise(),
)
nsweeps = _compute_nsweeps(t; time_step, nsweeps)
## nsweeps = _compute_nsweeps(t; time_step, nsweeps)
maxdim, mindim, cutoff, noise = process_sweeps(; nsweeps, maxdim, mindim, cutoff, noise)
forward_order = TDVPOrder(order, Base.Forward)
psi = copy(psi0)
state = copy(init)
# Keep track of the start of the current time step.
# Helpful for tracking the total time, for example
# when using time-dependent solvers.
Expand All @@ -86,17 +91,17 @@ function alternating_update(
"write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim(sweeps, sw) = $(maxdim(sweeps, sweep)), writing environment tensors to disk",
)
end
PH = disk(PH)
reduced_operator = disk(reduced_operator)
end
sweep_time = @elapsed begin
psi, PH, info = sweep_update(
sweep_elapsed_time = @elapsed begin
state, reduced_operator, info = sweep_update(
forward_order,
solver,
PH,
time_step,
psi;
reduced_operator,
state;
nsite,
current_time,
time_step,
reverse_step,
sweep,
observer!,
Expand All @@ -107,51 +112,54 @@ function alternating_update(
noise=noise[sweep],
)
end
current_time += time_step
update_observer!(step_observer!; psi, sweep, outputlevel, current_time)
if !isnothing(time_step)
current_time += time_step
end
update_observer!(step_observer!; state, sweep, outputlevel, current_time)
if outputlevel >= 1
print("After sweep ", sweep, ":")
print(" maxlinkdim=", maxlinkdim(psi))
print(" maxlinkdim=", maxlinkdim(state))
@printf(" maxerr=%.2E", info.maxtruncerr)
print(" current_time=", round(current_time; digits=3))
print(" time=", round(sweep_time; digits=3))
print(" time=", round(sweep_elapsed_time; digits=3))
println()
flush(stdout)
end
isdone = false
if !isnothing(checkdone)
isdone = checkdone(; psi, sweep, outputlevel)
elseif observer! isa AbstractObserver
isdone = checkdone!(observer!; psi, sweep, outputlevel)
end
isdone = checkdone(; state, sweep, outputlevel)
## isdone = false
## if !isnothing(checkdone)
## isdone = checkdone(; state, sweep, outputlevel)
## elseif observer! isa AbstractObserver
## isdone = checkdone!(observer!; state, sweep, outputlevel)
## end
isdone && break
end
return psi
return state
end

# Convenience wrapper to not have to specify time step.
# Use a time step of `Inf` as a convention, since TDVP
# with an infinite time step corresponds to DMRG.
function alternating_update(solver, H, psi0::MPS; kwargs...)
return alternating_update(solver, H, ITensors.scalartype(psi0)(Inf), psi0; kwargs...)
end
## function alternating_update(solver, operator, init::MPS; kwargs...)
## return alternating_update(solver, operator, ITensors.scalartype(init)(Inf), init; kwargs...)
## end

function alternating_update(solver, H::MPO, t::Number, psi0::MPS; kwargs...)
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
function alternating_update(solver, operator::MPO, init::MPS; kwargs...)
check_hascommoninds(siteinds, operator, init)
check_hascommoninds(siteinds, operator, init')
# Permute the indices to have a better memory layout
# and minimize permutations
H = permute(H, (linkind, siteinds, linkind))
PH = ProjMPO(H)
return alternating_update(solver, PH, t, psi0; kwargs...)
operator = permute(operator, (linkind, siteinds, linkind))
reduced_operator = ProjMPO(operator)
return alternating_update(solver, reduced_operator, init; kwargs...)
end

function alternating_update(solver, Hs::Vector{MPO}, t::Number, psi0::MPS; kwargs...)
for H in Hs
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
function alternating_update(solver, operators::Vector{MPO}, init::MPS; kwargs...)
for operator in operators
check_hascommoninds(siteinds, operator, init)
check_hascommoninds(siteinds, operator, init')
end
Hs .= ITensors.permute.(Hs, Ref((linkind, siteinds, linkind)))
PHs = ProjMPOSum(Hs)
return alternating_update(solver, PHs, t, psi0; kwargs...)
operators .= ITensors.permute.(operators, Ref((linkind, siteinds, linkind)))
reduced_operator = ProjMPOSum(operators)
return alternating_update(solver, reduced_operator, init; kwargs...)
end
57 changes: 32 additions & 25 deletions src/contract_mpo_mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,48 +12,55 @@ using ITensors:
siteinds

function contractmpo_solver(; kwargs...)
function solver(PH, t, psi; kws...)
v = ITensor(true)
for j in (PH.lpos + 1):(PH.rpos - 1)
v *= PH.psi0[j]
function solver(reduced_operator, psi; kws...)
reduced_state = ITensor(true)
for j in (reduced_operator.lpos + 1):(reduced_operator.rpos - 1)
reduced_state *= reduced_operator.input_state[j]
end
Hpsi0 = contract(PH, v)
return Hpsi0, nothing
reduced_state = contract(reduced_operator, reduced_state)
return reduced_state, nothing
end
return solver
end

# `init_mps` is for backwards compatibility.
function ITensors.contract(
::Algorithm"fit", A::MPO, psi0::MPS; init_mps=psi0, nsweeps=1, kwargs...
::Algorithm"fit",
operator::MPO,
input_state::MPS;
init=input_state,
init_mps=init,
kwargs...,
)::MPS
n = length(A)
n != length(psi0) &&
throw(DimensionMismatch("lengths of MPO ($n) and MPS ($(length(psi0))) do not match"))
n = length(operator)
n != length(input_state) && throw(
DimensionMismatch("lengths of MPO ($n) and MPS ($(length(input_state))) do not match")
)
if n == 1
return MPS([A[1] * psi0[1]])
return MPS([operator[1] * input_state[1]])
end
any(i -> isempty(i), siteinds(commoninds, A, psi0)) &&
error("In `contract(A::MPO, x::MPS)`, `A` and `x` must share a set of site indices")
# In case A and psi0 have the same link indices
A = sim(linkinds, A)
# Fix site and link inds of init_mps
init_mps = deepcopy(init_mps)
init_mps = sim(linkinds, init_mps)
Ai = siteinds(A)
any(i -> isempty(i), siteinds(commoninds, operator, input_state)) && error(
"In `contract(operator::MPO, x::MPS)`, `operator` and `x` must share a set of site indices",
)
# In case operator and input_state have the same link indices
operator = sim(linkinds, operator)
# Fix site and link inds of init
init = deepcopy(init)
init = sim(linkinds, init)
siteinds_operator = siteinds(operator)
ti = Vector{Index}(undef, n)
for j in 1:n
for i in Ai[j]
if !hasind(psi0[j], i)
for i in siteinds_operator[j]
if !hasind(input_state[j], i)
ti[j] = i
break
end
end
end
replace_siteinds!(init_mps, ti)
reverse_step = false
PH = ProjMPOApply(psi0, A)
replace_siteinds!(init, ti)
reduced_operator = ProjMPOApply(input_state, operator)
psi = alternating_update(
contractmpo_solver(; kwargs...), PH, init_mps; nsweeps, reverse_step, kwargs...
contractmpo_solver(; kwargs...), reduced_operator, init; kwargs...
)
return psi
end
14 changes: 7 additions & 7 deletions src/defaults.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
using ITensors: NoObserver
## using ITensors: NoObserver
using KrylovKit: eigsolve, exponentiate

default_nsweeps() = nothing
default_checkdone() = nothing
default_checkdone() = Returns(false)
default_write_when_maxdim_exceeds() = nothing
default_nsite() = 2
default_reverse_step() = true
default_time_start() = 0
default_time_step(t) = t
default_reverse_step() = false
default_time_start() = nothing
default_time_step() = nothing
default_order() = 2
default_observer!() = NoObserver()
default_step_observer!() = NoObserver()
default_observer!() = EmptyObserver()
default_step_observer!() = EmptyObserver()
default_outputlevel() = 0
default_normalize() = false
default_sweep() = 1
Expand Down
20 changes: 9 additions & 11 deletions src/dmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ function dmrg_solver(
solver_maxiter,
solver_verbosity,
)
function solver(H, t, psi0; current_time, outputlevel)
function solver(operator, init; current_time, time_step, outputlevel)
howmany = 1
which = solver_which_eigenvalue
vals, vecs, info = f(
H,
psi0,
operator,
init,
howmany,
which;
ishermitian=default_ishermitian(),
Expand All @@ -27,8 +27,8 @@ function dmrg_solver(
end

function dmrg(
H,
psi0::MPS;
operator,
init::MPS;
ishermitian=default_ishermitian(),
solver_which_eigenvalue=default_solver_which_eigenvalue(eigsolve),
solver_tol=default_solver_tol(eigsolve),
Expand All @@ -38,11 +38,10 @@ function dmrg(
(observer!)=default_observer!(),
kwargs...,
)
reverse_step = false
info_ref! = Ref{Any}()
info_observer! = values_observer(; info=info_ref!)
observer! = compose_observers(observer!, info_observer!)
psi = alternating_update(
state = alternating_update(
dmrg_solver(
eigsolve;
solver_which_eigenvalue,
Expand All @@ -52,11 +51,10 @@ function dmrg(
solver_maxiter,
solver_verbosity,
),
H,
psi0;
reverse_step,
operator,
init;
observer!,
kwargs...,
)
return info_ref![].eigval, psi
return info_ref![].eigval, state
end
Loading

0 comments on commit 7f60c30

Please sign in to comment.