diff --git a/Project.toml b/Project.toml index 1e22fb09ef..b0e43f719d 100644 --- a/Project.toml +++ b/Project.toml @@ -17,7 +17,6 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e" NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" -Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0" PackageCompiler = "9b87118b-4619-50d2-8e1e-99f35a4d4d9d" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -35,12 +34,15 @@ Zeros = "bd1ec220-6eb4-527a-9b49-e79c3db6233b" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] +Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0" VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" [extensions] +ITensorsObserversExt = "Observers" ITensorsVectorInterfaceExt = "VectorInterface" [extras] +Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0" VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8" [compat] diff --git a/ext/ITensorsObserversExt/ITensorsObserversExt.jl b/ext/ITensorsObserversExt/ITensorsObserversExt.jl new file mode 100644 index 0000000000..32c16e1a09 --- /dev/null +++ b/ext/ITensorsObserversExt/ITensorsObserversExt.jl @@ -0,0 +1,11 @@ +module ITensorsObserversExt + +using ITensors.ITensorMPS: ITensorMPS +using Observers: Observers +using Observers.DataFrames: AbstractDataFrame + +function ITensorMPS.update_observer!(observer::AbstractDataFrame; kwargs...) + return Observers.update!(observer; kwargs...) +end + +end diff --git a/src/ITensorMPS/alternating_update.jl b/src/ITensorMPS/alternating_update.jl index 3770f8b86e..63c62154c6 100644 --- a/src/ITensorMPS/alternating_update.jl +++ b/src/ITensorMPS/alternating_update.jl @@ -98,7 +98,7 @@ function alternating_update( ) end current_time += time_step - update!(step_observer!; psi, sweep, outputlevel, current_time) + update_observer!(step_observer!; psi, sweep, outputlevel, current_time) if outputlevel >= 1 print("After sweep ", sweep, ":") print(" maxlinkdim=", maxlinkdim(psi)) diff --git a/test/ITensorsObserversExt/Project.toml b/test/ITensorsObserversExt/Project.toml new file mode 100644 index 0000000000..c279eee897 --- /dev/null +++ b/test/ITensorsObserversExt/Project.toml @@ -0,0 +1,3 @@ +[deps] +ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" +Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0" diff --git a/test/ITensorsObserversExt/runtests.jl b/test/ITensorsObserversExt/runtests.jl new file mode 100644 index 0000000000..5cdd0bdc2a --- /dev/null +++ b/test/ITensorsObserversExt/runtests.jl @@ -0,0 +1,25 @@ +@eval module $(gensym()) +using Test: @test, @testset +using ITensors.ITensorMPS: update_observer! +using Observers: observer + +@testset "ITensorsObserversExt" begin + function iterative_function(niter; observer!, observe_step) + for n in 1:niter + if iszero(n % observe_step) + update_observer!(observer!; iteration=n) + end + end + end + + # Record the iteration + iteration(; iteration) = iteration + + obs = observer(iteration) + niter = 100 + iterative_function(niter; (observer!)=obs, observe_step=10) + + @test size(obs) == (10, 1) +end + +end diff --git a/test/Project.toml b/test/Project.toml index 3d7da6330d..7d45ec9848 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,6 +10,7 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" +Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0" OptimKit = "77e91f04-9b3b-57a6-a776-40b61faaebe0" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" diff --git a/test/runtests.jl b/test/runtests.jl index 332b41a6a0..08fea5beb9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,6 +21,7 @@ ITensors.disable_threaded_blocksparse() "ContractionSequenceOptimization", "ITensorChainRules", "ITensorNetworkMaps", + "ITensorsObserversExt", "ITensorsVectorInterfaceExt", ] @time for dir in dirs