Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basis expansion method with expand #24

Merged
merged 23 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ julia = "1.6"
[extras]
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
mtfishman marked this conversation as resolved.
Show resolved Hide resolved

[targets]
test = ["Test"]
3 changes: 2 additions & 1 deletion src/ITensorTDVP.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module ITensorTDVP
export TimeDependentSum, dmrg_x, linsolve, tdvp, to_vec
export TimeDependentSum, dmrg_x, expand_basis, linsolve, tdvp, to_vec
include("ITensorsExtensions.jl")
using .ITensorsExtensions: to_vec
include("applyexp.jl")
Expand All @@ -17,6 +17,7 @@ include("contract.jl")
include("reducedconstantterm.jl")
include("reducedlinearproblem.jl")
include("linsolve.jl")
include("expand_basis.jl")
using PackageExtensionCompat: @require_extensions
function __init__()
@require_extensions
Expand Down
104 changes: 104 additions & 0 deletions src/expand_basis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
using ITensors:
ITensors,
Algorithm,
Index,
ITensor,
@Algorithm_str,
δ,
commonind,
dag,
denseblocks,
directsum,
hasqns,
prime,
scalartype,
uniqueinds
using ITensors.ITensorMPS: MPO, MPS, apply, dim, linkind, maxlinkdim, orthogonalize
using LinearAlgebra: normalize, svd, tr

#
# Possible improvements
# - allow a maxdim argument to be passed to `extend`
# and through `basis_extend`
# - current behavior is letting bond dimension get too
# big when used in imaginary time evolution
# - Use (1-tau*operator)|state> to generate "Krylov" vectors
# instead of operator|state>. Needed?
#

function expand_basis(state, reference; alg, kwargs...)
return expand_basis(Algorithm(alg), state, reference; kwargs...)
end

