Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle gradients and object trees together #24

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Optimisers.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Optimisers

using Functors
using Functors: functor, fmap, isleaf

include("interface.jl")
Expand Down
34 changes: 21 additions & 13 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,36 @@ function state(o, x)
if isleaf(x)
return init(o, x)
else
x, _ = functor(x)
x, _ = _functor(x)
return map(x -> state(o, x), x)
end
end

function _update(o, st, x, x̄s...)
st, x̄ = apply(o, st, x, x̄s...)
return st, patch(x, x̄)
function _update(o, x, x̄, st)
x̄, st = apply(o, x, x̄, st)
return patch(x, x̄), st
end

function update(o, state, x::T, x̄s...) where T
if all(isnothing, x̄s)
return state, x
function update(o, x::T, x̄, state) where T
if x̄ === nothing
return x, state
elseif isleaf(x)
return _update(o, state, x, x̄s...)
return _update(o, x, x̄, state)
else
x̄s = map(x̄ -> functor(typeof(x), x̄)[1], x̄s)
x, restructure = functor(typeof(x), x)
xstate = map((state, x, x̄s...) -> update(o, state, x, x̄s...), state, x, x̄s...)
return map(first, xstate), restructure(map(last, xstate))
x̄, _ = _functor(typeof(x), x̄)
x, restructure = _functor(typeof(x), x)
xstate = map((x, x̄, state) -> update(o, x, x̄, state), x, x̄, state)
return restructure(map(first, xstate)), map(x -> x[2], xstate)
end
end

_functor(x) = Functors.functor(x)
_functor(ref::Base.RefValue) = Functors.functor(ref[])
_functor(T, x) = Functors.functor(T, x)

# may be risky since Optimisers may silently call
# this if some structures don't have appropriate overrides
init(o, x) = nothing

# default all rules to first order calls
apply(o, state, x, dx, dxs...) = apply(o, state, x, dx)
# apply(o, x, dx, dxs, state) = apply(o, x, dx, state)
90 changes: 45 additions & 45 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ Descent() = Descent(1f-1)

init(o::Descent, x::AbstractArray) = nothing

function apply(o::Descent, state, x, dx)
function apply(o::Descent, x, dx, state)
η = convert(eltype(dx), o.eta)

return state, dx .* η
return dx .* η, state
end

(o::Descent)(state, m, dm) = update(o, state, m, dm)
(o::Descent)(m, dm, state) = update(o, m, dm, state)

"""
Momentum(η = 1f-2, ρ = 9f-1)
Expand All @@ -42,14 +42,14 @@ Momentum(η = 1f-2, ρ = 9f-1) = Momentum{typeof(η)}(η, ρ)

init(o::Momentum, x::AbstractArray) = zero(x)

function apply(o::Momentum, state, x, dx)
function apply(o::Momentum, x, dx, state)
η, ρ, v = o.eta, o.rho, state
@. v = ρ * v - η * dx

return v, -v
end

(o::Momentum)(state, m, dm) = update(o, state, m, dm)
(o::Momentum)(m, dm, state) = update(o, m, dm, state)

"""
Nesterov(η = 1f-3, ρ = 9f-1)
Expand All @@ -70,14 +70,14 @@ Nesterov(η = 1f-3, ρ = 9f-1) = Nesterov{typeof(η)}(η, ρ)

init(o::Nesterov, x::AbstractArray) = zero(x)

(o::Nesterov)(state, m, dm) = update(o, state, m, dm)
(o::Nesterov)(m, dm, state) = update(o, m, dm, state)

function apply(o::Nesterov, state, x, dx)
function apply(o::Nesterov, x, dx, state)
η, ρ, v = o.eta, o.rho, state
d = @. ρ^2 * v - (1+ρ) * η * dx
@. v = ρ * v - η * dx

return v, -d
return -d, v
end

