From 26395239c0307fadc4d1143d9ae1bc1a6cb2711e Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 8 Nov 2024 16:08:22 -0500
Subject: [PATCH] Add `Duplicated` methods (#192)

* add Duplicated methods

* add test

* test for shared params + minimal docs

* remove 1.6 CI

* indent by two spaces

* fix doctest
---
 .github/workflows/ci.yml       |  2 +-
 Project.toml                   | 14 ++++--
 docs/src/index.md              |  9 +++++
 ext/OptimisersEnzymeCoreExt.jl | 60 ++++++++++++++++++++++++++++++++++
 test/runtests.jl               | 24 +++++++++++++-
 5 files changed, 104 insertions(+), 5 deletions(-)
 create mode 100644 ext/OptimisersEnzymeCoreExt.jl

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 7c3b18c..629d90b 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -15,9 +15,9 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.6'
           - '1'
           - 'nightly'
+          - "1.10"
         os:
           - ubuntu-latest
         arch:
diff --git a/Project.toml b/Project.toml
index 0a19f49..a60b58f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,21 +1,29 @@
 name = "Optimisers"
 uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
+version = "0.4.1"
 authors = ["Mike J Innes "]
-version = "0.4.0"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
+[weakdeps]
+EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
+
+[extensions]
+OptimisersEnzymeCoreExt = "EnzymeCore"
+
 [compat]
 ChainRulesCore = "1"
+EnzymeCore = "0.8.5"
 Functors = "0.4.9, 0.5"
 Statistics = "1"
 Zygote = "0.6.40"
-julia = "1.6"
+julia = "1.10"
 
 [extras]
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -23,4 +31,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test", "StaticArrays", "Zygote"]
+test = ["Test", "EnzymeCore", "StaticArrays", "Zygote"]
diff --git a/docs/src/index.md b/docs/src/index.md
index 3cb32f8..a595d70 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -358,3 +358,12 @@ julia> Optimisers.update!(opt_state, x, g);
 julia> opt_state  # the state in `a` and `b` differ
 (a = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.09, 0.09], [0.000999, 0.000999], (0.729, 0.997003))), b = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.1, 0.1], [0.001, 0.001], (0.81, 0.998001))))
 ```
+
+## Usage with [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)
+
+Enzyme.jl is a new automatic differentiation package, an alternative to Zygote.jl.
+It likes to store the model and the gradient together, as an object `Duplicated(x, dx)`.
+
+Optimisers.jl now has some methods to handle this:
+* `update!(opt_state, Duplicated(model, grad))` uses the gradient to update both the model and the optimiser state, and
+* `setup(::AbstractRule, ::Duplicated)` ignores the gradient and returns `setup(rule, model)`.
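
Editor's note, not part of the patch: a minimal sketch of the two methods described in the docs addition just above. It assumes a session where Optimisers.jl (with this extension) and EnzymeCore.jl are both installed, and uses a plain vector in place of a real model; the variable names are illustrative only.

```julia
using Optimisers, EnzymeCore  # loading both packages triggers OptimisersEnzymeCoreExt

# Enzyme-style container: the model values and their gradient stored together
x_dx = Duplicated([1.0, 2.0, 3.0], [0.1, 0.0, -0.2])

# `setup` ignores the gradient half and initialises state for the values only
state = Optimisers.setup(Descent(0.1), x_dx)

# `update!` reads the gradient and mutates both the values and the state
Optimisers.update!(state, x_dx)

x_dx.val  # ≈ [0.99, 2.0, 3.02]
```
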
diff --git a/ext/OptimisersEnzymeCoreExt.jl b/ext/OptimisersEnzymeCoreExt.jl
new file mode 100644
index 0000000..a1c1ab9
--- /dev/null
+++ b/ext/OptimisersEnzymeCoreExt.jl
@@ -0,0 +1,60 @@
+module OptimisersEnzymeCoreExt
+
+import Optimisers: trainable, setup, update!, isnumeric, AbstractRule, _setup
+import EnzymeCore: Duplicated, Const
+
+using Functors: fmapstructure
+
+trainable(x::Duplicated) = (; val = x.val)
+trainable(x::Const) = (;)
+
+"""
+    setup(rule::AbstractRule, model_grad::Duplicated)
+
+For use with Enzyme's Duplicated, this just calls `setup(rule, model_grad.val)`.
+"""
+setup(rule::AbstractRule, model_grad::Duplicated) = setup(rule, model_grad.val)
+
+_setup(rule, x::Duplicated; cache) = throw(ArgumentError(
+  """Objects of type `Duplicated` are only supported by Optimisers.jl at top level,
+  they may not appear deep inside other objects."""
+))
+
+"""
+    update!(opt_state, model_grad::Duplicated)
+
+For use with Enzyme's `Duplicated`, which holds both a model/parameters
+and the corresponding gradient.
+
+# Example
+
+```jldoctest
+julia> using Optimisers, EnzymeCore
+
+julia> x_dx = Duplicated(Float16[1,2,3], Float16[1,0,-4])
+Duplicated{Vector{Float16}}(Float16[1.0, 2.0, 3.0], Float16[1.0, 0.0, -4.0])
+
+julia> st = Optimisers.setup(Momentum(1/9), x_dx)  # acts only on x not on dx
+Leaf(Momentum(0.111111, 0.9), Float16[0.0, 0.0, 0.0])
+
+julia> Optimisers.update!(st, x_dx)  # mutates both arguments
+
+julia> x_dx
+Duplicated{Vector{Float16}}(Float16[0.8887, 2.0, 3.445], Float16[1.0, 0.0, -4.0])
+
+julia> st
+Leaf(Momentum(0.111111, 0.9), Float16[0.1111, 0.0, -0.4443])
+```
+"""
+function update!(opt_state, model_grad::Duplicated)
+  _, _ = update!(opt_state, model_grad.val, _grad_or_nothing(model_grad))
+  nothing
+end
+
+# This function strips the returned gradient to be Zygote-like,
+# most importantly prune=nothing removes 2nd appearance of shared gradient to avoid double-counting.
+_grad_or_nothing(dup::Duplicated) = fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
+_grad_or_nothing(::Const) = nothing
+_grad_or_nothing(x) = isnumeric(x) ? x : nothing
+
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index ae2d9d0..956aa04 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,5 @@
 using Optimisers
-using ChainRulesCore, Functors, StaticArrays, Zygote
+using ChainRulesCore, Functors, StaticArrays, Zygote, EnzymeCore
 using LinearAlgebra, Statistics, Test, Random
 using Optimisers: @.., @lazy
 using Base.Broadcast: broadcasted, instantiate, Broadcasted
@@ -534,6 +534,28 @@ end
       @test Optimisers._norm(bc2, p) isa Float64
     end
   end
+
+  @testset "Enzyme Duplicated" begin
+    x_dx = Duplicated(Float16[1,2,3], Float16[1,0,-4])
+    st = Optimisers.setup(Momentum(1/9), x_dx)  # acts only on x not on dx
+    @test st isa Optimisers.Leaf
+    @test nothing === Optimisers.update!(st, x_dx)  # mutates both arguments
+    @test x_dx.val ≈ Float16[0.8887, 2.0, 3.445]
+
+    shared = [1.0]
+    model = (x=shared, y=shared)
+    grad = deepcopy(model)  # Enzyme produces something like this, grad.x === grad.y, already accumulated.
+    dup = Duplicated(model, model)
+    st2 = Optimisers.setup(Descent(0.1), model)
+    Optimisers.update!(st2, dup)
+    @test model.x ≈ [0.9]
+    shared .= 1
+    Optimisers.update!(st2, model, grad)
+    model.x ≈ [0.8]  # This is wrong, but don't make it a test.
+    # Ideally, perhaps the 3-arg update! could notice that grad.x===grad.y, and not accumulate the gradient in this case?
+
+    @test_throws ArgumentError Optimisers.setup(Adam(), (; a=[1,2,3.], b=x_dx))  # Duplicated deep inside is not allowed
+  end
 end
 @testset verbose=true "Destructure" begin
   include("destructure.jl")
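
Editor's note, not part of the patch: the shared-parameter test above relies on `_grad_or_nothing`, whose `fmapstructure(...; prune=nothing)` call drops the second appearance of an aliased gradient array so that tied parameters are not double-counted. A minimal sketch of that behaviour, under the same assumptions as the earlier example and with illustrative variable names:

```julia
using Optimisers, EnzymeCore

shared = [1.0]
model  = (x = shared, y = shared)            # one array reachable twice (tied parameters)
dup    = Duplicated(model, deepcopy(model))  # the shadow keeps the aliasing, as Enzyme's output would

state = Optimisers.setup(Descent(0.1), model)
Optimisers.update!(state, dup)

model.x  # ≈ [0.9]: the shared gradient of 1.0 is applied once, not twice
```

The three-argument `update!(state, model, grad)` path does no such pruning, which is why the test above only comments on `model.x ≈ [0.8]` instead of asserting it.
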