Skip to content

Commit

Permalink
Automatic Differentiation test rewrite (#150)
Browse files Browse the repository at this point in the history
* Add missing unthunk in rrules

* rewrite AD tests using ChainRulesTestUtils

* Add rrule for `tensorscalar`

* Add mixed scalartype tests
  • Loading branch information
lkdvos authored Sep 29, 2023
1 parent 3a0b1e0 commit e7c5192
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 156 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ julia = "1.6"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[targets]
test = ["Test", "Random", "DynamicPolynomials", "Zygote", "CUDA", "cuTENSOR", "Aqua", "Logging"]
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging"]
17 changes: 14 additions & 3 deletions ext/TensorOperationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ trivtuple(N) = ntuple(identity, N)

# TODO: possibly use the non-inplace functions, to avoid depending on Base.copy

# Reverse-mode rule for `tensorscalar`.
#
# The primal value is `tensorscalar(C)`. The pullback allocates a tensor with the type
# and structure of `C` and fills every entry of it with the unthunked incoming cotangent
# `Δc`. The function `tensorscalar` itself receives `NoTangent()`.
function ChainRulesCore.rrule(::typeof(tensorscalar), C)
    function tensorscalar_pullback(Δc)
        structure = TensorOperations.tensorstructure(C)
        cotangent = TensorOperations.tensoralloc(typeof(C), structure)
        fill!(cotangent, unthunk(Δc))
        return NoTangent(), cotangent
    end
    scalar = tensorscalar(C)
    return scalar, tensorscalar_pullback
end

function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
C, pC::Index2Tuple,
A, conjA::Symbol,
Expand All @@ -33,7 +41,8 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function pullback(ΔC)
function pullback(ΔC′)
ΔC = unthunk(ΔC′)
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
ipC = invperm(linearize(pC))
Expand Down Expand Up @@ -76,7 +85,8 @@ function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function pullback(ΔC)
function pullback(ΔC′)
ΔC = unthunk(ΔC′)
ipC = invperm(linearize(pC))
pΔC = (TupleTools.getindices(ipC, trivtuple(numout(pA))),
TupleTools.getindices(ipC, numout(pA) .+ trivtuple(numin(pB))))
Expand Down Expand Up @@ -141,7 +151,8 @@ function ChainRulesCore.rrule(::typeof(tensortrace!), C, pC::Index2Tuple, A,
projectα = ProjectTo(α)
projectβ = ProjectTo(β)

function pullback(ΔC)
function pullback(ΔC′)
ΔC = unthunk(ΔC′)
dC = @thunk projectC(scale(ΔC, conj(β)))
dA = @thunk begin
ipC = invperm((linearize(pC)..., pA[1]..., pA[2]...))
Expand Down
211 changes: 60 additions & 151 deletions test/ad.jl
Original file line number Diff line number Diff line change
@@ -1,161 +1,70 @@
using TensorOperations
using Test
using Zygote
using LinearAlgebra
using Base.Iterators: product
Zygote.refresh()

# Reference implementation of a tensor addition written with plain array operations.
#
# Each operand is conjugated unless its flag equals `:N`, permuted according to
# `linearize` of its index tuple, and scaled (`A` by `α`, `B` by `β`); the two scaled
# terms are then summed.
function LinAlg_tensoradd(A, pA, conjA, B, pB, conjB, α=true, β=true)
    opA = conjA == :N ? A : conj(A)
    opB = conjB == :N ? B : conj(B)
    termA = α * permutedims(opA, linearize(pA))
    termB = β * permutedims(opB, linearize(pB))
    return termA + termB
end
# Reference implementation of a tensor contraction written with `permutedims`,
# `reshape` and an ordinary matrix product.
#
# Returns `β * C + α * permutedims(AB, linearize(pC))`, where `AB` is the product of the
# matricized operands reshaped to the sizes of `pA[1]` followed by those of `pB[2]`.
function LinAlg_tensorcontract(C, pC, A, pA, conjA, B, pB, conjB, α=true, β=false)
    # Sizes of the two index groups of each operand.
    rowsA = map(i -> size(A, i), pA[1])
    colsA = map(i -> size(A, i), pA[2])
    rowsB = map(i -> size(B, i), pB[1])
    colsB = map(i -> size(B, i), pB[2])

    # Conjugate unless the flag is `:N`, permute, and flatten each operand to a matrix.
    opA = conjA == :N ? A : conj(A)
    matA = reshape(permutedims(opA, linearize(pA)), prod(rowsA), prod(colsA))
    opB = conjB == :N ? B : conj(B)
    matB = reshape(permutedims(opB, linearize(pB)), prod(rowsB), prod(colsB))

    # Multiply, restore the uncontracted dimensions, then permute with `pC`.
    contracted = reshape(matA * matB, rowsA..., colsB...)
    return β * C + α * permutedims(contracted, linearize(pC))
end
function LinAlg_tensortrace(C, pC, A, pA, conjA, α=true, β=false)
szA(i) = size(A, i)
A′ = reshape(permutedims(conjA == :N ? A : conj(A),
(linearize(pC)..., pA[1]..., pA[2]...)),
prod(szA.(linearize(pC))), prod(szA.(pA[1])), prod(szA.(pA[2])))
C′ = map(i -> tr(A′[i, :, :]), axes(A′, 1))
return β * C + α * reshape(C′, szA.(linearize(pC)))
using ChainRulesTestUtils

ChainRulesTestUtils.test_method_tables()

# Fixed test tolerances selected by scalar type: `1e-2` for single-precision real or
# complex types, `1e-8` for double-precision real or complex types.
function precision(::Type{<:Union{Float32,Complex{Float32}}})
    return 1e-2
end
function precision(::Type{<:Union{Float64,Complex{Float64}}})
    return 1e-8
end

# Check the `tensortrace!` rrule with `test_rrule` for several pairs of scalar types:
# `A` has element type `T₁`, `C` has element type `T₂`, and the scalars `α` and `β` use
# the promoted type.
@testset "tensortrace! ($T₁, $T₂)" for (T₁, T₂) in ((Float64, Float64),
                                                    (Float32, Float64),
                                                    (ComplexF64, ComplexF64),
                                                    (Float64, ComplexF64))
    # The looser tolerance of the two scalar types is used as both `atol` and `rtol`.
    tol = max(precision(T₁), precision(T₂))
    Tαβ = promote_type(T₁, T₂)

    pC = ((3, 5, 2), ())
    pA = ((1,), (4,))
    α = rand(Tαβ)
    β = rand(Tαβ)
    A = rand(T₁, (2, 3, 4, 2, 5))
    # `C` takes the sizes of the dimensions of `A` listed in `pC[1]`.
    C = rand(T₂, map(d -> size(A, d), pC[1]))
    test_rrule(tensortrace!, C, pC, A, pA, :N, α, β; atol=tol, rtol=tol)
end

# Tolerance derived from machine epsilon: `eps(T)^(3 / 4)` for a number type `T`.
# A complex type defers to the tolerance of its real component type.
function precision(T::Type{<:Complex})
    return precision(real(T))
end
function precision(T::Type{<:Number})
    return eps(T)^(3 / 4)
end

@testset "tensoradd" begin
f(A, B) = tensoradd(A, ((1, 2, 3), ()), :N, B, ((1, 3, 2), ()), :N)
f′(A, B) = LinAlg_tensoradd(A, ((1, 2, 3), ()), :N, B, ((1, 3, 2), ()), :N)

@testset for T in (Float64, ComplexF64)
A = rand(T, 2, 3, 4)
B = rand(T, 2, 4, 3)

C, pullback = Zygote.pullback(f, A, B)
C′, pullback′ = Zygote.pullback(f′, A, B)

@test C C′ rtol = precision(T)

ΔC = rand(T, size(C))
ΔA, ΔB = pullback(ΔC)
ΔA′, ΔB′ = pullback′(ΔC)
@test ΔA ΔA′ rtol = precision(T)
@test ΔB ΔB′ rtol = precision(T)

D = rand(T, 4, 2, 3, 2)
E = rand(T, 2, 3, 4, 2)
α = rand(T)
β = rand(T)

pD = ((2, 1, 4, 3), ())
pE = ((1, 3, 4, 2), ())

for conjD in (:N, :C), conjE in (:N, :C)
F, pullback2 = Zygote.pullback(tensoradd, D, pD, conjD, E, pE, conjE, α, β)
F′, pullback2′ = Zygote.pullback(LinAlg_tensoradd, D, pD, conjD, E, pE, conjE,
α, β)
@test F F′ rtol = precision(T)

ΔF = rand(T, size(F))
ΔD, ΔpD, ΔconjD, ΔE, ΔpE, ΔconjE, Δα, Δβ = pullback2(ΔF)
ΔD′, ΔpD′, ΔconjD′, ΔE′, ΔpE′, ΔconjE′, Δα′, Δβ′ = pullback2′(ΔF)
@test ΔD ΔD′ rtol = precision(T)
@test ΔE ΔE′ rtol = precision(T)
@test Δα Δα′ rtol = precision(T)
@test Δβ Δβ′ rtol = precision(T)
end
end
# Check the `tensoradd!` rrule with `test_rrule` for several pairs of scalar types:
# `A` has element type `T₁`, `C` has element type `T₂`, and the scalars `α` and `β` use
# the promoted type. Both conjugation flags `:N` and `:C` are exercised.
@testset "tensoradd! ($T₁, $T₂)" for (T₁, T₂) in ((Float64, Float64),
                                                  (Float32, Float64),
                                                  (ComplexF64, ComplexF64),
                                                  (Float64, ComplexF64))
    # The looser tolerance of the two scalar types is used as both `atol` and `rtol`.
    tol = max(precision(T₁), precision(T₂))
    Tαβ = promote_type(T₁, T₂)

    pC = ((2, 1, 4, 3, 5), ())
    A = rand(T₁, (2, 3, 4, 2, 1))
    # `C` takes the sizes of the dimensions of `A` listed in `pC[1]`.
    C = rand(T₂, map(d -> size(A, d), pC[1]))
    α = rand(Tαβ)
    β = rand(Tαβ)
    for conjA in (:N, :C)
        test_rrule(tensoradd!, C, pC, A, conjA, α, β; atol=tol, rtol=tol)
    end
end

@testset "tensorcontract" begin
@testset for T in (Float64, ComplexF64)
A = rand(T, 2, 4, 3, 2)
B = rand(T, 1, 3, 2)
C = rand(T, 1, 4, 2)

α = rand(T)
β = rand(T)

pA = ((2, 4), (1, 3))
pB = ((3, 2), (1,))
pC = ((3, 1, 2), ())

for conjA in (:N, :C), conjB in (:N, :C)
D, pullback = Zygote.pullback(tensorcontract!, C, pC, A, pA, conjA, B, pB,
conjB, α,
β)
D′, pullback′ = Zygote.pullback(LinAlg_tensorcontract, C, pC, A, pA, conjA, B,
pB,
conjB, α, β)

@test D D′ rtol = precision(T)
ΔD = rand(T, size(D))
ΔC, ΔpC, ΔA, ΔpA, ΔconjA, ΔB, ΔpB, ΔconjB, Δα, Δβ = pullback(ΔD)
ΔC′, ΔpC′, ΔA′, ΔpA′, ΔconjA′, ΔB′, ΔpB′, ΔconjB′, Δα′, Δβ′ = pullback′(ΔD)
@test ΔC ΔC′ rtol = precision(T)
@test ΔA ΔA′ rtol = precision(T)
@test ΔB ΔB′ rtol = precision(T)
@test Δα Δα′ rtol = precision(T)
@test Δβ Δβ′ rtol = precision(T)
end
end
# Check the `tensorcontract!` rrule with `test_rrule` for several pairs of scalar types:
# `A` has element type `T₁`, `B` has element type `T₂`, and `C`, `α` and `β` use the
# promoted type. All four combinations of the `:N`/`:C` flags for `A` and `B` are run.
@testset "tensorcontract! ($T₁, $T₂)" for (T₁, T₂) in ((Float64, Float64),
                                                       (Float32, Float64),
                                                       (ComplexF64, ComplexF64),
                                                       (Float64, ComplexF64))
    # The looser tolerance of the two scalar types is used as both `atol` and `rtol`.
    tol = max(precision(T₁), precision(T₂))
    Tout = promote_type(T₁, T₂)

    pC = ((3, 2, 4, 1), ())
    pA = ((2, 4, 5), (1, 3))
    pB = ((2, 1), (3,))

    A = rand(T₁, (2, 3, 4, 2, 5))
    B = rand(T₂, (4, 2, 3))
    C = rand(Tout, (5, 2, 3, 3))
    α = randn(Tout)
    β = randn(Tout)

    # The flag of `A` varies fastest: (:N, :N), (:C, :N), (:N, :C), (:C, :C).
    for conjB in (:N, :C), conjA in (:N, :C)
        test_rrule(tensorcontract!, C, pC, A, pA, conjA, B, pB, conjB, α, β;
                   atol=tol, rtol=tol)
    end
end

@testset "tensortrace" begin
# single trace index, homogeneous scalar type, no conjugation
@testset for T in (Float64, ComplexF64)
A = rand(T, 2, 3, 4, 2)
C = rand(T, 4, 3)
α = rand(T)
β = rand(T)

pA = ((1,), (4,))
pC = ((3, 2), ())

conjA = :N

D, pullback = Zygote.pullback(tensortrace!, C, pC, A, pA, conjA, α, β)
D′, pullback′ = Zygote.pullback(LinAlg_tensortrace, C, pC, A, pA, conjA, α, β)
@test D D′ rtol = precision(T)

ΔD = rand(T, size(D))
ΔC, ΔpC, ΔA, ΔpA, ΔconjA, Δα, Δβ = pullback(ΔD)
ΔC′, ΔpC′, ΔA′, ΔpA′, ΔconjA′, Δα′, Δβ′ = pullback′(ΔD)
@test ΔC ΔC′ rtol = precision(T)
@test ΔA ΔA′ rtol = precision(T)
@test Δα Δα′ rtol = precision(T)
@test Δβ Δβ′ rtol = precision(T)
end
# multiple trace indices, mixed scalar types, conjugation
@testset for T in (Float64, ComplexF64)
A = rand(T, 2, 4, 3, 3, 4, 2, 3, 4)
C = rand(T, 4, 3)
α = rand(T)
β = rand(real(T))

pA = ((1, 2, 7), (6, 8, 3))
pC = ((5, 4), ())

conjA = :C

D, pullback = Zygote.pullback(tensortrace!, C, pC, A, pA, conjA, α, β)
D′, pullback′ = Zygote.pullback(LinAlg_tensortrace, C, pC, A, pA, conjA, α, β)
@test D D′ rtol = precision(T)
@testset "tensorscalar ($T)" for T in (Float32, Float64, ComplexF64)
atol = precision(T)
rtol = precision(T)

ΔD = rand(T, size(D))
ΔC, ΔpC, ΔA, ΔpA, ΔconjA, Δα, Δβ = pullback(ΔD)
ΔC′, ΔpC′, ΔA′, ΔpA′, ΔconjA′, Δα′, Δβ′ = pullback′(ΔD)
@test ΔC ΔC′ rtol = precision(T)
@test ΔA ΔA′ rtol = precision(T)
@test Δα Δα′ rtol = precision(T)
@test Δβ Δβ′ rtol = precision(T)
end
C = Array{T,0}(undef, ())
fill!(C, rand(T))
test_rrule(tensorscalar, C; atol, rtol)
end

0 comments on commit e7c5192

Please sign in to comment.