Add Duplicated methods #192

Merged: 6 commits, Nov 8, 2024
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -15,9 +15,9 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.6'
           - '1'
           - 'nightly'
+          - "1.10"
         os:
           - ubuntu-latest
         arch:
14 changes: 11 additions & 3 deletions Project.toml
@@ -1,26 +1,34 @@
name = "Optimisers"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
version = "0.4.1"
authors = ["Mike J Innes <mike.j.innes@gmail.com>"]
version = "0.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[extensions]
OptimisersEnzymeCoreExt = "EnzymeCore"

[compat]
ChainRulesCore = "1"
EnzymeCore = "0.8.5"
Functors = "0.4.9, 0.5"
Statistics = "1"
Zygote = "0.6.40"
julia = "1.6"
julia = "1.10"

[extras]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "StaticArrays", "Zygote"]
test = ["Test", "EnzymeCore", "StaticArrays", "Zygote"]
9 changes: 9 additions & 0 deletions docs/src/index.md
@@ -358,3 +358,12 @@ julia> Optimisers.update!(opt_state, x, g);
julia> opt_state # the state in `a` and `b` differ
(a = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.09, 0.09], [0.000999, 0.000999], (0.729, 0.997003))), b = Leaf(Adam(0.1, (0.9, 0.999), 1.0e-8), ([0.1, 0.1], [0.001, 0.001], (0.81, 0.998001))))
```

## Usage with [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl)

Enzyme.jl is a new automatic differentiation package, an alternative to Zygote.jl.
It likes to store the model and the gradient together, as an object `Duplicated(x, dx)`.

Optimisers.jl now has some methods to handle this:
* `update!(opt_state, Duplicated(model, grad))` uses the gradient to update both the model and the optimiser state, and
* `setup(::AbstractRule, ::Duplicated)` ignores the gradient and returns `setup(rule, model)`.
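
A minimal sketch of how these fit together (here `x` and the hand-written `dx` are stand-ins; in real use Enzyme fills the shadow `dx` during `autodiff`):

```julia
using Optimisers, EnzymeCore

x  = [1.0, 2.0, 3.0]    # parameters
dx = [0.1, 0.0, -0.4]   # gradient, as Enzyme would accumulate it into the shadow

x_dx = Duplicated(x, dx)                      # model and gradient, stored together

state = Optimisers.setup(Descent(0.1), x_dx)  # ignores dx, same as setup(rule, x)
Optimisers.update!(state, x_dx)               # mutates x (and state) in place
```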
60 changes: 60 additions & 0 deletions ext/OptimisersEnzymeCoreExt.jl
@@ -0,0 +1,60 @@
module OptimisersEnzymeCoreExt

import Optimisers: trainable, setup, update!, isnumeric, AbstractRule, _setup
import EnzymeCore: Duplicated, Const

using Functors: fmapstructure

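# Walk into the stored value but never into the shadow gradient; a `Const` has nothing trainable: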
trainable(x::Duplicated) = (; val = x.val)
trainable(x::Const) = (;)

"""
setup(rule::AbstractRule, model_grad::Duplicated)

For use with Enzyme's `Duplicated`, this just calls `setup(rule, model_grad.val)`.
"""
setup(rule::AbstractRule, model_grad::Duplicated) = setup(rule, model_grad.val)

_setup(rule, x::Duplicated; cache) = throw(ArgumentError(
"""Objects of type `Duplicated` are only supported by Optimisers.jl at top level,
they may not appear deep inside other objects."""
))

"""
update!(opt_state, model_grad::Duplicated)

For use with Enzyme's `Duplicated`, which holds both a model/parameters
and the corresponding gradient.

# Example

```jldoctest
julia> using Optimisers, EnzymeCore

julia> x_dx = Duplicated(Float16[1,2,3], Float16[1,0,-4])
Duplicated{Vector{Float16}}(Float16[1.0, 2.0, 3.0], Float16[1.0, 0.0, -4.0])

julia> st = Optimisers.setup(Momentum(1/9), x_dx) # acts only on x not on dx
Leaf(Momentum(0.111111, 0.9), Float16[0.0, 0.0, 0.0])

julia> Optimisers.update!(st, x_dx) # mutates both arguments

julia> x_dx
Duplicated{Vector{Float16}}(Float16[0.8887, 2.0, 3.445], Float16[1.0, 0.0, -4.0])

julia> st
Leaf(Momentum(0.111111, 0.9), Float16[0.1111, 0.0, -0.4443])
```
"""
function update!(opt_state, model_grad::Duplicated)
_, _ = update!(opt_state, model_grad.val, _grad_or_nothing(model_grad))
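# both arguments are mutated in place; discard the returned pair so update! returns nothing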
nothing
end

# This function strips the returned gradient to be Zygote-like. Most importantly,
# `prune=nothing` replaces the second appearance of a shared gradient array with `nothing`,
# to avoid double-counting when `update!` walks the model.
_grad_or_nothing(dup::Duplicated) = fmapstructure(_grad_or_nothing, dup.dval; prune=nothing)
_grad_or_nothing(::Const) = nothing
_grad_or_nothing(x) = isnumeric(x) ? x : nothing

end
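
To illustrate the pruning (a sketch using made-up names `shared` and `dval`, not part of the PR): Enzyme accumulates the gradient of tied parameters into one shadow array, so its second appearance must not be applied again:

```julia
using Functors: fmapstructure

shared = [1.0, 2.0]                        # one shadow array, reached twice
dval = (layer1 = shared, layer2 = shared)

fmapstructure(identity, dval; prune = nothing)
# (layer1 = [1.0, 2.0], layer2 = nothing)  # second appearance pruned
```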
24 changes: 23 additions & 1 deletion test/runtests.jl
@@ -1,5 +1,5 @@
using Optimisers
-using ChainRulesCore, Functors, StaticArrays, Zygote
+using ChainRulesCore, Functors, StaticArrays, Zygote, EnzymeCore
using LinearAlgebra, Statistics, Test, Random
using Optimisers: @.., @lazy
using Base.Broadcast: broadcasted, instantiate, Broadcasted
@@ -534,6 +534,28 @@ end
@test Optimisers._norm(bc2, p) isa Float64
end
end

@testset "Enzyme Duplicated" begin
x_dx = Duplicated(Float16[1,2,3], Float16[1,0,-4])
st = Optimisers.setup(Momentum(1/9), x_dx) # acts only on x not on dx
@test st isa Optimisers.Leaf
@test nothing === Optimisers.update!(st, x_dx) # mutates both arguments
@test x_dx.val ≈ Float16[0.8887, 2.0, 3.445]

shared = [1.0]
model = (x=shared, y=shared)
grad = deepcopy(model) # Enzyme produces something like this, grad.x === grad.y, already accumulated.
dup = Duplicated(model, model)
st2 = Optimisers.setup(Descent(0.1), model)
Optimisers.update!(st2, dup)
@test model.x ≈ [0.9]
shared .= 1
Optimisers.update!(st2, model, grad)
model.x ≈ [0.8] # This is wrong, but don't make it a test.
# Ideally, perhaps the 3-arg update! could notice that grad.x===grad.y, and not accumulate the gradient in this case?
Comment on lines +552 to +555
Member
I get what you mean, but do we promise to handle shared gradient arrays? Might be easier to add a warning in the docs. Maybe a broken test outside of the Enzyme test suite if we really care about this.

Member Author

At present (after this PR) I think the usual path with Zygote and this new path with Enzyme should always agree.

Whether the normal path can auto-detect this (as in the comment) I'm not really sure. It's possible to make Zygote return two components which are ===, and it's conceivable that this could happen with shared parameters.

Nevertheless, detecting this would probably be better than not. In real use, e.g. when the same layer is used twice, Zygote will certainly return two new arrays for the shared parameters.
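
For illustration (a sketch, not part of the test suite): the gradient Zygote returns for tied arrays holds two separate arrays, each with only its own contribution:

```julia
using Zygote

shared = [1.0]
model = (x = shared, y = shared)

g = Zygote.gradient(m -> sum(m.x) + sum(m.y), model)[1]
# g.x == [1.0] and g.y == [1.0], but g.x !== g.y:
# each array holds only its own contribution, so applying both is correct.
# Enzyme instead accumulates both contributions into the one shared shadow array.
```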


@test_throws ArgumentError Optimisers.setup(Adam(), (; a=[1,2,3.], b=x_dx)) # Duplicated deep inside is not allowed
end
end
@testset verbose=true "Destructure" begin
include("destructure.jl")