Skip to content

Commit

Permalink
[ITensors] [BUG] Fix bug in loginner when inner is negative or comp…
Browse files Browse the repository at this point in the history
…lex (#945)

* Fix bug in `loginner` when inner is negative or complex.

* Fix subtraction bug in `OpSum`.

* Forward truncation arguments to more operations in `rrule` for `apply`.
  • Loading branch information
mtfishman authored Jun 29, 2022
1 parent 9bf2e13 commit 7fabf97
Show file tree
Hide file tree
Showing 10 changed files with 244 additions and 8 deletions.
49 changes: 49 additions & 0 deletions examples/autodiff/mps_autodiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using ITensors
using OptimKit
using Zygote

function ising(n; J, h)
os = OpSum()
for j in 1:(n - 1)
os += -J, "Z", j, "Z", j + 1
end
for j in 1:n
os += -h, "X", j
end
return os
end

function loss(H, ψ)
n = length(ψ)
ψHψ = ITensor(1.0)
ψψ = ITensor(1.0)
for j in 1:n
ψHψ = ψHψ * dag(ψ[j]') * H[j] * ψ[j]
ψψ = ψψ * replaceinds(dag(ψ[j]'), s[j]' => s[j]) * ψ[j]
end
return ψHψ[] / ψψ[]
end

n = 10
s = siteinds("S=1/2", n)
J = 1.0
h = 0.5

# Loss function only works with `Vector{ITensor}`,
# extract with `ITensors.data`.
ψ0 = ITensors.data(randomMPS(s; linkdims=10))
H = ITensors.data(MPO(ising(n; J, h), s))

loss(ψ) = loss(H, ψ)

optimizer = LBFGS(; maxiter=25, verbosity=2)
function loss_and_grad(x)
y, (∇,) = withgradient(loss, x)
return y, ∇
end
ψ, fs, gs, niter, normgradhistory = optimize(loss_and_grad, ψ0, optimizer)
Edmrg, ψdmrg = dmrg(MPO(H), MPS(ψ0); nsweeps=10, cutoff=1e-8)

@show loss(ψ0), norm(loss'(ψ0))
@show loss(ψ), norm(loss'(ψ))
@show loss(ITensors.data(ψdmrg)), norm(loss'(ITensors.data(ψdmrg)))
63 changes: 63 additions & 0 deletions examples/exact_diagonalization/exact_diagonalization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using ITensors
using KrylovKit
using LinearAlgebra
using MKL

include("fuse_inds.jl")

ITensors.Strided.disable_threads()
ITensors.disable_threaded_blocksparse()

function heisenberg(n)
os = OpSum()
for j in 1:(n - 1)
os += 1 / 2, "S+", j, "S-", j + 1
os += 1 / 2, "S-", j, "S+", j + 1
os += "Sz", j, "Sz", j + 1
end
return os
end

function main(n; blas_num_threads=Sys.CPU_THREADS, fuse=true, binary=true)
if n > 16
@warn "System size of $n is likely too large for exact diagonalization."
end

BLAS.set_num_threads(blas_num_threads)

# Hilbert space
s = siteinds("S=1/2", n; conserve_qns=true)
H = MPO(heisenberg(n), s)
initstate(j) = isodd(j) ? "" : ""
ψ0 = randomMPS(s, initstate; linkdims=10)

edmrg, ψdmrg = dmrg(H, ψ0; nsweeps=10, cutoff=1e-6)

if fuse
if binary
println("Fuse the indices using a binary tree")
T = fusion_tree_binary(s)
H_full = @time fuse_inds_binary(H, T)
ψ0_full = @time fuse_inds_binary(ψ0, T)
else
println("Fuse the indices using an unbalances tree")
T = fusion_tree(s)
H_full = @time fuse_inds(H, T)
ψ0_full = @time fuse_inds(ψ0, T)
end
else
println("Don't fuse the indices")
@disable_warn_order begin
H_full = @time contract(H)
ψ0_full = @time contract(ψ0)
end
end

vals, vecs, info = @time eigsolve(
H_full, ψ0_full, 1, :SR; ishermitian=true, tol=1e-6, krylovdim=30, eager=true
)

@show edmrg, vals[1]
end

main(14)
91 changes: 91 additions & 0 deletions examples/exact_diagonalization/fuse_inds.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
using ITensors

function fusion_tree(s::Vector{<:Index})
n = length(s)
Cs = Vector{ITensor}(undef, n - 1)
cj = s[1]
for j in 1:(n - 1)
fuse_inds = (cj, s[j + 1])
Cj = combiner(fuse_inds...)
Cs[j] = Cj
cj = uniqueind(Cj, fuse_inds)
end
return Cs
end

function fuse_inds(A::MPS, fusion_tree::Vector{ITensor})
n = length(A)
A_fused = A[1]
for j in 2:n
A_fused = A_fused * A[j] * fusion_tree[j - 1]
end
return A_fused
end

function fuse_inds(A::MPO, fusion_tree::Vector{ITensor})
n = length(A)
A_fused = A[1]
for j in 2:n
A_fused = A_fused * A[j] * dag(fusion_tree[j - 1]) * fusion_tree[j - 1]'
end
return A_fused
end

function fusion_tree_binary_layer(s::Vector{IndexT}; layer=1) where {IndexT<:Index}
n = length(s)
Cs = ITensor[]
cs = IndexT[]
for j in 1:2:(n - 1)
fuse_inds = (s[j], s[j + 1])
Cj = combiner(fuse_inds...; tags="n=$(j)$(j + 1),l=$(layer)")
push!(Cs, Cj)
cj = uniqueind(Cj, fuse_inds)
push!(cs, cj)
end
if isodd(n)
push!(cs, last(s))
end
return Cs, cs
end

function fusion_tree_binary(s::Vector{<:Index}; depth=ceil(Int, log2(length(s))))
Cs = Vector{ITensor}[]
c_layer = s
for layer in 1:depth
C_layer, c_layer = fusion_tree_binary_layer(c_layer; layer)
push!(Cs, C_layer)
end
return Cs
end

function fuse_tensors(A::MPS, fusion_tree_layer::Vector{ITensor}, j::Int)
return A[j] * A[j + 1] * fusion_tree_layer[(j + 1) ÷ 2]
end

function fuse_tensors(A::MPO, fusion_tree_layer::Vector{ITensor}, j::Int)
return A[j] *
A[j + 1] *
dag(fusion_tree_layer[(j + 1) ÷ 2]) *
fusion_tree_layer[(j + 1) ÷ 2]'
end

function fuse_inds_binary_layer(A::Union{MPS,MPO}, fusion_tree_layer::Vector{ITensor})
n = length(fusion_tree_layer)
A_fused = ITensor[]
for j in 1:2:(2n)
push!(A_fused, fuse_tensors(A, fusion_tree_layer, j))
end
if isodd(length(A))
push!(A_fused, A[end])
end
return typeof(A)(A_fused)
end

function fuse_inds_binary(A::Union{MPS,MPO}, fusion_tree::Vector{Vector{ITensor}})
depth = length(fusion_tree)
A_fused = A
for layer in 1:depth
A_fused = fuse_inds_binary_layer(A_fused, fusion_tree[layer])
end
return only(A_fused)
end
15 changes: 9 additions & 6 deletions src/ITensorChainRules/mps/abstractmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,16 @@ function _contract(::Type{MPO}, ψ::MPS, ϕ::MPS; kwargs...)
return contract(ψmat, ϕmat; kwargs...)
end

function rrule(::typeof(apply), x1::Vector{ITensor}, x2::Union{MPS,MPO}; kwargs...)
function rrule(
::typeof(apply), x1::Vector{ITensor}, x2::Union{MPS,MPO}; apply_dag=false, kwargs...
)
N = length(x1) + 1
apply_dag = x2 isa MPO ? get(kwargs, :apply_dag, false) : nothing

# Apply circuit and store intermediates in the forward direction
x1x2 = Vector{typeof(x2)}(undef, N)
x1x2[1] = x2
for n in 2:N
x1x2[n] = apply(x1[n - 1], x1x2[n - 1]; move_sites_back=true, kwargs...)
x1x2[n] = apply(x1[n - 1], x1x2[n - 1]; move_sites_back=true, apply_dag, kwargs...)
end
y = x1x2[end]

Expand All @@ -72,7 +73,9 @@ function rrule(::typeof(apply), x1::Vector{ITensor}, x2::Union{MPS,MPO}; kwargs.
x1dag_ȳ = Vector{typeof(x2)}(undef, N)
x1dag_ȳ[end] =
for n in (N - 1):-1:1
x1dag_ȳ[n] = apply(x1dag[n], x1dag_ȳ[n + 1]; move_sites_back=true, kwargs...)
x1dag_ȳ[n] = apply(
x1dag[n], x1dag_ȳ[n + 1]; move_sites_back=true, apply_dag, kwargs...
)
end

x̄1 = similar(x1)
Expand All @@ -87,7 +90,7 @@ function rrule(::typeof(apply), x1::Vector{ITensor}, x2::Union{MPS,MPO}; kwargs.
# apply U on one side of the MPO
if apply_dag
ϕ̃ = swapprime(x1x2dag[n], 0 => 1)
ϕ̃ = apply(x1[n], ϕ̃; move_sites_back=true, apply_dag=false)
ϕ̃ = apply(x1[n], ϕ̃; move_sites_back=true, apply_dag=false, kwargs...)
ϕ̃ = mapprime(ϕ̃, 1 => 2, 0 => 1)
ϕ̃ = replaceprime(ϕ̃, 1 => 0; inds=gateinds')
ξ̃ = 2 * dag(x1dag_ȳ[n + 1])'
Expand All @@ -97,7 +100,7 @@ function rrule(::typeof(apply), x1::Vector{ITensor}, x2::Union{MPS,MPO}; kwargs.
ξ̃ = mapprime(x1dag_ȳ[n + 1], 0 => 2)
end
end
x̄1[n] = _contract(ITensor, ξ̃, ϕ̃)
x̄1[n] = _contract(ITensor, ξ̃, ϕ̃; kwargs...)
else
s = inds(x1[n])
x̄1[n] = itensor(zeros(dim.(s)), s...)
Expand Down
4 changes: 4 additions & 0 deletions src/LazyApply/LazyApply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ function (a1::Sum{Scaled{C,Prod{A}}} - a2::Prod{A}) where {C,A}
return a1 + (-a2)
end

function (a1::Sum{Scaled{C1,Prod{A}}} - a2::Scaled{C2,Prod{A}}) where {C1,C2,A}
return a1 + (-a2)
end

function (a1::Sum{A} + a2::Scaled{C,Prod{A}}) where {C,A}
return Sum{Scaled{C,Prod{A}}}() + a1 + a2
end
Expand Down
3 changes: 2 additions & 1 deletion src/Ops/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ function split(f, t::Tuple)
return s
end

## XXX: Very long compile times
## XXX: Very long compile times:
## https://github.com/JuliaLang/julia/issues/45545
##
## julia> using ITensors
##
Expand Down
5 changes: 4 additions & 1 deletion src/mps/abstractmps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,9 @@ function _log_or_not_dot(
end

if loginner
if !isreal(O[]) || real(O[]) < 0
log_inner_tot += log(complex(O[]))
end
return log_inner_tot
end

Expand Down Expand Up @@ -1213,7 +1216,7 @@ function lognorm(M::AbstractMPS)
"log(norm²) is $lognorm2_M, which is not real up to a relative tolerance of $rtol"
)
end
return 0.5 * lognorm2_M
return 0.5 * real(lognorm2_M)
end

function isapprox(
Expand Down
7 changes: 7 additions & 0 deletions src/physics/autompo/opsum_to_mpo_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ function sorteachterm(os::OpSum, sites)
for n in eachindex(os)
t = os[n]
Nt = length(t)

if maximum(ITensors.sites(t)) > length(sites)
error(
"The OpSum contains a term $t that extends beyond the number of sites $(length(sites)).",
)
end

prevsite = N + 1 #keep track of whether we are switching
#to a new site to make sure F string
#is only placed at most once for each site
Expand Down
2 changes: 2 additions & 0 deletions test/Ops/test_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ end
@test Ops.OpSum() + 1.2 * o1 isa Sum{Scaled{ComplexF64,Prod{Op}}}
@test Ops.OpSum() + (1.2 + 2.3im) * o1 isa Sum{Scaled{ComplexF64,Prod{Op}}}
@test Ops.OpSum() + 1.2 * o1 * o2 isa Sum{Scaled{ComplexF64,Prod{Op}}}
@test Ops.OpSum() - 1.2 * o1 * o2 isa Sum{Scaled{ComplexF64,Prod{Op}}}
@test Ops.OpSum() + o1 * o2 isa Sum{Scaled{ComplexF64,Prod{Op}}}
@test o1 + o2 + 2.3 * o1 * o2 isa Sum{Scaled{Float64,Prod{Op}}}
@test Sum{Op}() + ("X", 1, "Y", 2) + ("Y", 2) isa Sum{Prod{Op}}
@test Sum{Op}() + ("X", 1, "Y", 2) + (1.2, "Y", 2) isa Sum{Scaled{Float64,Prod{Op}}}
@test OpSum() - (0.5, "Z", 1, "Z", 2) isa Sum{Scaled{ComplexF64,Prod{Op}}}

N = 4
s = siteinds("Qubit", N)
Expand Down
13 changes: 13 additions & 0 deletions test/mps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,19 @@ include("util.jl")
@test_throws DimensionMismatch inner(phi, badpsi)
end

@testset "loginner" begin
n = 4
c = 2

s = siteinds("S=1/2", n)
ψ = c .* randomMPS(s; linkdims=4)
@test exp(loginner(ψ, ψ)) c^(2n)
@test exp(loginner(ψ, -ψ)) -c^(2n)

α = randn(ComplexF64)
@test exp(loginner(ψ, α * ψ)) α * c^(2n)
end

@testset "broadcasting" begin
psi = randomMPS(sites)
orthogonalize!(psi, 1)
Expand Down

0 comments on commit 7fabf97

Please sign in to comment.