diff --git a/src/destructure.jl b/src/destructure.jl
index ffff14b5..3be43f6c 100644
--- a/src/destructure.jl
+++ b/src/destructure.jl
@@ -3,10 +3,10 @@ using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo
 const NoT = NoTangent()
 
 """
-    destructure([T], model) -> vector, reconstructor
+    destructure(model) -> vector, reconstructor
 
 Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model
-to a `Vector{T}`, and returns also a function which reverses this transformation.
+to a vector, and returns also a function which reverses this transformation.
 Differentiable.
 
 # Example
@@ -18,8 +18,8 @@
 julia> re([10,20,30])
 (x = [10.0, 20.0], y = (sin, [30.0]))
 ```
 """
-function destructure(::Type{T}, x) where T
-  flat, off, len = alpha!(x, T[])
+function destructure(x)
+  flat, off, len = alpha(x)
   flat, Restucture(x, off, len)
 end
@@ -32,19 +32,22 @@ end
 Base.show(io::IO, re::Restucture{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
 
 # This flattens a model, and returns a web of offsets for later use:
-function alpha!(x, flat::AbstractVector)
-  isempty(flat) || error("this won't work")
-  isnumeric(x) && return append!(flat, x), 0  # trivial case
+function alpha(x)
+  isnumeric(x) && return vcat(vec(x)), 0, length(x)  # trivial case
+  arrays = AbstractVector[]
+  len = Ref(0)
   off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
-    append!(flat, y)
-    length(flat) - length(y)
+    push!(arrays, vec(y))
+    o = len[]
+    len[] = o + length(y)
+    o
   end
-  flat, off, length(flat)
+  reduce(vcat, arrays), off, len[]
 end
 
-function ChainRulesCore.rrule(::typeof(alpha!), x, flat)
-  flat′, off, len = alpha!(x, flat)
-  alpha_back((dflat, _)) = (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len), NoT)
+function ChainRulesCore.rrule(::typeof(alpha), x)
+  flat, off, len = alpha(x)
+  alpha_back((dflat, _)) = (NoT, beta(x, off, dflat; walk = _Tangent_biwalk, prune = NoT, len))
   (flat, off, len), alpha_back
 end
 
@@ -100,14 +103,3 @@ function gamma!(x, dx, off::Integer, flat::AbstractVector)
 end
 gamma!(x, dx::Zero, off, flat::AbstractVector) = nothing
 gamma!(x, dx::Zero, off::Integer, flat::AbstractVector) = nothing  # ambiguity
-
-# Least importantly, this infers the eltype if one is not given:
-destructure(x) = destructure(omega(x), x)
-function omega(x)
-  T = Bool
-  fmap(x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z))) do y
-    T = promote_type(T, eltype(y))
-  end
-  T
-end
-ChainRulesCore.@non_differentiable omega(::Any)
diff --git a/test/destructure.jl b/test/destructure.jl
index c8a3c1f2..ad685c7b 100644
--- a/test/destructure.jl
+++ b/test/destructure.jl
@@ -8,9 +8,7 @@ m6 = (a = m1, b = [4.0 + im], c = m1)
 m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
 
 @testset "flatten & restore" begin
-  @test destructure(Int, m1)[1] isa Vector{Int}
   @test destructure(m1)[1] isa Vector{Float64}
-  @test destructure(m1)[1] == 1:3
   @test destructure(m2)[1] == 1:6
   @test destructure(m3)[1] == 1:6
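
For context, a minimal usage sketch of the API after this patch (hypothetical model values, not from the test suite). With the explicit eltype argument and the `omega` inference removed, the eltype of the flat vector should now fall out of the `reduce(vcat, ...)` over the trainable leaves:

```julia
# Hypothetical model mixing Float32 and Float64 array leaves; non-array fields like `sin` are skipped.
m = (x = Float32[1, 2], y = (sin, [3.0]))

v, re = destructure(m)   # no eltype argument any more
v                        # expected [1.0, 2.0, 3.0]: Float64 via vcat promotion of the leaves
re([10, 20, 30])         # rebuilds the nested structure from a flat vector of matching length
```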