diff --git a/src/train.jl b/src/train.jl
index 884c3578c0..0a7433ac6d 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -90,8 +90,6 @@ The built-in loss functions accept 3 arguments, allowing for instance `train!(Fl
 
 Callback functions are not supported.  But see 3-argument `train!(loss, model, opt)` for an easy way
 to construct more complicated training loops.
-
-To change the package used to calculate gradients, use [`Flux.@train_autodiff`](@ref).
 """
 function train!(loss, model, data, opt)
   losses = Float32[]
@@ -144,8 +142,6 @@ for (i, d) in enumerate(data)
   end
 end
 ```
-To change the package used to calculate gradients, use [`Flux.@train_autodiff`](@ref).
-
 !!! note
     This method has no implicit `Params` analog in Flux ≤ 0.13.
 """
@@ -178,56 +174,4 @@ end
 
 explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor
 
-"""
-    Flux.@train_autodiff Tracker
-    Flux.@train_autodiff Zygote
-    Flux.@train_autodiff Yota
-    Flux.@train_autodiff Diffractor
-
-This macro allows the use of `train!` with various automatic differentiation (AD) packages,
-instead of the default Zygote.jl.
-
-You should load AD package, and then call this macro with the chosen name.
-The macro overwrites a method withing Flux, thus is a global setting, lasting until you re-start Julia.
-
-Only works with [Yota.jl](https://github.com/dfdx/Yota.jl),
-[Tracker.jl](https://github.com/FluxML/Tracker.jl) (Flux's old AD),
-[Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) (which is not yet registered),
-and with the default [Zygote.jl](https://github.com/FluxML/Zygote.jl).
-
-!!! note
-    This is mechanism is experimental! And there are known bugs, in particular Tracker will not automatically switch to training mode for `Dropout` etc.
-"""
-macro train_autodiff(pkg)
-  if pkg == :Diffractor
-    return quote
-      Diffractor.gradient(sin, 0.0)[1] ≈ 1.0  # ensures an error if not loaded
-      function Flux.Train.explicit_withgradient(f, args...)
-        y, back = Diffractor.∂⃖¹(f, args...)
-        dy1 = Flux.Zygote.sensitivity(y)  # Zygote is loaded, and this gives nice errors
-        return (; value = y, gradient = Base.tail(back(dy1)))
-      end
-    end |> esc
-  elseif pkg == :Yota
-    return quote
-      Yota.grad(sin, 0.0)  # [2][1] ≈ 1.0
-      function Flux.Train.explicit_withgradient(f, args...)
-        value, (_, gradient...) = Yota.grad(f, args...)
-        return (; value, gradient)
-      end
-    end |> esc
-  elseif pkg == :Tracker
-    return quote
-      Tracker.withgradient(sum, [1.0]).val == 1.0  # ensures an error if too-old version
-      Flux.Train.explicit_withgradient(f, args...) = Tracker.withgradient(f, args...)
-    end |> esc
-  elseif pkg == :Zygote
-    return quote
-      Flux.Train.explicit_withgradient(f, args...) = Flux.Zygote.withgradient(f, args...)
-    end |> esc
-  else
-    throw("@train_autodiff expects one of Tracker, Zygote, Yota, or Diffractor. No other arguments are understood.")
-  end
-end
-
 end # module
diff --git a/test/train.jl b/test/train.jl
index 81ffa2f3db..443a39dd75 100644
--- a/test/train.jl
+++ b/test/train.jl
@@ -53,79 +53,3 @@ end
   # Test NaN / Inf early stop
   # Test that loss is returned
 end
-
-import Tracker
-Flux.@train_autodiff Tracker
-
-@testset "Explicit Flux.train! with Tracker" begin
-  Random.seed!(84)
-  w = randn(10, 10)
-  w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
-  @testset for rule in [Descent(0.1), Adam(), AdamW()]
-
-    loss(m, x) = begin
-      Flux.istraining() && error("This test is not in fact using Tracker!")
-      Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
-    end
-    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
-    @test loss(model, rand(10, 10)) > 1
-
-    opt = Flux.setup(rule, model)
-    Flux.train!(loss, model, ((rand(10),) for _ in 1:10^5), opt)
-    @test loss(model, rand(10, 10)) < 0.01
-  end
-
-  # Test 3-arg `Flux.train!` method:
-  @testset for rule in [Descent(0.1), Adam()]
-
-    loss(m) = let x = rand(10)
-      Flux.istraining() && error("This test is not in fact using Tracker!")
-      Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
-    end
-    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
-    @test loss(model) > 1
-
-    opt = Flux.setup(rule, model)
-    for i in 1:10^5
-      Flux.train!(loss, model, opt)
-    end
-    @test loss(model) < 0.01
-  end
-end
-
-import Yota
-Flux.@train_autodiff Yota
-
-@testset "Explicit Flux.train! with Yota" begin
-  Random.seed!(84)
-  w = randn(10, 10)
-  w2 = randn(10, 10)  # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
-  @testset for rule in [Descent(0.1), Adam(), AdamW()]
-
-    loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
-    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
-    @test loss(model, rand(10, 10)) > 1
-
-    opt = Flux.setup(rule, model)
-    Flux.train!(loss, model, ((rand(10),) for _ in 1:10^5), opt)
-    @test loss(model, rand(10, 10)) < 0.01
-  end
-
-  # Test 3-arg `Flux.train!` method:
-  @testset for rule in [Descent(0.1), Adam()]
-
-    loss(m) = let x = rand(10)
-      Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
-    end
-    model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
-    @test loss(model) > 1
-
-    opt = Flux.setup(rule, model)
-    for i in 1:10^5
-      Flux.train!(loss, model, opt)
-    end
-    @test loss(model) < 0.01
-  end
-end
-
-Flux.@train_autodiff Zygote
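Note for anyone who was using the removed macro: the internal hook it targeted, `Flux.Train.explicit_withgradient`, is kept by this diff (see the `# can overload this to use e.g. Yota / Diffractor` comment above). Below is a minimal sketch, distilled from the deleted `@train_autodiff Tracker` branch, of swapping the AD backend by hand. It overloads an unexported internal function, so treat it as an assumption about downstream use rather than a supported API.

```julia
# Sketch only: reproduces what `Flux.@train_autodiff Tracker` used to expand to.
# It globally overwrites an internal Flux method, so the change lasts until Julia restarts.
import Flux, Tracker

# Route the gradient call inside `train!` through Tracker instead of Zygote.
# The deleted code checked `Tracker.withgradient(sum, [1.0]).val == 1.0`, i.e. it
# requires a Tracker version whose `withgradient` returns a result with a `.val` field.
Flux.Train.explicit_withgradient(f, args...) = Tracker.withgradient(f, args...)
```

As the removed docstring warned, the known limitations remain: in particular, Tracker will not automatically switch layers such as `Dropout` into training mode.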