Skip to content

Commit

Permalink
Merge pull request #16 from gaurav-arya/ag-compact-kwargs
Browse files Browse the repository at this point in the history
Improve keyword argument handling in `@compact`
  • Loading branch information
mcabbott authored Aug 21, 2023
2 parents 9dcae27 + 9f43108 commit 514dff6
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 142 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Fluxperimental"
uuid = "3102ee7a-c841-4564-8f7f-ec69bd4fd658"
version = "0.1.3"
version = "0.2.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
30 changes: 21 additions & 9 deletions src/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@ r([1, 1, 1]) # x is set to [1, 1, 1].
Here is a linear model with bias and activation:
```
d = @compact(in=5, out=7, W=randn(out, in), b=zeros(out), act=relu) do x
d_in = 5
d_out = 7
d = @compact(W = randn(d_out, d_in), b = zeros(d_out), act = relu) do x
y = W * x
act.(y .+ b)
end
d(ones(5, 10)) # 7×10 Matrix as output.
d(ones(5, 10)) # 7×10 Matrix as output.
d([1,2,3,4,5]) ≈ Dense(d.variables.W, zeros(7), relu)([1,2,3,4,5]) # Equivalent to a dense layer
```
```
Finally, here is a simple MLP:
Expand Down Expand Up @@ -79,11 +83,21 @@ println(model) # "Linear(3 => 1)"
This can be useful when using `@compact` to hierarchically construct
complex models to be used inside a `Chain`.
"""
macro compact(fex, kwexs...)
# check input
macro compact(fex, _kwexs...)
# check inputs
Meta.isexpr(fex, :(->)) || error("expects a do block")
isempty(kwexs) && error("expects keyword arguments")
all(ex -> Meta.isexpr(ex, (:kw,:(=))), kwexs) || error("expects only keyword argumens")
isempty(_kwexs) && error("expects keyword arguments")
all(ex -> Meta.isexpr(ex, (:kw,:(=),:parameters)), _kwexs) || error("expects only keyword arguments")

# process keyword arguments
if Meta.isexpr(_kwexs[1], :parameters) # handle keyword arguments provided after semicolon
kwexs1 = map(ex -> ex isa Symbol ? Expr(:kw, ex, ex) : ex, _kwexs[1].args)
_kwexs = _kwexs[2:end]
else
kwexs1 = ()
end
kwexs2 = map(ex -> Expr(:kw, ex.args...), _kwexs) # handle keyword arguments provided before semicolon
kwexs = (kwexs1..., kwexs2...)

# check if user has named layer:
name = findfirst(ex -> ex.args[1] == :name, kwexs)
Expand All @@ -103,16 +117,14 @@ macro compact(fex, kwexs...)

# edit expressions
vars = map(ex -> ex.args[1], kwexs)
assigns = map(ex -> Expr(:(=), ex.args...), kwexs)
@gensym self
pushfirst!(fex.args[1].args, self)
addprefix!(fex, self, vars)

# assemble
return esc(quote
let
$(assigns...)
$CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(vars...))
$CompactLayer($fex, $name, ($layer, $input, $block), $setup; $(kwexs...))
end
end)
end
Expand Down
296 changes: 164 additions & 132 deletions test/compact.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,158 +27,190 @@ end

@testset "@compact" begin

r = @compact(w = [1, 5, 10]) do x
sum(w .* x)
@testset "Linear layer" begin
r = @compact(w = [1, 5, 10]) do x
sum(w .* x)
end
@test Flux.params(r) == Flux.Params([[1, 5, 10]])
@test r([1, 1, 1]) == 1 + 5 + 10
@test r([1, 2, 3]) == 1 + 2 * 5 + 3 * 10
@test r(ones(3, 3)) == 3 * (1 + 5 + 10)

# Test gradients:
@test gradient(r, [1, 1, 1])[1] == [1, 5, 10]
end
@test Flux.params(r) == Flux.Params([[1, 5, 10]])
@test r([1, 1, 1]) == 1 + 5 + 10
@test r([1, 2, 3]) == 1 + 2 * 5 + 3 * 10
@test r(ones(3, 3)) == 3 * (1 + 5 + 10)

# Test gradients:
@test gradient(r, [1, 1, 1])[1] == [1, 5, 10]
@testset "Linear layer with activation" begin
d_in = 5
d_out = 7
d = @compact(W = randn(d_out, d_in), b = zeros(d_out), act = relu) do x
y = W * x
act.(y .+ b)
end

@test size.(Flux.params(d)) == [(7, 5), (7,)]

@test size(d(ones(5, 10))) == (7, 10)
@test all(d(randn(5, 10)) .>= 0)

d = @compact(in = 5, out = 7, W = randn(out, in), b = zeros(out), act = relu) do x
y = W * x
act.(y .+ b)
# Test gradients:
y, ∇ = Flux.withgradient(Flux.params(d)) do
input = randn(5, 32)
desired_output = randn(7, 32)
prediction = d(input)
sum((prediction - desired_output) .^ 2)
end
@test typeof(y) == Float64
grads =.grads
@test typeof(grads) <: IdDict
@test length(grads) == 3
@test Set(size.(values(grads))) == Set([(7, 5), (), (7,)])

# Test equivalence to Dense layer:
d([1,2,3,4,5]) Dense(d.variables.W, zeros(7), relu)([1,2,3,4,5])
end

@test size.(Flux.params(d)) == [(7, 5), (7,)]
@testset "MLP" begin
n_in = 1
n_out = 1
nlayers = 3

@test size(d(ones(5, 10))) == (7, 10)
@test all(d(randn(5, 10)) .>= 0)
model = @compact(
w1 = Dense(n_in, 128),
w2 = [Dense(128, 128) for i = 1:nlayers],
w3 = Dense(128, n_out),
act = relu
) do x
embed = act(w1(x))
for w in w2
embed = act(w(embed))
end
out = w3(embed)
return out
end

# Test gradients:
y, ∇ = Flux.withgradient(Flux.params(d)) do
input = randn(5, 32)
desired_output = randn(7, 32)
prediction = d(input)
sum((prediction - desired_output) .^ 2)
@test size.(Flux.params(model)) == [
(128, 1),
(128,),
(128, 128),
(128,),
(128, 128),
(128,),
(128, 128),
(128,),
(1, 128),
(1,),
]
@test size(model(randn(n_in, 32))) == (1, 32)
end
@test typeof(y) == Float64
grads =.grads
@test typeof(grads) <: IdDict
@test length(grads) == 3
@test Set(size.(values(grads))) == Set([(7, 5), (), (7,)])


# MLP:
n_in = 1
n_out = 1
nlayers = 3

model = @compact(
w1 = Dense(n_in, 128),
w2 = [Dense(128, 128) for i = 1:nlayers],
w3 = Dense(128, n_out),
act = relu
) do x
embed = act(w1(x))
for w in w2
embed = act(w(embed))

@testset "String representations" begin
model = @compact(w=Dense(32 => 32)) do x, y
tmp = sum(w(x))
return tmp + y
end
out = w3(embed)
return out
expected_string = """@compact(
w = Dense(32=>32), #1_056 parameters
) do x, y
tmp = sum(w(x))
return tmp + y
end"""
@test similar_strings(get_model_string(model), expected_string)
end

@test size.(Flux.params(model)) == [
(128, 1),
(128,),
(128, 128),
(128,),
(128, 128),
(128,),
(128, 128),
(128,),
(1, 128),
(1,),
]
@test size(model(randn(n_in, 32))) == (1, 32)

# Test string representations:
model = @compact(w=Dense(32 => 32)) do x, y
tmp = sum(w(x))
return tmp + y
end
expected_string = """@compact(
w = Dense(32=>32), #1_056 parameters
) do x, y
tmp = sum(w(x))
return tmp + y
end"""
@test similar_strings(get_model_string(model), expected_string)

# Custom naming:
model = @compact(w=Dense(32, 32), name="Linear(...)") do x, y
tmp = sum(w(x))
return tmp + y
@testset "Custom naming" begin
model = @compact(w=Dense(32, 32), name="Linear(...)") do x, y
tmp = sum(w(x))
return tmp + y
end
expected_string = "Linear(...) # 1_056 parameters"
@test similar_strings(get_model_string(model), expected_string)
end
expected_string = "Linear(...) # 1_056 parameters"
@test similar_strings(get_model_string(model), expected_string)

# Hierarchical models should work too:
model1 = @compact(w1=Dense(32=>32, relu), w2=Dense(32=>32, relu)) do x
w2(w1(x))
end
model2 = @compact(w1=model1, w2=Dense(32=>32, relu)) do x
w2(w1(x))
end
expected_string = """@compact(
w1 = @compact(
w1 = Dense(32 => 32, relu), # 1_056 parameters
w2 = Dense(32 => 32, relu), # 1_056 parameters
) do x
@testset "Hierarchical models" begin
model1 = @compact(w1=Dense(32=>32, relu), w2=Dense(32=>32, relu)) do x
w2(w1(x))
end,
w2 = Dense(32 => 32, relu), # 1_056 parameters
) do x
end
model2 = @compact(w1=model1, w2=Dense(32=>32, relu)) do x
w2(w1(x))
end # Total: 6 arrays, 3_168 parameters, 13.271 KiB."""
@test similar_strings(get_model_string(model2), expected_string)
end
expected_string = """@compact(
w1 = @compact(
w1 = Dense(32 => 32, relu), # 1_056 parameters
w2 = Dense(32 => 32, relu), # 1_056 parameters
) do x
w2(w1(x))
end,
w2 = Dense(32 => 32, relu), # 1_056 parameters
) do x
w2(w1(x))
end # Total: 6 arrays, 3_168 parameters, 13.271 KiB."""
@test similar_strings(get_model_string(model2), expected_string)
end

# With array params:
model = @compact(x=randn(32), w=Dense(32=>32)) do s
w(x .* s)
@testset "Array parameters" begin
model = @compact(x=randn(32), w=Dense(32=>32)) do s
w(x .* s)
end
expected_string = """@compact(
x = randn(32), # 32 parameters
w = Dense(32 => 32), # 1_056 parameters
) do s
w(x .* s)
end # Total: 3 arrays, 1_088 parameters, 4.734 KiB."""
@test similar_strings(get_model_string(model), expected_string)
end
expected_string = """@compact(
x = randn(32), # 32 parameters
w = Dense(32 => 32), # 1_056 parameters
) do s
w(x .* s)
end # Total: 3 arrays, 1_088 parameters, 4.734 KiB."""
@test similar_strings(get_model_string(model), expected_string)

# Hierarchy with inner model named:
model = @compact(
w1=@compact(w1=randn(32, 32), name="Model(32)") do x
w1 * x
end,
w2=randn(32, 32),
w3=randn(32),
) do x
w2 * w1(x)

@testset "Hierarchy with inner model named" begin
model = @compact(
w1=@compact(w1=randn(32, 32), name="Model(32)") do x
w1 * x
end,
w2=randn(32, 32),
w3=randn(32),
) do x
w2 * w1(x)
end
expected_string = """@compact(
Model(32), # 1_024 parameters
w2 = randn(32, 32), # 1_024 parameters
w3 = randn(32), # 32 parameters
) do x
w2 * w1(x)
end # Total: 3 arrays, 2_080 parameters, 17.089 KiB."""
@test similar_strings(get_model_string(model), expected_string)
end
expected_string = """@compact(
Model(32), # 1_024 parameters
w2 = randn(32, 32), # 1_024 parameters
w3 = randn(32), # 32 parameters
) do x

@testset "Hierarchy with outer model named" begin
model = @compact(
w1=@compact(w1=randn(32, 32)) do x
w1 * x
end,
w2=randn(32, 32),
w3=randn(32),
name="Model(32)"
) do x
w2 * w1(x)
end # Total: 3 arrays, 2_080 parameters, 17.089 KiB."""
@test similar_strings(get_model_string(model), expected_string)
end
expected_string = """Model(32) # Total: 3 arrays, 2_080 parameters, 17.057KiB."""
@test similar_strings(get_model_string(model), expected_string)
end

# Hierarchy with outer model named:
model = @compact(
w1=@compact(w1=randn(32, 32)) do x
w1 * x
end,
w2=randn(32, 32),
w3=randn(32),
name="Model(32)"
) do x
w2 * w1(x)
@testset "Dependent initializations" begin
# Test that initialization lines cannot depend on each other
@test_throws UndefVarError @compact(y = 3, z = y^2) do x
y + z + x
end
end
expected_string = """Model(32) # Total: 3 arrays, 2_080 parameters, 17.057KiB."""
@test similar_strings(get_model_string(model), expected_string)

@testset "Keyword argument syntax" begin
_a = 3
_b = 4
c = 5
model = @compact(a=_a; b=_b, c) do x
a + b * x + c * x^2
end
@test model(2) == _a + _b * 2 + c * 2^2
end
end

0 comments on commit 514dff6

Please sign in to comment.