Skip to content

Commit

Permalink
add trainables_with_path
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Apr 3, 2024
1 parent 292a82d commit 62cd081
Show file tree
Hide file tree
Showing 8 changed files with 217 additions and 107 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1"
Functors = "0.4"
Functors = "0.4.9"
Statistics = "1"
Zygote = "0.6.40"
julia = "1.6"
Expand Down
9 changes: 9 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Such restrictions are also obeyed by this function for flattening a model:
Optimisers.destructure
Optimisers.Restructure
Optimisers.trainables
Optimisers.trainables_with_path
```

## Rule Definition
Expand All @@ -70,3 +71,11 @@ Optimisers.@..
Optimisers.@lazy
Optimisers.adjust(::AbstractRule, ::Real)
```

## KeyPath

```@docs
Functors.KeyPath
Functors.haskeypath
Functors.getkeypath
```
9 changes: 7 additions & 2 deletions src/Optimisers.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
module Optimisers

using Functors: functor, fmap, isleaf, @functor, fmapstructure, children, AbstractWalk
using Functors: functor, fmap, fmap_with_path,
KeyPath, haskeypath, getkeypath,
isleaf, @functor, fmapstructure, children, AbstractWalk
using LinearAlgebra

include("interface.jl")
export AbstractRule

include("utils.jl")

include("adjust.jl")

include("destructure.jl")
export destructure

include("trainables.jl")
export trainables
export trainables, trainables_with_path
export KeyPath, haskeypath, getkeypath # from Functors.jl

include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
Expand Down
2 changes: 1 addition & 1 deletion src/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ end

# Walk that recurses only into the *trainable* children of a node, so that
# frozen/non-trainable fields are skipped during flattening.
struct TrainableStructWalk <: AbstractWalk end

# Apply `recurse` to each trainable child of `x`. Uses `mapvalue` (not `map`)
# so that `Dict`-valued children are mapped over their values with keys kept.
# NOTE: the scraped diff retained the superseded `map(recurse, _trainable(x))`
# definition alongside this one; only the committed `mapvalue` method is kept.
(::TrainableStructWalk)(recurse, x) = mapvalue(recurse, _trainable(x))

# View a scalar as a one-element vector without allocating (degenerate range).
function _vec(n::Number)
    return LinRange(n, n, 1)
end

# Flatten an array to a vector in column-major order.
function _vec(a::AbstractArray)
    return vec(a)
