Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add path option to trainables #174

Merged
merged 6 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 0 additions & 24 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,27 +53,3 @@ jobs:
file: lcov.info
continue-on-error: ${{ matrix.julia-version == 'nightly' }}

docs:
name: Documentation
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: julia-actions/setup-julia@v1
with:
version: '1.6'
- run: |
julia --project=docs -e '
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()'
- run: |
julia --color=yes --project=docs/ -e '
using Optimisers
using Documenter
using Documenter: doctest
DocMeta.setdocmeta!(Optimisers, :DocTestSetup, :(using Optimisers); recursive = true)
doctest(Optimisers)'
- run: julia --project=docs docs/make.jl
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
7 changes: 7 additions & 0 deletions .github/workflows/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
28 changes: 28 additions & 0 deletions .github/workflows/documentation.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: Documentation

on:
push:
branches:
- master # update to match your development branch (master, main, dev, trunk, ...)
tags: '*'
pull_request:

jobs:
build:
permissions:
contents: write
statuses: write
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: '1.10'
- uses: julia-actions/cache@v1
- name: Install dependencies
run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate()'
- name: Build and deploy
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # If authenticating with GitHub Actions token
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # If authenticating with SSH deploy key
run: julia --project=docs/ docs/make.jl
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1"
Functors = "0.4"
Functors = "0.4.9"
Statistics = "1"
Zygote = "0.6.40"
julia = "1.6"
Expand Down
Binary file modified docs/.DS_Store
Binary file not shown.
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6 changes: 4 additions & 2 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using Documenter, Optimisers, Zygote, StaticArrays, Functors

DocMeta.setdocmeta!(Optimisers, :DocTestSetup, :(using Optimisers); recursive = true)
DocMeta.setdocmeta!(Optimisers, :DocTestSetup, :(using Optimisers, Functors); recursive = true)
DocMeta.setdocmeta!(Functors, :DocTestSetup, :(using Functors); recursive = true)

