Add Adapt.adapt_structure method for Optimisers.Leaf #180

Merged (9 commits, Nov 12, 2024)
Project.toml (3 changes: 3 additions & 0 deletions)

@@ -11,12 +11,15 @@
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[extensions]
OptimisersAdaptExt = ["Adapt"]
OptimisersEnzymeCoreExt = "EnzymeCore"

[compat]
Adapt = "4"
ChainRulesCore = "1"
EnzymeCore = "0.8.5"
Functors = "0.4.9, 0.5"
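Because Adapt is declared under [weakdeps] with a matching [extensions] entry, the new OptimisersAdaptExt module is only loaded once Adapt itself is loaded in the session. A quick way to check this (a sketch, assuming Julia 1.9+ where package extensions and Base.get_extension are available):

using Optimisers
Base.get_extension(Optimisers, :OptimisersAdaptExt)  # returns nothing while Adapt is not loaded

import Adapt
Base.get_extension(Optimisers, :OptimisersAdaptExt)  # now returns the extension module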
ext/OptimisersAdaptExt.jl (20 changes: 20 additions & 0 deletions)
@@ -0,0 +1,20 @@
module OptimisersAdaptExt

import Adapt
import Optimisers: Leaf

function Adapt.adapt_structure(to, leaf::Leaf)
@warn """`Optimisers.Leaf` object does not support device transfer via
`Adapt.jl`. This is because `Adapt.jl` does not handle shared parameters (i.e. the same parameter array
appearing more than once in the model), and in such cases this will lead to incorrect gradient updates.
Avoid this by calling `Flux.gpu/cpu` or `MLDataDevices.cpu_device()/gpu_device()` on the
optimiser state object.
""" maxlog=1
Member:
Do we want this only once? It's a potential correctness bug, not a performance issue.

Suggested change:
-    """ maxlog=1
+    """

Member:
Is your worry that the message will get lost in the shuffle somehow? My thought was that people may have a valid use for this, and as long as they know what they're getting into, the library doesn't have to remind them on every call.

Another practical concern would be what happens when someone tries to call cu(large state tree). Not setting maxlog would mean other logging is drowned out because this warning would trigger for every Leaf.

Member:
I agree that printing 100 times on a big model is too much. Ideally it would print once per invocation, IMO... but that's hard to make happen.

It's not the world's biggest correctness bug to ignore shared parameters, so maybe we should live with it. Maybe the message should say that's what the problem is?

Contributor Author:
@mcabbott I edited the warning to say that this is a correctness issue.
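For background on the maxlog discussion above: Julia's standard loggers stop displaying a given log record after it has been shown maxlog times, so with maxlog=1 the warning appears once per session rather than once per Leaf. A toy illustration (not part of this PR):

for i in 1:1000
    @warn "Leaf does not support device transfer via Adapt.jl" maxlog=1  # printed only on the first iteration
end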


rule = Adapt.adapt(to, leaf.rule)
state = Adapt.adapt(to, leaf.state)

Leaf(rule, state, leaf.frozen)
end

end
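For illustration, a minimal usage sketch of the new method (the model and variable names here are hypothetical; Flux.gpu/cpu or MLDataDevices device objects remain the recommended way to move optimiser state):

using Optimisers, Adapt

# Hypothetical single-parameter "model" and its optimiser state.
model = (; w = rand(Float32, 3))
state = Optimisers.setup(Optimisers.Adam(), model)   # state.w is an Optimisers.Leaf

# With this method, Adapt recurses into each Leaf (adapting rule and state,
# keeping `frozen`); the warning above is emitted once (maxlog=1), since Adapt
# cannot preserve parameter arrays that are shared (tied) across the model.
state_cpu = Adapt.adapt(Array, state)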