Skip to content

Commit

Permalink
tweak
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 10, 2022
1 parent 520efbe commit af14f84
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
24 changes: 12 additions & 12 deletions src/destructure.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo
using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, unthunk
const NoT = NoTangent()

"""
Expand All @@ -11,11 +11,11 @@ Differentiable.
# Example
```jldoctest
julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0])))
([1.0, 2.0, 3.0], Restructure(NamedTuple, ..., 3))
julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3 + 4im])))
(ComplexF64[1.0 + 0.0im, 2.0 + 0.0im, 3.0 + 4.0im], Restructure(NamedTuple, ..., 3))
julia> re([10,20,30])
(x = [10.0, 20.0], y = (sin, [30.0]))
julia> re([3, 5-im, 7+11im])
(x = [3.0, 5.0], y = (sin, ComplexF64[7.0 + 11.0im]))
```
"""
function destructure(x)
Expand All @@ -27,7 +27,7 @@ end
Restructure(Model, ..., length)
This is what [`destructure`](@ref) returns, and `re(p)` will re-build the model with
new parameters from vector `p`. If the model is callable, then `re(x, p)` .
new parameters from vector `p`. If the model is callable, then `re(x, p) == re(p)(x)`.
# Example
```julia
Expand Down Expand Up @@ -107,22 +107,22 @@ end

function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat; len)
dflat = map!(zero, similar(flat, float(eltype(flat))), flat)
_rebuild_back(dx) = (NoT, NoT, NoT, _accumulate!(x, dx, off, dflat))
_rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, dflat))
_rebuild(x, off, flat; len), _rebuild_back
end

# This is the gradient of model reconstruction, accumulating duplicates:
function _accumulate!(x, dx, off, flat::AbstractVector)
function _grad!(x, dx, off, flat::AbstractVector)
x′, _ = functor(typeof(x), x)
dx′, _ = functor(typeof(x), dx)
off′, _ = functor(typeof(x), off)
foreach((xᵢ, dxᵢ, oᵢ) -> _accumulate!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′)
flat
end
function _accumulate!(x, dx, off::Integer, flat::AbstractVector)
# Leaf case: `off::Integer` is the 0-based position of this array's parameters
# inside `flat`. Accumulate with `.+=` rather than assign, so that arrays which
# appear at several tied nodes of the model have their gradients summed.
function _grad!(x, dx, off::Integer, flat::AbstractVector)
    segment = view(flat, off+1:off+length(x))
    segment .+= dx  # must visit all tied nodes
    return flat
end
_accumulate!(x, dx::Zero, off, flat::AbstractVector) = nothing
_accumulate!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity
# A `Zero` tangent contributes nothing to `flat`, so there is nothing to write.
# (`Zero` is presumably ChainRulesCore's zero-tangent type, aliased elsewhere
# in this file — TODO confirm.)
function _grad!(model, dmodel::Zero, offsets, flat::AbstractVector)
    return nothing
end
# Identical behaviour, restricted to `offsets::Integer` purely to resolve the
# method ambiguity with the leaf method above.
function _grad!(model, dmodel::Zero, offsets::Integer, flat::AbstractVector)
    return nothing
end

1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
@testset verbose=true "Destructure" begin
include("destructure.jl")
end
@info "finished feature testing"
@testset verbose=true "Optimisation Rules" begin
include("rules.jl")
end
Expand Down

0 comments on commit af14f84

Please sign in to comment.