makedocs(modules = [Optimisers],
makedocs(modules = [Optimisers, Functors],
doctest = false,
sitename = "Optimisers.jl",
pages = ["Home" => "index.md",
Expand All @@ -13,6 +14,7 @@ makedocs(modules = [Optimisers],
assets = ["assets/flux.css"],
prettyurls = get(ENV, "CI", nothing) == "true"
),
checkdocs = :none, # don't check that Functors' docstrings are all reported here
)

deploydocs(
Expand Down
14 changes: 14 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
```@meta
CollapsedDocStrings = true
```

## Optimisation Rules

Expand Down Expand Up @@ -72,3 +75,14 @@ Optimisers.@lazy
Optimisers.adjust(::AbstractRule, ::Real)
Optimisers.@def
```

## KeyPath

A `KeyPath` is a sequence of keys that can be used to access a value within a nested structure.
It is defined in Functors.jl and re-exported by Optimisers.jl here for convenience.

```@docs
Functors.KeyPath
Functors.haskeypath
Functors.getkeypath
```
7 changes: 6 additions & 1 deletion src/Optimisers.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
module Optimisers

using Functors: functor, fmap, isleaf, @functor, fmapstructure, children, AbstractWalk
using Functors: functor, fmap, fmap_with_path,
KeyPath, haskeypath, getkeypath,
isleaf, @functor, fmapstructure, children, AbstractWalk
using LinearAlgebra

include("interface.jl")
export AbstractRule

include("utils.jl")

include("adjust.jl")

include("destructure.jl")
export destructure

include("trainables.jl")
export trainables
export KeyPath, haskeypath, getkeypath # from Functors.jl

include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
Expand Down
2 changes: 1 addition & 1 deletion src/destructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ end

struct TrainableStructWalk <: AbstractWalk end

# Walk used by `fmap` to descend only into the trainable children of each node;
# non-trainable children show up as `nothing` in `_trainable(x)` and are skipped.
# NOTE: diff scrape retained both the old `map` and new `mapvalue` versions of
# this method; resolved to the merged `mapvalue` form (handles `Dict` children).
(::TrainableStructWalk)(recurse, x) = mapvalue(recurse, _trainable(x))

# Flatten a leaf for `destructure`: a scalar becomes a lazy length-1 `LinRange`
# (no allocation), an array is flattened column-major via `vec` (shares memory).
_vec(n::Number) = LinRange(n, n, 1)
_vec(a::AbstractArray) = vec(a)
Expand Down
16 changes: 4 additions & 12 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function _setup(rule, x; cache)
cache[x] = ℓ
end
else
valuemap(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
mapvalue(xᵢ -> _setup(rule, xᵢ; cache), _trainable(x))
end
end

Expand Down Expand Up @@ -82,7 +82,7 @@ function _update!(tree, x; grads, params)
haskey(params, (tree,x)) && return params[(tree,x)]
isbits(tree) && return x # means () is not cached, and also (((),),)
x′, re = functor(x)
x′′ = re(valuemap((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
x′′ = re(mapvalue((tᵢ, xᵢ) -> _update!(tᵢ, xᵢ; grads, params), tree, x′))
if ismutable(x′′)
params[(tree,x)] = x′′
else # no ties to preserve between immutable structs, right?
Expand Down Expand Up @@ -115,7 +115,7 @@ function _grads!(dict::IdDict, tree, x, x̄s...)
# functor(typeof(tree), base(x̄)), for things like Transpose
x̄s′ = map(x̄ -> functor(typeof(x), base(x̄))[1], x̄s)
x′, _ = functor(typeof(x), x)
valueforeach((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
foreachvalue((tᵢ, xᵢ, x̄sᵢ...) -> _grads!(dict, tᵢ, xᵢ, x̄sᵢ...), tree, x′, x̄s′...)
end

# default all rules to first order calls
Expand Down Expand Up @@ -172,22 +172,14 @@ _trainable(x) = _trainable(functor(x)[1], trainable(x))
_trainable(ch::NamedTuple, tr::NamedTuple) = merge(map(_ -> nothing, ch), tr)
_trainable(ch::Tuple{Vararg{Any,N}}, tr::Tuple{Vararg{Any,N}}) where N = tr
_trainable(ch::AbstractArray, tr::AbstractArray) = tr
_trainable(ch::Dict, tr::Dict) = merge(valuemap(_ -> nothing, ch), tr)
_trainable(ch::Dict, tr::Dict) = merge(mapvalue(_ -> nothing, ch), tr)

# Legacy path: old Flux-style `trainable` returned an unnamed Tuple instead of a
# NamedTuple. Keep each child that appears in `tr`; blank the rest with `nothing`.
function _trainable(ch::NamedTuple, tr::Tuple)
    @warn "trainable(x) should now return a NamedTuple with the field names, not a Tuple" maxlog=3
    return map(ch) do child
        child in tr ? child : nothing
    end
end


valuemap(f, x...) = map(f, x...)
valuemap(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x)
valueforeach(f, x...) = foreach(f, x...)
valueforeach(f, x::Dict, ys...) = foreach(pairs(x)) do (k, v)
f(v, (get(y, k, nothing) for y in ys)...)
end


###
### rule definition helpers
###
Expand Down
81 changes: 73 additions & 8 deletions src/trainables.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@

"""
    trainables(x; path = false)

Return an iterable over all the trainable parameters in `x`, that is all the numerical
arrays (see [`isnumeric`](@ref Optimisers.isnumeric)) which are reachable through [`trainable`](@ref Optimisers.trainable).

Parameters appearing multiple times in the model (tied weights) will be present only once in the output.

If `path = false`, the output is a list of numerical arrays.

If `path = true`, the output is a list of `(KeyPath, AbstractArray)` pairs, where [`KeyPath`](@ref) is a type
representing the path to the array in the original structure.

See also [`destructure`](@ref) for a similar operation that returns a single flat vector instead.

# Examples
Expand All @@ -33,27 +38,87 @@
2-element Vector{AbstractArray}:
[1.0, 2.0]
[3.0]
```

```jldoctest
julia> x = (a = [1.0,2.0], b = (Dict("c" => [3.0, 4.0], "d" => 5.0), [6.0,7.0]));

julia> for (kp, y) in trainables(x, path = true)
println(kp, " => ", y)
end
KeyPath(:a,) => [1.0, 2.0]
KeyPath(:b, 1, "c") => [3.0, 4.0]
KeyPath(:b, 2) => [6.0, 7.0]

julia> getkeypath(x, KeyPath(:b, 1, "c"))
2-element Vector{Float64}:
3.0
4.0
```
"""
# Public entry point; see the docstring above. Dispatches to one of two private
# implementations so that each mode can carry its own `rrule`.
# NOTE: diff scrape retained both the old `trainables(x)` and the new keyword
# signature; resolved to the merged keyword-argument version.
function trainables(x; path = false)
    if path
        return _trainables_with_path(x)
    else
        return _trainables(x)
    end
end


# Collect every trainable numeric leaf array, in `fmap` traversal order.
# NOTE: diff scrape retained the stale `exclude(x) = ...` helper and the old
# `fmap` call beside the new one; resolved to the merged version.
function _trainables(x)
    arrays = AbstractArray[]
    fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do y
        push!(arrays, y)
        return y
    end
    return arrays
end

# Pullback helper for `_trainables`: rebuild a structural tangent of `x`,
# handing out entries of `Δ` in the same traversal order that `_trainables`
# used to collect the leaves (both use `fmapstructure`/`TrainableStructWalk`).
# NOTE: diff scrape retained stale pre-change lines; resolved to merged version.
function ∇trainables(x, Δ)
    i = 0
    return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
        return Δ[i+=1]
    end
end

# Reverse rule for the private `_trainables` (the public `trainables` merely
# forwards to it, so AD differentiates through the keyword dispatch for free).
# Fix: call the primal `_trainables(x)` directly instead of the public
# `trainables(x)` — same result today, but consistent and not dependent on the
# wrapper's default keyword.
function ChainRulesCore.rrule(::typeof(_trainables), x)
    y = _trainables(x)
    trainables_back(Δ) = (NoTangent(), ∇trainables(x, unthunk(Δ)))
    return y, trainables_back
end

# Collect `(KeyPath, array)` pairs for every trainable numeric leaf, where the
# KeyPath records the leaf's location in the original structure.
function _trainables_with_path(x)
    # Typed accumulator instead of `[]` (Vector{Any}): the docstring promises
    # `(KeyPath, AbstractArray)` pairs, and a concrete eltype avoids boxing.
    named_params = Tuple{KeyPath,AbstractArray}[]
    exclude(kp, x) = isnumeric(x)
    fmap_with_path(x; exclude, walk = TrainableStructWalkWithPath()) do kp, y
        push!(named_params, (kp, y))
        return y
    end
    return named_params
end

struct TrainableStructWalkWithPath <: AbstractWalk end

# Path-aware walk for `fmap_with_path`: recurse into the trainable children of
# `x`, extending the current KeyPath with each child's key so that every leaf
# remembers where it lives in the original structure.
function (::TrainableStructWalkWithPath)(recurse, kp::KeyPath, x)
    subtrees = trainable(x)
    subtree_paths = mapkey(key -> KeyPath(kp, key), subtrees)
    return mapvalue(recurse, subtree_paths, subtrees)
end

# Reverse rule for the path-returning variant; the KeyPath halves of the output
# pairs carry no gradient, which `∇trainables_with_path` accounts for.
function ChainRulesCore.rrule(::typeof(_trainables_with_path), x)
    primal = _trainables_with_path(x)
    function trainables_with_path_back(Δ)
        return (NoTangent(), ∇trainables_with_path(x, unthunk(Δ)))
    end
    return primal, trainables_with_path_back
end

# Pullback helper for `_trainables_with_path`: `Δ` holds one cotangent per
# `(KeyPath, array)` pair, in the same traversal order used on the forward pass.
# Only the array half (second element) carries gradient; the KeyPath is dropped.
# NOTE: the scraped source had Codecov annotation text embedded mid-function
# (invalid syntax); removed, code resolved to the actual diff content.
function ∇trainables_with_path(x, Δ)
    i = 0
    return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
        Δi = Δ[i+=1]
        if isnothing(Δi)
            return nothing
        else
            return Δi[2]
        end
    end
end
15 changes: 15 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

# Like `map`, but with a `Dict` method that maps over values: each value of `x`
# is paired with the entries stored under the same key in the trailing
# collections (`nothing` when a key is absent there).
mapvalue(f, xs...) = map(f, xs...)
mapvalue(f, x::Dict, rest...) =
    Dict(key => f(val, map(d -> get(d, key, nothing), rest)...) for (key, val) in pairs(x))

# Map `f` over the *keys* of a container, keeping the container's shape:
# field names for NamedTuples, keys for Dicts, indices for Tuples and Arrays.
mapkey(f, x::NamedTuple{Ks}) where Ks = NamedTuple{Ks}(map(f, Ks))
mapkey(f, x::Dict) = Dict(k => f(k) for k in keys(x))
mapkey(f, x::Tuple) = ntuple(f, length(x))                  # idiom: pass f directly
mapkey(f, x::AbstractArray) = [f(i) for i in eachindex(x)]  # correct for offset axes too

Check warning on line 8 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L6-L8

Added lines #L6 - L8 were not covered by tests

# Like `foreach`, but with a `Dict` method that iterates values: each value of
# `x` is passed to `f` together with the entries stored under the same key in
# the trailing collections (`nothing` when a key is absent there).
foreachvalue(f, xs...) = foreach(f, xs...)

function foreachvalue(f, x::Dict, rest...)
    for (key, val) in pairs(x)
        f(val, map(d -> get(d, key, nothing), rest)...)
    end
    return nothing
end

Loading
Loading