-
-
Notifications
You must be signed in to change notification settings - Fork 611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Activations #860
Activations #860
Changes from 16 commits
540b736
38790dd
82261b5
bb84aee
1bb25dc
f412191
46abfbb
3b7b780
cdaaca8
d0202a2
99679f7
6475f6a
db92b0e
0fe3ac4
58c7947
89afa20
20eb840
dea2953
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,11 +4,13 @@ import Flux: activations | |
@testset "basic" begin | ||
@testset "helpers" begin | ||
@testset "activations" begin | ||
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax) | ||
x = rand(10) | ||
@test activations(Chain(), x) == [] | ||
@test activations(dummy_model, x)[1] == dummy_model[1](x) | ||
@test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2] | ||
dummy_model = Chain(x->x.^2, x->x .- 3, x -> tan.(x)) | ||
x = randn(10) | ||
@test activations(dummy_model, x)[1] == x.^2 | ||
@test activations(dummy_model, x)[2] == (x.^2 .- 3) | ||
@test activations(dummy_model, x)[3] == tan.(x.^2 .- 3) | ||
|
||
@test activations(Chain(), x) == () | ||
@test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type | ||
end | ||
end | ||
|
@@ -19,6 +21,12 @@ import Flux: activations | |
# numeric test should be put into testset of corresponding layer | ||
end | ||
|
||
@testset "Activations" begin | ||
c = Chain(Dense(3,5,relu), Dense(5,1,relu)) | ||
X = Float32.([1.0; 1.0; 1.0]) | ||
@test_nowarn gradient(()->Flux.activations(c, X)[2][1], params(c)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add a regular test, i.e. making sure the output is right? Rather than chaining dense layers, it might be useful to chain something simple like Otherwise really happy with this patch! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, I see that there are some other tests above here; what's the need for the additional |
||
end | ||
|
||
@testset "Dense" begin | ||
@test length(Dense(10, 5)(randn(10))) == 5 | ||
@test_throws DimensionMismatch Dense(10, 5)(randn(1)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This definition doesn't look right to me.