
Fix Conv transfer to AMDGPU #2235

Merged · 3 commits · Apr 23, 2023
Conversation

@pxl-th (Member) commented Apr 21, 2023

Previously, only this form had its conv weights flipped on transfer:

Conv((3, 3), 3 => 3) |> gpu

but not this (or anything more complex):

Chain(Conv((3, 3), 3 => 3)) |> gpu

To fix this, I modified the `exclude` function passed to `fmap`, for both the CPU -> GPU and the GPU -> CPU transfer (so that on the way back the weights are flipped again).
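A minimal sketch of the mechanism described above, using the `fmap`/`exclude` machinery from Functors.jl. The helper and adaptor names mirror identifiers visible in this PR's diff, but the definitions here are illustrative, not the exact Flux source:

```julia
using Flux, Functors  # Flux for Conv/ConvTranspose; Functors for fmap/isleaf

# Illustrative stand-in for the real adaptor type used in this PR.
struct FluxAMDAdaptor end

# By default, fmap recurses until it reaches leaves (arrays, numbers, ...).
# A Conv nested inside a Chain is therefore decomposed into bare weight
# arrays, so a Conv-specific adapt method is never called -- which is why
# only a top-level Conv had its weights flipped before this fix.

# The fix: tell fmap to also stop at Conv/ConvTranspose layers, so the
# whole layer reaches the adaptor, which can then flip the weights itself.
_isleaf(x) = Functors.isleaf(x)
_isleaf(::Union{Flux.Conv, Flux.ConvTranspose}) = true

# Hypothetical transfer entry point built on the predicate above:
# amd_gpu(m) = fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), m; exclude = _isleaf)
```

With this `exclude`, both `Conv((3, 3), 3 => 3) |> gpu` and `Chain(Conv((3, 3), 3 => 3)) |> gpu` route the full layer through the adaptor.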

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

src/functor.jl (review thread, outdated)
ext/AMDGPUExt/functor.jl (review thread, outdated)
ext/AMDGPUExt/functor.jl (review thread, outdated)
@ToucheSir (Member) left a comment

A couple of small touch-ups, but otherwise LGTM


_conv_basetype(c::Type{C}) where C <: Conv = Conv
_conv_basetype(c::Type{C}) where C <: ConvTranspose = ConvTranspose
Flux._isleaf(::AMD_CONV) = return true

Suggested change:
- Flux._isleaf(::AMD_CONV) = return true
+ Flux._isleaf(::AMD_CONV) = true


_amd(m::Union{Conv, ConvTranspose}) = adapt_storage(FluxAMDAdaptor(), m)
Adapt.adapt_structure(to::FluxAMDAdaptor, m::AMD_CONV) = return m
Suggested change:
- Adapt.adapt_structure(to::FluxAMDAdaptor, m::AMD_CONV) = return m
+ Adapt.adapt_structure(to::FluxAMDAdaptor, m::AMD_CONV) = m

Comment on lines 40 to 41
_conv_basetype(c::C) where C <: Conv = Conv
_conv_basetype(c::C) where C <: ConvTranspose = ConvTranspose
Suggested change:
- _conv_basetype(c::C) where C <: Conv = Conv
- _conv_basetype(c::C) where C <: ConvTranspose = ConvTranspose
+ _conv_basetype(::Conv) = Conv
+ _conv_basetype(::ConvTranspose) = ConvTranspose
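Both forms are equivalent at call sites; the suggestion simply replaces a redundant `where` type parameter with plain value dispatch. A hypothetical use (the reconstruction helper below is illustrative, not this PR's exact code):

```julia
# Dispatch on the argument's type directly; the value itself is unused,
# so it needs no binding and no `where` clause.
_conv_basetype(::Conv) = Conv
_conv_basetype(::ConvTranspose) = ConvTranspose

# The returned constructor lets one code path rebuild either layer family
# after adapting its fields, e.g. (hypothetically):
# rebuild(layer, fields...) = _conv_basetype(layer)(fields...)
```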

@pxl-th (Member, Author) commented Apr 23, 2023

Done! Once this is merged, can we also tag a release?
I'd like to specify a compat bound for one of the packages.

@ToucheSir merged commit 3392a02 into FluxML:master Apr 23, 2023
@ToucheSir (Member) commented:
I want to see if the Julia v1 CUDA failures on this PR propagate to the main branch. They shouldn't be caused by this test suite because it never runs on CI (we should add that), so it'd be good to figure out what's going on before tagging.

@pxl-th deleted the amd-transfer branch Apr 24, 2023 05:29
@pxl-th (Member, Author) commented Apr 24, 2023

The one with Float16 or the one at the very end?

@ToucheSir (Member) commented:
The Float16 one. I'm not sure why it's seemingly spread from just the 1.6 job to the 1.x one as well...

@pxl-th (Member, Author) commented Apr 24, 2023

This could be some kind of synchronization issue, probably unrelated to Flux.
Running the tests multiple times either fails or finishes successfully (mostly successfully)...

@ToucheSir (Member) commented:
You mean running the CUDA tests locally works? That's odd, I wonder why it's failing consistently on CI then. CUDNN_STATUS_BAD_PARAM is an input validation error, so I would have expected it to be deterministic.

@pxl-th (Member, Author) commented Apr 24, 2023

> You mean running the CUDA tests locally works?

Yes. And now it always succeeds. Can't reproduce the error for some reason...

@ToucheSir (Member) commented:
I was able to repro consistently yesterday on 1.6 and 1.8. Will try to look into it over the next couple of days.
