Simplify @compact printing #20

Draft · wants to merge 5 commits into base: master
1 change: 1 addition & 0 deletions src/Fluxperimental.jl
@@ -12,6 +12,7 @@ export shinkansen!
include("chain.jl")

include("compact.jl")
export @compact

include("noshow.jl")
export NoShow
137 changes: 82 additions & 55 deletions src/compact.jl
@@ -4,64 +4,66 @@ import Flux: _big_show
@compact(forward::Function; name=nothing, parameters...)

Creates a layer by specifying some `parameters`, in the form of keywords,
and (usually as a `do` block) a function for the forward pass.
and a function for the forward pass (often as a `do` block).

You may think of `@compact` as a specialized `let` block creating local variables
that are trainable in Flux.
Declared variable names may be used within the body of the `forward` function.

Here is a linear model:
# Examples

Here is a linear model, equivalent to `Flux.Scale`:

```
r = @compact(w = rand(3)) do x
w .* x
end
r([1, 1, 1]) # x is set to [1, 1, 1].
using Flux, Fluxperimental

w = rand(3)
sc = @compact(x -> x .* w; w)

sc([1 10 100]) # 3×3 Matrix as output.
ans ≈ Flux.Scale(w)([1 10 100]) # equivalent Flux layer
```

Here is a linear model with bias and activation:
Here is a linear model with bias and activation, equivalent to Flux's `Dense` layer.
The forward pass is now written as a `do` block, instead of `x -> begin y = W * x; ... end`:

```
d_in = 5
d_in = 3
d_out = 7
d = @compact(W = randn(d_out, d_in), b = zeros(d_out), act = relu) do x
layer = @compact(W = randn(d_out, d_in), b = zeros(d_out), act = relu) do x
y = W * x
act.(y .+ b)
end
d(ones(5, 10)) # 7×10 Matrix as output.
d([1,2,3,4,5]) ≈ Dense(d.variables.W, zeros(7), relu)([1,2,3,4,5]) # Equivalent to a dense layer

den = Dense(layer.variables.W, zeros(7), relu) # equivalent Flux layer
layer(ones(3, 10)) ≈ den(ones(3, 10)) # 7×10 Matrix as output.
```

Finally, here is a simple MLP:
Finally, here is a simple MLP, equivalent to a `Chain` with 5 `Dense` layers:

```
using Flux

n_in = 1
n_out = 1
d_in = 1
nlayers = 3

model = @compact(
w1=Dense(n_in, 128),
w2=[Dense(128, 128) for i=1:nlayers],
w3=Dense(128, n_out),
act=relu
lay1 = Dense(d_in => 64),
lay234 = [Dense(64 => 64) for i=1:nlayers],
wlast = rand32(64),
) do x
embed = act(w1(x))
for w in w2
embed = act(w(embed))
y = tanh.(lay1(x))
for lay in lay234
y = relu.(lay(y))
end
out = w3(embed)
return out
return wlast' * y
end

model(randn(n_in, 32)) # 1×32 Matrix as output.
model(randn(Float32, d_in, 8)) # 1×8 array as output.
```

We can train this model just like any `Chain`:
We can train this model just like any `Chain`, for example:

```
data = [([x], 2x-x^3) for x in -2:0.1f0:2]
data = [([x], [2x-x^3]) for x in -2:0.1f0:2]
optim = Flux.setup(Adam(), model)

for epoch in 1:1000
Expand All @@ -71,19 +73,23 @@ end
To specify a custom printout for the model, you may find [`NoShow`](@ref) useful.
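For instance, something along these lines (a sketch: `NoShow`'s exact
signature is documented in its own docstring; here it is assumed to take the
display string first, then the layer to wrap):

```
block = NoShow("FancyBlock(...)", Dense(3 => 8))
m = @compact(enc = block) do x
    enc(x)
end  # `enc` now shows as "FancyBlock(...)" when `m` is printed
```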
"""
macro compact(_exs...)
_compact(_exs...) |> esc
end

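# The macro just escapes whatever `_compact` builds; keeping the expansion
# logic in a plain function makes it straightforward to inspect and test.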
function _compact(_exs...)
# check inputs, extracting function expression fex and unprocessed keyword arguments _kwexs
isempty(_exs) && error("expects at least two expressions: a function and at least one keyword")
isempty(_exs) && error("@compact expects at least two expressions: a function and at least one keyword")
if Meta.isexpr(_exs[1], :parameters)
length(_exs) >= 2 || error("expects an anonymous function")
length(_exs) >= 2 || error("@compact expects an anonymous function")
fex = _exs[2]
_kwexs = (_exs[1], _exs[3:end]...)
else
fex = _exs[1]
_kwexs = _exs[2:end]
end
Meta.isexpr(fex, :(->)) || error("expects an anonymous function")
isempty(_kwexs) && error("expects keyword arguments")
all(ex -> Meta.isexpr(ex, (:kw,:(=),:parameters)), _kwexs) || error("expects only keyword arguments")
Meta.isexpr(fex, :(->)) || error("@compact expects an anonymous function")
isempty(_kwexs) && error("@compact expects keyword arguments")
all(ex -> Meta.isexpr(ex, (:kw,:(=),:parameters)), _kwexs) || error("@compact expects only keyword arguments")

# process keyword arguments
if Meta.isexpr(_kwexs[1], :parameters) # handle keyword arguments provided after semicolon
@@ -96,27 +102,25 @@ macro compact(_exs...)
kwexs = (kwexs1..., kwexs2...)

# make strings
layer = "@compact"
setup = NamedTuple(map(ex -> Symbol(string(ex.args[1])) => string(ex.args[2]), kwexs))
input =
try
fex_args = fex.args[1]
isa(fex_args, Symbol) ? string(fex_args) : join(fex_args.args, ", ")
catch e
@warn "Function stringifying does not yet handle all cases. Falling back to empty string for input arguments"
""
@warn """@compact's function stringifying does not yet handle all cases. Falling back to "?" """ maxlog=1
"?"
end
block = string(Base.remove_linenums!(fex).args[2])
block = string(Base.remove_linenums!(fex).args[2]) # TODO make this remove macro comments

# edit expressions
vars = map(ex -> ex.args[1], kwexs)
fex = supportself(fex, vars)
fex = _supportself(fex, vars)

# assemble
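# E.g. `@compact(x -> w .* x; w = rand(3))` becomes, roughly,
# `CompactLayer(fex′, ("x", "w .* x"); w = rand(3))`, where `fex′` is the
# forward function rewritten by `_supportself`.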
return esc(:($CompactLayer($fex, ($layer, $input, $block), $setup; $(kwexs...))))
return :($CompactLayer($fex, ($input, $block); $(kwexs...)))
end

function supportself(fex::Expr, vars)
function _supportself(fex::Expr, vars)
@gensym self
@gensym curried_f
# To avoid having to manipulate fex's arguments and body explicitly, we form a curried function first
@@ -130,17 +134,18 @@ function supportself(fex::Expr, vars)
end
end

struct CompactLayer{F,NT1<:NamedTuple,NT2<:NamedTuple}
struct CompactLayer{F<:Function, NT<:NamedTuple}
fun::F
strings::NTuple{3,String}
setup_strings::NT1
variables::NT2
strings::NTuple{2,String}
variables::NT
end
CompactLayer(f::Function, str::Tuple, setup_str::NamedTuple; kw...) = CompactLayer(f, str, setup_str, NamedTuple(kw))
(m::CompactLayer)(x...) = m.fun(m.variables, x...)
CompactLayer(args...) = error("CompactLayer is meant to be constructed by the macro")
CompactLayer(f::Function, str::Tuple; kw...) = CompactLayer(f, str, NamedTuple(kw))
CompactLayer(args...) = error("CompactLayer is meant to be constructed by the macro @compact")

Flux.@functor CompactLayer

(m::CompactLayer)(x...) = m.fun(m.variables, x...)
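# i.e. calling the layer passes the captured variables as a hidden first
# argument: for `m = @compact(w = rand(3)) do x; w .* x; end`, `m(xs)`
# computes `w .* xs` using `w = m.variables.w`.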

Flux._show_children(m::CompactLayer) = m.variables

function Base.show(io::IO, ::MIME"text/plain", m::CompactLayer)
@@ -154,16 +159,17 @@ function Base.show(io::IO, ::MIME"text/plain", m::CompactLayer)
end

function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
setup_strings = obj.setup_strings
layer, input, block = obj.strings
input, block = obj.strings
pre, post = ("(", ")")
println(io, " "^indent, isnothing(name) ? "" : "$name = ", layer, pre)
println(io, " "^indent, "@compact", pre)
for k in keys(obj.variables)
v = obj.variables[k]
if Flux._show_leaflike(v)
if false # Flux._show_leaflike(v)
# If the value is a leaf, just print verbatim what the user wrote:
str = String(k) * " = " * setup_strings[k]
# str = String(k) * " = " * summary(v)
str = String(k) * " isa " * string(typeof(v))
_just_show_params(io, str, v, indent+2)
# Flux._layer_show(io::IO, str, indent+2, nothing) # doesn't work
else
Flux._big_show(io, v, indent+2, String(k))
end
Expand All @@ -174,7 +180,7 @@ function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
print(io, " "^indent, post)
end

input != "" && print(io, " do ", input)
print(io, " do ", input)
if block != ""
block_to_print = block[6:end]
# Increase indentation of block according to `indent`:
Expand All @@ -188,6 +194,27 @@ function Flux._big_show(io::IO, obj::CompactLayer, indent::Int=0, name=nothing)
end
end

# # Temporarily fixing things via piracy, but would be an easy change in Flux
# using Flux: params, underscorise, _childarray_sum, _nan_show
# function Flux._layer_show(io::IO, layer::AbstractArray, indent::Int=0, name=nothing)
# _str = isnothing(name) ? "" : "$name = "
# # str = _str * sprint(show, layer, context=io) # before
# # str = _str * String(typeof(layer).name.name) # print Array
# str = _str * summary(layer) # print size too, sometimes too long... trim it?
# print(io, " "^indent, str, indent==0 ? "" : ",")
# if !isempty(params(layer))
# print(io, " "^max(2, (indent==0 ? 20 : 39) - indent - length(str)))
# printstyled(io, "# ", underscorise(sum(length, params(layer); init=0)), " parameters";
# color=:light_black)
# nonparam = _childarray_sum(length, layer) - sum(length, params(layer), init=0)
# if nonparam > 0
# printstyled(io, ", plus ", underscorise(nonparam), indent==0 ? " non-trainable" : ""; color=:light_black)
# end
# _nan_show(io, params(layer))
# end
# indent==0 || println(io)
# end

# Modified from src/layers/show.jl
function _just_show_params(io::IO, str::String, layer, indent::Int=0)
print(io, " "^indent, str, indent==0 ? "" : ",")
29 changes: 24 additions & 5 deletions test/compact.jl
@@ -19,11 +19,7 @@ function similar_strings(s1, s2)
return s1 == s2
end

function get_model_string(model)
io = IOBuffer()
show(io, MIME"text/plain"(), model)
String(take!(io))
end
get_model_string(model) = repr(MIME("text/plain"), model)
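# `repr(MIME("text/plain"), x)` captures exactly what `show(io, MIME"text/plain"(), x)` prints.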

@testset "@compact" begin

@@ -139,6 +135,29 @@ end
@test similar_strings(get_model_string(model2), expected_string)
end

#= # This test is broken: the new `Flux._big_show` ignores `name`, so a
# nested `@compact` prints without its `w1 = ` prefix:

julia> model1 = @compact(w1=Dense(32=>32, relu), w2=Dense(32=>32, relu)) do x
w2(w1(x))
end;

julia> model2 = @compact(w1=model1, w2=Dense(32=>32, relu)) do x
w2(w1(x))
end
@compact(
@compact(
w1 = Dense(32 => 32, relu), # 1_056 parameters
w2 = Dense(32 => 32, relu), # 1_056 parameters
) do x
w2(w1(x))
end,
w2 = Dense(32 => 32, relu), # 1_056 parameters
) do x
w2(w1(x))
end # Total: 6 arrays, 3_168 parameters, 13.239 KiB.

=#

@testset "Array parameters" begin
model = @compact(x=randn(32), w=Dense(32=>32)) do s
w(x .* s)