Add Adapt.adapt_structure method for Optimisers.Leaf (FluxML#180)
* Adapt.adapt_structure method for Optimisers.Leaf

* import Adapt.jl

* add Adapt.jl to Project.toml

* adapt compat

* based on discussion: the adapt_structure method does not maintain the IdDict handled by Functors, so we add a warning referring the user to Flux.gpu or MLDataDevices.gpu_device()

* Update ext/OptimisersAdaptExt.jl

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>

* edit warning to indicate that this is a correctness issue

* Update ext/OptimisersAdaptExt.jl

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>

---------

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
3 people authored and mashu committed Nov 14, 2024
1 parent 42714cf commit 445c490
Showing 2 changed files with 23 additions and 0 deletions.
3 changes: 3 additions & 0 deletions Project.toml
@@ -11,12 +11,15 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[extensions]
OptimisersAdaptExt = ["Adapt"]
OptimisersEnzymeCoreExt = "EnzymeCore"

[compat]
Adapt = "4"
ChainRulesCore = "1"
EnzymeCore = "0.8.5"
Functors = "0.4.9, 0.5"
20 changes: 20 additions & 0 deletions ext/OptimisersAdaptExt.jl
@@ -0,0 +1,20 @@
module OptimisersAdaptExt

import Adapt
import Optimisers: Leaf

function Adapt.adapt_structure(to, leaf::Leaf)
@warn """`Optimisers.Leaf` object does not support device transfer via
`Adapt.jl`. This is because `Adapt.jl` does not handle shared parameters (i.e. the same parameter array
appearing more than once in the model), and in such cases this will lead to incorrect gradient updates.
Avoid this by calling `Flux.gpu/cpu` or `MLDataDevices.cpu_device()/gpu_device()` on the
optimiser state object.
""" maxlog=1

rule = Adapt.adapt(to, leaf.rule)
state = Adapt.adapt(to, leaf.state)

Leaf(rule, state, leaf.frozen)
end

end
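The commit message and warning above turn on one point: `Adapt.adapt` maps each field of a struct independently, with no IdDict cache, so two fields that alias the same parameter array before the transfer end up as two separate arrays after it. A minimal CPU-only sketch (not part of the commit; the named tuple `m` and the `Array{Float32}` target are illustrative assumptions) of how that sharing is lost:

```julia
using Adapt

# `a` and `b` alias one parameter array, as shared weights in a model would.
w = rand(Float64, 3)
m = (a = w, b = w)
@assert m.a === m.b

# Adapting each field separately (as adapt_structure does when recursing
# through nested structs) converts the shared array twice, producing two
# independent copies; no IdDict records that the input was the same object.
m2 = (a = Adapt.adapt(Array{Float32}, m.a),
      b = Adapt.adapt(Array{Float32}, m.b))
@assert m2.a !== m2.b   # aliasing is gone; updates to one no longer affect the other
```

This is why the warning points users at `Flux.gpu`/`Flux.cpu` or `MLDataDevices.gpu_device()`, which walk the optimiser state with a Functors-style cache and so preserve the aliasing.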
