Commit

add trainables (#171)
* trainables

* trainables

* cl/trainables

* trainables

* test second order derivatives

* add doc section

* fix test

* Update src/trainables.jl
CarloLucibello authored Apr 4, 2024
1 parent b4920f7 commit a87ffd5
Showing 8 changed files with 214 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/src/api.md
@@ -59,6 +59,7 @@ Such restrictions are also obeyed by this function for flattening a model:
```@docs
Optimisers.destructure
Optimisers.Restructure
Optimisers.trainables
```

## Rule Definition
26 changes: 26 additions & 0 deletions docs/src/index.md
@@ -290,3 +290,29 @@ flat, re = destructure(params)
end
```

## Collecting all trainable parameters

Sometimes it is useful to collect all trainable parameters in a model,
similar to what [`destructure`](@ref Optimisers.destructure) does, but without
concatenating the arrays into a flat vector.
This is done by [`trainables`](@ref Optimisers.trainables), which returns a list of arrays:

```julia
julia> using Flux, Optimisers

julia> model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2));

julia> trainables(model)
6-element Vector{AbstractArray}:
Float32[0.5756773 -0.1975264; 0.4723181 -0.7546912; -0.91631395 0.07392061]
Float32[0.0, 0.0, 0.0]
Float32[0.0, 0.0, 0.0]
Float32[1.0, 1.0, 1.0]
Float32[-0.8764882 0.40812716 0.1919528; -0.9123545 -0.4462516 0.6751252]
Float32[0.0, 0.0]

julia> l2reg(model) = sum([sum(abs2,p) for p in trainables(model)]);

julia> g = gradient(l2reg, model)[1];
```
Notice that the `BatchNorm` layer has two trainable parameters, `γ` and `β`, which are included in the list, while the `μ` and `σ²` buffers are not.
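
A quick way to see which arrays a single layer contributes is to call `trainables` on it in isolation. The snippet below is a minimal sketch and assumes Flux's default `BatchNorm` initialisation (`β` and `μ` as zeros, `γ` and `σ²` as ones):

```julia
julia> bn = BatchNorm(3);

julia> trainables(bn)  # only the affine parameters β and γ are collected
2-element Vector{AbstractArray}:
 Float32[0.0, 0.0, 0.0]
 Float32[1.0, 1.0, 1.0]

julia> bn.μ, bn.σ²     # the buffers are still fields of the layer, just not trainable
(Float32[0.0, 0.0, 0.0], Float32[1.0, 1.0, 1.0])
```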
3 changes: 3 additions & 0 deletions src/Optimisers.jl
@@ -11,6 +11,9 @@ include("adjust.jl")
include("destructure.jl")
export destructure

include("trainables.jl")
export trainables

include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
9 changes: 5 additions & 4 deletions src/destructure.jl
@@ -66,19 +66,19 @@ function _flatten(x)
isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case
arrays = AbstractVector[]
len = Ref(0)
- off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
+ off = fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y
push!(arrays, _vec(y))
o = len[]
len[] = o + length(y)
o
end
isempty(arrays) && return Bool[], off, 0
- reduce(vcat, arrays), off, len[]
+ return reduce(vcat, arrays), off, len[]
end

- struct _TrainableStructWalk <: AbstractWalk end
+ struct TrainableStructWalk <: AbstractWalk end

- (::_TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x))
+ (::TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x))

_vec(x::Number) = LinRange(x,x,1)
_vec(x::AbstractArray) = vec(x)
@@ -174,3 +174,4 @@ function ChainRulesCore.rrule(::typeof(_maybewarn))
@warn "second derivatives of destructure may not work yet, sorry!" maxlog=3
nothing, _ -> (NoT,)
end

1 change: 1 addition & 0 deletions src/interface.jl
@@ -167,6 +167,7 @@ and `trainable(x)` must contain a subset of these.
"""
trainable(x) = functor(x)[1]

# Like trainable(x), but also includes the non-trainable children, each given the value `nothing`
_trainable(x) = _trainable(functor(x)[1], trainable(x))
_trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr
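
As a hedged illustration (reusing the hypothetical `MyLayer` from the `trainables` docstring added in `src/trainables.jl` below, where only `w` is declared trainable), `_trainable` reports the non-trainable child as `nothing`:

```julia
julia> x = MyLayer([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);

julia> Optimisers._trainable(x)  # b is a child of the struct, but not trainable
(w = [1.0, 2.0, 3.0], b = nothing)
```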
59 changes: 59 additions & 0 deletions src/trainables.jl
@@ -0,0 +1,59 @@

"""
trainables(x)
Return a list over all the trainable parameters in `x`, that is all the numerical
arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable).
Parameters appearing multiple times in the model (tied weights) will be present only once in the output.
See also [`destructure`](@ref) for a similar operation that returns a single flat vector instead.
# Examples
```jldoctest
julia> struct MyLayer
w
b
end
julia> Functors.@functor MyLayer
julia> Optimisers.trainable(x::MyLayer) = (; w = x.w,) # only w is trainable in this example
julia> x = MyLayer([1.0,2.0,3.0], [4.0,5.0,6.0]);
julia> trainables(x)
1-element Vector{AbstractArray}:
[1.0, 2.0, 3.0]
julia> x = MyLayer((a=[1.0,2.0], b=[3.0]), [4.0,5.0,6.0]);
julia> trainables(x) # collects nested parameters
2-element Vector{AbstractArray}:
[1.0, 2.0]
[3.0]
"""
function trainables(x)
arrays = AbstractArray[]
exclude(x) = Optimisers.isnumeric(x)
fmap(x; exclude, walk = Optimisers.TrainableStructWalk()) do y
push!(arrays, y)
return y
end
return arrays
end

function ∇trainables(x, Δ)
exclude(x) = Optimisers.isnumeric(x)
i = 0
return fmapstructure(x; exclude, walk = TrainableStructWalk()) do _
return Δ[i+=1]
end
end

function ChainRulesCore.rrule(::typeof(trainables), x)
y = trainables(x)
trainables_back(Δ) = (NoTangent(), ∇trainables(x, unthunk(Δ)))
return y, trainables_back
end
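
For context, a small sketch of how the rule behaves end to end, assuming Zygote as the AD backend (the values mirror the cases exercised in `test/trainables.jl` below): `∇trainables` routes the cotangent of each collected array back to its position in the model, so the caller receives a gradient with the model's structure rather than a flat list.

```julia
julia> using Optimisers, Zygote

julia> m = (x = [1.0, 2.0, 3.0], y = sin, z = [4.0, 5.0, 6.0]);

julia> loss(m) = sum([sum(abs2, p) for p in trainables(m)]);

julia> gradient(loss, m)[1]  # structured gradient; non-trainable leaves come back as nothing
(x = [2.0, 4.0, 6.0], y = nothing, z = [8.0, 10.0, 12.0])
```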
5 changes: 4 additions & 1 deletion test/runtests.jl
@@ -10,7 +10,7 @@ Random.seed!(1)

struct Foo; x; y; end
Functors.@functor Foo
- Optimisers.trainable(x::Foo) = (x.y, x.x)
+ Optimisers.trainable(x::Foo) = (; x.y, x.x)

struct TwoThirds a; b; c; end
Functors.@functor TwoThirds (a, c)
@@ -539,6 +539,9 @@ end
@testset verbose=true "Destructure" begin
include("destructure.jl")
end
@testset verbose=true "Trainables" begin
include("trainables.jl")
end
@testset verbose=true "Optimisation Rules" begin
include("rules.jl")
end
115 changes: 115 additions & 0 deletions test/trainables.jl
@@ -0,0 +1,115 @@

m1 = collect(1:3.0)
m2 = (collect(1:3.0), collect(4:6.0))
m3 = (x = m1, y = sin, z = collect(4:6.0))

m4 = (x = m1, y = m1, z = collect(4:6.0)) # tied
m5 = (a = (m3, true), b = (m1, false), c = (m4, true))
m6 = (a = m1, b = [4.0 + im], c = m1)

m7 = TwoThirds((sin, collect(1:3.0)), (cos, collect(4:6.0)), (tan, collect(7:9.0)))
m8 = [Foo(m1, m1), (a = true, b = Foo([4.0], false), c = ()), [[5.0]]]

mat = Float32[4 6; 5 7]
m9 = (a = m1, b = mat, c = [mat, m1])

@testset "trainables" begin
ps = trainables(m1)
@test ps isa Vector
@test length(ps) == 1
@test ps[1] == m1

ps = trainables(m2)
@test ps isa Vector
@test length(ps) == 2
@test ps[1] == m2[1]
@test ps[2] == m2[2]

ps = trainables(m3)
@test length(ps) == 2
@test ps[1] == 1:3
@test ps[2] == 4:6

ps = trainables(m4)
@test length(ps) == 2
@test ps[1] == 1:3
@test ps[2] == 4:6

ps = trainables(m5)
@test length(ps) == 3
@test ps[1] == 1:3
@test ps[2] == 4:6
@test ps[3] == 4:6

ps = trainables(m6)
@test length(ps) == 2
@test ps[1] == 1:3
@test ps[2] == ComplexF64[4.0 + 1.0im]

ps = trainables(m7)
@test length(ps) == 1
@test ps[1] == [1.0, 2.0, 3.0]

ps = trainables(m8)
@test length(ps) == 3
@test ps[1] == 1:3
@test ps[2] == [4.0]
@test ps[3] == [5.0]

ps = trainables(m9)
@test length(ps) == 2
@test ps[1] == 1:3
@test ps[2] == mat
end

@testset "gradient" begin
loss(m) = sum([sum(abs2, p) for p in trainables(m)])
g = gradient(loss, m1)[1]
@test g == [2.0, 4.0, 6.0]

g = gradient(loss, m2)[1]
@test g == ([2.0, 4.0, 6.0], [8.0, 10.0, 12.0])

g = gradient(loss, m3)[1]
@test g.x == [2.0, 4.0, 6.0]
@test g.y === nothing
@test g.z == [8.0, 10.0, 12.0]

g = gradient(loss, m4)[1]
@test g == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0])
@test g.x === g.y # shared gradient for shared weights

g = gradient(loss, m5)[1]
@test g == (a = ((x = [2.0, 4.0, 6.0], y = nothing, z = [8.0, 10.0, 12.0]), nothing), b = ([2.0, 4.0, 6.0], nothing), c = ((x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0], z = [8.0, 10.0, 12.0]), nothing))

g = gradient(loss, m6)[1]
@test g == (a = [2.0, 4.0, 6.0], b = ComplexF64[8.0 + 2.0im], c = [2.0, 4.0, 6.0])

g = gradient(loss, m7)[1]
@test g == (a = (nothing, [2.0, 4.0, 6.0]), b = nothing, c = nothing)

g = gradient(loss, m8)[1]
@test g[1] == (x = [2.0, 4.0, 6.0], y = [2.0, 4.0, 6.0])
@test g[2] == (a = nothing, b = (x = [8.0], y = nothing), c = nothing)
@test g[3] == [[10.0]]

g = gradient(loss, m9)[1]
@test g == (a = [2.0, 4.0, 6.0], b = Float32[8.0 12.0; 10.0 14.0], c = Array[Float32[8.0 12.0; 10.0 14.0], [2.0, 4.0, 6.0]])
end

@testset "second order derivatives" begin
struct DenseLayer
w
b
end

Functors.@functor DenseLayer

loss(m) = sum([sum(abs2, p) for p in trainables(m)])

model = DenseLayer([1. 2.; 3. 4.], [0., 0.])

g = gradient(m -> loss(gradient(loss, m)), model)[1]
@test g.w == [8.0 16.0; 24.0 32.0]
@test g.b == [0.0, 0.0]
end