"""
Given an MPS state and a collection of MPS references,
returns an MPS which is equal to state
(has fidelity 1.0 with state) but whose MPS basis
is expanded to contain a portion of the basis of
the references that is orthogonal to the MPS basis of state.
"""
function expand_basis(
::Algorithm"orthogonalize",
state::MPS,
references::Vector{MPS};
cutoff=10^2 * eps(real(scalartype(state))),
)
n = length(state)
state = orthogonalize(state, n)
references = map(reference -> orthogonalize(reference, n), references)
s = siteinds(state)
for j in reverse(2:n)
# SVD state[j] to compute basisⱼ
linds = [s[j - 1]; linkinds(state, j - 1)]
_, λⱼ, basisⱼ = svd(state[j], linds; righttags="bψ_$j,Link")
rinds = uniqueinds(basisⱼ, λⱼ)
# Make projectorⱼ
idⱼ = prod(r -> denseblocks(δ(scalartype(state), r', dag(r))), rinds)
projectorⱼ = idⱼ - prime(basisⱼ, rinds) * dag(basisⱼ)
# Sum reference density matrices
ρⱼ = sum(reference -> prime(reference[j], rinds) * dag(reference[j]), references)
# TODO: Fix bug that `tr` isn't preserving the element type.
ρⱼ /= scalartype(state)(tr(ρⱼ))
# Apply projectorⱼ
ρⱼ_projected = apply(apply(projectorⱼ, ρⱼ), projectorⱼ)
expanded_basisⱼ = basisⱼ
if norm(ρⱼ_projected) > 10^3 * eps(real(scalartype(state)))
# Diagonalize projected density matrix ρⱼ_projected
# to compute reference_basisⱼ, which spans part of right basis
# of references which is orthogonal to right basis of state
dⱼ, reference_basisⱼ = eigen(
ρⱼ_projected; cutoff, ishermitian=true, righttags="bϕ_$j,Link"
)
state_indⱼ = only(commoninds(basisⱼ, λⱼ))
reference_indⱼ = only(commoninds(reference_basisⱼ, dⱼ))
expanded_basisⱼ, bx = directsum(
basisⱼ => state_indⱼ, reference_basisⱼ => reference_indⱼ
)
end
# Shift ortho center one site left using dag(expanded_basisⱼ)
# and replace tensor at site j with expanded_basisⱼ
state[j - 1] = state[j - 1] * (state[j] * dag(expanded_basisⱼ))
state[j] = expanded_basisⱼ
for reference in references
reference[j - 1] = reference[j - 1] * (reference[j] * dag(expanded_basisⱼ))
reference[j] = expanded_basisⱼ
end
end
return state
end

function expand_basis(
::Algorithm"global_krylov",
state::MPS,
operator::MPO;
krylovdim=2,
cutoff=(√(eps(real(scalartype(state))))),
)
maxdim = maxlinkdim(state) + 1
references = Vector{MPS}(undef, krylovdim)
for k in 1:krylovdim
prev = k == 1 ? state : references[k - 1]
references[k] = normalize(apply(operator, prev; maxdim))
end
return expand_basis(state, references; alg="orthogonalize", cutoff)
end
10 changes: 9 additions & 1 deletion src/sweep_update.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
using ITensors: ITensors, uniqueinds
using ITensors.ITensorMPS:
ITensorMPS, MPS, isortho, orthocenter, orthogonalize!, position!, replacebond!, set_nsite!
ITensorMPS,
MPS,
isortho,
noiseterm,
orthocenter,
orthogonalize!,
position!,
replacebond!,
set_nsite!
using LinearAlgebra: norm, normalize!, svd
using Printf: @printf

Expand Down
86 changes: 86 additions & 0 deletions test/test_expand_basis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
@eval module $(gensym())
using ITensors: scalartype
using ITensors.ITensorMPS: OpSum, MPO, MPS, inner, linkdims, maxlinkdim, randomMPS, siteinds
using ITensorTDVP: dmrg, expand_basis, tdvp
using LinearAlgebra: normalize
using Test: @test, @testset
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "expand_basis (eltype=$elt)" for elt in elts
@testset "expand_basis (alg=\"orthogonalize\", conserve_qns=$conserve_qns, eltype=$elt)" for conserve_qns in
(
false, true
)
n = 6
s = siteinds("S=1/2", n; conserve_qns)
state = randomMPS(elt, s, j -> isodd(j) ? "↑" : "↓"; linkdims=4)
reference = randomMPS(elt, s, j -> isodd(j) ? "↑" : "↓"; linkdims=2)
state_expanded = expand_basis(state, [reference]; alg="orthogonalize")
@test scalartype(state_expanded) === elt
@test inner(state_expanded, state) ≈ inner(state, state)
@test inner(state_expanded, reference) ≈ inner(state, reference)
end
@testset "basis_extend (alg=\"global_krylov\", conserve_qns=$conserve_qns, eltype=$elt)" for conserve_qns in
(
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
false, true
)
n = 10
s = siteinds("S=1/2", n; conserve_qns)
opsum = OpSum()
for j in 1:(n - 1)
opsum += 0.5, "S+", j, "S-", j + 1
opsum += 0.5, "S-", j, "S+", j + 1
opsum += "Sz", j, "Sz", j + 1
end
operator = MPO(elt, opsum, s)
state = MPS(elt, s, j -> isodd(j) ? "↑" : "↓")
state_expanded = expand_basis(state, operator; alg="global_krylov")
@test scalartype(state_expanded) === elt
@test maxlinkdim(state_expanded) > 1
@test inner(state_expanded, state) ≈ inner(state, state)
end
@testset "Decoupled ladder (alg=\"global_krylov\", eltype=$elt)" begin
nx = 10
ny = 2
n = nx * ny
s = siteinds("S=1/2", n)
opsum = OpSum()
for j in 1:2:(n - 2)
opsum += 1 / 2, "S+", j, "S-", j + 2
opsum += 1 / 2, "S-", j, "S+", j + 2
opsum += "Sz", j, "Sz", j + 2
end
for j in 2:2:(n - 2)
opsum += 1 / 2, "S+", j, "S-", j + 2
opsum += 1 / 2, "S-", j, "S+", j + 2
opsum += "Sz", j, "Sz", j + 2
end
operator = MPO(elt, opsum, s)
init = randomMPS(elt, s; linkdims=30)
reference_energy, reference_state = dmrg(
operator,
init;
nsweeps=15,
maxdim=[10, 10, 20, 20, 40, 80, 100],
cutoff=(√(eps(real(elt)))),
noise=(√(eps(real(elt)))),
)
state = randomMPS(elt, s)
nexpansions = 10
tau = elt(0.5)
for step in 1:nexpansions
state = expand_basis(state, operator; alg="global_krylov", cutoff=(∜(eps(real(elt)))))
state = tdvp(
operator,
-4tau,
state;
nsteps=4,
cutoff=1e-5,
updater_kwargs=(; tol=1e-3, krylovdim=5),
)
state = normalize(state)
end
@test scalartype(state) === elt
@test inner(state', operator, state) ≈ reference_energy rtol = 2 * ∜(eps(real(elt)))
end
end
end
2 changes: 1 addition & 1 deletion test/test_exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Test: @test, @testset
@testset "Test exports" begin
@test issetequal(
names(ITensorTDVP),
[:ITensorTDVP, :TimeDependentSum, :dmrg_x, :linsolve, :tdvp, :to_vec],
[:ITensorTDVP, :TimeDependentSum, :dmrg_x, :expand_basis, :linsolve, :tdvp, :to_vec],
)
end
end
Loading