From d9333748f5d891a555fc8507a71fba189dc610b8 Mon Sep 17 00:00:00 2001 From: Hyeongjin Kim <42390787+hjkqubit@users.noreply.github.com> Date: Wed, 29 Jun 2022 16:53:16 -0400 Subject: [PATCH 1/5] [ITensors] [ENHANCEMENT] `rrule` for `MPO` constructor * `rrule` for `MPO` constructor by generalizing the `rrule` for the `MPS` constructor --- src/ITensorChainRules/mps/abstractmps.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ITensorChainRules/mps/abstractmps.jl b/src/ITensorChainRules/mps/abstractmps.jl index 9d3bae187b..04218590b1 100644 --- a/src/ITensorChainRules/mps/abstractmps.jl +++ b/src/ITensorChainRules/mps/abstractmps.jl @@ -1,6 +1,6 @@ -function rrule(::typeof(MPS), x::Vector{<:ITensor}; kwargs...) - y = MPS(x; kwargs...) - function MPS_pullback(ȳ) +function rrule(::Type{T}, x::Vector{<:ITensor}; kwargs...) where {T<:Union{MPS,MPO}} + y = T(x; kwargs...) + function T_pullback(ȳ) ȳtensors = ȳ.data n = length(ȳtensors) envL = [ȳtensors[1] * dag(x[1])] @@ -18,7 +18,7 @@ function rrule(::typeof(MPS), x::Vector{<:ITensor}; kwargs...) push!(x̄, envL[n - 1] * ȳtensors[n]) return (NoTangent(), x̄) end - return y, MPS_pullback + return y, T_pullback end function rrule(::typeof(inner), x1::T, x2::T; kwargs...) where {T<:Union{MPS,MPO}} From f772def553f7c06af39afacdc1603d2e4e7a0232 Mon Sep 17 00:00:00 2001 From: Miles Date: Tue, 5 Jul 2022 12:19:03 -0400 Subject: [PATCH 2/5] Fix Markdown issue in Observer docs. Fixes #947 [no ci] --- docs/src/Observer.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/Observer.md b/docs/src/Observer.md index 212fbbc9c2..07cdc892ba 100644 --- a/docs/src/Observer.md +++ b/docs/src/Observer.md @@ -97,7 +97,7 @@ which include: - psi: the current wavefunction MPS - bond: the bond `b` that was just optimized, corresponding to sites `(b,b+1)` in the two-site DMRG algorihtm - sweep: the current sweep number - - sweep_is_done: true if at the end of the current sweep, otherwise false + - sweep\_is\_done: true if at the end of the current sweep, otherwise false - half_sweep: the half-sweep number, equal to 1 for a left-to-right, first half sweep, or 2 for the second, right-to-left half sweep - spec: the Spectrum object returned from factorizing the local superblock wavefunction tensor in two-site DMRG - outputlevel: an integer specifying the amount of output to show From b66d1b7c1f0b67332d07d7b539c6916dbe6d673c Mon Sep 17 00:00:00 2001 From: Matthew Fishman Date: Tue, 5 Jul 2022 20:23:36 -0400 Subject: [PATCH 3/5] [ITensors] [NDTensors] [BUG] Fix some AD bugs --- NDTensors/src/imports.jl | 3 + NDTensors/src/linearalgebra.jl | 16 +++- NDTensors/src/tensor.jl | 11 ++- NDTensors/src/tensorstorage.jl | 7 ++ src/ITensorChainRules/ITensorChainRules.jl | 5 +- src/ITensorChainRules/indexset.jl | 50 +++++------- src/ITensorChainRules/itensor.jl | 93 +++++++++++++++------- src/ITensorChainRules/mps/mpo.jl | 35 ++------ src/ITensorChainRules/projection.jl | 10 +++ src/imports.jl | 2 + src/itensor.jl | 15 ++++ src/mps/abstractmps.jl | 7 +- src/mps/mps.jl | 2 +- src/physics/site_types/qudit.jl | 5 +- test/ITensorChainRules/test_chainrules.jl | 24 ++++++ test/decomp.jl | 9 +++ test/phys_site_types.jl | 1 + 17 files changed, 196 insertions(+), 99 deletions(-) create mode 100644 src/ITensorChainRules/projection.jl diff --git a/NDTensors/src/imports.jl b/NDTensors/src/imports.jl index a63dc99940..de8f45d42f 100644 --- a/NDTensors/src/imports.jl +++ b/NDTensors/src/imports.jl @@ -24,15 +24,18 @@ import Base: fill!, getindex, hash, + imag, isempty, isless, iterate, length, + map, ndims, permutedims, permutedims!, promote_rule, randn, + real, reshape, setindex, setindex!, diff --git a/NDTensors/src/linearalgebra.jl b/NDTensors/src/linearalgebra.jl index 62ce96799c..21df0579a0 100644 --- a/NDTensors/src/linearalgebra.jl +++ b/NDTensors/src/linearalgebra.jl @@ -218,7 +218,13 @@ function LinearAlgebra.eigen( use_absolute_cutoff::Bool = get(kwargs, :use_absolute_cutoff, use_absolute_cutoff) use_relative_cutoff::Bool = get(kwargs, :use_relative_cutoff, use_relative_cutoff) - DM, VM = eigen(matrix(T)) + matrixT = matrix(T) + if any(!isfinite, matrixT) + display(matrixT) + throw(ArgumentError("Trying to perform the eigendecomposition of a matrix containing NaNs or Infs")) + end + + DM, VM = eigen(matrixT) # Sort by largest to smallest eigenvalues p = sortperm(DM; rev=true, by=abs) @@ -343,7 +349,13 @@ function LinearAlgebra.eigen( use_absolute_cutoff::Bool = get(kwargs, :use_absolute_cutoff, use_absolute_cutoff) use_relative_cutoff::Bool = get(kwargs, :use_relative_cutoff, use_relative_cutoff) - DM, VM = eigen(matrix(T)) + matrixT = matrix(T) + if any(!isfinite, matrixT) + display(matrixT) + throw(ArgumentError("Trying to perform the eigendecomposition of a matrix containing NaNs or Infs")) + end + + DM, VM = eigen(matrixT) # Sort by largest to smallest eigenvalues #p = sortperm(DM; rev = true) diff --git a/NDTensors/src/tensor.jl b/NDTensors/src/tensor.jl index ab251abc05..a34672913d 100644 --- a/NDTensors/src/tensor.jl +++ b/NDTensors/src/tensor.jl @@ -124,13 +124,20 @@ copyto!(R::Tensor, T::Tensor) = (copyto!(storage(R), storage(T)); R) complex(T::Tensor) = setstorage(T, complex(storage(T))) -Base.real(T::Tensor) = setstorage(T, real(storage(T))) +real(T::Tensor) = setstorage(T, real(storage(T))) -Base.imag(T::Tensor) = setstorage(T, imag(storage(T))) +imag(T::Tensor) = setstorage(T, imag(storage(T))) # Define Base.similar in terms of NDTensors.similar Base.similar(T::Tensor, args...) = similar(T, args...) +function map(f, x::Tensor{T}) where {T} + if !iszero(f(zero(T))) + error("map(f, ::Tensor) currently doesn't support functions that don't preserve zeros, while you passed a function such that f(0) = $(f(zero(T))). This isn't supported right now because it doesn't necessarily preserve the sparsity structure of the input tensor.") + end + return setstorage(x, map(f, storage(x))) +end + # # Necessary to overload since the generic fallbacks are # slow diff --git a/NDTensors/src/tensorstorage.jl b/NDTensors/src/tensorstorage.jl index 619c6f22de..4d4a31a20d 100644 --- a/NDTensors/src/tensorstorage.jl +++ b/NDTensors/src/tensorstorage.jl @@ -72,6 +72,13 @@ Base.copyto!(S1::TensorStorage, S2::TensorStorage) = (copyto!(data(S1), data(S2) Random.randn!(S::TensorStorage) = (randn!(data(S)); S) +function map(f, x::TensorStorage{T}) where {T} + if !iszero(f(zero(T))) + error("map(f, ::TensorStorage) currently doesn't support functions that don't preserve zeros, while you passed a function such that f(0) = $(f(zero(T))). This isn't supported right now because it doesn't necessarily preserve the sparsity structure of the input tensor.") + end + return setdata(x, map(f, data(x))) +end + Base.fill!(S::TensorStorage, v) = (fill!(data(S), v); S) LinearAlgebra.rmul!(S::TensorStorage, v::Number) = (rmul!(data(S), v); S) diff --git a/src/ITensorChainRules/ITensorChainRules.jl b/src/ITensorChainRules/ITensorChainRules.jl index 78fa34fbf2..ad03092fbd 100644 --- a/src/ITensorChainRules/ITensorChainRules.jl +++ b/src/ITensorChainRules/ITensorChainRules.jl @@ -12,8 +12,9 @@ import ChainRulesCore: rrule ITensors.dag(z::AbstractZero) = z -broadcast_notangent(a) = broadcast(_ -> NoTangent(), a) +map_notangent(a) = map(Returns(NoTangent()), a) +include("projection.jl") include(joinpath("NDTensors", "tensor.jl")) include(joinpath("NDTensors", "dense.jl")) include("indexset.jl") @@ -24,7 +25,7 @@ include(joinpath("mps", "mpo.jl")) include(joinpath("LazyApply", "LazyApply.jl")) include("zygoterules.jl") -@non_differentiable broadcast_notangent(::Any) +@non_differentiable map_notangent(::Any) @non_differentiable Index(::Any...) @non_differentiable delta(::Any...) @non_differentiable dag(::Index) diff --git a/src/ITensorChainRules/indexset.jl b/src/ITensorChainRules/indexset.jl index 6cc9ef4e69..3ffb62b00a 100644 --- a/src/ITensorChainRules/indexset.jl +++ b/src/ITensorChainRules/indexset.jl @@ -1,31 +1,6 @@ -function ChainRulesCore.rrule(::typeof(getindex), x::ITensor, I...) - y = getindex(x, I...) - function getindex_pullback(ȳ) - # TODO: add definition `ITensor(::Tuple{}) = ITensor()` - # to ITensors.jl so no splatting is needed here. - x̄ = ITensor(inds(x)...) - x̄[I...] = unthunk(ȳ) - Ī = broadcast_notangent(I) - return (NoTangent(), x̄, Ī...) - end - return y, getindex_pullback -end - -# Specialized version in order to avoid call to `setindex!` -# within the pullback, should be better for taking higher order -# derivatives in Zygote. -function ChainRulesCore.rrule(::typeof(getindex), x::ITensor) - y = x[] - function getindex_pullback(ȳ) - x̄ = ITensor(unthunk(ȳ)) - return (NoTangent(), x̄) - end - return y, getindex_pullback -end - function setinds_pullback(ȳ, x, a...) x̄ = ITensors.setinds(ȳ, inds(x)) - ā = broadcast_notangent(a) + ā = map_notangent(a) return (NoTangent(), x̄, ā...) end @@ -72,7 +47,7 @@ for fname in ( "Trying to differentiate function `$f` with arguments $a and keyword arguments $kwargs. The forward pass indices $(inds(x)) do not match the reverse pass indices $(inds(x̄)). Likely this is because the priming/tagging operation you tried to perform is not invertible. Please write your code in a way where the index manipulation operation you are performing is invertible. For example, `prime(A::ITensor)` is invertible, with an inverse `prime(A, -1)`. However, `noprime(A)` is in general not invertible since the information about the prime levels of the original tensor are lost. Instead, you might try `prime(A, -1)` or `replaceprime(A, 1 => 0)` which are invertible.", ) end - ā = broadcast_notangent(a) + ā = map_notangent(a) return (NoTangent(), x̄, ā...) end return y, f_pullback @@ -102,7 +77,7 @@ for fname in ( function f_pullback(ȳ) uȳ = unthunk(ȳ) x̄ = replaceinds(uȳ, inds(y), inds(x)) - ā = broadcast_notangent(a) + ā = map_notangent(a) return (NoTangent(), x̄, ā...) end return y, f_pullback @@ -110,4 +85,23 @@ for fname in ( end end +function ChainRulesCore.rrule(::typeof(adjoint), x::ITensor) + y = x' + function adjoint_pullback(ȳ) + uȳ = unthunk(ȳ) + x̄ = replaceinds(uȳ, inds(y), inds(x)) + return (NoTangent(), x̄) + end + return y, adjoint_pullback +end + +function ChainRulesCore.rrule(::typeof(adjoint), x::Union{MPS,MPO}) + y = x' + function adjoint_pullback(ȳ) + x̄ = inv_op(prime, ȳ) + return (NoTangent(), x̄) + end + return y, adjoint_pullback +end + @non_differentiable permute(::Indices, ::Indices) diff --git a/src/ITensorChainRules/itensor.jl b/src/ITensorChainRules/itensor.jl index 874a56593e..38c4d1b06e 100644 --- a/src/ITensorChainRules/itensor.jl +++ b/src/ITensorChainRules/itensor.jl @@ -1,3 +1,28 @@ +function ChainRulesCore.rrule(::typeof(getindex), x::ITensor, I...) + y = getindex(x, I...) + function getindex_pullback(ȳ) + # TODO: add definition `ITensor(::Tuple{}) = ITensor()` + # to ITensors.jl so no splatting is needed here. + x̄ = ITensor(inds(x)...) + x̄[I...] = unthunk(ȳ) + Ī = map_notangent(I) + return (NoTangent(), x̄, Ī...) + end + return y, getindex_pullback +end + +# Specialized version in order to avoid call to `setindex!` +# within the pullback, should be better for taking higher order +# derivatives in Zygote. +function ChainRulesCore.rrule(::typeof(getindex), x::ITensor) + y = x[] + function getindex_pullback(ȳ) + x̄ = ITensor(unthunk(ȳ)) + return (NoTangent(), x̄) + end + return y, getindex_pullback +end + function rrule(::Type{ITensor}, x1::AllowAlias, x2::TensorStorage, x3) y = ITensor(x1, x2, x3) function ITensor_pullback(ȳ) @@ -67,91 +92,103 @@ end # Special case for contracting a pair of ITensors function ChainRulesCore.rrule(::typeof(contract), x1::ITensor, x2::ITensor) - y = x1 * x2 + project_x1 = ProjectTo(x1) + project_x2 = ProjectTo(x2) function contract_pullback(ȳ) - x̄1 = ȳ * dag(x2) - x̄2 = dag(x1) * ȳ + x̄1 = project_x1(ȳ * dag(x2)) + x̄2 = project_x2(dag(x1) * ȳ) return (NoTangent(), x̄1, x̄2) end - return y, contract_pullback + return x1 * x2, contract_pullback end @non_differentiable ITensors.optimal_contraction_sequence(::Any) function ChainRulesCore.rrule(::typeof(*), x1::Number, x2::ITensor) - y = x1 * x2 + project_x1 = ProjectTo(x1) + project_x2 = ProjectTo(x2) function contract_pullback(ȳ) - x̄1 = ȳ * dag(x2) - x̄2 = dag(x1) * ȳ - return (NoTangent(), x̄1[], x̄2) + x̄1 = project_x1((ȳ * dag(x2))[]) + x̄2 = project_x2(dag(x1) * ȳ) + return (NoTangent(), x̄1, x̄2) end - return y, contract_pullback + return x1 * x2, contract_pullback end function ChainRulesCore.rrule(::typeof(*), x1::ITensor, x2::Number) - y = x1 * x2 + project_x1 = ProjectTo(x1) + project_x2 = ProjectTo(x2) function contract_pullback(ȳ) - x̄1 = ȳ * dag(x2) - x̄2 = dag(x1) * ȳ - return (NoTangent(), x̄1, x̄2[]) + x̄1 = project_x1(ȳ * dag(x2)) + x̄2 = project_x2((dag(x1) * ȳ)[]) + return (NoTangent(), x̄1, x̄2) end - return y, contract_pullback + return x1 * x2, contract_pullback end function ChainRulesCore.rrule(::typeof(+), x1::ITensor, x2::ITensor) - y = x1 + x2 function add_pullback(ȳ) return (NoTangent(), ȳ, ȳ) end - return y, add_pullback + return x1 + x2, add_pullback +end + +function ChainRulesCore.rrule(::typeof(-), x1::ITensor, x2::ITensor) + function subtract_pullback(ȳ) + return (NoTangent(), ȳ, -ȳ) + end + return x1 - x2, subtract_pullback +end + +function ChainRulesCore.rrule(::typeof(-), x::ITensor) + function minus_pullback(ȳ) + return (NoTangent(), -ȳ) + end + return -x, minus_pullback end function ChainRulesCore.rrule(::typeof(itensor), x::Array, a...) - y = itensor(x, a...) function itensor_pullback(ȳ) uȳ = permute(unthunk(ȳ), a...) x̄ = reshape(array(uȳ), size(x)) - ā = broadcast_notangent(a) + ā = map_notangent(a) return (NoTangent(), x̄, ā...) end - return y, itensor_pullback + return itensor(x, a...), itensor_pullback end function ChainRulesCore.rrule(::Type{ITensor}, x::Array{<:Number}, a...) - y = ITensor(x, a...) function ITensor_pullback(ȳ) # TODO: define `Array(::ITensor)` directly uȳ = Array(unthunk(ȳ), a...) x̄ = reshape(uȳ, size(x)) - ā = broadcast_notangent(a) + ā = map_notangent(a) return (NoTangent(), x̄, ā...) end - return y, ITensor_pullback + return ITensor(x, a...), ITensor_pullback end function ChainRulesCore.rrule(::Type{ITensor}, x::Number) - y = ITensor(x) function ITensor_pullback(ȳ) x̄ = ȳ[] return (NoTangent(), x̄) end - return y, ITensor_pullback + return ITensor(x), ITensor_pullback end -function ChainRulesCore.rrule(::typeof(dag), x) - y = dag(x) +function ChainRulesCore.rrule(::typeof(dag), x::ITensor) function dag_pullback(ȳ) x̄ = dag(unthunk(ȳ)) return (NoTangent(), x̄) end - return y, dag_pullback + return dag(x), dag_pullback end function ChainRulesCore.rrule(::typeof(permute), x::ITensor, a...) y = permute(x, a...) function permute_pullback(ȳ) x̄ = permute(unthunk(ȳ), inds(x)) - ā = broadcast_notangent(a) + ā = map_notangent(a) return (NoTangent(), x̄, ā...) end return y, permute_pullback diff --git a/src/ITensorChainRules/mps/mpo.jl b/src/ITensorChainRules/mps/mpo.jl index 8a544a7c49..0167dcb9ef 100644 --- a/src/ITensorChainRules/mps/mpo.jl +++ b/src/ITensorChainRules/mps/mpo.jl @@ -1,28 +1,3 @@ -#function rrule(::typeof(MPO), x::Vector{<:ITensor}; kwargs...) -# y = MPO(x; kwargs...) -# #@show y -# function MPO_pullback(ȳ) -# #@show ȳ -# return ȳ.data -# #ȳtensors = ȳ.data -# #n = length(ȳtensors) -# #envL = [ȳtensors[1] * dag(x[1]), ] -# #envR = [ȳtensors[n] * dag(x[n]), ] -# #for j in 2:n-1 -# # push!(envL, envL[j-1] * ȳtensors[j] * dag(x[j])) -# # push!(envR, envR[j-1] * ȳtensors[n+1-j] * dag(x[n+1-j])) -# #end -# #x̄= ITensor[] -# #push!(x̄, ȳtensors[1] * envR[n-1]) -# #for j in 2:n-1 -# # push!(x̄, envL[j-1] * ȳtensors[j] * envR[n-j]) -# #end -# #push!(x̄, envL[n-1] * ȳtensors[n]) -# #return (NoTangent(), x̄) -# end -# return y, MPO_pullback -#end - function rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...) y = *(x1, x2; kwargs...) function contract_pullback(ȳ) @@ -43,18 +18,18 @@ end function ChainRulesCore.rrule(::typeof(-), x1::MPO, x2::MPO; kwargs...) y = -(x1, x2; kwargs...) - function add_pullback(ȳ) + function subtract_pullback(ȳ) return (NoTangent(), ȳ, -ȳ) end - return y, add_pullback + return y, subtract_pullback end function rrule(::typeof(tr), x::MPO; kwargs...) y = tr(x; kwargs...) - function contract_pullback(ȳ) + function tr_pullback(ȳ) s = noprime(firstsiteinds(x)) n = length(s) - x̄ = ȳ * MPO(s, "Id") + x̄ = MPO(s, "Id") plev = get(kwargs, :plev, 0 => 1) for j in 1:n @@ -62,7 +37,7 @@ function rrule(::typeof(tr), x::MPO; kwargs...) end return (NoTangent(), ȳ * x̄) end - return y, contract_pullback + return y, tr_pullback end function rrule(::typeof(inner), x1::MPS, x2::MPO, x3::MPS; kwargs...) diff --git a/src/ITensorChainRules/projection.jl b/src/ITensorChainRules/projection.jl new file mode 100644 index 0000000000..443d6e5fcb --- /dev/null +++ b/src/ITensorChainRules/projection.jl @@ -0,0 +1,10 @@ +function ChainRulesCore.ProjectTo(x::ITensor) + return ProjectTo{ITensor}(; element=ProjectTo(zero(eltype(x)))) +end + +function (project::ProjectTo{ITensor})(dx::ITensor) + S = eltype(dx) + T = ChainRulesCore.project_type(project.element) + dy = S <: T ? dx : map(project.element, dx) + return dy +end diff --git a/src/imports.jl b/src/imports.jl index b8b3a01f61..63c7e03473 100644 --- a/src/imports.jl +++ b/src/imports.jl @@ -42,6 +42,8 @@ import Base: isassigned, isempty, isless, + isreal, + iszero, iterate, keys, lastindex, diff --git a/src/itensor.jl b/src/itensor.jl index 378499f3c3..af15cad4fb 100644 --- a/src/itensor.jl +++ b/src/itensor.jl @@ -1544,6 +1544,16 @@ function randomITensor(::Type{S}, is...) where {S<:Number} return randomITensor(S, indices(is...)) end +# To fix ambiguity with QN version +function randomITensor(::Type{ElT}, ::Tuple{}) where {ElT<:Number} + return randomITensor(ElT, Index{Int}[]) +end + +# To fix ambiguity with QN version +function randomITensor(is::Tuple{}) + return randomITensor(Float64, is) +end + # To fix ambiguity errors with QN version function randomITensor(::Type{ElT}) where {ElT<:Number} return randomITensor(ElT, ()) @@ -2598,6 +2608,8 @@ function map!(f::Function, R::ITensor, T1::ITensor, T2::ITensor) return settensor!(R, _map!!(f, tensor(R), tensor(T1), tensor(T2))) end +map(f, x::ITensor) = itensor(map(f, tensor(x))) + """ axpy!(a::Number, v::ITensor, w::ITensor) ``` @@ -2723,6 +2735,9 @@ isemptystorage(T::ITensor) = isemptystorage(tensor(T)) isemptystorage(T::Tensor) = isempty(T) isempty(T::ITensor) = isemptystorage(T) +isreal(T::ITensor) = eltype(T) <: Real +iszero(T::ITensor) = all(iszero, T) + ####################################################################### # # Developer functions diff --git a/src/mps/abstractmps.jl b/src/mps/abstractmps.jl index 031dcc082e..7fd8bd8556 100644 --- a/src/mps/abstractmps.jl +++ b/src/mps/abstractmps.jl @@ -1092,11 +1092,8 @@ function _log_or_not_dot( dot_M1_M2 = O[] - T = promote_type(ITensors.promote_itensor_eltype(M1), ITensors.promote_itensor_eltype(M2)) - _max_dot_warn = inv(eps(real(float(T)))) - - if isnan(dot_M1_M2) || isinf(dot_M1_M2) || abs(dot_M1_M2) > _max_dot_warn - @warn "The inner product (or norm²) you are computing is very large: $dot_M1_M2, which is greater than $_max_dot_warn and may lead to floating point errors when used. You should consider using `lognorm` or `loginner` instead, which will help avoid floating point errors. For example if you are trying to normalize your MPS/MPO `A`, the normalized MPS/MPO `B` would be given by `B = A ./ z` where `z = exp(lognorm(A) / length(A))`." + if !isfinite(dot_M1_M2) + @warn "The inner product (or norm²) you are computing is very large ($dot_M1_M2). You should consider using `lognorm` or `loginner` instead, which will help avoid floating point errors. For example if you are trying to normalize your MPS/MPO `A`, the normalized MPS/MPO `B` would be given by `B = A ./ z` where `z = exp(lognorm(A) / length(A))`." end return dot_M1_M2 diff --git a/src/mps/mps.jl b/src/mps/mps.jl index 83bbbbc7e3..3b9120a02a 100644 --- a/src/mps/mps.jl +++ b/src/mps/mps.jl @@ -141,7 +141,7 @@ function randomCircuitMPS( M = MPS(N) if N == 1 - M[1] = ITensor(randn(dim(sites[1])), sites[1]) + M[1] = ITensor(randn(ElT, dim(sites[1])), sites[1]) M[1] /= norm(M[1]) return M end diff --git a/src/physics/site_types/qudit.jl b/src/physics/site_types/qudit.jl index 333f31ec4c..d9f336016e 100644 --- a/src/physics/site_types/qudit.jl +++ b/src/physics/site_types/qudit.jl @@ -1,4 +1,3 @@ - """ space(::SiteType"Qudit"; dim = 2, @@ -44,6 +43,10 @@ function _op(::OpName"Id", ::SiteType"Qudit"; dim::Tuple=(2,)) return mat end +function _op(::OpName"I", st::SiteType"Qudit"; kwargs...) + return _op(OpName"Id"(), st; kwargs...) +end + function _op(::OpName"Adag", ::SiteType"Qudit"; dim::Tuple=(2,)) d = dim[1] mat = zeros(d, d) diff --git a/test/ITensorChainRules/test_chainrules.jl b/test/ITensorChainRules/test_chainrules.jl index 3b6dbeb216..b83039312e 100644 --- a/test/ITensorChainRules/test_chainrules.jl +++ b/test/ITensorChainRules/test_chainrules.jl @@ -262,6 +262,30 @@ Random.seed!(1234) rtol=1e-4, atol=1e-4, ) + + # https://github.com/ITensor/ITensors.jl/issues/933 + f2 = function (x, a) + y = a + im * x + return real(dag(y) * y)[] + end + a = randomITensor() + f_itensor = x -> f2(x, a) + f_number = x -> f2(x, a[]) + x = randomITensor() + @test f_number(x[]) ≈ f_itensor(x) + @test f_number'(x[]) ≈ f_itensor'(x)[] + @test isreal(f_itensor'(x)) + + # https://github.com/ITensor/ITensors.jl/issues/936 + n = 2 + s = siteinds("S=1/2", n) + x = randomMPS(s) |> x -> outer(x', x) + f1 = x -> tr(x) + f2 = x -> 2tr(x) + f3 = x -> -tr(x) + @test f1'(x) ≈ MPO(s, "I") + @test f2'(x) ≈ 2MPO(s, "I") + @test f3'(x) ≈ -MPO(s, "I") end @testset "ChainRules rrules: op" begin diff --git a/test/decomp.jl b/test/decomp.jl index 0443c61909..c8ee14e25f 100644 --- a/test/decomp.jl +++ b/test/decomp.jl @@ -74,6 +74,15 @@ using ITensors, LinearAlgebra, Test eigArr = eigen(array(A)) @test diag(array(eigA.D), 0) ≈ eigArr.values @test diag(array(Dt), 0) == eigArr.values + + @test_throws ArgumentError eigen(ITensor(NaN, i', i)) + @test_throws ArgumentError eigen(ITensor(NaN, i', i); ishermitian=true) + @test_throws ArgumentError eigen(ITensor(complex(NaN), i', i)) + @test_throws ArgumentError eigen(ITensor(complex(NaN), i', i); ishermitian=true) + @test_throws ArgumentError eigen(ITensor(Inf, i', i)) + @test_throws ArgumentError eigen(ITensor(Inf, i', i); ishermitian=true) + @test_throws ArgumentError eigen(ITensor(complex(Inf), i', i)) + @test_throws ArgumentError eigen(ITensor(complex(Inf), i', i); ishermitian=true) end @testset "exp function" begin diff --git a/test/phys_site_types.jl b/test/phys_site_types.jl index 432b7b2381..80605c2ae2 100644 --- a/test/phys_site_types.jl +++ b/test/phys_site_types.jl @@ -435,6 +435,7 @@ using ITensors, Test s = siteinds(st, 4; dim=3, conserve_qns=true) @test all(hasqns, s) @test op(s, "Id", 2) == itensor([1 0 0; 0 1 0; 0 0 1], s[2]', dag(s[2])) + @test op(s, "I", 2) == itensor([1 0 0; 0 1 0; 0 0 1], s[2]', dag(s[2])) @test op(s, "N", 2) == itensor([0 0 0; 0 1 0; 0 0 2], s[2]', dag(s[2])) @test op(s, "n", 2) == itensor([0 0 0; 0 1 0; 0 0 2], s[2]', dag(s[2])) @test op(s, "Adag", 2) ≈ itensor([0 0 0; 1 0 0; 0 √2 0], s[2]', dag(s[2])) From efacea7b2e5bc1894b58ce75874352d825eb8954 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Wed, 6 Jul 2022 21:22:10 -0400 Subject: [PATCH 4/5] [ITensors] [ENHANCEMENT] Fix `apply(::MPO, ::MPO)` autodiff, add Vararg `apply(::MPO...)` (#949) --- NDTensors/NEWS.md | 11 +++++++ NDTensors/Project.toml | 2 +- NDTensors/src/linearalgebra.jl | 14 +++++--- NDTensors/src/tensor.jl | 4 ++- NDTensors/src/tensorstorage.jl | 4 ++- NEWS.md | 37 ++++++++++++++++++++++ Project.toml | 4 +-- src/ITensorChainRules/ITensorChainRules.jl | 7 +++- src/ITensorChainRules/itensor.jl | 9 ++++++ src/ITensorChainRules/mps/mpo.jl | 22 ++++++------- src/ITensorChainRules/zygoterules.jl | 16 +--------- src/mps/abstractmps.jl | 2 +- src/mps/mpo.jl | 4 +++ test/ITensorChainRules/test_chainrules.jl | 21 +++++++++++- test/itensor.jl | 33 +++++++++++++++++++ test/mpo.jl | 8 +++++ 16 files changed, 160 insertions(+), 38 deletions(-) diff --git a/NDTensors/NEWS.md b/NDTensors/NEWS.md index 7a5a55c3e0..fe27388823 100644 --- a/NDTensors/NEWS.md +++ b/NDTensors/NEWS.md @@ -6,6 +6,17 @@ Note that as of Julia v1.5, in order to see deprecation warnings you will need t After we release v1 of the package, we will start following [semantic versioning](https://semver.org). +NDTensors v0.1.42 Release Notes +=============================== + +Bugs: + +Enhancements: + +- Define `map` for Tensor and TensorStorage (b66d1b7) +- Define `real` and `imag` for Tensor (b66d1b7) +- Throw error when trying to do an eigendecomposition of Tensor with Infs or NaNs (b66d1b7) + NDTensors v0.1.41 Release Notes =============================== diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index 79c0ebb60a..4629b07ecb 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -1,7 +1,7 @@ name = "NDTensors" uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" authors = ["Matthew Fishman "] -version = "0.1.41" +version = "0.1.42" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" diff --git a/NDTensors/src/linearalgebra.jl b/NDTensors/src/linearalgebra.jl index 21df0579a0..22e90cee84 100644 --- a/NDTensors/src/linearalgebra.jl +++ b/NDTensors/src/linearalgebra.jl @@ -220,8 +220,11 @@ function LinearAlgebra.eigen( matrixT = matrix(T) if any(!isfinite, matrixT) - display(matrixT) - throw(ArgumentError("Trying to perform the eigendecomposition of a matrix containing NaNs or Infs")) + throw( + ArgumentError( + "Trying to perform the eigendecomposition of a matrix containing NaNs or Infs" + ), + ) end DM, VM = eigen(matrixT) @@ -351,8 +354,11 @@ function LinearAlgebra.eigen( matrixT = matrix(T) if any(!isfinite, matrixT) - display(matrixT) - throw(ArgumentError("Trying to perform the eigendecomposition of a matrix containing NaNs or Infs")) + throw( + ArgumentError( + "Trying to perform the eigendecomposition of a matrix containing NaNs or Infs" + ), + ) end DM, VM = eigen(matrixT) diff --git a/NDTensors/src/tensor.jl b/NDTensors/src/tensor.jl index a34672913d..68e4a23908 100644 --- a/NDTensors/src/tensor.jl +++ b/NDTensors/src/tensor.jl @@ -133,7 +133,9 @@ Base.similar(T::Tensor, args...) = similar(T, args...) function map(f, x::Tensor{T}) where {T} if !iszero(f(zero(T))) - error("map(f, ::Tensor) currently doesn't support functions that don't preserve zeros, while you passed a function such that f(0) = $(f(zero(T))). This isn't supported right now because it doesn't necessarily preserve the sparsity structure of the input tensor.") + error( + "map(f, ::Tensor) currently doesn't support functions that don't preserve zeros, while you passed a function such that f(0) = $(f(zero(T))). This isn't supported right now because it doesn't necessarily preserve the sparsity structure of the input tensor.", + ) end return setstorage(x, map(f, storage(x))) end diff --git a/NDTensors/src/tensorstorage.jl b/NDTensors/src/tensorstorage.jl index 4d4a31a20d..08bf03a5fb 100644 --- a/NDTensors/src/tensorstorage.jl +++ b/NDTensors/src/tensorstorage.jl @@ -74,7 +74,9 @@ Random.randn!(S::TensorStorage) = (randn!(data(S)); S) function map(f, x::TensorStorage{T}) where {T} if !iszero(f(zero(T))) - error("map(f, ::TensorStorage) currently doesn't support functions that don't preserve zeros, while you passed a function such that f(0) = $(f(zero(T))). This isn't supported right now because it doesn't necessarily preserve the sparsity structure of the input tensor.") + error( + "map(f, ::TensorStorage) currently doesn't support functions that don't preserve zeros, while you passed a function such that f(0) = $(f(zero(T))). This isn't supported right now because it doesn't necessarily preserve the sparsity structure of the input tensor.", + ) end return setdata(x, map(f, data(x))) end diff --git a/NEWS.md b/NEWS.md index 434e763b7e..72208ea452 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,43 @@ Note that as of Julia v1.5, in order to see deprecation warnings you will need t After we release v1 of the package, we will start following [semantic versioning](https://semver.org). +ITensors v0.3.18 Release Notes +============================== + +Bugs: + +- Extend `apply(::MPO, ::MPO)` to `apply(::MPO, ::MPO, ::MPO...)` (#949) +- Fix AD for `apply(::MPO, ::MPO)` and `contract(::MPO, ::MPO)` (#949) +- Properly use element type in `randomMPS` in the 1-site case (b66d1b7) +- Fix bug in `tr(::MPO)` rrule where the derivative was being multiplied twice into the identity MPO (b66d1b7) +- Fix directsum when specifying a single `Index` (#930) +- Fix bug in loginner when inner is negative or complex (#945) +- Fix subtraction bug in `OpSum` (#945) + +Enhancements: + +- Define "I" for Qudit/Boson type (b66d1b7) +- Only warn in `inner` if the result is `Inf` or `NaN` (b66d1b7) +- Make sure `randomITensor(())` and `randomITensor(Float64, ())` returns a Dense storage type (b66d1b7) +- Define `isreal` and `iszero` for ITensors (b66d1b7) +- Project element type of ITensor in reverse pass of tensor-tensor or scalar-tensor contraction (b66d1b7) +- Define reverse rules for ITensor subtraction and negation (b66d1b7) +- Define `map` for ITensors (b66d1b7) +- Throw error when performing eigendecomposition of tensor with NaN or Inf elements (b66d1b7) +- Fix `rrule` for `MPO` constructor by generalizing the `rrule` for the `MPS` constructor (#946) +- Forward truncation arguments to more operations in `rrule` for `apply` (#945) +- Add rrules for addition and subtraction of MPOs (#935) + +ITensors v0.3.17 Release Notes +============================== + +Bugs: + +Enhancements: + +- Add Zp as alias for operator Z+, etc. (#942) +- Export diag (#942) + ITensors v0.3.16 Release Notes ============================== diff --git a/Project.toml b/Project.toml index 5918cb9b57..edb849908c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensors" uuid = "9136182c-28ba-11e9-034c-db9fb085ebd5" authors = ["Matthew Fishman ", "Miles Stoudenmire "] -version = "0.3.17" +version = "0.3.18" [deps] BitIntegers = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1" @@ -36,7 +36,7 @@ HDF5 = "0.14, 0.15, 0.16" IsApprox = "0.1" KrylovKit = "0.4.2, 0.5" LinearMaps = "3" -NDTensors = "0.1.41" +NDTensors = "0.1.42" PackageCompiler = "1.0.0, 2" Requires = "1.1" SerializedElementArrays = "0.1" diff --git a/src/ITensorChainRules/ITensorChainRules.jl b/src/ITensorChainRules/ITensorChainRules.jl index ad03092fbd..c500df4456 100644 --- a/src/ITensorChainRules/ITensorChainRules.jl +++ b/src/ITensorChainRules/ITensorChainRules.jl @@ -12,7 +12,11 @@ import ChainRulesCore: rrule ITensors.dag(z::AbstractZero) = z -map_notangent(a) = map(Returns(NoTangent()), a) +if VERSION < v"1.7" + map_notangent(a) = map(_ -> NoTangent(), a) +else + map_notangent(a) = map(Returns(NoTangent()), a) +end include("projection.jl") include(joinpath("NDTensors", "tensor.jl")) @@ -40,5 +44,6 @@ include("zygoterules.jl") @non_differentiable ITensors.filter_inds_set_function(::Function, ::Any...) @non_differentiable ITensors.indpairs(::Any...) @non_differentiable onehot(::Any...) +@non_differentiable Base.convert(::Type{TagSet}, str::String) end diff --git a/src/ITensorChainRules/itensor.jl b/src/ITensorChainRules/itensor.jl index 38c4d1b06e..9dbda701f9 100644 --- a/src/ITensorChainRules/itensor.jl +++ b/src/ITensorChainRules/itensor.jl @@ -194,4 +194,13 @@ function ChainRulesCore.rrule(::typeof(permute), x::ITensor, a...) return y, permute_pullback end +# Needed because by default it was calling the generic +# `rrule` for `tr` inside ChainRules. +# TODO: Raise an issue with ChainRules. +function ChainRulesCore.rrule( + config::RuleConfig{>:HasReverseMode}, ::typeof(tr), x::ITensor; kwargs... +) + return rrule_via_ad(config, ITensors._tr, x; kwargs...) +end + @non_differentiable combiner(::Indices) diff --git a/src/ITensorChainRules/mps/mpo.jl b/src/ITensorChainRules/mps/mpo.jl index 0167dcb9ef..78c0d1fba2 100644 --- a/src/ITensorChainRules/mps/mpo.jl +++ b/src/ITensorChainRules/mps/mpo.jl @@ -1,13 +1,17 @@ -function rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...) - y = *(x1, x2; kwargs...) +function ChainRulesCore.rrule(::typeof(contract), x1::MPO, x2::MPO; kwargs...) + y = contract(x1, x2; kwargs...) function contract_pullback(ȳ) - x̄1 = *(ȳ, dag(x2); kwargs...) - x̄2 = *(dag(x1), ȳ; kwargs...) + x̄1 = contract(ȳ, dag(x2); kwargs...) + x̄2 = contract(dag(x1), ȳ; kwargs...) return (NoTangent(), x̄1, x̄2) end return y, contract_pullback end +function ChainRulesCore.rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...) + return rrule(contract, x1, x2; kwargs...) +end + function ChainRulesCore.rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...) y = +(x1, x2; kwargs...) function add_pullback(ȳ) @@ -17,14 +21,10 @@ function ChainRulesCore.rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...) end function ChainRulesCore.rrule(::typeof(-), x1::MPO, x2::MPO; kwargs...) - y = -(x1, x2; kwargs...) - function subtract_pullback(ȳ) - return (NoTangent(), ȳ, -ȳ) - end - return y, subtract_pullback + return rrule(+, x1, -x2; kwargs...) end -function rrule(::typeof(tr), x::MPO; kwargs...) +function ChainRulesCore.rrule(::typeof(tr), x::MPO; kwargs...) y = tr(x; kwargs...) function tr_pullback(ȳ) s = noprime(firstsiteinds(x)) @@ -40,7 +40,7 @@ function rrule(::typeof(tr), x::MPO; kwargs...) return y, tr_pullback end -function rrule(::typeof(inner), x1::MPS, x2::MPO, x3::MPS; kwargs...) +function ChainRulesCore.rrule(::typeof(inner), x1::MPS, x2::MPO, x3::MPS; kwargs...) if !hassameinds(siteinds, x1, (x2, x3)) || !hassameinds(siteinds, x3, (x2, x1)) error( "Taking gradients of `inner(x::MPS, A::MPO, y::MPS)` is not supported if the site indices of the input MPS and MPO don't match. Try using if you input `inner(x, A, y), try `inner(x', A, y)` instead.", diff --git a/src/ITensorChainRules/zygoterules.jl b/src/ITensorChainRules/zygoterules.jl index 7141bf3d4d..65e82549c2 100644 --- a/src/ITensorChainRules/zygoterules.jl +++ b/src/ITensorChainRules/zygoterules.jl @@ -1,8 +1,7 @@ +using ZygoteRules: @adjoint # Needed for defining the rule for `adjoint(A::ITensor)` # which currently doesn't work by overloading `ChainRulesCore.rrule` -using ZygoteRules: @adjoint - @adjoint function Base.adjoint(x::Union{ITensor,MPS,MPO}) y = prime(x) function adjoint_pullback(ȳ) @@ -11,16 +10,3 @@ using ZygoteRules: @adjoint end return y, adjoint_pullback end - -## XXX: raise issue about `tr` being too generically -## defined in ChainRules -## -## using Zygote -## -## # Needed because by default it was calling the generic -## # rrule for `tr` inside ChainRules -## function rrule(::typeof(tr), x::ITensor; kwargs...) -## y, tr_pullback_zygote = pullback(ITensors._tr, x; kwargs...) -## tr_pullback(ȳ) = (NoTangent(), tr_pullback_zygote(ȳ)...) -## return y, tr_pullback -## end diff --git a/src/mps/abstractmps.jl b/src/mps/abstractmps.jl index 7fd8bd8556..f2785e0398 100644 --- a/src/mps/abstractmps.jl +++ b/src/mps/abstractmps.jl @@ -1624,7 +1624,7 @@ function truncate(ψ0::AbstractMPS; kwargs...) return ψ end -# Make `*` and alias for `contract` of two `AbstractMPS` +# Make `*` an alias for `contract` of two `AbstractMPS` *(A::AbstractMPS, B::AbstractMPS; kwargs...) = contract(A, B; kwargs...) function _apply_to_orthocenter!(f, ψ::AbstractMPS, x) diff --git a/src/mps/mpo.jl b/src/mps/mpo.jl index 251e1a7a82..c7cf5faaad 100644 --- a/src/mps/mpo.jl +++ b/src/mps/mpo.jl @@ -783,6 +783,10 @@ function apply(A::MPO, B::MPO; kwargs...) return replaceprime(AB, 2 => 1) end +function apply(A1::MPO, A2::MPO, A3::MPO, As::MPO...; kwargs...) + return apply(apply(A1, A2; kwargs...), A3, As...; kwargs...) +end + (A::MPO)(B::MPO; kwargs...) = apply(A, B; kwargs...) contract_mpo_mpo_doc = """ diff --git a/test/ITensorChainRules/test_chainrules.jl b/test/ITensorChainRules/test_chainrules.jl index b83039312e..d7899d9224 100644 --- a/test/ITensorChainRules/test_chainrules.jl +++ b/test/ITensorChainRules/test_chainrules.jl @@ -279,7 +279,7 @@ Random.seed!(1234) # https://github.com/ITensor/ITensors.jl/issues/936 n = 2 s = siteinds("S=1/2", n) - x = randomMPS(s) |> x -> outer(x', x) + x = (x -> outer(x', x))(randomMPS(s)) f1 = x -> tr(x) f2 = x -> 2tr(x) f3 = x -> -tr(x) @@ -627,3 +627,22 @@ end ∇num = (f(θ + ϵ) - f(θ)) / ϵ @test ∇f ≈ ∇num atol = 1e-5 end + +@testset "contract/apply MPOs" begin + n = 2 + s = siteinds("S=1/2", n) + x = (x -> outer(x', x))(randomMPS(s; linkdims=4)) + x_itensor = contract(x) + + f = x -> tr(apply(x, x)) + @test f(x) ≈ f(x_itensor) + @test contract(f'(x)) ≈ f'(x_itensor) + + f = x -> tr(replaceprime(contract(x', x), 2 => 1)) + @test f(x) ≈ f(x_itensor) + @test contract(f'(x)) ≈ f'(x_itensor) + + f = x -> tr(replaceprime(*(x', x), 2 => 1)) + @test f(x) ≈ f(x_itensor) + @test contract(f'(x)) ≈ f'(x_itensor) +end diff --git a/test/itensor.jl b/test/itensor.jl index e09da45793..94cc518b89 100644 --- a/test/itensor.jl +++ b/test/itensor.jl @@ -54,6 +54,39 @@ end @test !hascommoninds(A, C) end + @testset "isreal, iszero, real, imag" begin + i, j = Index.(2, ("i", "j")) + A = randomITensor(i, j) + Ac = randomITensor(ComplexF64, i, j) + Ar = real(Ac) + Ai = imag(Ac) + @test Ac ≈ Ar + im * Ai + @test isreal(A) + @test !isreal(Ac) + @test isreal(Ar) + @test isreal(Ai) + @test !iszero(A) + @test !iszero(real(A)) + @test iszero(imag(A)) + @test iszero(ITensor(0.0, i, j)) + @test iszero(ITensor(i, j)) + end + + @testset "map" begin + A = randomITensor(Index(2)) + @test eltype(A) == Float64 + B = map(ComplexF64, A) + @test B ≈ A + @test eltype(B) == ComplexF64 + B = map(Float32, A) + @test B ≈ A + @test eltype(B) == Float32 + B = map(x -> 2x, A) + @test B ≈ 2A + @test eltype(B) == Float64 + @test_throws ErrorException map(x -> x + 1, A) + end + @testset "getindex with state string" begin i₁ = Index(2, "S=1/2") i₂ = Index(2, "S=1/2") diff --git a/test/mpo.jl b/test/mpo.jl index 1b4cfd779f..0984856b39 100644 --- a/test/mpo.jl +++ b/test/mpo.jl @@ -370,6 +370,14 @@ end @test_throws DimensionMismatch K * badL end + @testset "Multi-arg apply(::MPO...)" begin + ρ1 = (x -> outer(x', x; maxdim=4))(randomMPS(sites; linkdims=2)) + ρ2 = (x -> outer(x', x; maxdim=4))(randomMPS(sites; linkdims=2)) + ρ3 = (x -> outer(x', x; maxdim=4))(randomMPS(sites; linkdims=2)) + @test apply(ρ1, ρ2, ρ3; cutoff=1e-8) ≈ + apply(apply(ρ1, ρ2; cutoff=1e-8), ρ3; cutoff=1e-8) + end + sites = siteinds("S=1/2", N) O = MPO(sites, "Sz") @test length(O) == N # just make sure this works From 6a83a8327425a2008d4f1f7d66b8377e6bac2aba Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 7 Jul 2022 12:10:01 -0400 Subject: [PATCH 5/5] [ITensors] Simplify the rrules for priming and tagging MPS/MPO (#950) --- NEWS.md | 9 +++ src/ITensorChainRules/indexset.jl | 94 ++++++---------------------- src/ITensorChainRules/itensor.jl | 30 +++++---- src/ITensorChainRules/mps/mpo.jl | 12 ++-- src/ITensorChainRules/zygoterules.jl | 8 +-- 5 files changed, 51 insertions(+), 102 deletions(-) diff --git a/NEWS.md b/NEWS.md index 72208ea452..c0ac9334de 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,15 @@ Note that as of Julia v1.5, in order to see deprecation warnings you will need t After we release v1 of the package, we will start following [semantic versioning](https://semver.org). +ITensors v0.3.19 Release Notes +============================== + +Bugs: + +Enhancements: + +- Simplify the `rrule`s for priming and tagging MPS/MPO + ITensors v0.3.18 Release Notes ============================== diff --git a/src/ITensorChainRules/indexset.jl b/src/ITensorChainRules/indexset.jl index 3ffb62b00a..253a8157c8 100644 --- a/src/ITensorChainRules/indexset.jl +++ b/src/ITensorChainRules/indexset.jl @@ -1,60 +1,3 @@ -function setinds_pullback(ȳ, x, a...) - x̄ = ITensors.setinds(ȳ, inds(x)) - ā = map_notangent(a) - return (NoTangent(), x̄, ā...) -end - -function inv_op(f::Function, args...; kwargs...) - return error( - "Trying to differentiate `$f` but the inverse of the operation (`inv_op`) `$f` with arguments $args and keyword arguments $kwargs is not defined.", - ) -end - -function inv_op(::typeof(prime), x, n::Integer=1; kwargs...) - return prime(x, -n; kwargs...) -end - -function inv_op(::typeof(replaceprime), x, n1n2::Pair; kwargs...) - return replaceprime(x, reverse(n1n2); kwargs...) -end - -function inv_op(::typeof(addtags), x, args...; kwargs...) - return removetags(x, args...; kwargs...) -end - -function inv_op(::typeof(removetags), x, args...; kwargs...) - return addtags(x, args...; kwargs...) -end - -function inv_op(::typeof(replacetags), x, n1n2::Pair; kwargs...) - return replacetags(x, reverse(n1n2); kwargs...) -end - -_check_inds(x::ITensor, y::ITensor) = hassameinds(x, y) -_check_inds(x::MPS, y::MPS) = hassameinds(siteinds, x, y) -_check_inds(x::MPO, y::MPO) = hassameinds(siteinds, x, y) - -for fname in ( - :prime, :setprime, :noprime, :replaceprime, :addtags, :removetags, :replacetags, :settags -) - @eval begin - function ChainRulesCore.rrule(f::typeof($fname), x::Union{MPS,MPO}, a...; kwargs...) - y = f(x, a...; kwargs...) - function f_pullback(ȳ) - x̄ = inv_op(f, unthunk(ȳ), a...; kwargs...) - if !_check_inds(x, x̄) - error( - "Trying to differentiate function `$f` with arguments $a and keyword arguments $kwargs. The forward pass indices $(inds(x)) do not match the reverse pass indices $(inds(x̄)). Likely this is because the priming/tagging operation you tried to perform is not invertible. Please write your code in a way where the index manipulation operation you are performing is invertible. For example, `prime(A::ITensor)` is invertible, with an inverse `prime(A, -1)`. However, `noprime(A)` is in general not invertible since the information about the prime levels of the original tensor are lost. Instead, you might try `prime(A, -1)` or `replaceprime(A, 1 => 0)` which are invertible.", - ) - end - ā = map_notangent(a) - return (NoTangent(), x̄, ā...) - end - return y, f_pullback - end - end -end - for fname in ( :prime, :setprime, @@ -72,11 +15,10 @@ for fname in ( :swapinds, ) @eval begin - function ChainRulesCore.rrule(f::typeof($fname), x::ITensor, a...; kwargs...) + function rrule(f::typeof($fname), x::ITensor, a...; kwargs...) y = f(x, a...; kwargs...) function f_pullback(ȳ) - uȳ = unthunk(ȳ) - x̄ = replaceinds(uȳ, inds(y), inds(x)) + x̄ = replaceinds(unthunk(ȳ), inds(y) => inds(x)) ā = map_notangent(a) return (NoTangent(), x̄, ā...) end @@ -85,23 +27,25 @@ for fname in ( end end -function ChainRulesCore.rrule(::typeof(adjoint), x::ITensor) - y = x' - function adjoint_pullback(ȳ) - uȳ = unthunk(ȳ) - x̄ = replaceinds(uȳ, inds(y), inds(x)) - return (NoTangent(), x̄) +for fname in ( + :prime, :setprime, :noprime, :replaceprime, :addtags, :removetags, :replacetags, :settags +) + @eval begin + function rrule(f::typeof($fname), x::Union{MPS,MPO}, a...; kwargs...) + y = f(x, a...; kwargs...) + function f_pullback(ȳ) + x̄ = copy(unthunk(ȳ)) + for j in eachindex(x̄) + x̄[j] = replaceinds(ȳ[j], inds(y[j]) => inds(x[j])) + end + ā = map_notangent(a) + return (NoTangent(), x̄, ā...) + end + return y, f_pullback + end end - return y, adjoint_pullback end -function ChainRulesCore.rrule(::typeof(adjoint), x::Union{MPS,MPO}) - y = x' - function adjoint_pullback(ȳ) - x̄ = inv_op(prime, ȳ) - return (NoTangent(), x̄) - end - return y, adjoint_pullback -end +rrule(::typeof(adjoint), x::Union{ITensor,MPS,MPO}) = rrule(prime, x) @non_differentiable permute(::Indices, ::Indices) diff --git a/src/ITensorChainRules/itensor.jl b/src/ITensorChainRules/itensor.jl index 9dbda701f9..a830cc77bd 100644 --- a/src/ITensorChainRules/itensor.jl +++ b/src/ITensorChainRules/itensor.jl @@ -1,4 +1,4 @@ -function ChainRulesCore.rrule(::typeof(getindex), x::ITensor, I...) +function rrule(::typeof(getindex), x::ITensor, I...) y = getindex(x, I...) function getindex_pullback(ȳ) # TODO: add definition `ITensor(::Tuple{}) = ITensor()` @@ -14,7 +14,7 @@ end # Specialized version in order to avoid call to `setindex!` # within the pullback, should be better for taking higher order # derivatives in Zygote. -function ChainRulesCore.rrule(::typeof(getindex), x::ITensor) +function rrule(::typeof(getindex), x::ITensor) y = x[] function getindex_pullback(ȳ) x̄ = ITensor(unthunk(ȳ)) @@ -91,7 +91,7 @@ function rrule(::typeof(tensor), x1::ITensor) end # Special case for contracting a pair of ITensors -function ChainRulesCore.rrule(::typeof(contract), x1::ITensor, x2::ITensor) +function rrule(::typeof(contract), x1::ITensor, x2::ITensor) project_x1 = ProjectTo(x1) project_x2 = ProjectTo(x2) function contract_pullback(ȳ) @@ -104,7 +104,7 @@ end @non_differentiable ITensors.optimal_contraction_sequence(::Any) -function ChainRulesCore.rrule(::typeof(*), x1::Number, x2::ITensor) +function rrule(::typeof(*), x1::Number, x2::ITensor) project_x1 = ProjectTo(x1) project_x2 = ProjectTo(x2) function contract_pullback(ȳ) @@ -115,7 +115,7 @@ function ChainRulesCore.rrule(::typeof(*), x1::Number, x2::ITensor) return x1 * x2, contract_pullback end -function ChainRulesCore.rrule(::typeof(*), x1::ITensor, x2::Number) +function rrule(::typeof(*), x1::ITensor, x2::Number) project_x1 = ProjectTo(x1) project_x2 = ProjectTo(x2) function contract_pullback(ȳ) @@ -126,28 +126,28 @@ function ChainRulesCore.rrule(::typeof(*), x1::ITensor, x2::Number) return x1 * x2, contract_pullback end -function ChainRulesCore.rrule(::typeof(+), x1::ITensor, x2::ITensor) +function rrule(::typeof(+), x1::ITensor, x2::ITensor) function add_pullback(ȳ) return (NoTangent(), ȳ, ȳ) end return x1 + x2, add_pullback end -function ChainRulesCore.rrule(::typeof(-), x1::ITensor, x2::ITensor) +function rrule(::typeof(-), x1::ITensor, x2::ITensor) function subtract_pullback(ȳ) return (NoTangent(), ȳ, -ȳ) end return x1 - x2, subtract_pullback end -function ChainRulesCore.rrule(::typeof(-), x::ITensor) +function rrule(::typeof(-), x::ITensor) function minus_pullback(ȳ) return (NoTangent(), -ȳ) end return -x, minus_pullback end -function ChainRulesCore.rrule(::typeof(itensor), x::Array, a...) +function rrule(::typeof(itensor), x::Array, a...) function itensor_pullback(ȳ) uȳ = permute(unthunk(ȳ), a...) x̄ = reshape(array(uȳ), size(x)) @@ -157,7 +157,7 @@ function ChainRulesCore.rrule(::typeof(itensor), x::Array, a...) return itensor(x, a...), itensor_pullback end -function ChainRulesCore.rrule(::Type{ITensor}, x::Array{<:Number}, a...) +function rrule(::Type{ITensor}, x::Array{<:Number}, a...) function ITensor_pullback(ȳ) # TODO: define `Array(::ITensor)` directly uȳ = Array(unthunk(ȳ), a...) @@ -168,7 +168,7 @@ function ChainRulesCore.rrule(::Type{ITensor}, x::Array{<:Number}, a...) return ITensor(x, a...), ITensor_pullback end -function ChainRulesCore.rrule(::Type{ITensor}, x::Number) +function rrule(::Type{ITensor}, x::Number) function ITensor_pullback(ȳ) x̄ = ȳ[] return (NoTangent(), x̄) @@ -176,7 +176,7 @@ function ChainRulesCore.rrule(::Type{ITensor}, x::Number) return ITensor(x), ITensor_pullback end -function ChainRulesCore.rrule(::typeof(dag), x::ITensor) +function rrule(::typeof(dag), x::ITensor) function dag_pullback(ȳ) x̄ = dag(unthunk(ȳ)) return (NoTangent(), x̄) @@ -184,7 +184,7 @@ function ChainRulesCore.rrule(::typeof(dag), x::ITensor) return dag(x), dag_pullback end -function ChainRulesCore.rrule(::typeof(permute), x::ITensor, a...) +function rrule(::typeof(permute), x::ITensor, a...) y = permute(x, a...) function permute_pullback(ȳ) x̄ = permute(unthunk(ȳ), inds(x)) @@ -197,9 +197,7 @@ end # Needed because by default it was calling the generic # `rrule` for `tr` inside ChainRules. # TODO: Raise an issue with ChainRules. -function ChainRulesCore.rrule( - config::RuleConfig{>:HasReverseMode}, ::typeof(tr), x::ITensor; kwargs... -) +function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(tr), x::ITensor; kwargs...) return rrule_via_ad(config, ITensors._tr, x; kwargs...) end diff --git a/src/ITensorChainRules/mps/mpo.jl b/src/ITensorChainRules/mps/mpo.jl index 78c0d1fba2..c33b620ab8 100644 --- a/src/ITensorChainRules/mps/mpo.jl +++ b/src/ITensorChainRules/mps/mpo.jl @@ -1,4 +1,4 @@ -function ChainRulesCore.rrule(::typeof(contract), x1::MPO, x2::MPO; kwargs...) +function rrule(::typeof(contract), x1::MPO, x2::MPO; kwargs...) y = contract(x1, x2; kwargs...) function contract_pullback(ȳ) x̄1 = contract(ȳ, dag(x2); kwargs...) @@ -8,11 +8,11 @@ function ChainRulesCore.rrule(::typeof(contract), x1::MPO, x2::MPO; kwargs...) return y, contract_pullback end -function ChainRulesCore.rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...) +function rrule(::typeof(*), x1::MPO, x2::MPO; kwargs...) return rrule(contract, x1, x2; kwargs...) end -function ChainRulesCore.rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...) +function rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...) y = +(x1, x2; kwargs...) function add_pullback(ȳ) return (NoTangent(), ȳ, ȳ) @@ -20,11 +20,11 @@ function ChainRulesCore.rrule(::typeof(+), x1::MPO, x2::MPO; kwargs...) return y, add_pullback end -function ChainRulesCore.rrule(::typeof(-), x1::MPO, x2::MPO; kwargs...) +function rrule(::typeof(-), x1::MPO, x2::MPO; kwargs...) return rrule(+, x1, -x2; kwargs...) end -function ChainRulesCore.rrule(::typeof(tr), x::MPO; kwargs...) +function rrule(::typeof(tr), x::MPO; kwargs...) y = tr(x; kwargs...) function tr_pullback(ȳ) s = noprime(firstsiteinds(x)) @@ -40,7 +40,7 @@ function ChainRulesCore.rrule(::typeof(tr), x::MPO; kwargs...) return y, tr_pullback end -function ChainRulesCore.rrule(::typeof(inner), x1::MPS, x2::MPO, x3::MPS; kwargs...) +function rrule(::typeof(inner), x1::MPS, x2::MPO, x3::MPS; kwargs...) if !hassameinds(siteinds, x1, (x2, x3)) || !hassameinds(siteinds, x3, (x2, x1)) error( "Taking gradients of `inner(x::MPS, A::MPO, y::MPS)` is not supported if the site indices of the input MPS and MPO don't match. Try using if you input `inner(x, A, y), try `inner(x', A, y)` instead.", diff --git a/src/ITensorChainRules/zygoterules.jl b/src/ITensorChainRules/zygoterules.jl index 65e82549c2..3a3607ea8b 100644 --- a/src/ITensorChainRules/zygoterules.jl +++ b/src/ITensorChainRules/zygoterules.jl @@ -2,11 +2,9 @@ using ZygoteRules: @adjoint # Needed for defining the rule for `adjoint(A::ITensor)` # which currently doesn't work by overloading `ChainRulesCore.rrule` +# since it is defined in `Zygote`, which takes precedent. @adjoint function Base.adjoint(x::Union{ITensor,MPS,MPO}) - y = prime(x) - function adjoint_pullback(ȳ) - x̄ = inv_op(prime, ȳ) - return (x̄,) - end + y, adjoint_rrule_pullback = rrule(adjoint, x) + adjoint_pullback(ȳ) = Base.tail(adjoint_rrule_pullback(ȳ)) return y, adjoint_pullback end