diff --git a/src/functor.jl b/src/functor.jl
index 592e70d9b1..430751cf9e 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -37,6 +37,8 @@ Possible values include:
 """
 trainmode!(m, mode = true) = mode isa Bool ? testmode!(m, !mode) : testmode!(m, mode)
 
+params!(p::Flux.Params, x::Zeros{<:Number}, seen = Flux.IdSet()) = nothing
+
 params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
 
 function params!(p::Params, x, seen = IdSet())
@@ -84,4 +86,4 @@ f64(m) = paramtype(Float64, m)
 
 # Functors for certain Julia data structures
 @functor Cholesky
-trainable(c::Cholesky) = ()
\ No newline at end of file
+trainable(c::Cholesky) = ()
diff --git a/test/layers/conv.jl b/test/layers/conv.jl
index dd732efab1..cdfaf46919 100644
--- a/test/layers/conv.jl
+++ b/test/layers/conv.jl
@@ -43,10 +43,19 @@ end
   @test sum(op) == prod(size(op))
   bias = Conv((2,2), 1=>3, bias = Flux.Zeros())
+
+  # Test that disabled bias is not in parameters
+  # and gathered parameters can be loaded back
+  bias_parameters = bias |> Flux.params
+  @test length(bias_parameters) == 1
+  @test Flux.Zeros() ∉ bias_parameters
+  Flux.loadparams!(bias, bias_parameters)
+
+  # Test that disabled bias does not appear in gradients
   op = bias(ip)
   @test sum(op) ≈ 0.f0
 
   gs = gradient(() -> sum(bias(ip)), Flux.params(bias))
-  @test gs[bias.bias] == nothing
+  @test !haskey(gs, bias.bias)
 
   # Train w/o bias and make sure no convergence happens
   # when only bias can be converged
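
A minimal standalone sketch of the behaviour these changes test, assuming the Flux.Zeros bias flag on this branch; the layer and input names below are illustrative and mirror the test above, not part of the patch. It shows that a Conv built with bias = Flux.Zeros() exposes only its weight in params, that the collected params can be loaded back, and that the disabled bias gets no gradient entry.

using Flux, Test

layer = Conv((2, 2), 1 => 3, bias = Flux.Zeros())  # bias disabled via Zeros
ip = zeros(Float32, 28, 28, 1, 1)                  # dummy input, as in the test above

ps = Flux.params(layer)
@test length(ps) == 1              # only the weight array is collected
@test Flux.Zeros() ∉ ps            # the Zeros bias is excluded from params
Flux.loadparams!(layer, ps)        # gathered parameters can be loaded back

gs = gradient(() -> sum(layer(ip)), ps)
@test !haskey(gs, layer.bias)      # no gradient entry for the disabled bias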