Skip to content

Commit

Permalink
don't capture a string for each variable
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 13, 2023
1 parent d19d235 commit 34d569d
Showing 1 changed file with 35 additions and 14 deletions.
49 changes: 35 additions & 14 deletions src/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ macro compact(_exs...)
kwexs = (kwexs1..., kwexs2...)

# make strings
layer = "@compact"
setup = NamedTuple(map(ex -> Symbol(string(ex.args[1])) => string(ex.args[2]), kwexs))
input =
try
fex_args = fex.args[1]
Expand All @@ -113,7 +111,7 @@ macro compact(_exs...)
fex = supportself(fex, vars)

# assemble
return esc(:($CompactLayer($fex, ($layer, $input, $block), $setup; $(kwexs...))))
return esc(:($CompactLayer($fex, ($input, $block); $(kwexs...))))
end

function supportself(fex::Expr, vars)
Expand All @@ -130,17 +128,18 @@ function supportself(fex::Expr, vars)
end
end

struct CompactLayer{F,NT1<:NamedTuple,NT2<:NamedTuple}
struct CompactLayer{F<:Function, NT<:NamedTuple}
fun::F
strings::NTuple{3,String}
setup_strings::NT1
variables::NT2
strings::NTuple{2,String}
variables::NT
end
CompactLayer(f::Function, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, str, setup_str, NamedTuple(kw))
(m::CompactLayer)(x...) = m.fun(m.variables, x...)
CompactLayer(args...) = error("CompactLayer is meant to be constructed by the macro")
CompactLayer(f::Function, str::Tuple; kw...) = CompactLayer(f, str, NamedTuple(kw))
CompactLayer(args...) = error("CompactLayer is meant to be constructed by the macro @compact")

Flux.@functor CompactLayer

(m::CompactLayer)(x...) = m.fun(m.variables, x...)

Flux._show_children(m::CompactLayer) = m.variables

function Base.show(io::IO, ::MIME"text/plain", m::CompactLayer)
Expand All @@ -154,16 +153,17 @@ function Base.show(io::IO, ::MIME"text/plain", m::CompactLayer)
end

function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
setup_strings = obj.setup_strings
layer, input, block = obj.strings
input, block = obj.strings
pre, post = ("(", ")")
println(io, " "^indent, "@compact", pre)
for k in keys(obj.variables)
v = obj.variables[k]
if Flux._show_leaflike(v)
if false # Flux._show_leaflike(v)
# If the value is a leaf, just print verbatim what the user wrote:
str = String(k) * " = " * setup_strings[k]
# str = String(k) * " = " * summary(v)
str = String(k) * " isa " * string(typeof(v))
_just_show_params(io, str, v, indent+2)
# Flux._layer_show(io::IO, str, indent+2, nothing) # doesn't work
else
Flux._big_show(io, v, indent+2, String(k))
end
Expand All @@ -188,6 +188,27 @@ function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
end
end

# Temporarily fixing things via piracy, but would be an easy change in Flux
using Flux: params, underscorise, _childarray_sum, _nan_show
function Flux._layer_show(io::IO, layer::AbstractArray, indent::Int=0, name=nothing)
_str = isnothing(name) ? "" : "$name = "
# str = _str * sprint(show, layer, context=io) # before
# str = _str * String(typeof(layer).name.name) # print Array
str = _str * summary(layer) # print size too, sometimes too long... trim it?
print(io, " "^indent, str, indent==0 ? "" : ",")
if !isempty(params(layer))
print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str)))
printstyled(io, "# ", underscorise(sum(length, params(layer); init=0)), " parameters";
color=:light_black)
nonparam = _childarray_sum(length, layer) - sum(length, params(layer), init=0)
if nonparam > 0
printstyled(io, ", plus ", underscorise(nonparam), indent==0 ? " non-trainable" : ""; color=:light_black)
end
_nan_show(io, params(layer))
end
indent==0 || println(io)
end

# Modified from src/layers/show.jl
function _just_show_params(io::IO, str::String, layer, indent::Int=0)
print(io, " "^indent, str, indent==0 ? "" : ",")
Expand Down

0 comments on commit 34d569d

Please sign in to comment.