diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3b209a82f..9869a5d66 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,8 +1,13 @@
 # News
 
-## [0.3.0] - 04.04.2022
+## Unreleased
+
+### Added
+- Support for [Optimisers.jl](https://github.com/FluxML/Optimisers.jl) optimisers (https://github.com/FluxML/FluxTraining.jl/pull/114).
+
+## [0.3.0] - 04.04.2022
 
 ### Added
 
diff --git a/Project.toml b/Project.toml
index 12c4f464d..2da0bcafc 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,6 +13,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
 InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6"
 OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
 Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
 PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
@@ -34,8 +35,8 @@ Graphs = "1"
 ImageCore = "0.8, 0.9"
 InlineTest = "0.2"
 OnlineStats = "1.5"
-Parameters = "0.12"
 ParameterSchedulers = "0.3.1"
+Parameters = "0.12"
 PrettyTables = "1, 1.1, 1.2"
 ProgressMeter = "1.4"
 Reexport = "1.0"
diff --git a/src/FluxTraining.jl b/src/FluxTraining.jl
index 2001cb0d8..e9e8452b8 100644
--- a/src/FluxTraining.jl
+++ b/src/FluxTraining.jl
@@ -15,6 +15,7 @@ module ES end
 
 import OnlineStats
 using OnlineStats: EqualWeight, Mean, OnlineStat
+import Optimisers
 using Parameters
 using ProgressMeter: Progress, next!
 using Statistics: mean
diff --git a/src/learner.jl b/src/learner.jl
index 8ac5b8064..930083416 100644
--- a/src/learner.jl
+++ b/src/learner.jl
@@ -16,6 +16,8 @@ mutable struct Learner
     data::PropDict
     optimizer
     lossfn
+    # This used to store `Flux.Params`; it now stores the optimiser state
+    # if an optimiser from Optimisers.jl is used.
     params
     step::PropDict
     callbacks::Callbacks
@@ -96,7 +98,7 @@ function Learner(
         _dataiters(data),
         optimizer,
         lossfn,
-        paramsrec(model),
+        setupoptimstate(model, optimizer),
         PropDict(),
         cbs,
         PropDict())
@@ -129,9 +131,15 @@ phasedataiter(::AbstractValidationPhase) = :validation
 
 function model!(learner, model)
     learner.model = model
-    learner.params = paramsrec(model)
+    learner.params = setupoptimstate(model, learner.optimizer)
 end
 
+# Flux.jl optimisers store `params`, while Optimisers.jl optimisers store the result of `setup`.
+setupoptimstate(model, ::Flux.Optimise.AbstractOptimiser) = Flux.params(model)
+# Optimisers.jl has no abstract optimiser supertype, so we assume any non-Flux
+# optimiser conforms to the Optimisers.jl interface.
+setupoptimstate(model, optim) = Optimisers.setup(optim, model)
+
 
 _dataiters(d::PropDict) = d
 _dataiters(t::NamedTuple) = PropDict(pairs(t))
@@ -146,9 +154,3 @@ function _dataiters(t::Tuple)
         error("Please pass a `NamedTuple` or `PropDict` as `data`.")
     end
 end
-
-
-paramsrec(m) = Flux.params(m)
-paramsrec(t::Union{Tuple,NamedTuple}) = map(paramsrec, t)
-
-# Callback utilities
diff --git a/src/training.jl b/src/training.jl
index 4356252d2..308a91b63 100644
--- a/src/training.jl
+++ b/src/training.jl
@@ -49,19 +49,36 @@ function step! end
 function step!(learner, phase::TrainingPhase, batch)
     xs, ys = batch
     runstep(learner, phase, (; xs=xs, ys=ys)) do handle, state
-        state.grads = gradient(learner.params) do
-            state.ŷs = learner.model(state.xs)
+
+        state.grads = _gradient(learner.optimizer, learner.model, learner.params) do model
+            state.ŷs = model(state.xs)
             handle(LossBegin())
             state.loss = learner.lossfn(state.ŷs, state.ys)
             handle(BackwardBegin())
             return state.loss
         end
         handle(BackwardEnd())
-        update!(learner.optimizer, learner.params, state.grads)
+        learner.params, learner.model = _update!(
+            learner.optimizer, learner.params, learner.model, state.grads)
     end
 end
 
+# Handle both old Flux.jl and new Optimisers.jl optimisers
+
+_gradient(f, _, m, _) = gradient(f, m)[1]
+_gradient(f, ::Flux.Optimise.AbstractOptimiser, m, ps::Params) = gradient(() -> f(m), ps)
+
+function _update!(optimizer::Flux.Optimise.AbstractOptimiser, params, model, grads)
+    update!(optimizer, params, grads)
+    return params, model
+end
+function _update!(_, st, model, grads)
+    st, model = Optimisers.update!(st, model, grads)
+    return st, model
+end
+
+
 
 function step!(learner, phase::ValidationPhase, batch)
     xs, ys = batch
     runstep(learner, phase, (;xs=xs, ys=ys)) do _, state
diff --git a/test/Project.toml b/test/Project.toml
index 9ec14e997..16ca9e068 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -6,6 +6,7 @@ version = "0.1.0"
 Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
 ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89"
 Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
diff --git a/test/imports.jl b/test/imports.jl
index bb4b4ef5e..a6f694336 100644
--- a/test/imports.jl
+++ b/test/imports.jl
@@ -1,4 +1,5 @@
 using ReTest
+import Optimisers
 using FluxTraining
 using ParameterSchedulers
 using Colors
diff --git a/test/training.jl b/test/training.jl
index f1846ae96..8b16e6a4b 100644
--- a/test/training.jl
+++ b/test/training.jl
@@ -47,3 +47,10 @@ end
     fit!(learner, 5)
     @test learner.model.coeff[1] ≈ 3 atol = 0.1
 end
+
+
+@testset "Optimisers.jl compatibility" begin
+    learner = testlearner(coeff = 3, opt=Optimisers.Descent(0.001))
+    fit!(learner, 5)
+    @test learner.model.coeff[1] ≈ 3 atol = 0.1
+end
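
A minimal usage sketch of the new Optimisers.jl path, not taken from the patch itself: it assumes the keyword-based `Learner(model, lossfn; data, optimizer)` constructor, and the model, data, and hyperparameters below are illustrative placeholders.

```julia
using Flux, FluxTraining
import Optimisers

# Toy regression batches of (features, targets).
batches = [(rand(Float32, 4, 16), rand(Float32, 1, 16)) for _ in 1:32]

model = Chain(Dense(4, 8, relu), Dense(8, 1))

learner = Learner(
    model, Flux.Losses.mse;
    data = (batches, batches),            # (training, validation) iterators
    optimizer = Optimisers.Descent(0.01)  # an Optimisers.jl rule instead of a Flux.Optimise optimiser
)

fit!(learner, 5)
```

With an Optimisers.jl rule, `learner.params` holds the state returned by `Optimisers.setup`, and each training step dispatches to the explicit-gradient `_gradient`/`_update!` methods added in `src/training.jl`; passing a `Flux.Optimise.AbstractOptimiser` keeps the old implicit-`Params` path.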