Skip to content

Commit

Permalink
second derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 10, 2022
1 parent 17b57f0 commit 6e4f634
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
5 changes: 5 additions & 0 deletions src/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,8 @@ end
# No-op methods of `_grad!` for the case where the gradient `dx` is a `Zero`:
# the body does nothing and returns `nothing`.
function _grad!(x, dx::Zero, off, flat::AbstractVector)
    return nothing
end

# Same no-op with `off::Integer`; exists only to resolve a dispatch ambiguity.
function _grad!(x, dx::Zero, off::Integer, flat::AbstractVector)
    return nothing
end

# Reverse-mode rule for `_grad!`, so that the call can itself be differentiated
# (needed for second derivatives through `destructure`).
#
# Returns the primal result `_grad!(x, dx, off, flat)` together with a pullback.
# The pullback takes the cotangent `dflat` of the flat vector and returns one
# entry per argument of `rrule` (the function, `x`, `dx`, `off`, `flat`). Only
# the slot for `dx` is non-trivial: it is built by `_rebuild` from `x`, `off` and
# the unthunked `dflat`. Every other slot is `NoT`.
# NOTE(review): `NoT` is presumably an alias for `NoTangent()` -- confirm.
function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
    # Diagnostic only: uses the logging system rather than printing to stdout,
    # so library users do not see output unless debug logging is enabled.
    @debug "grad! fwd" length(flat)
    function _grad_back(dflat)
        dx_bar = _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT)
        return (NoT, NoT, dx_bar, NoT, NoT)
    end
    return _grad!(x, dx, off, flat), _grad_back
end
29 changes: 21 additions & 8 deletions test/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,15 @@ end
@test g6.a isa Vector{Float64}
@test g6.b == [0+im]

# Second derivative -- no method matching rrule(::typeof(Optimisers._rebuild), ...?
# Expected value: the inner gradient of sum(abs2, v) is 2v, so the outer
# function is 4 * sum(abs2, v), whose gradient is 8v.
@test_broken gradient([1,2,3]) do v
    sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6]))[1][1])
end[1] ≈ [8,16,24]
@testset "second derivative" begin
    # Gradient of a function that itself takes a gradient through `destructure`.
    # Expected value: the inner gradient of sum(abs2, v) is 2v, so the outer
    # function is 4 * sum(abs2, v), whose gradient is 8v.
    @test_broken gradient([1,2,3.0]) do v
        sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (v, [4,5,6.0]))[1][1])
    end[1] ≈ [8,16,24]

    # Recorded but not evaluated (`@test_skip`); no expected value is asserted yet.
    @test_skip gradient([1,2,3.0]) do v
        sum(gradient(m -> sum(destructure(m)[1]), (v, [4,5,6.0]))[1][1])
    end
end
end

@testset "gradient of rebuild" begin
Expand All @@ -85,10 +90,18 @@ end
@test gradient(x -> re7(x).b[2][2], rand(3))[1] == [0,0,0]
@test gradient(x -> re7(x).c[2][1], rand(3))[1] == [0,0,0]

# Second derivative -- error from _tryaxes(x::Tangent) in Zygote's map rule
@test_broken gradient(collect(1:6)) do y
    sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])
end[1] ≈ [8,16,24,0,0,0]
@testset "second derivative" begin
    # ERROR: Need an adjoint for constructor ChainRulesCore.Tangent{Any, Tuple{Vector{Float64}, ChainRulesCore.ZeroTangent}}. Gradient is of type Tuple{Vector{Float64}, Vector{Float64}}
    @test_broken gradient(collect(1:6.0)) do y
        sum(abs2, gradient(x -> sum(abs2, re2(x)[1]), y)[1])
    end[1] ≈ [8,16,24,0,0,0]
    # Defining the following Zygote adjoint for the Tuple-backed Tangent fixes the error above:
    # Zygote.@adjoint Tangent{T,B}(x::Tuple) where {T,B<:Tuple} = Tangent{T,B}(x), dx -> (dx,)

    # Recorded but not evaluated (`@test_skip`); no expected value is asserted yet.
    @test_skip gradient(collect(1:6.0)) do y
        sum(abs2, gradient(x -> sum(abs2, re3(x).z), y)[1])
    end[1]
    # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
end
end

@testset "Flux issue 1826" begin
Expand Down

0 comments on commit 6e4f634

Please sign in to comment.