
Add Adapt.adapt_structure method for Optimisers.Leaf #180

Merged · 9 commits · Nov 12, 2024
5 changes: 3 additions & 2 deletions ext/OptimisersAdaptExt.jl

@@ -5,8 +5,9 @@ import Optimisers: Leaf

  function Adapt.adapt_structure(to, leaf::Leaf)
      @warn """`Optimisers.Leaf` object does not support device transfer via
-     `Adapt.jl`. Avoid this by calling `Flux.gpu/cpu` or
-     `MLDataDevices.cpu_device()/gpu_device()` on the optimiser state object.
+     `Adapt.jl`. This could lead to incorrect gradient updates. Avoid this by
+     calling `Flux.gpu/cpu` or `MLDataDevices.cpu_device()/gpu_device()` on the
+     optimiser state object.
""" maxlog=1
Member:
Do we want this only once? It's a potential correctness bug, not a performance issue.

Suggested change:
-     """ maxlog=1
+     """

Member:
Is your worry that the message will get lost in the shuffle somehow? My thought was that people may have a valid use for this, and as long as they know what they're getting into the library doesn't have to remind them on every call.

Another practical concern would be what happens when someone tries to call cu(large state tree). Not setting maxlog would mean other logging is drowned out because this warning would trigger for every Leaf.
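The "once per Leaf" concern above can be illustrated concretely. The sketch below is hypothetical (model shape and rule are made up); it assumes Optimisers.jl's `setup` creates one `Leaf` per parameter array, so `Adapt.adapt` recursing over the state tree hits the warning in `adapt_structure` once per `Leaf` unless `maxlog=1` caps it:

```julia
using Optimisers, Adapt

# A toy state tree: one Leaf per parameter array.
model = (layer1 = (W = rand(4, 4), b = rand(4)),
         layer2 = (W = rand(4, 4), b = rand(4)))
state = Optimisers.setup(Optimisers.Adam(), model)

# Adapt recurses through the NamedTuple and reaches adapt_structure(to, ::Leaf)
# for each of the four Leaf objects. With `maxlog=1` the warning prints once;
# without it, a large model would emit one warning per parameter array.
state32 = Adapt.adapt(Array{Float32}, state)
```

For a model with hundreds of parameter arrays, that is the "drowned out" scenario: hundreds of identical warnings per `adapt` call.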

Member:

I agree that printing 100 times on a big model is too much. Ideal would be once, on every invocation, IMO... but that's hard to make happen.

It's not the world's biggest correctness bug to ignore shared parameters, so maybe we should live with it. Maybe the message should say that's what the problem is?
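The shared-parameter issue mentioned above can be sketched as follows. This is a hedged illustration, not code from the PR: it assumes Optimisers.jl deduplicates tied arrays into a single shared `Leaf` (identity-based sharing), while Adapt's generic recursion over a NamedTuple adapts each field independently and so silently loses that sharing:

```julia
using Optimisers, Adapt

W = rand(2, 2)
model = (a = W, b = W)   # `a` and `b` are the same (tied) array
state = Optimisers.setup(Optimisers.Adam(), model)

# Optimisers.setup ties the optimiser state for shared parameters:
state.a === state.b      # expected: the two fields share one Leaf

# Adapt recurses field-by-field, producing two independent copies,
# so subsequent updates to `a` and `b` would diverge:
adapted = Adapt.adapt(Array{Float32}, state)
adapted.a === adapted.b  # expected: sharing is lost
```

This is why the warning frames the transfer as a correctness hazard rather than a mere inconvenience.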

Contributor Author:

@mcabbott I edited the warning to say that this is a correctness issue.


rule = Adapt.adapt(to, leaf.rule)