Skip to content

Commit

Permalink
more enzyme fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 10, 2024
1 parent 2d91378 commit 1bfc0d3
Showing 1 changed file with 39 additions and 34 deletions.
73 changes: 39 additions & 34 deletions test/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,49 +125,54 @@ end
end

for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
@testset "L2 regularisation with $name" begin
# New docs claim an exact equivalent. It's a bit long to put the example in there,
# but perhaps the tests should contain it.

model = Dense(3 => 2, tanh);
init_weight = copy(model.weight);
data = [(randn(Float32, 3,5), randn(Float32, 2,5)) for _ in 1:10];

# Take 1: explicitly add a penalty in the loss function
opt = Flux.setup(Adam(0.1), model)
trainfn!(model, data, opt) do m, x, y
err = Flux.mse(m(x), y)
l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
err + 0.33 * l2

if (name == "Enzyme" && get(ENV, "FLUX_TEST_ENZYME", "true") == "false")
continue
end
diff1 = model.weight .- init_weight

@testset "L2 regularisation with $name" begin
# New docs claim an exact equivalent. It's a bit long to put the example in there,
# but perhaps the tests should contain it.

# Take 2: the same, but with Flux.params. Was broken for a bit, no tests!
# skipping this test for Enzyme cause implicit params is unsupported
if name == "Zygote"
model.weight .= init_weight
model.bias .= 0
pen2(x::AbstractArray) = sum(abs2, x)/2
model = Dense(3 => 2, tanh);
init_weight = copy(model.weight);
data = [(randn(Float32, 3,5), randn(Float32, 2,5)) for _ in 1:10];

# Take 1: explicitly add a penalty in the loss function
opt = Flux.setup(Adam(0.1), model)
trainfn!(model, data, opt) do m, x, y
err = Flux.mse(m(x), y)
l2 = sum(pen2, Flux.params(m))
l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
err + 0.33 * l2
end
diff2 = model.weight .- init_weight
@test diff1 diff2
end
diff1 = model.weight .- init_weight

# Take 2: the same, but with Flux.params. Was broken for a bit, no tests!
# skipping this test for Enzyme cause implicit params is unsupported
if name == "Zygote"
model.weight .= init_weight
model.bias .= 0
pen2(x::AbstractArray) = sum(abs2, x)/2
opt = Flux.setup(Adam(0.1), model)
trainfn!(model, data, opt) do m, x, y
err = Flux.mse(m(x), y)
l2 = sum(pen2, Flux.params(m))
err + 0.33 * l2
end
diff2 = model.weight .- init_weight
@test diff1 diff2
end

# Take 3: using WeightDecay instead. Need the /2 above, to match exactly.
model.weight .= init_weight
model.bias .= 0
decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model);
trainfn!(model, data, decay_opt) do m, x, y
Flux.mse(m(x), y)
# Take 3: using WeightDecay instead. Need the /2 above, to match exactly.
model.weight .= init_weight
model.bias .= 0
decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model);
trainfn!(model, data, decay_opt) do m, x, y
Flux.mse(m(x), y)
end
diff3 = model.weight .- init_weight
@test diff1 diff3
end
diff3 = model.weight .- init_weight
@test diff1 diff3
end
end

@testset "Flux.setup bugs" begin
Expand Down

0 comments on commit 1bfc0d3

Please sign in to comment.