Activations #860
Thanks a lot for the patch! I wonder if we could just have
Yeah, that's definitely a nicer way to do this (in the new commit). There are a couple of extra terms in the gradient that I'm sort of mystified by, but they don't seem to be causing problems.
What do you think of adding a Chain method that takes two arguments, the second being a list of indices of the depths at which to return the transform? That would remove the need for an activations function completely, though it would not really reduce the total number of lines of code.
This looks good, thanks. Would be great to have a quick test for it so we don't miss it again.
I don't understand this; would be great to perhaps see an example of what this would look like.
Something like
src/layers/basic.jl
Outdated
    return (res, extraChain(Base.tail(fs), res)...)
end

extraChain(::Tuple{}, x) = []
This should probably be the empty tuple, so that the compiler can unroll everything
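To illustrate the suggestion, here is a minimal sketch of the tuple-recursive helper with an empty-tuple base case. The name `extra_chain` and the toy layers are hypothetical stand-ins, not the Flux source. With `[]` as the base case the splat runs over a `Vector` whose length the compiler cannot see, whereas with `()` every step of the recursion stays a tuple of known length.

```julia
# Hypothetical stand-in for the helper under review; not the Flux source.
function extra_chain(fs::Tuple, x)
    res = first(fs)(x)                                # apply the first layer
    return (res, extra_chain(Base.tail(fs), res)...)  # recurse on the remaining layers
end

# Base case: an empty tuple, so the result is a tuple all the way down.
extra_chain(::Tuple{}, x) = ()

layers = (x -> x^2, x -> x + 1)  # toy "layers", assumed for illustration
acts = extra_chain(layers, 3)    # one entry per layer
```

Each entry of `acts` is the output of the corresponding layer, so the last entry matches what calling the whole chain would return.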
@testset "Activations" begin
  c = Chain(Dense(3,5,relu), Dense(5,1,relu))
  X = Float32.([1.0; 1.0; 1.0])
  @test_nowarn gradient(()->Flux.activations(c, X)[2][1], params(c))
Can we add a regular test, i.e. making sure the output is right? Rather than chaining dense layers, it might be useful to chain something simple like `Chain(x -> x^2, x -> x+1)` or something so that the outputs and gradients are trivial.
Otherwise really happy with this patch!
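A rough sketch of what such a test might look like, assuming a dependency-free stand-in for the real call: `collect_activations` is a hypothetical helper that mirrors the tuple recursion, where the PR's own test would presumably call `Flux.activations` on the `Chain`. A central finite difference stands in for `gradient` here only so the snippet runs without Flux or Zygote.

```julia
using Test

# Hypothetical helper mirroring the tuple-recursive idea; not the Flux API.
function collect_activations(fs::Tuple, x)
    y = first(fs)(x)
    return (y, collect_activations(Base.tail(fs), y)...)
end
collect_activations(::Tuple{}, x) = ()

@testset "activations of a trivial chain" begin
    layers = (x -> x^2, x -> x + 1)

    # Outputs are easy to verify by hand: 3^2 = 9, then 9 + 1 = 10.
    @test collect_activations(layers, 3.0) == (9.0, 10.0)

    # The last activation is x^2 + 1, so its derivative at x is 2x.
    f(x) = collect_activations(layers, x)[end]
    h = 1e-6
    @test isapprox((f(3.0 + h) - f(3.0 - h)) / (2h), 6.0; atol = 1e-4)
end
```

Because both the outputs and the derivative have closed forms, a failure points at the activation-collecting logic rather than at layer initialisation.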
Actually, I see that there are some other tests above here; what's the need for the additional `@test_nowarn` here? If it's redundant, it'd be best to remove.
Sorry about the extra commits; it looks like rebasing onto master made this a bit of a mess. I just made the two changes you suggested.
src/layers/basic.jl
Outdated
@@ -31,6 +31,8 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))

(c::Chain)(x) = applychain(c.layers, x)

(c::Chain)(x) = extraChain(c.layers, x)
This definition doesn't look right to me.
No worries, it just needs to target the master branch rather than the old zygote branch. CI appears to have an issue BTW.
The CI failure seems to be because I wasn't fully on the master branch, though nightly is having some issue with initializing Zygote. The extra
bors r+
860: Activations r=MikeInnes a=dsweber2

Taking derivatives w.r.t. the parameters results in complaints about mutability (mwe below). To get around this, I made the storage array a `Zygote.Buffer`, and then return a copy after inserting everything. I tried an `accumulate!` based version, which worked on commit ecc9ce9, but broke when I caught up.

Simple example:
```
c = Chain(Dense(3,5,relu), Dense(5,1,relu))
X = Float32.([1.0; 1.0; 1.0])
gradient(()->Flux.activations(c, X)[2][1], params(c))
```

Co-authored-by: dsweber2 <david.weber2@gmail.com>
Co-authored-by: Mosè Giordano <m.giordano@ucl.ac.uk>
Build failed
Failure is #923.
Thanks @dsweber2!