From 01bdd93eddd27c3b3098e5c1e5274cd7404fb328 Mon Sep 17 00:00:00 2001
From: Anton Smirnov <17990405+pxl-th@users.noreply.github.com>
Date: Sun, 12 Jul 2020 12:52:52 +0300
Subject: [PATCH 1/2] Ignore Zeros when gathering parameters

---
 src/functor.jl      | 4 +++-
 test/layers/conv.jl | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/functor.jl b/src/functor.jl
index 592e70d9b1..430751cf9e 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -37,6 +37,8 @@ Possible values include:
 """
 trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)

+params!(p::Flux.Params, x::Zeros{<:Number}, seen = Flux.IdSet()) = nothing
+
 params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)

 function params!(p::Params, x, seen = IdSet())
@@ -84,4 +86,4 @@ f64(m) = paramtype(Float64, m)

 # Functors for certain Julia data structures
 @functor Cholesky
-trainable(c::Cholesky) = ()
\ No newline at end of file
+trainable(c::Cholesky) = ()
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index dd732efab1..376cca8387 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -46,7 +46,7 @@ end
   op = bias(ip)
   @test sum(op) ≈ 0.f0
   gs = gradient(() -> sum(bias(ip)), Flux.params(bias))
-  @test gs[bias.bias] == nothing
+  @test !haskey(gs, bias.bias)

   # Train w/o bias and make sure no convergence happens
   # when only bias can be converged

From aefe6e35681d39ba30096b490048626b3df3ebbe Mon Sep 17 00:00:00 2001
From: Anton Smirnov <17990405+pxl-th@users.noreply.github.com>
Date: Sun, 12 Jul 2020 14:50:50 +0300
Subject: [PATCH 2/2] Add tests

---
 test/layers/conv.jl | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index 376cca8387..cdfaf46919 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -43,6 +43,15 @@ end
   @test sum(op) == prod(size(op))

   bias = Conv((2,2), 1=>3, bias = Flux.Zeros())
+
+  # Test that disabled bias is not in parameters
+  # and gathered parameters can be loaded back
+  bias_parameters = bias |> Flux.params
+  @test length(bias_parameters) == 1
+  @test Flux.Zeros() ∉ bias_parameters
+  Flux.loadparams!(bias, bias_parameters)
+
+  # Test that disabled bias does not appear in gradients
   op = bias(ip)
   @test sum(op) ≈ 0.f0
   gs = gradient(() -> sum(bias(ip)), Flux.params(bias))
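
Usage note (not part of the patches above): a minimal sketch of the behaviour the
changes enable, assuming a Flux build where Conv accepts bias = Flux.Zeros(); the
names m, x, ps and gs below are illustrative only.

    using Flux

    # Convolution layer with its bias disabled via the Zeros placeholder.
    m = Conv((2, 2), 1 => 3, bias = Flux.Zeros())

    # With the new params! method the Zeros bias is skipped, so only the
    # weight array is collected, and the collection round-trips through
    # loadparams! without touching the disabled bias.
    ps = Flux.params(m)
    @assert length(ps) == 1
    Flux.loadparams!(m, ps)

    # Gradients likewise carry no entry for the disabled bias.
    x = rand(Float32, 28, 28, 1, 1)
    gs = Flux.gradient(() -> sum(m(x)), ps)
    @assert !haskey(gs, m.bias)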