Fix gradient of convolution for complex values #389
Conversation
Thanks! Can you add a regression test for complex convs so that we know this path is covered?
Great work, @zsoerenm! I'm thinking a test case as simple as a 1x1 complex convolution with known gradient should be sufficient, just to check that it gives the gradient in the correct direction!
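For concreteness, a minimal sketch of such a test (not taken from the PR; the array sizes and the abs2-based loss are illustrative assumptions) could check Zygote's gradient of a 1x1 complex convolution against the closed-form answer:

```julia
using NNlib, Zygote

x = rand(ComplexF64, 1, 1, 1)   # (width, channels_in, batch)
w = rand(ComplexF64, 1, 1, 1)   # (kernel_width, channels_in, channels_out)

# With a 1x1 kernel, conv(x, w) reduces to x .* w, so for the real-valued loss
# sum(abs2, conv(x, w)) the expected gradient w.r.t. w is 2 * abs2(x[1]) * w[1]
# under Zygote's convention for gradients of real-valued functions of complex inputs.
gw = Zygote.gradient(w -> sum(abs2, conv(x, w)), w)[1]
@assert gw[1] ≈ 2 * abs2(x[1]) * w[1]
```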
src/gemm.jl
Outdated
@@ -35,7 +35,7 @@ for (gemm, elt) in gemm_datatype_mappings
                    beta::$(elt), C::Ptr{$elt})
             # Convert our compile-time transpose marker to a char for BLAS
             convtrans(V::Val{false}) = 'N'
-            convtrans(V::Val{true}) = 'T'
+            convtrans(V::Val{true}) = $elt <: Complex ? 'C' : 'T'
I believe that the underlying {S,D,C,Z}GEMM routine should correctly interpret 'C', even for real-valued input. In the real-valued case, 'C' should be equivalent to 'T'.
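A quick illustration of that claim with the CPU BLAS wrappers from LinearAlgebra (a generic example, not code from this PR):

```julia
using LinearAlgebra

A, B = rand(ComplexF64, 2, 2), rand(ComplexF64, 2, 2)
BLAS.gemm('T', 'N', A, B) ≈ transpose(A) * B   # 'T' is the plain transpose
BLAS.gemm('C', 'N', A, B) ≈ A' * B             # 'C' is the conjugate transpose

Ar, Br = rand(2, 2), rand(2, 2)
BLAS.gemm('C', 'N', Ar, Br) ≈ BLAS.gemm('T', 'N', Ar, Br)   # identical for real input
```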
I thought so, too, but buildkite was consistently failing on the GPU. Therefore, I implemented the branch above.
Re tests, it should be possible to test the complex gradient of a function directly. Alternatively, a path where you take real input, build complex arrays, call conv, sum, and take the real part is something on which you could test Zygote against ForwardDiff etc., end-to-end. The test in FluxML/Flux.jl#1876 (comment) is probably right, but it is quite complicated, and without reading it very carefully I can't tell whether it adopts the same convention for complex numbers that Zygote uses, and not (say) the conjugate.
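A rough sketch of that end-to-end path (the sizes, names, and hand-rolled finite difference are assumptions for illustration, not taken from the PR or from Flux):

```julia
using NNlib, Zygote

xr, xi = rand(8, 2, 1), rand(8, 2, 1)   # real and imaginary parts, (width, channels_in, batch)
w = rand(ComplexF64, 3, 2, 2)           # complex kernel, (kernel_width, channels_in, channels_out)

# Real inputs -> complex array -> conv -> sum -> real scalar.
loss(xr, xi) = real(sum(conv(xr + im * xi, w)))

gxr, gxi = Zygote.gradient(loss, xr, xi)

# Central finite difference on a single entry of xr as a reference
# (ForwardDiff or FiniteDifferences could play the same role).
ϵ = 1e-6
δ = zeros(size(xr)); δ[1] = ϵ
fd = (loss(xr .+ δ, xi) - loss(xr .- δ, xi)) / (2ϵ)
@assert isapprox(gxr[1], fd; atol = 1e-4)
```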
Because that linked test case compares against
It uses complex grad only on one side of this comparison, the other hand-assembles complex out of real gradients (outside of Zygote's sight):
I have made more changes to the pull request. This is not finished yet. I wanted to reach out to you as soon as possible, as I'd like to start a little debate here:
src/impl/conv_im2col.jl
Outdated
-            w_ptr = pointer(w)
+            w_ptr = pointer(copy(conj(w)))
             y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
             gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
This change was a little awkward for me. All dense convolution tests are green if I simply change the last line here:

w_ptr = pointer(w)
y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
gemm!(Val(false), Val(true), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)

Note the Val(true) here. I wondered whether Val(false) is actually correct? The gradtests are failing with this change, though, so I left it as above.
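As a side note on why those two variants are not interchangeable for complex element types: conjugating the buffer and flipping the transpose flag do different things. A generic illustration (not code from the PR):

```julia
using LinearAlgebra

W = rand(ComplexF64, 2, 3)

conj(W)        # size (2, 3): entries conjugated, layout unchanged
transpose(W)   # size (3, 2): layout transposed, entries unchanged
W'             # size (3, 2): adjoint, i.e. conj ∘ transpose
```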
Interesting, looking at CI I wonder if this fixed the flaky gradtests we have when nthreads > 1!
No, what I meant to say is that with
gemm!(Val(false), Val(true), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
all gradtests are failing. All other tests are green, though.
EDIT: I just had another look. 7 out of 22 gradtests are failing for each spatial rank, not all.
Yes, sorry. I wasn't referring to the change you wrote, but to the diff here with copy.
src/conv.jl
Outdated
@@ -100,7 +100,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
                 dy::AbstractArray{yT,N}, w::AbstractArray{wT,N},
                 cdims::C; kwargs...) where {yT, wT, N, C <: ConvDims}
         dx = similar(dy, input_size(cdims)..., channels_in(cdims), size(dy, N))
-        return $(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...)
+        return conj($(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...))
I'm unsure about this change. My toy neural network example works with and without this change.
Ah, my loss function is real in the end, due to abs2. This might be the reason why it does not have an effect ;)
I think we're going to need to test a stack of multiple layers to truly see if this matters, because, as you say, immediately taking the absolute value (since it's the end of the network) defeats any changes. :)
Okay, having looked through this a bit more thoroughly now with these changes, I see what you're proposing: you're saying that, due to the definition of convolution, we should generally be taking the conjugate of the weights before convolving.

In a similar vein to how many neural network libraries treat correlation and convolution as the same (e.g. not flipping the kernel in time), I am also unsure if we actually need to do the conjugation here. It's pleasing on an academic and technical level, since we then get the "textbook correct" results, and it certainly does make it easier to deal with non-ML fields (e.g. if the result of your ML training should correspond to some physical value). However, it is (slightly) more work in some cases, as can be seen by how you have to

I'm happy to let the NNlib maintainers weigh in on whether they think this is important enough to include, or if we should stick with only the first commit (along with the tests) and be happy to just state somewhere in the documentation that NNlib learns the conjugate of what convolution "should" learn.

One thing that is definitely important, though, is to match "conjugate style" with the other layers, such as
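To make the two conventions in this debate concrete, here is a tiny 1-D illustration (hypothetical arrays and helper functions, not code from the PR): sliding the kernel with or without conjugating it differs exactly by a conj on the weights, which is why one could equivalently document that the library learns the conjugate of the "textbook" kernel.

```julia
x = ComplexF64[1 + 2im, 3 - 1im, 0.5 + 0.5im, -1 + 1im]
w = ComplexF64[0.2 - 0.3im, 1 + 0im]

# Plain sliding product (correlation-style, no conjugation) vs. the
# conjugated-kernel variant discussed above.
slide(x, w)      = [sum(w[k] * x[n+k-1] for k in eachindex(w)) for n in 1:length(x)-length(w)+1]
slide_conj(x, w) = [sum(conj(w[k]) * x[n+k-1] for k in eachindex(w)) for n in 1:length(x)-length(w)+1]

@assert slide_conj(x, w) == slide(x, conj(w))   # the two differ only by conjugating the kernel
```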
Yes, to be consistent, the weights of the
I have split the pull request.
Great, I think this change is pretty non-controversial, and the tests look good to me. Thanks so much @zsoerenm!
Hi, does this PR also solve this issue?
It does not. Supporting GPU will require quite a bit more effort.
Apologies for missing the merge on this. Thanks @zsoerenm!
This fixes FluxML/Flux.jl#1876