Fix Conv transfer to AMDGPU #2235
Conversation
A couple of small touch-ups, but otherwise LGTM
ext/AMDGPUExt/functor.jl (Outdated)

```julia
_conv_basetype(c::Type{C}) where C <: Conv = Conv
_conv_basetype(c::Type{C}) where C <: ConvTranspose = ConvTranspose
Flux._isleaf(::AMD_CONV) = return true
```
Suggested change:

```diff
-Flux._isleaf(::AMD_CONV) = return true
+Flux._isleaf(::AMD_CONV) = true
```
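(In a one-line method definition the `return` is redundant: Julia returns the value of the last expression automatically.)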
ext/AMDGPUExt/functor.jl (Outdated)

```julia
_amd(m::Union{Conv, ConvTranspose}) = adapt_storage(FluxAMDAdaptor(), m)
Adapt.adapt_structure(to::FluxAMDAdaptor, m::AMD_CONV) = return m
```
Suggested change:

```diff
-Adapt.adapt_structure(to::FluxAMDAdaptor, m::AMD_CONV) = return m
+Adapt.adapt_structure(to::FluxAMDAdaptor, m::AMD_CONV) = m
```
ext/AMDGPUExt/functor.jl (Outdated)

```julia
_conv_basetype(c::C) where C <: Conv = Conv
_conv_basetype(c::C) where C <: ConvTranspose = ConvTranspose
```
Suggested change:

```diff
-_conv_basetype(c::C) where C <: Conv = Conv
-_conv_basetype(c::C) where C <: ConvTranspose = ConvTranspose
+_conv_basetype(::Conv) = Conv
+_conv_basetype(::ConvTranspose) = ConvTranspose
```
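(Dispatching on the instance directly is equivalent and simpler; the binding `c` and the type parameter `C` were unused.)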
Done! Once this is merged, can we also tag a release?
I want to see if the Julia v1 CUDA failures on this PR propagate to the main branch. They shouldn't be caused by this test suite because it never runs on CI (we should add that), so it'd be good to figure out what's going on before tagging.
The one with Float16 or the one at the very end?
The Float16 one. I'm not sure why it's seemingly spread from just the 1.6 job to the 1.x one as well...
This could be some kind of synchronization issue, probably unrelated to Flux.
You mean running the CUDA tests locally works? That's odd, I wonder why it's failing consistently on CI then. CUDNN_STATUS_BAD_PARAM is an input validation error, so I would have expected it to be deterministic.
Yes. And now it always succeeds. Can't reproduce the error for some reason...
I was able to repro consistently yesterday on 1.6 and 1.8. Will try to look into it over the next couple of days.
Previously, only transferring a conv layer directly flipped the conv weights, but not transferring a layer nested inside a container (or anything more complex). To fix this, I modified the `exclude` function passed to `fmap`, for both the CPU -> GPU and the GPU -> CPU transfer (so that weights are flipped back); see the sketch below.
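For illustration, a minimal sketch of the idea, assuming the `AMD_CONV` alias and `FluxAMDAdaptor` from the snippets above, with `model` standing in for any Flux model (an illustration of the mechanism, not the exact diff):

```julia
using Adapt, Flux
using Functors: fmap

# Treating a whole conv layer as a leaf stops `fmap` from recursing
# into its individual fields, so the adaptor receives the complete
# layer and can flip the kernel once, whether the layer is bare or
# nested inside a container.
Flux._isleaf(::AMD_CONV) = true

# Both transfer directions then walk the model with this predicate:
roc_model = fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), model;
                 exclude = Flux._isleaf)
```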
PR Checklist