Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ITensorMPS] Observers package extension #1376

Closed
wants to merge 9 commits into from
9 changes: 9 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
Zeros = "bd1ec220-6eb4-527a-9b49-e79c3db6233b"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"

[extensions]
ITensorsObserversExt = "Observers"

[extras]
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"

[compat]
Adapt = "3.5, 4"
BitIntegers = "0.2, 0.3"
Expand Down
11 changes: 11 additions & 0 deletions ext/ITensorsObserversExt/ITensorsObserversExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module ITensorsObserversExt

using ITensors.ITensorMPS: ITensorMPS
using Observers: Observers
using Observers.DataFrames: AbstractDataFrame

function ITensorMPS.update_observer!(observer::AbstractDataFrame; kwargs...)
return Observers.update!(observer; kwargs...)
end

end
2 changes: 1 addition & 1 deletion src/ITensorMPS/alternating_update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ function alternating_update(
)
end
current_time += time_step
update!(step_observer!; psi, sweep, outputlevel, current_time)
update_observer!(step_observer!; psi, sweep, outputlevel, current_time)
if outputlevel >= 1
print("After sweep ", sweep, ":")
print(" maxlinkdim=", maxlinkdim(psi))
Expand Down
8 changes: 8 additions & 0 deletions src/ITensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,12 @@ end
# _precompile_()
#end

#
# See section on "Transition from normal dependency to extension"
# in https://pkgdocs.julialang.org/v1/creating-packages
#
if !isdefined(Base, :get_extension)
include("../ext/ITensorsObserversExt.jl")
end

emstoudenmire marked this conversation as resolved.
Show resolved Hide resolved
end # module ITensors
3 changes: 3 additions & 0 deletions test/ITensorsObserversExt/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[deps]
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
25 changes: 25 additions & 0 deletions test/ITensorsObserversExt/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
@eval module $(gensym())
using Test: @test, @testset
using ITensors.ITensorMPS: update_observer!
using Observers: observer

@testset "ITensorsObserversExt" begin
function iterative_function(niter; observer!, observe_step)
for n in 1:niter
if iszero(n % observe_step)
update_observer!(observer!; iteration=n)
end
end
end

# Record the iteration
iteration(; iteration) = iteration

obs = observer(iteration)
niter = 100
iterative_function(niter; (observer!)=obs, observe_step=10)

@test size(obs) == (10, 1)
end

end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
OptimKit = "77e91f04-9b3b-57a6-a776-40b61faaebe0"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ ITensors.disable_threaded_blocksparse()
"ContractionSequenceOptimization",
"ITensorChainRules",
"ITensorNetworkMaps",
"ITensorsObserversExt",
]
@time for dir in dirs
println("\nTest $(@__DIR__)/$(dir)")
Expand Down
Loading