Skip to content

Commit

Permalink
replace append! with reduce(vcat, ...)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Feb 7, 2022
1 parent 868903a commit 5a7bfc8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 26 deletions.
40 changes: 16 additions & 24 deletions src/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo
const NoT = NoTangent()

"""
destructure([T], model) -> vector, reconstructor
destructure(model) -> vector, reconstructor
Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model
to a `Vector{T}`, and returns also a function which reverses this transformation.
to a vector, and returns also a function which reverses this transformation.
Differentiable.
# Example
Expand All @@ -18,8 +18,8 @@ julia> re([10,20,30])
(x = [10.0, 20.0], y = (sin, [30.0]))
```
"""
function destructure(::Type{T}, x) where T
flat, off, len = alpha!(x, T[])
function destructure(x)
flat, off, len = alpha(x)
flat, Restucture(x, off, len)
end

Expand All @@ -32,19 +32,22 @@ end
Base.show(io::IO, re::Restucture{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")

# This flattens a model, and returns a web of offsets for later use:
function alpha!(x, flat::AbstractVector)
isempty(flat) || error("this won't work")
isnumeric(x) && return append!(flat, x), 0 # trivial case
function alpha(x)
isnumeric(x) && return vcat(vec(x)), 0, length(x) # trivial case
arrays = AbstractVector[]
len = Ref(0)
off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
append!(flat, y)
length(flat) - length(y)
push!(arrays, vec(y))
o = len[]
len[] = o + length(y)
o
end
flat, off, length(flat)
reduce(vcat, arrays), off, len[]
end

function ChainRulesCore.rrule(::typeof(alpha!), x, flat)
flat, off, len = alpha!(x, flat)
alpha_back((dflat, _)) = (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len), NoT)
function ChainRulesCore.rrule(::typeof(alpha), x)
flat, off, len = alpha(x)
alpha_back((dflat, _)) = (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len))
(flat, off, len), alpha_back
end

Expand Down Expand Up @@ -100,14 +103,3 @@ function gamma!(x, dx, off::Integer, flat::AbstractVector)
end
gamma!(x, dx::Zero, off, flat::AbstractVector) = nothing
gamma!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing # ambiguity

# Least importantly, this infers the eltype if one is not given:
destructure(x) = destructure(omega(x), x)
function omega(x)
T = Bool
fmap(x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z))) do y
T = promote_type(T, eltype(y))
end
T
end
ChainRulesCore.@non_differentiable omega(::Any)
2 changes: 0 additions & 2 deletions test/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ m6 = (a = m1, b = [4.0 + im], c = m1)
m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))

@testset "flatten & restore" begin
@test destructure(Int, m1)[1] isa Vector{Int}
@test destructure(m1)[1] isa Vector{Float64}

@test destructure(m1)[1] == 1:3
@test destructure(m2)[1] == 1:6
@test destructure(m3)[1] == 1:6
Expand Down

0 comments on commit 5a7bfc8

Please sign in to comment.