Fix gradient of convolution for complex values #389

Merged 6 commits into FluxML:master on Mar 2, 2022

Conversation

zsoerenm (Contributor)

This fixes FluxML/Flux.jl#1876

@ToucheSir (Member)

Thanks! Can you add a regression test for complex convs so that we know this path is covered?

@staticfloat (Contributor) left a comment

Great work, @zsoerenm!

I'm thinking a test case as simple as a 1x1 complex convolution with known gradient should be sufficient, just to check that it gives the gradient in the correct direction!
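As an illustration of the kind of 1x1 known-gradient test being suggested (a sketch only, not the test that was ultimately added to the PR; it assumes NNlib's conv, Zygote's convention that for f: C^n -> R the returned gradient is dL/d(re w) + im*dL/d(im w), and that conv does not conjugate the weights):

using NNlib, Zygote, Test

x = reshape([2.0 + 1.0im], 1, 1, 1, 1)   # 1x1 spatial, 1 channel, 1 batch
w = reshape([1.0 - 3.0im], 1, 1, 1, 1)   # 1x1 kernel, 1 in / 1 out channel

# loss(w) = |x*w|^2, so under the stated convention the expected gradient is 2*conj(x)*(x*w)
g = Zygote.gradient(w -> sum(abs2, conv(x, w)), w)[1]
@test g ≈ 2 .* conj.(x) .* (x .* w)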

src/gemm.jl (outdated diff)
@@ -35,7 +35,7 @@ for (gemm, elt) in gemm_datatype_mappings
                        beta::$(elt), C::Ptr{$elt})
             # Convert our compile-time transpose marker to a char for BLAS
             convtrans(V::Val{false}) = 'N'
-            convtrans(V::Val{true}) = 'T'
+            convtrans(V::Val{true}) = $elt <: Complex ? 'C' : 'T'
Contributor

I believe that the underlying {S,D,C,Z}GEMM routine should correctly interpret 'C', even if it's a real-valued input. In the case of real-valued input, 'C' should be equivalent to 'T'.
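This is easy to sanity-check for CPU BLAS from the Julia side (a quick illustration only; it says nothing about the GPU failure mentioned below):

using LinearAlgebra, Test

A = rand(Float64, 4, 3)
B = rand(Float64, 4, 5)

# For real element types the conjugate-transpose flag 'C' and the plain
# transpose flag 'T' should select the same op(A) = A^T.
@test BLAS.gemm('C', 'N', 1.0, A, B) ≈ BLAS.gemm('T', 'N', 1.0, A, B)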

zsoerenm (Contributor, author)

I thought so too, but Buildkite was consistently failing on the GPU, so I implemented the branch above.

@mcabbott (Member)

Re tests, it should be possible to directly test the complex gradient of a function C^n x C^m -> R. I think ChainRulesTestUtils will even handle this by finite differencing.

Alternatively, a path where you take real input, build the complex arrays, and then call conv, sum, and real is something on which you could test Zygote against ForwardDiff etc. end-to-end.

The test in FluxML/Flux.jl#1876 (comment) looks probably right, but it is quite complicated, and without reading it very carefully I can't tell whether it adopts the same convention for complex numbers that Zygote uses, rather than (say) the conjugate.
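A minimal sketch of such a directional finite-difference check (illustrative only, not the test that was merged; it assumes Zygote's convention that, for f: C^n -> R with returned gradient g, a small perturbation dx changes f by approximately real(dot(g, dx))):

using NNlib, Zygote, LinearAlgebra, Test

x = randn(ComplexF64, 8, 1, 1)   # (width, channels, batch)
w = randn(ComplexF64, 3, 1, 1)   # (width, ch_in, ch_out)
loss(x) = sum(abs2, conv(x, w))

g  = Zygote.gradient(loss, x)[1]
dx = 1e-6 .* randn(ComplexF64, size(x)...)

# First-order Taylor check: loss(x + dx) - loss(x) ≈ real(dot(g, dx))
@test loss(x .+ dx) - loss(x) ≈ real(dot(g, dx)) atol=1e-9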

@staticfloat (Contributor)

Because that linked test case compares against Flux.gradient(), I think it should natively compare with Zygote and ensure that the result is correct?

@mcabbott (Member)

It uses the complex gradient only on one side of this comparison; the other side hand-assembles the complex gradient out of real gradients (outside of Zygote's sight):

vec(grads[params[1]]) ≈ complex.(grads_real[params_real[1]], grads_real[params_real[2]])

@zsoerenm changed the title from "Fix gradient of convolution for complex values" to "[WIP] Fix gradient of convolution for complex values" on Feb 16, 2022
@zsoerenm (Contributor, author) commented Feb 16, 2022

I have made more changes to the pull request. This is not finished yet; I wanted to reach out to you as soon as possible because I'd like to start a little debate here:
I think that for complex values the weights should be conjugated before they are multiplied with the input x. It does not really make much of a difference, because in the end this just conjugates the internal weights of the neural network (as far as I can see).
However, taking the conjugate is more textbook-like (see e.g. https://dsp.stackexchange.com/questions/55388/for-complex-values-why-use-complex-conjugate-in-convolution).
What's more, with this modification, taking the derivative with ∇conv_filter_direct!, which is based on conv, just worked out of the box.
I have only tested this with a 1x1 convolution kernel with multiple input channels so far. Note the vec(w)' * vec(x), which feels more natural to me than transpose(vec(w)) * vec(x). Also note that DepthwiseConv will need adaptations; I will add these and more tests once we have resolved this debate.
I'm quite new to neural networks, so I'm not 100% sure that everything is correct here. In particular, I'm unsure about ∇conv_data: taking the conjugate of the result or not has no effect on my toy example. In both cases the neural network converges to its expected result (see comment below).
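For readers following along, the two inner products contrasted above differ only in whether w is conjugated (a toy illustration, not code from the PR):

w = [1.0 + 2.0im, 3.0 - 1.0im]
x = [2.0 - 1.0im, 1.0 + 1.0im]

vec(w)' * vec(x)            # adjoint: sum(conj(w) .* x), the conjugated, "textbook" form
transpose(vec(w)) * vec(x)  # transpose: sum(w .* x), no conjugation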

Comment on lines 54 to 56
-            w_ptr = pointer(w)
+            w_ptr = pointer(copy(conj(w)))
             y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
             gemm!(Val(false), Val(false), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)
zsoerenm (Contributor, author)

This change was a little awkward for me.
All dense convolution tests are green if I simply change the last line here:

w_ptr = pointer(w)
y_ptr = pointer(y, (batch_idx - 1)*M*N + 1)
gemm!(Val(false), Val(true), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)

Note the Val(true) here; I wondered whether Val(false) is actually correct.
The gradtests fail with this change, though, so I left it as above.

Member

Interesting, looking at CI I wonder if this fixed the flaky gradtests we have when nthreads > 1!

@zsoerenm (Contributor, author), Feb 16, 2022

No, what I meant to say is that with

gemm!(Val(false), Val(true), M, N, K, alpha, col_ptr, w_ptr, beta, y_ptr)

all gradtests are failing. All other tests are green, though.
EDIT: I just had another look. 7 out of 22 gradtests are failing for each spatial rank, not all.

Member

Yes, sorry. I wasn't referring to the change you wrote, but to the diff here with copy.

src/conv.jl (outdated diff)
@@ -100,7 +100,7 @@ for backend in (Symbol(), :_direct, :_im2col, :_nnpack)
                 dy::AbstractArray{yT,N}, w::AbstractArray{wT,N},
                 cdims::C; kwargs...) where {yT, wT, N, C <: ConvDims}
         dx = similar(dy, input_size(cdims)..., channels_in(cdims), size(dy, N))
-        return $(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...)
+        return conj($(Symbol("$(name)$(backend)!"))(dx, dy, w, cdims; kwargs...))
zsoerenm (Contributor, author)

I'm unsure about this change. My toy neural network example works with and without this change.

zsoerenm (Contributor, author)

Ah, my loss function is real in the end, due to abs2. This might be the reason why it does not have an effect ;)

Contributor

I think we're going to need to test a stack of multiple layers to truly see if this is important, because as you say, immediately taking the absolute value here (because it's the end of the network) defeats any changes. :)
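A sketch of the kind of stacked test this suggests (names and shapes are illustrative, not from the PR): abs2 is applied only at the very end, so the cotangent flowing from the outer convolution back into the inner one is itself complex, and the conjugation convention of that intermediate gradient actually matters.

using NNlib, Zygote

x  = randn(ComplexF64, 16, 1, 1)
w1 = randn(ComplexF64, 3, 1, 1)
w2 = randn(ComplexF64, 3, 1, 1)

# Two stacked complex convolutions; only the final reduction is real-valued.
loss(w1, w2) = sum(abs2, conv(conv(x, w1), w2))
g1, g2 = Zygote.gradient(loss, w1, w2)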

@staticfloat (Contributor) commented Feb 16, 2022

Okay, having looked through this a bit more thoroughly now with these changes, I see what you're proposing; you're saying that due to the definition of convolution, we should generally be taking the conjugate of the weights before convolving.

In a similar vein to how many neural network libraries treat correlation and convolution as the same (e.g. not flipping the kernel in time), I am also unsure whether we actually need to do the conjugation here. It's pleasing on an academic and technical level, since we now get the "textbook correct" results, and it certainly makes it easier to connect to non-ML fields (e.g. if the result of your ML training should correspond to some physical value). However, it is (slightly) more work in some cases, as can be seen by how you have to conj(x) every now and then, which eats up precious memory bandwidth and FLOPs.

I'm happy to let the NNlib maintainers weigh in on whether they think this is important enough to include, or if we should stick with only the first commit (along with the tests) and be happy to just state somewhere in the documentation that NNlib learns the conjugate of what convolution "should" learn.

One thing that is definitely important, though, is to match "conjugate style" with the other layers, such as Dense. If the weights of the Dense layer learn the "textbook correct" values, I think we should learn the textbook-correct values here as well, because if users create fancy architectures that feed off of the weights of multiple layers at once, we don't want the layers to fight against each other.
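To make the two conventions under discussion concrete, here is a 1-D, single-channel sketch (illustrative only; it ignores kernel flipping, padding, and batch/channel dimensions, and the only point is the presence or absence of conj on the weights):

# Conjugated ("textbook" proposal) vs. unconjugated sliding product:
conj_conv(w, x)   = [sum(conj(w[k]) * x[n+k-1] for k in eachindex(w)) for n in 1:length(x)-length(w)+1]
noconj_conv(w, x) = [sum(w[k]       * x[n+k-1] for k in eachindex(w)) for n in 1:length(x)-length(w)+1]

For real-valued w the two definitions coincide, which is why the question only shows up for complex weights.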

@zsoerenm (Contributor, author)

Yes, to be consistent, the weights of the Dense layer should also be conjugated before being multiplied with the input.
I guess that this change is not worth the trouble?
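For concreteness, a conjugating dense forward pass would look something like the following hypothetical sketch (not Flux's actual Dense implementation; W, b, x and textbook_dense are illustrative names):

# Hypothetical "textbook" dense layer that conjugates its weights,
# mirroring the convention proposed for conv:
textbook_dense(W, b, x) = conj.(W) * x .+ b

W = randn(ComplexF64, 3, 4); b = randn(ComplexF64, 3); x = randn(ComplexF64, 4)
textbook_dense(W, b, x)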

@zsoerenm changed the title from "[WIP] Fix gradient of convolution for complex values" to "Fix gradient of convolution for complex values" on Feb 17, 2022
@zsoerenm (Contributor, author)

I have split the pull request.
This pull request simply fixes the convolution for complex values.
The other pull request, #390, addresses the question of whether the weights should be conjugated before they are multiplied with the input.

@staticfloat (Contributor) left a comment

Great, I think this change is pretty non-controversial, and the tests look good to me. Thanks so much @zsoerenm!

@foldfelis

Hi, does this PR also solve this issue?

@ToucheSir (Member)

It does not. Supporting GPU will require quite a bit more effort.

@ToucheSir merged commit aa86827 into FluxML:master on Mar 2, 2022
@ToucheSir (Member)

Apologies for missing the merge on this. Thanks @zsoerenm!

Successfully merging this pull request may close these issues.

Gradient incorrect for Conv-layer and complex numbers
5 participants