Skip to content

Commit

Permalink
cl/comp
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Apr 13, 2024
1 parent c2ae321 commit fccdd11
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/Optimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Functors: functor, fmap, fmap_with_path,
isleaf, @functor, fmapstructure, children, AbstractWalk
using LinearAlgebra


include("interface.jl")
export AbstractRule

Expand All @@ -16,7 +17,7 @@ include("destructure.jl")
export destructure

include("trainables.jl")
export trainables
export trainables, trainables_nt
export KeyPath, haskeypath, getkeypath # from Functors.jl

include("rules.jl")
Expand Down
105 changes: 103 additions & 2 deletions src/trainables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ end

function ∇trainables(x, Δ)
i = 0
return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
return fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
return Δ[i+=1]
end
end
Expand Down Expand Up @@ -113,7 +113,7 @@ end

function ∇trainables_with_path(x, Δ)
i = 0
return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
return fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
Δi = Δ[i+=1]
if isnothing(Δi)
return nothing
Expand All @@ -122,3 +122,104 @@ function ∇trainables_with_path(x, Δ)
end
end
end


### trainables_nt ######################

"""
trainables_nt(model) -> ps, re
Return a pair `(ps, re)` where `ps` is a nested named tuple with the same structure as
the trainable part of `model` and with leaves the trainable parameters.
Parameters are not copied, but the returned `ps` is a view into the original model.
The `re` is a function that reconstructs a model from the parameters,
i.e. `re(ps)` is the same as the origin `model` but with the trainable parameters replaced by `ps`.
# Examples
```jldoctest
julia> using Flux, Optimisers
julia> model = Chain(Dense(784, 32, relu), Dense(32, 10));
julia> ps, re = trainables_nt(model);
julia> ps.layers._1.weight === model.layers[1].weight
true
```
```jldoctest
julia> v = ComponentVector(ps)
julia> model2 = re(2 * v)
```
"""
function trainables_nt(x)
walknt = TrainableNamedTupleWalk()
ps = fmap(identity, x; exclude=isnumeric, walk=walknt, cache=nothing)
re = RestructureFromNT(x)
return ps, re
end


struct RestructureFromNT{T}
x::T
end

function (re::RestructureFromNT)(ps)
walk = RestructureFromNamedTupleWalk()
return fmap(re.x, ps; exclude=isnumeric, walk, cache=nothing) do y, p
return p
end
end

struct TrainableNamedTupleWalk <: AbstractWalk end

function (::TrainableNamedTupleWalk)(recurse, x)
ch = trainable(x)
y = map(recurse, make_named_tuple(ch))
return y
end

struct RestructureFromNamedTupleWalk <: AbstractWalk end

function (::RestructureFromNamedTupleWalk)(recurse, x, nt)
children, re = functor(x)
newchildren = map_commons(recurse, children, nt)
return re(newchildren)
end

function map_commons(f, x::NamedTuple{xkeys}, y) where {xkeys}
ykeys = propertynames(y)
vals = map(k -> k in ykeys ? f(x[k], getproperty(y, k)) : x[k], xkeys)
return NamedTuple{xkeys}(vals)
end

function map_commons(f, x::Tuple, y)
ykeys = propertynames(y)
vals = ntuple(length(x)) do i
k = Symbol("_", i)
k in ykeys ? f(x[i], getproperty(y, k)) : x[i]
end
return vals
end

function map_commons(f, x::Vector, y)
ykeys = propertynames(y)
vals = map(1:length(x)) do i
k = Symbol("_", i)
k in ykeys ? f(x[i], getproperty(y, k)) : x[i]
end
return vals
end

make_named_tuple(x::NamedTuple) = x
make_named_tuple(x::AbstractDict{Symbol}) = NamedTuple(x)
make_named_tuple(x::AbstractDict) = NamedTuple(Symbol("_", k) => v for (k, v) in pairs(x))
make_named_tuple(x::Tuple) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x)
make_named_tuple(x::Vector) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x)

19 changes: 19 additions & 0 deletions test/trainables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,22 @@ end
@test g.y == [2.0, 4.0, 6.0]
@test g.z === nothing
end

using Flux, Optimisers
using ComponentArrays
using Test


model0 = Chain(
Dense(784, 32, relu),
Dense(32, 10))

ps, re = trainables_nt(model0)
@test ps.layers._1.weight === model0[1].weight
model1 = re(ps)
@test model1[1].weight === ps.layers._1.weight

v = ComponentVector(ps)
v2 = 2 * v
model2 = re(v2)
@test model2[1].weight === v2.layers._1.weight

0 comments on commit fccdd11

Please sign in to comment.