diff --git a/src/layers/basic.jl b/src/layers/basic.jl index f42a9619f9..2a46520818 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -44,19 +44,23 @@ end # it might be replaced in the future for better performance # see issue https://github.com/FluxML/Flux.jl/issues/702 # Johnny Chen -- @johnnychen94 +# only slightly changed to better handle interaction with Zygote @dsweber2 """ activations(c::Chain, input) Calculate the forward results of each layers in Chain `c` with `input` as model input. """ function activations(c::Chain, input) - rst = [] - for l in c - x = get(rst, length(rst), input) - push!(rst, l(x)) - end - return rst + extraChain(c.layers, input) end +function extraChain(fs::Tuple, x) + res = first(fs)(x) + return (res, extraChain(Base.tail(fs), res)...) +end + +extraChain(::Tuple{}, x) = () + + """ Dense(in::Integer, out::Integer, σ = identity) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index cbe250fcca..0ff1776db8 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -4,11 +4,13 @@ import Flux: activations @testset "basic" begin @testset "helpers" begin @testset "activations" begin - dummy_model = Chain(Dense(10,5,σ),Dense(5,2),softmax) - x = rand(10) - @test activations(Chain(), x) == [] - @test activations(dummy_model, x)[1] == dummy_model[1](x) - @test activations(dummy_model, x)[2] == x |> dummy_model[1] |> dummy_model[2] + dummy_model = Chain(x->x.^2, x->x .- 3, x -> tan.(x)) + x = randn(10) + @test activations(dummy_model, x)[1] == x.^2 + @test activations(dummy_model, x)[2] == (x.^2 .- 3) + @test activations(dummy_model, x)[3] == tan.(x.^2 .- 3) + + @test activations(Chain(), x) == () @test activations(Chain(identity, x->:foo), x)[2] == :foo # results include `Any` type end end @@ -19,6 +21,12 @@ import Flux: activations # numeric test should be put into testset of corresponding layer end + @testset "Activations" begin + c = Chain(Dense(3,5,relu), Dense(5,1,relu)) + X = Float32.([1.0; 1.0; 1.0]) + @test_nowarn gradient(()->Flux.activations(c, X)[2][1], params(c)) + end + @testset "Dense" begin @test length(Dense(10, 5)(randn(10))) == 5 @test_throws DimensionMismatch Dense(10, 5)(randn(1))