Skip to content

Commit

Permalink
make len positional, fix a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 10, 2022
1 parent 6f3eefa commit 17b57f0
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct Restructure{T,S}
offsets::S
length::Int
end
(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat; len = re.length)
(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat, re.length)
(re::Restructure)(x, flat::AbstractVector) = re(flat)(x)
Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
Base.length(re::Restructure) = re.length
Expand All @@ -69,13 +69,13 @@ end

function ChainRulesCore.rrule(::typeof(_flatten), x)
flat, off, len = _flatten(x)
_flatten_back((dflat, _)) = (NoT, _rebuild(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len))
_flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, dflat, len; walk = _Tangent_biwalk, prune = NoT))
(flat, off, len), _flatten_back
end

# This reconstructs either a model like x, or a gradient for it:
function _rebuild(x, off, flat::AbstractVector; len, walk = _trainable_biwalk, kw...)
len == length(flat) || error("wrong length")
function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _trainable_biwalk, kw...)
len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
_getat(y, o, flat)
end
Expand Down Expand Up @@ -105,12 +105,14 @@ function _Tangent_biwalk(f, x, aux) # use with prune = NoT
Tangent{typeof(x), typeof(y)}(y)
end

function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat; len)
dflat = map!(zero, similar(flat, float(eltype(flat))), flat)
_rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, dflat))
_rebuild(x, off, flat; len), _rebuild_back
function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...)
_rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, _zero(flat)), NoT)
_rebuild(x, off, flat, len; kw...), _rebuild_back
end

_zero(x) = map!(zero, similar(x, float(eltype(x))), x) # mutable zero array for _grad!
ChainRulesCore.@non_differentiable _zero(x)

# This is the gradient of model reconstruction, accumulating duplicates:
function _grad!(x, dx, off, flat::AbstractVector)
x′, _ = functor(typeof(x), x)
Expand Down

0 comments on commit 17b57f0

Please sign in to comment.