Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve keyword argument handling in @compact #16

Merged
merged 6 commits into from
Aug 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Fluxperimental"
uuid = "3102ee7a-c841-4564-8f7f-ec69bd4fd658"
version = "0.1.3"
version = "0.2.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
30 changes: 21 additions & 9 deletions src/compact.jl
Original file line number Diff line number Diff line change
@@ -21,11 +21,15 @@ r([1, 1, 1]) # x is set to [1, 1, 1].
Here is a linear model with bias and activation:

```
d = @compact(in=5, out=7, W=randn(out, in), b=zeros(out), act=relu) do x
d_in = 5
d_out = 7
d = @compact(W = randn(d_out, d_in), b = zeros(d_out), act = relu) do x
y = W * x
act.(y .+ b)
end
d(ones(5, 10)) # 7×10 Matrix as output.
d(ones(5, 10)) # 7×10 Matrix as output.
d([1,2,3,4,5]) ≈ Dense(d.variables.W, zeros(7), relu)([1,2,3,4,5]) # Equivalent to a dense layer
```
```

Finally, here is a simple MLP:
@@ -79,11 +83,21 @@ println(model) # "Linear(3 => 1)"
This can be useful when using `@compact` to hierarchically construct
complex models to be used inside a `Chain`.
"""
macro compact(fex, kwexs...)
# check input
macro compact(fex, _kwexs...)
# check inputs
Meta.isexpr(fex, :(->)) || error("expects a do block")
isempty(kwexs) && error("expects keyword arguments")
all(ex -> Meta.isexpr(ex, (:kw,:(=))), kwexs) || error("expects only keyword argumens")
isempty(_kwexs) && error("expects keyword arguments")
all(ex -> Meta.isexpr(ex, (:kw,:(=),:parameters)), _kwexs) || error("expects only keyword arguments")

# process keyword arguments
if Meta.isexpr(_kwexs[1], :parameters) # handle keyword arguments provided after semicolon
kwexs1 = map(ex -> ex isa Symbol ? Expr(:kw, ex, ex) : ex, _kwexs[1].args)
_kwexs = _kwexs[2:end]
else
kwexs1 = ()
end
kwexs2 = map(ex -> Expr(:kw, ex.args...), _kwexs) # handle keyword arguments provided before semicolon
kwexs = (kwexs1..., kwexs2...)
Comment on lines +92 to +100
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic handles the first case below, but can it be done nicely on all arguments to handle also the second case here?

julia> macro pr(exs...)
         @show exs; nothing
       end
@pr (macro with 1 method)

julia> @pr(; a=1, b) do x
         x+a+b
       end
exs = (:((x,)->begin
          #= REPL[57]:2 =#
          x + a + b
      end), :($(Expr(:parameters, :($(Expr(:kw, :a, 1))), :b))))

julia> @pr(x -> x+a+b; a=1, b)
exs = (:($(Expr(:parameters, :($(Expr(:kw, :a, 1))), :b))), :(x->begin
          #= REPL[58]:1 =#
          x + a + b
      end))

julia> @pr(x -> x+a+b, a=1, b=2)
exs = (:(x->begin
          #= REPL[59]:1 =#
          x + a + b
      end), :(a = 1), :(b = 2))

Or might it be better to have the function call itself with arguments in re-arranged order?

Copy link
Contributor Author

@gaurav-arya gaurav-arya Aug 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I added

@warn "function stringifying does not yet handle all cases, using empty string"
""
just so that the test could run, it should not stay very long hopefully:)

Copy link
Contributor Author

@gaurav-arya gaurav-arya Aug 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. That warning isn't sufficient because this PR doesn't have the new logic in #17. So I addressed this comment there instead, in c1f6605 (sorry for the mess...)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. So we merge this, and work on any further refinements in #17?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that sounds good.


# check if user has named layer:
name = findfirst(ex -> ex.args[1] == :name, kwexs)
@@ -103,16 +117,14 @@ macro compact(fex, kwexs...)

# edit expressions
vars = map(ex -> ex.args[1], kwexs)
assigns = map(ex -> Expr(:(=), ex.args...), kwexs)
@gensym self
pushfirst!(fex.args[1].args, self)
addprefix!(fex, self, vars)

# assemble
return esc(quote
let
$(assigns...)
$CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(vars...))
$CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(kwexs...))
end
end)
end
296 changes: 164 additions & 132 deletions test/compact.jl
Original file line number Diff line number Diff line change
@@ -27,158 +27,190 @@ end

@testset "@compact" begin

r = @compact(w = [1, 5, 10]) do x
sum(w .* x)
@testset "Linear layer" begin
r = @compact(w = [1, 5, 10]) do x
sum(w .* x)
end
@test Flux.params(r) == Flux.Params([[1, 5, 10]])
@test r([1, 1, 1]) == 1 + 5 + 10
@test r([1, 2, 3]) == 1 + 2 * 5 + 3 * 10
@test r(ones(3, 3)) == 3 * (1 + 5 + 10)

# Test gradients:
@test gradient(r, [1, 1, 1])[1] == [1, 5, 10]
end
@test Flux.params(r) == Flux.Params([[1, 5, 10]])
@test r([1, 1, 1]) == 1 + 5 + 10
@test r([1, 2, 3]) == 1 + 2 * 5 + 3 * 10
@test r(ones(3, 3)) == 3 * (1 + 5 + 10)

# Test gradients:
@test gradient(r, [1, 1, 1])[1] == [1, 5, 10]
@testset "Linear layer with activation" begin
d_in = 5
d_out = 7
d = @compact(W = randn(d_out, d_in), b = zeros(d_out), act = relu) do x
y = W * x
act.(y .+ b)
end

@test size.(Flux.params(d)) == [(7, 5), (7,)]

@test size(d(ones(5, 10))) == (7, 10)
@test all(d(randn(5, 10)) .>= 0)

d = @compact(in = 5, out = 7, W = randn(out, in), b = zeros(out), act = relu) do x
y = W * x
act.(y .+ b)
# Test gradients:
y, ∇ = Flux.withgradient(Flux.params(d)) do
input = randn(5, 32)
desired_output = randn(7, 32)
prediction = d(input)
sum((prediction - desired_output) .^ 2)
end
@test typeof(y) == Float64
grads = ∇.grads
@test typeof(grads) <: IdDict
@test length(grads) == 3
@test Set(size.(values(grads))) == Set([(7, 5), (), (7,)])

# Test equivalence to Dense layer:
d([1,2,3,4,5]) ≈ Dense(d.variables.W, zeros(7), relu)([1,2,3,4,5])
end

@test size.(Flux.params(d)) == [(7, 5), (7,)]
@testset "MLP" begin
n_in = 1
n_out = 1
nlayers = 3

@test size(d(ones(5, 10))) == (7, 10)
@test all(d(randn(5, 10)) .>= 0)
model = @compact(
w1 = Dense(n_in, 128),
w2 = [Dense(128, 128) for i = 1:nlayers],
w3 = Dense(128, n_out),
act = relu
) do x
embed = act(w1(x))
for w in w2
embed = act(w(embed))
end
out = w3(embed)
return out
end

# Test gradients:
y, ∇ = Flux.withgradient(Flux.params(d)) do
input = randn(5, 32)
desired_output = randn(7, 32)
prediction = d(input)
sum((prediction - desired_output) .^ 2)
@test size.(Flux.params(model)) == [
(128, 1),
(128,),
(128, 128),
(128,),
(128, 128),
(128,),
(128, 128),
(128,),
(1, 128),
(1,),
]
@test size(model(randn(n_in, 32))) == (1, 32)
end
@test typeof(y) == Float64
grads = ∇.grads
@test typeof(grads) <: IdDict
@test length(grads) == 3
@test Set(size.(values(grads))) == Set([(7, 5), (), (7,)])


# MLP:
n_in = 1
n_out = 1
nlayers = 3

model = @compact(
w1 = Dense(n_in, 128),
w2 = [Dense(128, 128) for i = 1:nlayers],
w3 = Dense(128, n_out),
act = relu
) do x
embed = act(w1(x))
for w in w2
embed = act(w(embed))

@testset "String representations" begin
model = @compact(w=Dense(32 => 32)) do x, y
tmp = sum(w(x))
return tmp + y
end
out = w3(embed)
return out
expected_string = """@compact(
w = Dense(32=>32), #1_056 parameters
) do x, y
tmp = sum(w(x))
return tmp + y
end"""
@test similar_strings(get_model_string(model), expected_string)
end

@test size.(Flux.params(model)) == [
(128, 1),
(128,),
(128, 128),
(128,),
(128, 128),
(128,),
(128, 128),
(128,),
(1, 128),
(1,),
]
@test size(model(randn(n_in, 32))) == (1, 32)

# Test string representations:
model = @compact(w=Dense(32 => 32)) do x, y
tmp = sum(w(x))
return tmp + y
end
expected_string = """@compact(
w = Dense(32=>32), #1_056 parameters
) do x, y
tmp = sum(w(x))
return tmp + y
end"""
@test similar_strings(get_model_string(model), expected_string)

# Custom naming:
model = @compact(w=Dense(32, 32), name="Linear(...)") do x, y
tmp = sum(w(x))
return tmp + y
@testset "Custom naming" begin
model = @compact(w=Dense(32, 32), name="Linear(...)") do x, y
tmp = sum(w(x))
return tmp + y
end
expected_string = "Linear(...) # 1_056 parameters"
@test similar_strings(get_model_string(model), expected_string)
end
expected_string = "Linear(...) # 1_056 parameters"
@test similar_strings(get_model_string(model), expected_string)

# Hierarchical models should work too:
model1 = @compact(w1=Dense(32=>32, relu), w2=Dense(32=>32, relu)) do x
w2(w1(x))
end
model2 = @compact(w1=model1, w2=Dense(32=>32, relu)) do x
w2(w1(x))
end
expected_string = """@compact(
w1 = @compact(
w1 = Dense(32 => 32, relu), # 1_056 parameters
w2 = Dense(32 => 32, relu), # 1_056 parameters
) do x
@testset "Hierarchical models" begin
model1 = @compact(w1=Dense(32=>32, relu), w2=Dense(32=>32, relu)) do x
w2(w1(x))
end,
w2 = Dense(32 => 32, relu), # 1_056 parameters
) do x
end
model2 = @compact(w1=model1, w2=Dense(32=>32, relu)) do x
w2(w1(x))
end # Total: 6 arrays, 3_168 parameters, 13.271 KiB."""
@test similar_strings(get_model_string(model2), expected_string)
end
expected_string = """@compact(
w1 = @compact(
w1 = Dense(32 => 32, relu), # 1_056 parameters
w2 = Dense(32 => 32, relu), # 1_056 parameters
) do x
w2(w1(x))
end,
w2 = Dense(32 => 32, relu), # 1_056 parameters
) do x
w2(w1(x))
end # Total: 6 arrays, 3_168 parameters, 13.271 KiB."""
@test similar_strings(get_model_string(model2), expected_string)
end

# With array params:
model = @compact(x=randn(32), w=Dense(32=>32)) do s
w(x .* s)
@testset "Array parameters" begin
model = @compact(x=randn(32), w=Dense(32=>32)) do s
w(x .* s)
end
expected_string = """@compact(
x = randn(32), # 32 parameters
w = Dense(32 => 32), # 1_056 parameters
) do s
w(x .* s)
end # Total: 3 arrays, 1_088 parameters, 4.734 KiB."""
@test similar_strings(get_model_string(model), expected_string)
end
expected_string = """@compact(
x = randn(32), # 32 parameters
w = Dense(32 => 32), # 1_056 parameters
) do s
w(x .* s)
end # Total: 3 arrays, 1_088 parameters, 4.734 KiB."""
@test similar_strings(get_model_string(model), expected_string)

# Hierarchy with inner model named:
model = @compact(
w1=@compact(w1=randn(32, 32), name="Model(32)") do x
w1 * x
end,
w2=randn(32, 32),
w3=randn(32),
) do x
w2 * w1(x)

@testset "Hierarchy with inner model named" begin
model = @compact(
w1=@compact(w1=randn(32, 32), name="Model(32)") do x
w1 * x
end,
w2=randn(32, 32),
w3=randn(32),
) do x
w2 * w1(x)
end
expected_string = """@compact(
Model(32), # 1_024 parameters
w2 = randn(32, 32), # 1_024 parameters
w3 = randn(32), # 32 parameters
) do x
w2 * w1(x)
end # Total: 3 arrays, 2_080 parameters, 17.089 KiB."""
@test similar_strings(get_model_string(model), expected_string)
end
expected_string = """@compact(
Model(32), # 1_024 parameters
w2 = randn(32, 32), # 1_024 parameters
w3 = randn(32), # 32 parameters
) do x

@testset "Hierarchy with outer model named" begin
model = @compact(
w1=@compact(w1=randn(32, 32)) do x
w1 * x
end,
w2=randn(32, 32),
w3=randn(32),
name="Model(32)"
) do x
w2 * w1(x)
end # Total: 3 arrays, 2_080 parameters, 17.089 KiB."""
@test similar_strings(get_model_string(model), expected_string)
end
expected_string = """Model(32) # Total: 3 arrays, 2_080 parameters, 17.057KiB."""
@test similar_strings(get_model_string(model), expected_string)
end

# Hierarchy with outer model named:
model = @compact(
w1=@compact(w1=randn(32, 32)) do x
w1 * x
end,
w2=randn(32, 32),
w3=randn(32),
name="Model(32)"
) do x
w2 * w1(x)
@testset "Dependent initializations" begin
# Test that initialization lines cannot depend on each other
@test_throws UndefVarError @compact(y = 3, z = y^2) do x
y + z + x
end
end
expected_string = """Model(32) # Total: 3 arrays, 2_080 parameters, 17.057KiB."""
@test similar_strings(get_model_string(model), expected_string)

@testset "Keyword argument syntax" begin
_a = 3
_b = 4
c = 5
model = @compact(a=_a; b=_b, c) do x
a + b * x + c * x^2
end
@test model(2) == _a + _b * 2 + c * 2^2
end
end