Skip to content

Commit

Permalink
Merge pull request #2210 from mcabbott/show4
Browse files Browse the repository at this point in the history
Fix a bug in show
  • Loading branch information
mcabbott authored Mar 14, 2023
2 parents 0155f61 + 35f018a commit 0d83f60
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/layers/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
print(io, " "^indent, str, indent==0 ? "" : ",")
if !isempty(params(layer))
print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str)))
printstyled(io, "# ", underscorise(sum(length, params(layer))), " parameters"; color=:light_black)
nonparam = _childarray_sum(length, layer) - sum(length, params(layer))
printstyled(io, "# ", underscorise(sum(length, params(layer); init=0)), " parameters";
color=:light_black)
nonparam = _childarray_sum(length, layer) - sum(length, params(layer), init=0)
if nonparam > 0
printstyled(io, ", plus ", underscorise(nonparam), indent==0 ? " non-trainable" : ""; color=:light_black)
end
Expand All @@ -90,11 +91,11 @@ end
function _big_finale(io::IO, m)
ps = params(m)
if length(ps) > 2
pars = underscorise(sum(length, ps))
pars = underscorise(sum(length, ps; init=0))
bytes = Base.format_bytes(Base.summarysize(m))
noncnt = _childarray_sum(_->1, m) - length(ps)
if noncnt > 0
nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps))
nonparam = underscorise(_childarray_sum(length, m) - sum(length, ps; init=0))
printstyled(io, " "^08, "# Total: ", length(ps), " trainable arrays, "; color=:light_black)
println(io, pars, " parameters,")
printstyled(io, " "^10, "# plus ", noncnt, " non-trainable, ", nonparam, " parameters, summarysize "; color=:light_black)
Expand All @@ -107,7 +108,8 @@ function _big_finale(io::IO, m)
end

_childarray_sum(f, x::AbstractArray{<:Number}) = f(x)
_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x))
_childarray_sum(f, x) = isleaf(x) ? 0 : sum(y -> _childarray_sum(f, y), Functors.children(x),
init=0)

# utility functions

Expand Down
9 changes: 9 additions & 0 deletions test/layers/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,12 @@ end
@test occursin("Dense(2 => 2)", adjoint_chain)
@test occursin("Chain([", adjoint_chain)
end

# Bug when no children, https://github.com/FluxML/Flux.jl/issues/2208
struct NoFields end
Flux.@functor NoFields

@testset "show with no fields" begin
str = repr("text/plain", Chain(Dense(1=>1), Dense(1=>1), NoFields()))
@test occursin("4 arrays, 4 parameters", str)
end

0 comments on commit 0d83f60

Please sign in to comment.