-
-
Notifications
You must be signed in to change notification settings - Fork 611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
new cross convolution layer #423
Conversation
Will add tests for the layer upon approval. |
Looks good to me. Can you add some quick tests that it runs through, and perhaps also a gradient check for |
src/tracker/array.jl
Outdated
crosscor(data(x), data(w); kw...), | ||
Δ -> nobacksies(:crosscor, | ||
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...), | ||
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs the gradient check. I guess this won't be right without flipkernel
, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, missed that. Will add it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@MikeInnes We'd have to implement flipkernel
argument for the ∇conv_data
family of functions, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yup
@MikeInnes Tests pass locally. |
Sorry for dropping this for a while. Can you do a quick rebase? Will probably need the gradients part of the PR to move to Tracker.jl. |
2482c5c
to
161fa86
Compare
@staticfloat did If so, it's not necessarily an issue; we can just define it in Flux and use it the same way (AD should work automatically if it just forward to |
Ah, yes, it did. It has been subsumed into |
Ok cool. In that case @ayush1999 let's just define |
src/layers/conv.jl
Outdated
bias::V | ||
stride::NTuple{N,Int} | ||
pad::NTuple{N,Int} | ||
dilation::NTuple{N,Int} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Formatting is off here
Can we keep the Also, can we get a GPU test for this layer as well. |
@ayush1999 you can set new_cdims = DenseConvDims(cdims; flipkernel=true) |
If I'm not wrong, we'd need to tag a new release for NNlib so that the tests can pass. |
Those PRs are in, so we should be able to get this in now, just has one more conflict. Note that you can also put branches of packages in the manifest to get tests to pass. Shouldn't be necessary now though. |
I've added all these changes in a new PR (#762). (I lost access to the github account for this PR). Can you please close this and review the new PR? |
This works on top of changes in FluxML/NNlib.jl#71. Implemented a
CrossConv
layer, which callscrossconv
andcrossconv!
functions from NNlib.