Skip to content

Commit

Permalink
in-place destructure
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 2, 2023
1 parent 1cd1e87 commit 89c8d43
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/Optimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export AbstractRule
include("adjust.jl")

include("destructure.jl")
export destructure
export destructure, destructure!

include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
Expand Down
73 changes: 68 additions & 5 deletions src/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ Copies all [`trainable`](@ref), [`isnumeric`](@ref) parameters in the model
to a vector, and returns also a function which reverses this transformation.
Differentiable.
See also [`destructure!`](@ref).
# Example
```jldoctest
julia> v, re = destructure((x=[1.0, 2.0], y=(sin, [3.0 + 4.0im])))
Expand All @@ -31,6 +33,36 @@ function destructure(x)
flat, Restructure(x, off, len)
end

"""
destructure!(model) -> vector, reconstructor
This is a variant of [`destructure`](@ref), whose reconstruction function mutates the model.
Requires that all trainable parameters in the model be mutable arrays!
# Example
```jldoctest
julia> m = (x=[1.0, 2.0], y=(sin, Float32[3.0 4.0], cos))
julia> v, re! = destructure!(m)
([1.0, 2.0, 3.0, 4.0], Restructure!(NamedTuple, ..., 4))
julia> m === re!([3, 5, 7, 9]) # mutates the original m, and returns it
true
julia> m
(x = [3.0, 5.0], y = (sin, Float32[7.0 9.0], cos))
```
"""
function destructure!(x)
flat, off, len = _flatten(x)
flat, Restructure!(x, off, len)
end

# function destructure!(flat::AbstractVector, x)
# flat, off, len = _flatten!(flat, x)
# flat, Restructure!(x, off, len)
# end

"""
Restructure(Model, ..., length)
Expand All @@ -55,12 +87,20 @@ struct Restructure{T,S}
model::T
offsets::S
length::Int
mutate::Bool
end
(re::Restructure)(flat::AbstractVector) = _rebuild(re.model, re.offsets, flat, re.length)
Restructure(model, offsets, length) = Restructure(model, offsets, length, false)
Restructure!(model, offsets, length) = Restructure(model, offsets, length, true)

(re::Restructure)(flat::AbstractVector) = re.mutate ? _rebuild!(re.model, re.offsets, flat, re.length) : _rebuild(re.model, re.offsets, flat, re.length)
(re::Restructure)(x, flat::AbstractVector) = re(flat)(x)
Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
Base.length(re::Restructure) = re.length

function Base.show(io::IO, re::Restructure{T}) where T
print(io, "Restructure", re.mutate ? "!" : "")
print(io, "(", T.name.name, ", ..., ", re.length, ")")

Check warning on line 101 in src/destructure.jl

View check run for this annotation

Codecov / codecov/patch

src/destructure.jl#L99-L101

Added lines #L99 - L101 were not covered by tests
end

# This flattens a model, and returns a web of offsets for later use:
function _flatten(x)
isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case
Expand All @@ -75,6 +115,17 @@ function _flatten(x)
isempty(arrays) && return Bool[], off, 0
reduce(vcat, arrays), off, len[]
end
# function _flatten!(flat, x)
# isnumeric(x) && return copyto!(flat, _vec(x)) # trivial case
# len = Ref(0)
# off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
# o = len[]
# copyto!(flat, o, _vec(y))
# len[] = o + length(y)
# o
# end
# flat, off, len[]
# end

struct _TrainableStructWalk <: AbstractWalk end

Expand All @@ -97,10 +148,18 @@ function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _Trai
_getat(y, o, flat)
end
end
# (mutating version, same arguments & same return)
function _rebuild!(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...)
len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
copyto!(y, _getat(y, o, flat, view))
end
x
end

_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1])
_getat(y::AbstractArray, o::Int, flat::AbstractVector) =
ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes
_getat(y::Number, o::Int, flat::AbstractVector, _...) = ProjectTo(y)(flat[o + 1])

Check warning on line 160 in src/destructure.jl

View check run for this annotation

Codecov / codecov/patch

src/destructure.jl#L160

Added line #L160 was not covered by tests
_getat(y::AbstractArray, o::Int, flat::AbstractVector, get=getindex) =
ProjectTo(y)(reshape(get(flat, o .+ (1:length(y))), axes(y))) # ProjectTo is just correcting eltypes

struct _Trainable_biwalk <: AbstractWalk end

Expand Down Expand Up @@ -135,6 +194,10 @@ function ChainRulesCore.rrule(::typeof(_rebuild), x, off, flat, len; kw...)
_rebuild_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, _zero(flat)), NoT)
_rebuild(x, off, flat, len; kw...), _rebuild_back
end
function ChainRulesCore.rrule(::typeof(_rebuild!), x, off, flat, len; kw...)
_rebuild!_back(dx) = (NoT, NoT, NoT, _grad!(x, unthunk(dx), off, _zero(flat)), NoT)
_rebuild!(x, off, flat, len; kw...), _rebuild!_back
end

_zero(x) = map!(zero, similar(x, float(eltype(x))), x) # mutable zero array for _grad!
ChainRulesCore.@non_differentiable _zero(x)
Expand Down
42 changes: 42 additions & 0 deletions test/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ m9 = (a = m1, b = mat, c = [mat, m1])
@test destructure(m9)[1] == 1:7

@test destructure(m1)[2](7:9) == [7,8,9]
@test m1 == 1:3 # not mutated
@test destructure(m2)[2](4:9) == ([4,5,6], [7,8,9])
@test destructure(m3)[2](4:9) == (x = [4,5,6], y = sin, z = [7,8,9])
@test m3.z == 4:6 # not mutated
m4′ = destructure(m4)[2](4:9)
@test m4′ == (x = [4,5,6], y = [4,5,6], z = [7,8,9])
@test m4′.x === m4′.y
Expand Down Expand Up @@ -60,11 +62,31 @@ m9 = (a = m1, b = mat, c = [mat, m1])
@test_throws Exception destructure(m7)[2]([10,20,30,40])
end

@testset "destructure!" begin
m3′ = deepcopy(m3)
@test destructure!(m3′)[1] == 1:6
@test destructure!(m3′)[2](4:9) == (x = [4,5,6], y = sin, z = [7,8,9])
@test m3′ == (x = [4,5,6], y = sin, z = [7,8,9])

m7′ = deepcopy(m7)
@test destructure!(m7′)[1] == 1:3
destructure!(m7′)[2]([10,20,30])
@test m7′.a == (sin, [10,20,30])
@test m7′.b == (cos, [4,5,6])
@test m7′.c == (tan, [7,8,9])

# errors
@test_throws Exception destructure!(m7)[2]([10,20])
@test_throws Exception destructure!(m7)[2]([10,20,30,40])
end

@testset "gradient of flatten" begin
@test gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
@test gradient(m -> destructure!(m)[1][1], m1)[1] == [1,0,0]
@test gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
@test gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
@test gradient(m -> destructure(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
@test gradient(m -> destructure!(m)[1][1], m3)[1] == (x = [1,0,0], y = nothing, z = [0,0,0])
@test gradient(m -> destructure(m)[1][2], m4)[1] == (x = [0,1,0], y = nothing, z = [0,0,0])

g5 = gradient(m -> destructure(m)[1][3], m5)[1]
Expand Down Expand Up @@ -206,6 +228,26 @@ end
end
end

@testset "gradient of rebuild!" begin
re1 = destructure!(deepcopy(m1))[2]
@test gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]

re2 = destructure!(deepcopy(m2))[2]
@test gradient(x -> re2(x)[1][2], rand(6))[1] == [0,1,0,0,0,0]

re3 = destructure!(deepcopy(m3))[2]
@test gradient(x -> re3(x).x[3], rand(6))[1] == [0,0,1,0,0,0]
@test gradient(x -> re3(x).z[1], rand(6))[1] == [0,0,0,1,0,0]

re4 = destructure!(deepcopy(m4))[2]
@test gradient(x -> re4(x).x[1], rand(6))[1] == [1,0,0,0,0,0]
@test gradient(x -> re4(x).y[2], rand(6))[1] == [0,1,0,0,0,0]
@test gradient(rand(6)) do x
m = re4(x)
m.x[1] + 2*m.y[2] + 3*m.z[3]
end[1] == [1,2,0, 0,0,3]
end

@testset "Flux issue 1826" begin
v, re = destructure((x=[1,2.0], y=[3,4,5.0]))
@test gradient(zero(v)) do w
Expand Down

0 comments on commit 89c8d43

Please sign in to comment.