"""
Expand Down Expand Up @@ -105,15 +105,15 @@ RMSProp(η = 1f-3, ρ = 9f-1, ϵ = eps(typeof(η))) = RMSProp{typeof(η)}(η, ρ

init(o::RMSProp, x::AbstractArray) = zero(x)

function apply(o::RMSProp, state, x, dx)
function apply(o::RMSProp, x, dx, state)
η, ρ, ϵ, acc = o.eta, o.rho, o.epsilon, state
@. acc = ρ * acc + (1 - ρ) * dx^2
dx = @. dx * (η / (sqrt(acc) + ϵ))

return acc, dx
return dx, acc
end

(o::RMSProp)(state, m, dm) = update(o, state, m, dm)
(o::RMSProp)(m, dm, state) = update(o, m, dm, state)

"""
ADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
Expand All @@ -137,17 +137,17 @@ ADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = ADAM{typeof(η)}(

init(o::ADAM, x::AbstractArray) = (zero(x), zero(x), o.beta)

(o::ADAM)(state, m, dm) = update(o, state, m, dm)
(o::ADAM)(m, dm, state) = update(o, m, dm, state)

function apply(o::ADAM{T}, state, x, dx) where T
function apply(o::ADAM{T}, x, dx, state) where T
η, β, ϵ = o.eta, o.beta, o.epsilon
mt, vt, βt = state

@. mt = β[1] * mt + (one(T) - β[1]) * dx
@. vt = β[2] * vt + (one(T) - β[2]) * dx ^ 2
dx = @. mt / (one(T) - βt[1]) / (sqrt(vt / (one(T) - βt[2])) + ϵ) * η

return (mt, vt, βt .* β), dx
return dx, (mt, vt, βt .* β)
end

"""
Expand All @@ -172,9 +172,9 @@ RADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = RADAM{typeof(η)}

init(o::RADAM, x::AbstractArray) = (zero(x), zero(x), o.beta, 1)

(o::RADAM)(state, m, dm) = update(o, state, m, dm)
(o::RADAM)(m, dm, state) = update(o, m, dm, state)

function apply(o::RADAM, state, x, dx)
function apply(o::RADAM, x, dx, state)
η, β, ϵ = o.eta, o.beta, o.epsilon
ρ∞ = 2/(1-β[2])-1

Expand All @@ -190,7 +190,7 @@ function apply(o::RADAM, state, x, dx)
dx = @. mt / (1 - βt[1]) * η
end

return (mt, vt, βt .* β, t + 1), dx
return dx, (mt, vt, βt .* β, t + 1)
end

"""
Expand All @@ -215,9 +215,9 @@ AdaMax(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = AdaMax{typeof(η

init(o::AdaMax, x::AbstractArray) = (zero(x), zero(x), o.beta)

(o::AdaMax)(state, m, dm) = update(o, state, m, dm)
(o::AdaMax)(m, dm, state) = update(o, m, dm, state)

function apply(o::AdaMax, state, x, dx)
function apply(o::AdaMax, x, dx, state)
η, β, ϵ = o.eta, o.beta, o.epsilon

mt, ut, βt = state
Expand All @@ -226,7 +226,7 @@ function apply(o::AdaMax, state, x, dx)
@. ut = max(β[2] * ut, abs(dx))
dx = @. (η/(1 - βt[1])) * mt/(ut + ϵ)

return (mt, ut, βt .* β), dx
return dx, (mt, ut, βt .* β)
end

"""
Expand All @@ -252,9 +252,9 @@ OADAM(η = 1f-3, β = (5f-1, 9f-1), ϵ = eps(typeof(η))) = OADAM{typeof(η)}(η

init(o::OADAM, x::AbstractArray) = (zero(x), zero(x), o.beta, zero(x))

(o::OADAM)(state, m, dm) = update(o, state, m, dm)
(o::OADAM)(m, dm, state) = update(o, m, dm, state)

function apply(o::OADAM, state, x, dx)
function apply(o::OADAM, x, dx, state)
η, β, ϵ = o.eta, o.beta, o.epsilon

mt, vt, βt, dx_ = state
Expand All @@ -265,7 +265,7 @@ function apply(o::OADAM, state, x, dx)
@. dx_ = η * mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
dx = @. dx + 2*dx_

return (mt, vt, βt .* β, dx_), dx
return dx, (mt, vt, βt .* β, dx_)
end

"""
Expand All @@ -289,16 +289,16 @@ ADAGrad(η = 1f-1, ϵ = eps(typeof(η))) = ADAGrad{typeof(η)}(η, ϵ)

init(o::ADAGrad, x::AbstractArray) = fill!(similar(x), o.epsilon)

(o::ADAGrad)(state, m, dm) = update(o, state, m, dm)
(o::ADAGrad)(m, dm, state) = update(o, m, dm, state)

function apply(o::ADAGrad, state, x, dx)
function apply(o::ADAGrad, x, dx, state)
η, ϵ = o.eta, o.epsilon
acc = state

@. acc += dx^2
dx = @. dx * η / (sqrt(acc) + ϵ)

return acc, dx
return dx, acc
end

"""
Expand All @@ -321,9 +321,9 @@ ADADelta(ρ = 9f-1, ϵ = eps(typeof(ρ))) = ADADelta{typeof(ρ)}(ρ, ϵ)

init(o::ADADelta, x::AbstractArray) = (zero(x), zero(x))

(o::ADADelta)(state, m, dm) = update(o, state, m, dm)
(o::ADADelta)(m, dm, state) = update(o, m, dm, state)

function apply(o::ADADelta, state, x, dx)
function apply(o::ADADelta, x, dx, state)
ρ, ϵ = o.rho, o.epsilon
acc, Δacc = state

Expand All @@ -333,7 +333,7 @@ function apply(o::ADADelta, state, x, dx)
dx = @. dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ)
@. Δacc = ρ * Δacc + (1 - ρ) * dx^2

return (acc, Δacc), dx
return dx, (acc, Δacc)
end

"""
Expand All @@ -360,9 +360,9 @@ AMSGrad(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = AMSGrad{typeof(
init(o::AMSGrad, x::AbstractArray) =
(fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon))

(o::AMSGrad)(state, m, dm) = update(o, state, m, dm)
(o::AMSGrad)(m, dm, state) = update(o, m, dm, state)

function apply(o::AMSGrad, state, x, dx)
function apply(o::AMSGrad, x, dx, state)
η, β, ϵ = o.eta, o.beta, o.epsilon

mt, vt, v̂t = state
Expand All @@ -372,7 +372,7 @@ function apply(o::AMSGrad, state, x, dx)
@. v̂t = max(v̂t, vt)
dx = @. η * mt / (sqrt(v̂t) + ϵ)

return (mt, vt, v̂t), dx
return dx, (mt, vt, v̂t)
end

"""
Expand All @@ -398,9 +398,9 @@ NADAM(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = NADAM{typeof(η)}

init(o::NADAM, x::AbstractArray) = (zero(x), zero(x), o.beta)

(o::NADAM)(state, m, dm) = update(o, state, m, dm)
(o::NADAM)(m, dm, state) = update(o, m, dm, state)

function apply(o::NADAM, state, x, dx)
function apply(o::NADAM, x, dx, state)
η, β, ϵ = o.eta, o.beta, o.epsilon

mt, vt, βt = state
Expand All @@ -410,7 +410,7 @@ function apply(o::NADAM, state, x, dx)
dx = @. (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
(sqrt(vt * β[2] / (1 - βt[2])) + ϵ) * η

return (mt, vt, βt .* β), dx
return dx, (mt, vt, βt .* β)
end

"""
Expand Down Expand Up @@ -454,17 +454,17 @@ AdaBelief(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η))) = AdaBelief{typ

init(o::AdaBelief, x::AbstractArray) = (zero(x), zero(x))

(o::AdaBelief)(state, m, dm) = update(o, state, m, dm)
(o::AdaBelief)(m, dm, state) = update(o, m, dm, state)

function apply(o::AdaBelief, state, x, dx)
function apply(o::AdaBelief, x, dx, state)
η, β, ϵ = o.eta, o.beta, o.epsilon
mt, st = state

@. mt = β[1] * mt + (1 - β[1]) * dx
@. st = β[2] * st + (1 - β[2]) * (dx - mt)^2
dx = @. η * mt / (sqrt(st) + ϵ)

return (mt, st), dx
return dx, (mt, st)
end

"""
Expand All @@ -482,12 +482,12 @@ WeightDecay() = WeightDecay(5f-4)

init(o::WeightDecay, x::AbstractArray) = nothing

(o::WeightDecay)(state, m, dm) = update(o, state, m, dm)
(o::WeightDecay)(m, dm, state) = update(o, m, dm, state)

function apply(o::WeightDecay, state, x, dx)
function apply(o::WeightDecay, x, dx, state)
dx = @. dx + o.wd * x

return state, dx
return dx, state
end

"""
Expand All @@ -503,15 +503,15 @@ OptimiserChain(opts...) = OptimiserChain(opts)

init(o::OptimiserChain, x::AbstractArray) = [init(opt, x) for opt in o.opts]

(o::OptimiserChain)(state, m, dms...) = update(o, state, m, dms...)
(o::OptimiserChain)(m, dm, states) = update(o, m, dm, states)

function apply(o::OptimiserChain, states, x, dx, dxs...)
function apply(o::OptimiserChain, x, dx, states)
new_states = similar(states)
for (i, (opt, state)) in enumerate(zip(o.opts, states))
new_states[i], dx = apply(opt, state, x, dx, dxs...)
dx, new_states[i] = apply(opt, x, dx, state)
end

return new_states, dx
return dx, new_states
end

for Opt in (:Descent, :ADAM, :Momentum, :Nesterov, :RMSProp,
Expand Down
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ using Statistics
loss(x, y) = mean((x.α .* x.β .- y.α .* y.β) .^ 2)
l = loss(w, w′)
for i = 1:10^4
gs = gradient(x -> loss(x, w′), w)
st, w = o(st, w, gs...)
gs, = gradient(x -> loss(x, w′), w)
w, st = o(w, gs, st)
end
@test loss(w, w′) < 0.01
end
Expand All @@ -29,8 +29,8 @@ end
st = Optimisers.state(opt, w)
for t = 1:10^5
x = rand(10)
gs = gradient(w -> loss(x, w, w′), w)
st, w = Optimisers.update(opt, st, w, gs...)
gs, = gradient(w -> loss(x, w, w′), w)
w, st = Optimisers.update(opt, w, gs, st)
end
@test loss(rand(10, 10), w, w′) < 0.01
end
Expand Down