Skip to content

Commit

Permalink
Add new Observer that gets updated after each time step (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Apr 21, 2022
1 parent 79932ee commit acd15ef
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 19 deletions.
15 changes: 8 additions & 7 deletions examples/04_tdvp_observers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ ttotal = 1.0
s = siteinds("S=1/2", N; conserve_qns=true)
H = MPO(heisenberg(N), s)

function sweep(; sweep, bond, half_sweep)
function step(; sweep, bond, half_sweep)
if bond == 1 && half_sweep == 2
return sweep
end
Expand Down Expand Up @@ -49,10 +49,10 @@ function return_state(; psi, bond, half_sweep)
end

obs = Observer(
"sweeps" => sweep, "times" => current_time, "psis" => return_state, "Sz" => measure_sz
"steps" => step, "times" => current_time, "psis" => return_state, "Sz" => measure_sz
)

psi = productMPS(s, n -> isodd(n) ? "Up" : "Dn")
psi = MPS(s, n -> isodd(n) ? "Up" : "Dn")
psi_f = tdvp(
H,
-im * ttotal,
Expand All @@ -65,14 +65,15 @@ psi_f = tdvp(
)

res = results(obs)

sweeps = res["sweeps"]
steps = res["steps"]
times = res["times"]
psis = res["psis"]
Sz = res["Sz"]

for n in 1:length(sweeps)
print("sweep = ", sweeps[n])
println("\nResults")
println("=======")
for n in 1:length(steps)
print("step = ", steps[n])
print(", time = ", round(times[n]; digits=3))
print(", |⟨ψⁿ|ψⁱ⟩| = ", round(abs(inner(psis[n], psi)); digits=3))
print(", |⟨ψⁿ|ψᶠ⟩| = ", round(abs(inner(psis[n], psi_f)); digits=3))
Expand Down
47 changes: 47 additions & 0 deletions examples/05_tdvp_nonuniform_timesteps.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
using ITensors
using ITensorTDVP

include("05_utils.jl")

function heisenberg(N)
os = OpSum()
for j in 1:(N - 1)
os += 0.5, "S+", j, "S-", j + 1
os += 0.5, "S-", j, "S+", j + 1
os += "Sz", j, "Sz", j + 1
end
return os
end

N = 10
cutoff = 1e-12
outputlevel = 1
nsteps = 10
time_steps = [n 2 ? -0.2im : -0.1im for n in 1:nsteps]

obs = Observer("times" => (; current_time) -> current_time, "psis" => (; psi) -> psi)

s = siteinds("S=1/2", N; conserve_qns=true)
H = MPO(heisenberg(N), s)

psi0 = MPS(s, n -> isodd(n) ? "Up" : "Dn")
psi = tdvp_nonuniform_timesteps(
ProjMPO(H), psi0; time_steps, cutoff, outputlevel, (step_observer!)=obs
)

res = results(obs)
times = res["times"]
psis = res["psis"]

println("\nResults")
println("=======")
print("step = ", 0)
print(", time = ", zero(ComplexF64))
print(", ⟨Sᶻ⟩ = ", round(expect(psi0, "Sz"; sites=N ÷ 2); digits=3))
println()
for n in 1:length(times)
print("step = ", n)
print(", time = ", round(times[n]; digits=3))
print(", ⟨Sᶻ⟩ = ", round(expect(psis[n], "Sz"; sites=N ÷ 2); digits=3))
println()
end
60 changes: 60 additions & 0 deletions examples/05_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using ITensors
using ITensorTDVP
using Observers
using Printf

using ITensorTDVP: tdvp_solver, process_sweeps, TDVPOrder

function tdvp_nonuniform_timesteps(
solver,
PH,
psi::MPS;
time_steps,
reverse_step=true,
time_start=0.0,
order=2,
(step_observer!)=Observer(),
kwargs...,
)
nsweeps = length(time_steps)
maxdim, mindim, cutoff, noise = process_sweeps(; nsweeps, kwargs...)
tdvp_order = TDVPOrder(order, Base.Forward)
current_time = time_start
for sw in 1:nsweeps
sw_time = @elapsed begin
psi, PH, info = tdvp(
tdvp_order,
solver,
PH,
time_steps[sw],
psi;
kwargs...,
current_time,
reverse_step,
sweep=sw,
maxdim=maxdim[sw],
mindim=mindim[sw],
cutoff=cutoff[sw],
noise=noise[sw],
)
end
current_time += time_steps[sw]

update!(step_observer!; psi, sweep=sw, outputlevel, current_time)

if outputlevel 1
print("After sweep ", sw, ":")
print(" maxlinkdim=", maxlinkdim(psi))
@printf(" maxerr=%.2E", info.maxtruncerr)
print(" current_time=", round(current_time; digits=3))
print(" time=", round(sw_time; digits=3))
println()
flush(stdout)
end
end
return psi
end

function tdvp_nonuniform_timesteps(H, psi::MPS; kwargs...)
return tdvp_nonuniform_timesteps(tdvp_solver(; kwargs...), H, psi; kwargs...)
end
5 changes: 4 additions & 1 deletion src/tdvp_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ function tdvp(solver, PH, t::Number, psi0::MPS; kwargs...)
write_when_maxdim_exceeds::Union{Int,Nothing} = get(
kwargs, :write_when_maxdim_exceeds, nothing
)
observer = get(kwargs, :observer, NoObserver())
observer = get(kwargs, :observer!, NoObserver())
step_observer = get(kwargs, :step_observer!, NoObserver())
outputlevel::Int = get(kwargs, :outputlevel, 0)

psi = copy(psi0)
Expand Down Expand Up @@ -97,6 +98,8 @@ function tdvp(solver, PH, t::Number, psi0::MPS; kwargs...)

current_time += time_step

update!(step_observer; psi, sweep=sw, outputlevel, current_time)

if outputlevel >= 1
print("After sweep ", sw, ":")
print(" maxlinkdim=", maxlinkdim(psi))
Expand Down
2 changes: 1 addition & 1 deletion src/tdvp_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ function tdvp(direction::Base.Ordering, solver, PH, time_step::Number, psi::MPS;
@printf(" cutoff=%.1E", cutoff)
@printf(" maxdim=%.1E", maxdim)
print(" mindim=", mindim)
print(" current_time=", current_time)
print(" current_time=", round(current_time; digits=3))
println()
if spec != nothing
@printf(
Expand Down
31 changes: 21 additions & 10 deletions test/test_tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,23 +376,34 @@ end

obs = Observer("Sz" => measure_sz, "En" => measure_en)

psi2 = productMPS(s, n -> isodd(n) ? "Up" : "Dn")
tdvp(H, -im * ttotal, psi2; time_step=-im * tau, cutoff, normalize=false, (observer!)=obs)
step_measure_sz(; psi) = expect(psi, "Sz"; sites=c)

step_measure_en(; psi) = real(inner(psi', H, psi))

step_obs = Observer("Sz" => step_measure_sz, "En" => step_measure_en)

psi2 = MPS(s, n -> isodd(n) ? "Up" : "Dn")
tdvp(
H,
-im * ttotal,
psi2;
time_step=-im * tau,
cutoff,
normalize=false,
(observer!)=obs,
(step_observer!)=step_obs,
)

# Using filter here just due to the current
# behavior of Observers that nothing gets appended:
Sz2 = results(obs)["Sz"]
En2 = results(obs)["En"]

#display(En1)
#display(En2)
#display(Sz1)
#display(Sz2)
#@show norm(Sz1 - Sz2)
#@show norm(En1 - En2)
Sz2_step = results(step_obs)["Sz"]
En2_step = results(step_obs)["En"]

@test Sz1 Sz2
@test En1 En2
@test Sz1 Sz2_step
@test En1 En2_step
end

nothing

2 comments on commit acd15ef

@mtfishman
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/60668

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.0.1 -m "<description of version>" acd15efb066c44fb5801c764f6e7e4bd9b28bddb
git push origin v0.0.1

Please sign in to comment.