diff --git a/src/layers/show.jl b/src/layers/show.jl index 5494b958fe..aa9ccaf86f 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -77,8 +77,9 @@ function _layer_show(io::IO, layer, indent::Int=0, name=nothing) print(io, " "^indent, str, indent==0 ? "" : ",") if !isempty(params(layer)) print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str))) - printstyled(io, "# ", underscorise(sum(length, params(layer))), " parameters"; color=:light_black) - nonparam = _childarray_sum(length, layer) - sum(length, params(layer)) + printstyled(io, "# ", underscorise(sum(length, params(layer); init=0)), " parameters"; +color=:light_black) + nonparam = _childarray_sum(length, layer) - sum(length, params(layer), init=0) if nonparam > 0 printstyled(io, ", plus ", underscorise(nonparam), indent==0 ? " non-trainable" : ""; color=:light_black) end @@ -90,11 +91,11 @@ end function _big_finale(io::IO, m) ps = params(m) if length(ps) > 2 - pars = underscorise(sum(length, ps)) + pars = underscorise(sum(length, ps; init=0)) bytes = Base.format_bytes(Base.summarysize(m)) noncnt = _childarray_sum(_->1, m) - length(ps) if noncnt > 0 - nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps)) + nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps; init=0)) printstyled(io, " "^08, "# Total: ", length(ps), " trainable arrays, "; color=:light_black) println(io, pars, " parameters,") printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, summarysize "; color=:light_black) @@ -107,7 +108,8 @@ function _big_finale(io::IO, m) end _childarray_sum(f, x::AbstractArray{<:Number}) = f(x) -_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x)) +_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x), +init=0) # utility functions diff --git a/test/layers/show.jl b/test/layers/show.jl index 3fc9bd097b..6910e5fa08 100644 --- a/test/layers/show.jl +++ b/test/layers/show.jl @@ -73,3 +73,12 @@ end @test occursin("Dense(2 => 2)", adjoint_chain) @test occursin("Chain([", adjoint_chain) end + +# Bug when no children, https://github.com/FluxML/Flux.jl/issues/2208 +struct NoFields end +Flux.@functor NoFields + +@testset "show with no fields" begin + str = repr("text/plain", Chain(Dense(1=>1), Dense(1=>1), NoFields())) + @test occursin("4 arrays, 4 parameters", str) +end