Skip to content

Commit

Permalink
Extend MPS solvers to trees (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
leburgel authored Jan 11, 2023
1 parent 0d64da8 commit 18bad9f
Show file tree
Hide file tree
Showing 20 changed files with 1,288 additions and 783 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
IsApprox = "28f27b66-4bd8-47e7-9110-e2746eb8bed7"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
Expand Down
6 changes: 5 additions & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using IsApprox
using ITensors
using ITensors.ContractionSequenceOptimization
using ITensors.ITensorVisualizationCore
using IterTools
using KrylovKit: KrylovKit
using NamedGraphs
using Observers
Expand Down Expand Up @@ -87,6 +88,8 @@ include(joinpath("treetensornetworks", "abstractprojttno.jl"))
include(joinpath("treetensornetworks", "projttno.jl"))
include(joinpath("treetensornetworks", "projttnosum.jl"))
include(joinpath("treetensornetworks", "projttno_apply.jl"))
# Compatibility of ITensors.MPS/MPO with tree sweeping routines
include(joinpath("treetensornetworks", "solvers", "tree_patch.jl"))
# Compatibility of ITensor observer and Observers
# TODO: Delete this
include(joinpath("treetensornetworks", "solvers", "update_observer.jl"))
Expand All @@ -103,10 +106,11 @@ include(joinpath("treetensornetworks", "solvers", "tdvp.jl"))
include(joinpath("treetensornetworks", "solvers", "dmrg.jl"))
include(joinpath("treetensornetworks", "solvers", "dmrg_x.jl"))
include(joinpath("treetensornetworks", "solvers", "projmpo_apply.jl"))
include(joinpath("treetensornetworks", "solvers", "contract_mpo_mps.jl"))
include(joinpath("treetensornetworks", "solvers", "contract_operator_state.jl"))
include(joinpath("treetensornetworks", "solvers", "projmps2.jl"))
include(joinpath("treetensornetworks", "solvers", "projmpo_mps2.jl"))
include(joinpath("treetensornetworks", "solvers", "linsolve.jl"))
include(joinpath("treetensornetworks", "solvers", "tree_sweeping.jl"))

include("exports.jl")

Expand Down
52 changes: 0 additions & 52 deletions src/treetensornetworks/solvers/contract_mpo_mps.jl

This file was deleted.

62 changes: 62 additions & 0 deletions src/treetensornetworks/solvers/contract_operator_state.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
function contract_solver(; kwargs...)
function solver(PH, t, psi; kws...)
v = ITensor(1.0)
for j in sites(PH)
v *= PH.psi0[j]
end
Hpsi0 = contract(PH, v)
return Hpsi0, nothing
end
return solver
end

function ITensors.contract(
::ITensors.Algorithm"fit",
A::IsTreeOperator,
psi0::ST;
init_state=psi0,
nsweeps=1,
kwargs...,
)::ST where {ST<:IsTreeState}
n = nv(A)
n != nv(psi0) && throw(
DimensionMismatch("Number of sites operator ($n) and state ($(nv(psi0))) do not match"),
)
if n == 1
v = only(vertices(psi0))
return ST([A[v] * psi0[v]])
end

check_hascommoninds(siteinds, A, psi0)

# In case A and psi0 have the same link indices
A = sim(linkinds, A)

# Fix site and link inds of init_state
init_state = deepcopy(init_state)
init_state = sim(linkinds, init_state)
for v in vertices(psi0)
replaceinds!(
init_state[v], siteinds(init_state, v), uniqueinds(siteinds(A, v), siteinds(psi0, v))
)
end

t = Inf
reverse_step = false
PH = proj_operator_apply(psi0, A)
psi = tdvp(
contract_solver(; kwargs...), PH, t, init_state; nsweeps, reverse_step, kwargs...
)

return psi
end

# extra ITensors overloads for tree tensor networks
function ITensors.contract(A::TTNO, ψ::TTNS; alg="fit", kwargs...)
return contract(ITensors.Algorithm(alg), A, ψ; kwargs...)
end

function ITensors.apply(A::TTNO, ψ::TTNS; kwargs...)
= contract(A, ψ; kwargs...)
return replaceprime(Aψ, 1 => 0)
end
4 changes: 2 additions & 2 deletions src/treetensornetworks/solvers/dmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ function eigsolve_solver(; kwargs...)
return solver
end

function dmrg(H, psi0::MPS; kwargs...)
function dmrg(H, psi0::IsTreeState; kwargs...)
t = Inf # DMRG is TDVP with an infinite timestep and no reverse step
reverse_step = false
psi = tdvp(eigsolve_solver(; kwargs...), H, t, psi0; reverse_step, kwargs...)
return psi
end

# Alias for DMRG
function eigsolve(H, psi0::MPS; kwargs...)
function eigsolve(H, psi0::IsTreeState; kwargs...)
return dmrg(H, psi0; kwargs...)
end
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function dmrg_x_solver(PH, t, psi0; kwargs...)
return U_max, nothing
end

function dmrg_x(PH, psi0::MPS; reverse_step=false, kwargs...)
function dmrg_x(PH, psi0::IsTreeState; reverse_step=false, kwargs...)
t = Inf
psi = tdvp(dmrg_x_solver, PH, t, psi0; reverse_step, kwargs...)
return psi
Expand Down
2 changes: 2 additions & 0 deletions src/treetensornetworks/solvers/projmpo_mps2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,5 @@ end
contract(P::ProjMPO_MPS2, v::ITensor) = contract(P.PH, v)

proj_mps(P::ProjMPO_MPS2) = [proj_mps(m) for m in P.Ms]

underlying_graph(P::ProjMPO_MPS2) = chain_lattice_graph(length(P.PH.H)) # tree patch
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/solver_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct TimeDependentSum{S,T}
f::Vector{S}
H0::T
end
TimeDependentSum(f::Vector, H0::ProjMPOSum) = TimeDependentSum(f, H0.pm)
TimeDependentSum(f::Vector, H0::IsTreeProjOperatorSum) = TimeDependentSum(f, H0.pm)
Base.length(H::TimeDependentSum) = length(H.f)

function Base.:*(c::Number, H::TimeDependentSum)
Expand Down
6 changes: 3 additions & 3 deletions src/treetensornetworks/solvers/tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ function tdvp_solver(; kwargs...)
end
end

function tdvp(H, t::Number, psi0::MPS; kwargs...)
function tdvp(H, t::Number, psi0::IsTreeState; kwargs...)
return tdvp(tdvp_solver(; kwargs...), H, t, psi0; kwargs...)
end

function tdvp(t::Number, H, psi0::MPS; kwargs...)
function tdvp(t::Number, H, psi0::IsTreeState; kwargs...)
return tdvp(H, t, psi0; kwargs...)
end

function tdvp(H, psi0::MPS, t::Number; kwargs...)
function tdvp(H, psi0::IsTreeState, t::Number; kwargs...)
return tdvp(H, t, psi0; kwargs...)
end
24 changes: 13 additions & 11 deletions src/treetensornetworks/solvers/tdvp_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function process_sweeps(; kwargs...)
return (; maxdim, mindim, cutoff, noise)
end

function tdvp(solver, PH, t::Number, psi0::MPS; kwargs...)
function tdvp(solver, PH, t::Number, psi0::IsTreeState; kwargs...)
reverse_step = get(kwargs, :reverse_step, true)

nsweeps = _tdvp_compute_nsweeps(t; kwargs...)
Expand Down Expand Up @@ -124,37 +124,37 @@ function tdvp(solver, PH, t::Number, psi0::MPS; kwargs...)
end

"""
tdvp(H::MPO,psi0::MPS,t::Number; kwargs...)
tdvp(H::MPO,psi0::MPS,t::Number; kwargs...)
tdvp(H::MPS,psi0::MPO,t::Number; kwargs...)
tdvp(H::TTNS,psi0::TTNO,t::Number; kwargs...)
Use the time dependent variational principle (TDVP) algorithm
to compute `exp(t*H)*psi0` using an efficient algorithm based
on alternating optimization of the MPS tensors and local Krylov
on alternating optimization of the state tensors and local Krylov
exponentiation of H.
Returns:
* `psi::MPS` - time-evolved MPS
* `psi` - time-evolved state
Optional keyword arguments:
* `outputlevel::Int = 1` - larger outputlevel values resulting in printing more information and 0 means no output
* `observer` - object implementing the [Observer](@ref observer) interface which can perform measurements and stop early
* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations
"""
function tdvp(solver, H::MPO, t::Number, psi0::MPS; kwargs...)
function tdvp(solver, H::IsTreeOperator, t::Number, psi0::IsTreeState; kwargs...)
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
# Permute the indices to have a better memory layout
# and minimize permutations
H = ITensors.permute(H, (linkind, siteinds, linkind))
PH = ProjMPO(H)
PH = proj_operator(H)
return tdvp(solver, PH, t, psi0; kwargs...)
end

function tdvp(solver, t::Number, H, psi0::MPS; kwargs...)
function tdvp(solver, t::Number, H, psi0::IsTreeState; kwargs...)
return tdvp(solver, H, t, psi0; kwargs...)
end

function tdvp(solver, H, psi0::MPS, t::Number; kwargs...)
function tdvp(solver, H, psi0::IsTreeState, t::Number; kwargs...)
return tdvp(solver, H, t, psi0; kwargs...)
end

Expand All @@ -177,12 +177,14 @@ each step of the algorithm when optimizing the MPS.
Returns:
* `psi::MPS` - time-evolved MPS
"""
function tdvp(solver, Hs::Vector{MPO}, t::Number, psi0::MPS; kwargs...)
function tdvp(
solver, Hs::Vector{<:IsTreeOperator}, t::Number, psi0::IsTreeState; kwargs...
)
for H in Hs
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
end
Hs .= ITensors.permute.(Hs, Ref((linkind, siteinds, linkind)))
PHs = ProjMPOSum(Hs)
PHs = proj_operator_sum(Hs)
return tdvp(solver, PHs, t, psi0; kwargs...)
end
Loading

0 comments on commit 18bad9f

Please sign in to comment.