end
Expand Down
15 changes: 4 additions & 11 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function _setup(rule, x; cache)
cache[x] =
end
else
valuemap(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
mapvalue(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
end
end

Expand Down Expand Up @@ -82,7 +82,7 @@ function _update!(tree, x; grads, params)
haskey(params, (tree,x)) && return params[(tree,x)]
isbits(tree) && return x # means () is not cached, and also (((),),)
x′, re = functor(x)
x′′ = re(valuemap((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
x′′ = re(mapvalue((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
if ismutable(x′′)
params[(tree,x)] = x′′
else # no ties to preserve between immutable structs, right?
Expand Down Expand Up @@ -115,7 +115,7 @@ function _grads!(dict::IdDict, tree, x, x̄s...)
# functor(typeof(tree), base(x̄)), for things like Transpose
x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
x′, _ = functor(typeof(x), x)
valueforeach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
foreachvalue((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
end

# default all rules to first order calls
Expand Down Expand Up @@ -172,21 +172,14 @@ _trainable(x) = _trainable(functor(x)[1], trainable(x))
# Combine the functor children `ch` with the (possibly partial) trainable
# subset `tr`: children absent from `tr` are marked `nothing`.
# NOTE: the scraped diff retained the superseded `valuemap`-based Dict method
# alongside the committed `mapvalue` one; only the committed method is kept.
_trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr
_trainable(ch::AbstractArray, tr::AbstractArray) = tr
_trainable(ch::Dict, tr::Dict) = merge(mapvalue(_ -> nothing, ch), tr)

"""
    _trainable(ch::NamedTuple, tr::Tuple)

Compatibility path for old Flux-style `trainable` methods that returned an
unnamed `Tuple`: keep each child of `ch` that appears in `tr`, replacing the
rest with `nothing`. Warns (at most 3 times) that a `NamedTuple` is expected.
"""
function _trainable(ch::NamedTuple, tr::Tuple) # for old Flux-style no-names tuple
    @warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" maxlog=3
    return map(ch) do child
        child in tr ? child : nothing
    end
end


# Generic fallback: plain `map` over the containers.
valuemap(f, x...) = map(f, x...)

# Dict version: map over values, aligning the trailing containers `ys` by key
# (missing keys contribute `nothing`).
function valuemap(f, x::Dict, ys...)
    return Dict(k => f(v, map(y -> get(y, k, nothing), ys)...) for (k, v) in pairs(x))
end

# Generic fallback: plain `foreach`.
valueforeach(f, x...) = foreach(f, x...)

# Dict version: visit each value, aligning the trailing containers by key.
function valueforeach(f, x::Dict, ys...)
    foreach(pairs(x)) do (key, val)
        f(val, map(y -> get(y, key, nothing), ys)...)
    end
end


###
### rule definition helpers
Expand Down
71 changes: 67 additions & 4 deletions src/trainables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,16 @@ julia> trainables(x)
"""
function trainables(x)
    # NOTE: the scraped diff retained both the pre-commit lines (a local
    # `exclude` helper and the old `Optimisers.`-qualified fmap call) and the
    # committed replacement, which is not valid Julia; only the committed
    # version is kept here.
    arrays = AbstractArray[]
    # Walk only trainable branches; every leaf accepted by `isnumeric`
    # is collected, in the deterministic order fmap visits them.
    fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y
        push!(arrays, y)
        return y
    end
    return arrays
end

# Pullback helper for `trainables`: rebuild a structural tangent for `x` by
# visiting leaves in the same order as `trainables` and placing the i-th
# cotangent `Δ[i]` at the i-th leaf position.
# NOTE: the scraped diff retained both the pre-commit lines (local `exclude`
# helper and old `fmapstructure` call) and the committed replacement; only the
# committed version is kept here.
function ∇trainables(x, Δ)
    i = 0
    return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
        return Δ[i+=1]
    end
end
Expand All @@ -57,3 +55,68 @@ function ChainRulesCore.rrule(::typeof(trainables), x)
trainables_back(Δ) = (NoTangent(), ∇trainables(x, unthunk(Δ)))
return y, trainables_back
end

"""
    trainables_with_path(x)

Return an iterable over all the trainable parameters in `x`, that is all the numerical
arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable).

The output is a list of `(KeyPath, AbstractArray)` pairs, where [`KeyPath`](@ref Functors.KeyPath) is a type
representing the path to the array in the original structure.

See also [`trainables`](@ref) and [`destructure`](@ref).

# Examples

```jldoctest
julia> x = (a = [1.0,2.0], b = (Dict("c" => [3.0, 4.0], "d" => 5.0), [6.0,7.0]));

julia> for (kp, y) in trainables_with_path(x)
           println(kp, " => ", y)
       end
KeyPath(:a,) => [1.0, 2.0]
KeyPath(:b, 1, "c") => [3.0, 4.0]
KeyPath(:b, 2) => [6.0, 7.0]

julia> getkeypath(x, KeyPath(:b, 1, "c"))
2-element Vector{Float64}:
 3.0
 4.0
```
"""
function trainables_with_path(x)
    named_params = []
    # The exclude predicate receives (keypath, value); only the value matters.
    fmap_with_path(x; exclude = (kp, y) -> isnumeric(y),
                   walk = TrainableStructWalkWithPath()) do kp, y
        push!(named_params, (kp, y))
        return y
    end
    return named_params
end

# Path-tracking variant of `TrainableStructWalk`: recurses only into trainable
# children while extending the current `KeyPath` with each child's key.
struct TrainableStructWalkWithPath <: AbstractWalk end

function (::TrainableStructWalkWithPath)(recurse, kp::KeyPath, x)
    children = trainable(x)
    # One extended path per child, keyed like `children`.
    child_paths = mapkey(key -> KeyPath(kp, key), children)
    return mapvalue(recurse, child_paths, children)
end

# Reverse rule: the forward pass is `trainables_with_path` itself; the pullback
# delegates to `∇trainables_with_path` to build the structural tangent.
function ChainRulesCore.rrule(::typeof(trainables_with_path), x)
    ps = trainables_with_path(x)
    function trainables_with_path_back(Δ)
        return (NoTangent(), ∇trainables_with_path(x, unthunk(Δ)))
    end
    return ps, trainables_with_path_back
end

# Pullback helper for `trainables_with_path`: each cotangent entry pairs with a
# `(KeyPath, array)` output, so only its second component (the array part) is
# placed back at the corresponding leaf; `nothing` entries stay `nothing`.
function ∇trainables_with_path(x, Δ)
    counter = 0
    return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
        counter += 1
        Δc = Δ[counter]
        return Δc === nothing ? nothing : Δc[2]
    end
end
14 changes: 14 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Like `map`, but `Dict` arguments are mapped over their *values* with keys
# preserved; trailing containers are aligned by key (`nothing` if absent).
mapvalue(f, x...) = map(f, x...)

function mapvalue(f, x::Dict, ys...)
    return Dict(k => f(v, map(y -> get(y, k, nothing), ys)...) for (k, v) in pairs(x))
end

# Map `f` over the *keys* of a container (NamedTuple field names, Dict keys,
# or array/tuple indices), returning a container shaped like `x` whose values
# are `f(key)`.
mapkey(f, x::NamedTuple{Ks}) where Ks = NamedTuple{Ks}(map(f, Ks))
mapkey(f, x::Dict) = Dict(k => f(k) for k in keys(x))
# `ntuple(f, n)` instead of wrapping `f` in a redundant closure.
mapkey(f, x::Tuple) = ntuple(f, length(x))
# `eachindex` instead of `1:length(x)`: identical for 1-based arrays, and
# yields the true keys for non-standard index styles.
mapkey(f, x::AbstractArray) = [f(i) for i in eachindex(x)]

# `foreach` over values; for `Dict`, the trailing containers `ys` are aligned
# by key, contributing `nothing` where a key is absent.
foreachvalue(f, x...) = foreach(f, x...)

function foreachvalue(f, x::Dict, ys...)
    for (key, val) in pairs(x)
        f(val, map(y -> get(y, key, nothing), ys)...)
    end
end
Loading

0 comments on commit 62cd081

Please sign in to comment.