Skip to content

Commit

Permalink
remove train_autodiff macro
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Oct 16, 2022
1 parent 0cea7ee commit bbc0f85
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 132 deletions.
56 changes: 0 additions & 56 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ The built-in loss functions accept 3 arguments, allowing for instance `train!(Fl
Callback functions are not supported. But see 3-argument `train!(loss, model, opt)` for an
easy way to construct more complicated training loops.
To change the package used to calculate gradients, use [`Flux.@train_autodiff`](@ref).
"""
function train!(loss, model, data, opt)
losses = Float32[]
Expand Down Expand Up @@ -144,8 +142,6 @@ for (i, d) in enumerate(data)
end
```
To change the package used to calculate gradients, use [`Flux.@train_autodiff`](@ref).
!!! note
This method has no implicit `Params` analog in Flux ≤ 0.13.
"""
Expand Down Expand Up @@ -178,56 +174,4 @@ end

explicit_withgradient(f, args...) = Zygote.withgradient(f, args...) # can overload this to use e.g. Yota / Diffractor

"""
Flux.@train_autodiff Tracker
Flux.@train_autodiff Zygote
Flux.@train_autodiff Yota
Flux.@train_autodiff Diffractor
This macro allows the use of `train!` with various automatic differentiation (AD) packages,
instead of the default Zygote.jl.
You should load AD package, and then call this macro with the chosen name.
The macro overwrites a method withing Flux, thus is a global setting, lasting until you re-start Julia.
Only works with [Yota.jl](https://github.com/dfdx/Yota.jl),
[Tracker.jl](https://github.com/FluxML/Tracker.jl) (Flux's old AD),
[Diffractor.jl](https://github.com/JuliaDiff/Diffractor.jl) (which is not yet registered),
and with the default [Zygote.jl](https://github.com/FluxML/Zygote.jl).
!!! note
This is mechanism is experimental! And there are known bugs, in particular Tracker will not automatically switch to training mode for `Dropout` etc.
"""
macro train_autodiff(pkg)
if pkg == :Diffractor
return quote
Diffractor.gradient(sin, 0.0)[1] 1.0 # ensures an error if not loaded
function Flux.Train.explicit_withgradient(f, args...)
y, back = Diffractor.∂⃖¹(f, args...)
dy1 = Flux.Zygote.sensitivity(y) # Zygote is loaded, and this gives nice errors
return (; value = y, gradient = Base.tail(back(dy1)))
end
end |> esc
elseif pkg == :Yota
return quote
Yota.grad(sin, 0.0) # [2][1] ≈ 1.0
function Flux.Train.explicit_withgradient(f, args...)
value, (_, gradient...) = Yota.grad(f, args...)
return (; value, gradient)
end
end |> esc
elseif pkg == :Tracker
return quote
Tracker.withgradient(sum, [1.0]).val == 1.0 # ensures an error if too-old version
Flux.Train.explicit_withgradient(f, args...) = Tracker.withgradient(f, args...)
end |> esc
elseif pkg == :Zygote
return quote
Flux.Train.explicit_withgradient(f, args...) = Flux.Zygote.withgradient(f, args...)
end |> esc
else
throw("@train_autodiff expects one of Tracker, Zygote, Yota, or Diffractor. No other arguments are understood.")
end
end

end # module
76 changes: 0 additions & 76 deletions test/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,79 +53,3 @@ end
# Test NaN / Inf early stop
# Test that loss is returned
end

import Tracker
Flux.@train_autodiff Tracker

@testset "Explicit Flux.train! with Tracker" begin
Random.seed!(84)
w = randn(10, 10)
w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
@testset for rule in [Descent(0.1), Adam(), AdamW()]

loss(m, x) = begin
Flux.istraining() && error("This test is not in fact using Tracker!")
Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
end
model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
@test loss(model, rand(10, 10)) > 1

opt = Flux.setup(rule, model)
Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
@test loss(model, rand(10, 10)) < 0.01
end

# Test 3-arg `Flux.train!` method:
@testset for rule in [Descent(0.1), Adam()]

loss(m) = let x = rand(10)
Flux.istraining() && error("This test is not in fact using Tracker!")
Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
end
model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
@test loss(model) > 1

opt = Flux.setup(rule, model)
for i in 1:10^5
Flux.train!(loss, model, opt)
end
@test loss(model) < 0.01
end
end

import Yota
Flux.@train_autodiff Yota

@testset "Explicit Flux.train! with Yota" begin
Random.seed!(84)
w = randn(10, 10)
w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
@testset for rule in [Descent(0.1), Adam(), AdamW()]

loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
@test loss(model, rand(10, 10)) > 1

opt = Flux.setup(rule, model)
Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
@test loss(model, rand(10, 10)) < 0.01
end

# Test 3-arg `Flux.train!` method:
@testset for rule in [Descent(0.1), Adam()]

loss(m) = let x = rand(10)
Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
end
model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
@test loss(model) > 1

opt = Flux.setup(rule, model)
for i in 1:10^5
Flux.train!(loss, model, opt)
end
@test loss(model) < 0.01
end
end

Flux.@train_autodiff Zygote

0 comments on commit bbc0f85

Please sign in to comment.