diff --git a/test/train.jl b/test/train.jl
index 92b25bef5d..d2114d5a16 100644
--- a/test/train.jl
+++ b/test/train.jl
@@ -125,49 +125,54 @@ end
 end
 
 for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
-@testset "L2 regularisation with $name" begin
-  # New docs claim an exact equivalent. It's a bit long to put the example in there,
-  # but perhaps the tests should contain it.
-
-  model = Dense(3 => 2, tanh);
-  init_weight = copy(model.weight);
-  data = [(randn(Float32, 3,5), randn(Float32, 2,5)) for _ in 1:10];
-
-  # Take 1: explicitly add a penalty in the loss function
-  opt = Flux.setup(Adam(0.1), model)
-  trainfn!(model, data, opt) do m, x, y
-    err = Flux.mse(m(x), y)
-    l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
-    err + 0.33 * l2
+
+  if (name == "Enzyme" && get(ENV, "FLUX_TEST_ENZYME", "true") == "false")
+    continue
   end
-  diff1 = model.weight .- init_weight
+
+  @testset "L2 regularisation with $name" begin
+    # New docs claim an exact equivalent. It's a bit long to put the example in there,
+    # but perhaps the tests should contain it.
 
-  # Take 2: the same, but with Flux.params. Was broken for a bit, no tests!
-  # skipping this test for Enzyme cause implicit params is unsupported
-  if name == "Zygote"
-    model.weight .= init_weight
-    model.bias .= 0
-    pen2(x::AbstractArray) = sum(abs2, x)/2
+    model = Dense(3 => 2, tanh);
+    init_weight = copy(model.weight);
+    data = [(randn(Float32, 3,5), randn(Float32, 2,5)) for _ in 1:10];
+
+    # Take 1: explicitly add a penalty in the loss function
     opt = Flux.setup(Adam(0.1), model)
     trainfn!(model, data, opt) do m, x, y
       err = Flux.mse(m(x), y)
-      l2 = sum(pen2, Flux.params(m))
+      l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
       err + 0.33 * l2
     end
-    diff2 = model.weight .- init_weight
-    @test diff1 ≈ diff2
-  end
+    diff1 = model.weight .- init_weight
+
+    # Take 2: the same, but with Flux.params. Was broken for a bit, no tests!
+    # skipping this test for Enzyme cause implicit params is unsupported
+    if name == "Zygote"
+      model.weight .= init_weight
+      model.bias .= 0
+      pen2(x::AbstractArray) = sum(abs2, x)/2
+      opt = Flux.setup(Adam(0.1), model)
+      trainfn!(model, data, opt) do m, x, y
+        err = Flux.mse(m(x), y)
+        l2 = sum(pen2, Flux.params(m))
+        err + 0.33 * l2
+      end
+      diff2 = model.weight .- init_weight
+      @test diff1 ≈ diff2
+    end
 
-  # Take 3: using WeightDecay instead. Need the /2 above, to match exactly.
-  model.weight .= init_weight
-  model.bias .= 0
-  decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model);
-  trainfn!(model, data, decay_opt) do m, x, y
-    Flux.mse(m(x), y)
+    # Take 3: using WeightDecay instead. Need the /2 above, to match exactly.
+    model.weight .= init_weight
+    model.bias .= 0
+    decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model);
+    trainfn!(model, data, decay_opt) do m, x, y
+      Flux.mse(m(x), y)
+    end
+    diff3 = model.weight .- init_weight
+    @test diff1 ≈ diff3
   end
-  diff3 = model.weight .- init_weight
-  @test diff1 ≈ diff3
-end
 end
 
 @testset "Flux.setup bugs" begin
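
For context, the equivalence this testset exercises: the gradient of `0.33 * sum(abs2, w)/2` with respect to `w` is `0.33 .* w`, which is exactly the term `WeightDecay(0.33)` adds to the gradient when it precedes `Adam` in an `OptimiserChain` (hence the comment "Need the /2 above, to match exactly"). Below is a minimal standalone sketch of the same check, adapted from the patch above and using only the explicit-style `Flux.train!` API that appears in it; this is an illustration, not part of the test file:

```julia
using Flux

model = Dense(3 => 2, tanh)
init_weight = copy(model.weight)
data = [(randn(Float32, 3, 5), randn(Float32, 2, 5)) for _ in 1:10]

# Take 1: explicit penalty in the loss.
# The gradient of 0.33 * sum(abs2, w)/2 with respect to w is 0.33 .* w.
opt = Flux.setup(Adam(0.1), model)
Flux.train!(model, data, opt) do m, x, y
    Flux.mse(m(x), y) + 0.33 * (sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2)
end
diff1 = model.weight .- init_weight

# Take 3: WeightDecay(0.33) adds 0.33 .* w to the gradient before Adam sees it,
# so after resetting the parameters the two runs should move the weights identically.
model.weight .= init_weight
model.bias .= 0
decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model)
Flux.train!(model, data, decay_opt) do m, x, y
    Flux.mse(m(x), y)
end
diff3 = model.weight .- init_weight

diff1 ≈ diff3  # expected to hold, as the @test in the patch asserts
```

Take 2 (the `Flux.params` variant) is omitted from the sketch because, as the patch's own comment notes, implicit parameters are unsupported under Enzyme, so the test only runs it on the Zygote path.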