
Add Adapt.adapt_structure method for Optimisers.Leaf #180

Merged: 9 commits into FluxML:master on Nov 12, 2024

Conversation

@vpuri3 (Contributor) commented Oct 3, 2024

Fix #179

PR Checklist

  • Tests are added
  • Documentation, if applicable

vpuri3 changed the title from "Add Adapt.adapt_storage method for Optimisers.Leaf" to "Add Adapt.adapt_structure method for Optimisers.Leaf" on Oct 3, 2024
vpuri3 added a commit to vpuri3/NeuralROMs.jl that referenced this pull request Oct 3, 2024
Project.toml (resolved review thread)
@mcabbott (Member) commented Oct 3, 2024

I think this won't preserve identity? That's the big difference between Functors and Adapt.

@vpuri3 (Contributor, Author) commented Oct 3, 2024

@mcabbott can you explain? I don't know what identity is.

@CarloLucibello addressing your comment from #179

> Since Leaf is a functor, it will move to GPU when using Flux.gpu or MLDataDevices.jl.

That is true; the MWE works with MLDataDevices. However, we still need the Adapt functionality. Consider the case where a Leaf is stored as part of a struct. Then MLDataDevices.gpu_device() does not move the optimiser state to the GPU, even if Adapt.adapt_structure is defined for the enclosing object:

using Optimisers, CUDA, LuxCUDA, MLDataDevices, Adapt

struct TrainState{Tp, To}
  p::Tp
  opt_st::To
end

Adapt.@adapt_structure TrainState

p = rand(2)
opt_st = Optimisers.setup(Optimisers.Adam(), p)
ts = TrainState(p, opt_st)
device = gpu_device()
device(ts).opt_st.state[1]  # Adam's first-moment buffer

2-element Vector{Float64}:
 0.0
 0.0

So there is a need to define Adapt.adapt_structure for Leaf.
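
For reference, the simplest form of such a definition could be the one-liner sketched below (a sketch only; whether this handles every Leaf correctly is what the rest of this thread is about):

import Adapt, Optimisers

# Generate Adapt.adapt_structure for Leaf: adapt every field of the struct
# and rebuild it with its default constructor.
Adapt.@adapt_structure Optimisers.Leaf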

@mcabbott (Member) commented Oct 3, 2024

Functors keeps an IdDict so that if the same array appears twice, this property is preserved by fmap. Optimisers.jl follows that too, and will use (and, IIRC, expect) the same Leaf in such cases. So I don't see an easy way to make cu work for this.
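
A small runnable illustration of the identity point (not from the PR; it only assumes that setup and fmap share tied arrays through their IdDict cache):

using Functors, Optimisers

w = rand(2)
tied = (a = w, b = w)                  # the same array appears twice
opt_st = Optimisers.setup(Optimisers.Adam(), tied)
opt_st.a === opt_st.b                  # true: tied arrays get one shared Leaf

moved = fmap(copy, tied)
moved.a === moved.b                    # still true: fmap's IdDict preserves the tie

# Adapt.adapt carries no such cache, so adapting each Leaf independently
# cannot know that two Leafs were originally the same object.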

@vpuri3 (Contributor, Author) commented Oct 3, 2024

Would something like this solve the problem?

function Adapt.adapt_storage(to, leaf::Leaf)
    return fmap(x -> Adapt.adapt_storage(to, x), leaf)
end

@ToucheSir (Member) commented:

It would not, because the IdDict needs to be shared between Leafs. This is why Flux.gpu is a standalone function right now, FWIW.

@vpuri3 (Contributor, Author) commented Oct 3, 2024

> It would not, because the IdDict needs to be shared between Leafs. This is why Flux.gpu is a standalone function right now, FWIW.

Maybe it's possible to grab the IdDict and bind it to the new Leaf object? Where is it defined?

BTW, the fix in this PR (Adapt.@adapt_structure Optimisers.Leaf) is working fine in my training runs.

@ToucheSir (Member) commented Oct 4, 2024

> BTW, the fix in this PR (Adapt.@adapt_structure Optimisers.Leaf) is working fine in my training runs.

That's because your model doesn't have any shared/"tied" parameters, e.g. model.layer1.W === model.layer2.W. Which is fine, but libraries like Optimisers have to support all use cases.
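
A hypothetical illustration of the concern (it assumes an adapt_structure definition for Leaf as in this PR, plus a working CUDA device):

using Optimisers, CUDA, Adapt

W = rand(2)
ps = (layer1 = (; W), layer2 = (; W))  # tied weight: ps.layer1.W === ps.layer2.W
st = Optimisers.setup(Optimisers.Adam(), ps)
st.layer1.W === st.layer2.W            # true: one shared Leaf for the tied weight

st_gpu = cu(st)                        # Adapt transfers each node separately, with no cache
st_gpu.layer1.W === st_gpu.layer2.W    # false: the tie between the Leafs is silently lost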

> Maybe it's possible to grab the IdDict and bind it to the new Leaf object? Where is it defined?

It's created at the top level in Functors.fmap and threaded down through the state tree. I'm not sure what it means to "grab the IdDict" in the context of overriding adapt_structure. Flux and MLDataDevices only ever call adapt using fmap, and adapt doesn't take a cache argument.
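
Roughly, the pattern those libraries use looks like the sketch below (gpu_like is a hypothetical name, not the actual Flux or MLDataDevices code):

using Functors, Adapt, CUDA

# fmap owns the IdDict cache, so tied arrays and tied Leafs stay tied;
# adapt is only applied to the individual leaves of the tree.
gpu_like(x) = fmap(y -> Adapt.adapt(CuArray, y), x)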

@vpuri3 (Contributor, Author) commented Oct 4, 2024

@ToucheSir, thanks for explaining. I added a warning to the adapt_structure method that points the user to Flux.gpu, and moved it all to an extension. Now cu(opt_st) won't silently do something other than what the user expects.

The behavior is as follows:

julia> using Optimisers, CUDA, LuxCUDA

julia> opt_st = Optimisers.setup(Optimisers.Adam(), zeros(2))
Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), ([0.0, 0.0], [0.0, 0.0], (0.9, 0.999)))

julia> cu(opt_st)
┌ Warning: `Optimisers.Leaf` object does not support device transfer via
│ `Adapt.jl`. Avoid this by calling `Flux.gpu/cpu` or
│ `MLDataDevices.cpu_device()/gpu_device()` on the optimiser state object.
│ See below GitHub issue for more details.
│ https://github.com/FluxML/Optimisers.jl/issues/179 
└ @ OptimisersAdaptExt ~/.julia/dev/Optimisers.jl/ext/OptimisersAdaptExt.jl:7
Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))

julia> cu(opt_st).state[1] |> typeof
CuArray{Float32, 1, CUDA.DeviceMemory}

julia> using MLDataDevices

julia> gpu_device()(opt_st)
Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999)))

julia> gpu_device()(opt_st).state[1]
2-element CuArray{Float32, 1, CUDA.DeviceMemory}:
 0.0
 0.0
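
For reference, a rough sketch of what the extension amounts to at this point, reconstructed from the warning shown above (the actual file may differ; rebuilding the Leaf from just its rule and state is an assumption):

module OptimisersAdaptExt

import Adapt, Optimisers

function Adapt.adapt_structure(to, leaf::Optimisers.Leaf)
    @warn """`Optimisers.Leaf` object does not support device transfer via
          `Adapt.jl`. Avoid this by calling `Flux.gpu/cpu` or
          `MLDataDevices.cpu_device()/gpu_device()` on the optimiser state object.
          See below GitHub issue for more details.
          https://github.com/FluxML/Optimisers.jl/issues/179
          """
    # adapt the rule and the state separately, then rebuild the Leaf
    Optimisers.Leaf(Adapt.adapt(to, leaf.rule), Adapt.adapt(to, leaf.state))
end

end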

@ToucheSir (Member) left a comment:

Apologies, I had another look through this after some internal changes around device handling on the Flux side. Made some minor tweaks (removed PR link since we'd want a docs one, and used adapt to be more general), but it should be ready to go.

ext/OptimisersAdaptExt.jl (resolved review thread)
vpuri3 and others added 2 commits November 11, 2024 16:23
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
@warn """`Optimisers.Leaf` object does not support device transfer via
`Adapt.jl`. Avoid this by calling `Flux.gpu/cpu` or
`MLDataDevices.cpu_device()/gpu_device()` on the optimiser state object.
""" maxlog=1
A Member commented on this hunk:

Do we want this only once? It's a potential correctness bug, not a performance issue.

Suggested change:
- """ maxlog=1
+ """

A Member replied:

Is your worry that the message will get lost in the shuffle somehow? My thought was that people may have a valid use for this, and as long as they know what they're getting into the library doesn't have to remind them on every call.

Another practical concern would be what happens when someone tries to call cu(large state tree). Not setting maxlog would mean other logging is drowned out because this warning would trigger for every Leaf.

A Member replied:

I agree that printing 100 times on a big model is too much. Ideal, IMO, would be once per invocation... but that's hard to make happen.

It's not the world's biggest correctness bug to ignore shared parameters, so maybe we should live with it. Maybe the message should say that's what the problem is?

@vpuri3 (Contributor, Author) commented Nov 11, 2024

Thanks @ToucheSir, I resolved the merge conflicts.

`Adapt.jl`. This could lead to incorrect gradient updates. Avoid this by
calling `Flux.gpu/cpu` or `MLDataDevices.cpu_device()/gpu_device()` on the
optimiser state object.
""" maxlog=1
@vpuri3 (Contributor, Author) commented on this hunk:

@mcabbott I edited the warning to say that this is a correctness issue.

ext/OptimisersAdaptExt.jl (resolved review thread)
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
mcabbott merged commit 9928588 into FluxML:master on Nov 12, 2024 (3 of 4 checks passed).
vpuri3 deleted the patch-1 branch on November 12, 2024 at 17:58.
mashu pushed a commit to mashu/Optimisers.jl that referenced this pull request Nov 14, 2024
* Adapt.adapt_structure method for Optimisers.Leaf

* import Adapt.jl

* add Adapt.jl to Project.toml

* adapt compat

* based on discussion: adapt_structure method does not maintain IdDict handled by functors. So we add a warning referring the user to Flux.gpu or MLDataDevices.gpu_device()

* Update ext/OptimisersAdaptExt.jl

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>

* edit warning to indicate that this is a correctness issue

* Update ext/OptimisersAdaptExt.jl

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>

---------

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Successfully merging this pull request may close issue #179: Optimiser state not moving to GPU.