
Activations #860

Merged: 18 commits, Nov 19, 2019
Changes from 16 commits
src/layers/basic.jl (12 additions, 6 deletions)

@@ -31,6 +31,8 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))

(c::Chain)(x) = applychain(c.layers, x)

+(c::Chain)(x) = extraChain(c.layers, x)
Review comment (Member):
This definition doesn't look right to me.
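For context, a plausible basis for this concern (my inference; the thread does not elaborate): in Julia, a later method with the same signature silently replaces the earlier one, so this added line would overwrite `(c::Chain)(x) = applychain(c.layers, x)` and make every `Chain` call return the full tuple of activations instead of the last layer's output. A minimal sketch of the mechanism, using a hypothetical struct:

struct C end
(c::C)(x) = x + 1  # original behaviour
(c::C)(x) = x * 2  # same signature, so this replaces the method above
C()(3)             # returns 6, not 4; the first definition is gone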


Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)

function Base.show(io::IO, c::Chain)
@@ -44,19 +46,23 @@ end
# it might be replaced in the future for better performance
# see issue https://github.com/FluxML/Flux.jl/issues/702
# Johnny Chen -- @johnnychen94
+# only slightly changed to better handle interaction with Zygote @dsweber2
"""
    activations(c::Chain, input)
Calculate the forward results of each layer in Chain `c` with `input` as model input.
"""
function activations(c::Chain, input)
-  rst = []
-  for l in c
-    x = get(rst, length(rst), input)
-    push!(rst, l(x))
-  end
-  return rst
+  extraChain(c.layers, input)
end

+function extraChain(fs::Tuple, x)
+  res = first(fs)(x)
+  return (res, extraChain(Base.tail(fs), res)...)
+end
+
+extraChain(::Tuple{}, x) = ()



"""
    Dense(in::Integer, out::Integer, σ = identity)
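To make the new behaviour concrete, here is a small usage sketch (mine, not part of the patch). Note that `activations` now returns a `Tuple` rather than a `Vector`, and the recursion avoids the array mutation (`push!`) that Zygote cannot differentiate through:

using Flux

m = Chain(x -> x .+ 1, x -> 2 .* x)
as = Flux.activations(m, [1.0, 2.0])
# as == ([2.0, 3.0], [4.0, 6.0]): one entry per layer,
# each computed from the previous layer's output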
test/layers/basic.jl (13 additions, 5 deletions)

@@ -4,11 +4,13 @@ import Flux: activations
@testset "basic" begin
@testset "helpers" begin
@testset "activations" begin
dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax)
x = rand(10)
@test activations(Chain(), x) == []
@test activations(dummy_model, x)[1] == dummy_model[1](x)
@test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2]
dummy_model = Chain(x->x.^2, x->x .- 3, x -> tan.(x))
x = randn(10)
@test activations(dummy_model, x)[1] == x.^2
@test activations(dummy_model, x)[2] == (x.^2 .- 3)
@test activations(dummy_model, x)[3] == tan.(x.^2 .- 3)

@test activations(Chain(), x) == ()
@test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type
end
end
@@ -19,6 +21,12 @@ import Flux: activations
    # numeric test should be put into testset of corresponding layer
  end

@testset "Activations" begin
c = Chain(Dense(3,5,relu), Dense(5,1,relu))
X = Float32.([1.0; 1.0; 1.0])
@test_nowarn gradient(()->Flux.activations(c, X)[2][1], params(c))
Review comment (Member):
Can we add a regular test, i.e. making sure the output is right? Rather than chaining dense layers, it might be useful to chain something simple like Chain(x -> x^2, x -> x+1) or something so that the outputs and gradients are trivial.

Otherwise really happy with this patch!
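A minimal sketch of the kind of test being suggested (hypothetical, not part of the patch), where the outputs and the gradient can be checked by hand:

m = Chain(x -> x^2, x -> x + 1)
@test Flux.activations(m, 3) == (9, 10)
# activations(m, x)[2] == x^2 + 1, so the derivative at x = 3 is 2x = 6
@test gradient(x -> Flux.activations(m, x)[2], 3)[1] == 6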

Review comment (Member):

Actually, I see that there are some other tests above here; what's the need for the additional @test_nowarn here? If it's redundant, it'd be best to remove.

+  end

@testset "Dense" begin
@test length(Dense(10, 5)(randn(10))) == 5
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
Expand Down