Implement a caching-based version of the training
avik-pal committed May 13, 2024
1 parent 19ad710 commit 43a3bef
Showing 4 changed files with 119 additions and 20 deletions.
96 changes: 81 additions & 15 deletions ext/LuxEnzymeExt.jl
@@ -1,26 +1,92 @@
module LuxEnzymeExt

using ADTypes: AutoEnzyme
using Enzyme: Enzyme
using ConcreteStructs: @concrete
using Enzyme: Enzyme, Active, Const, Duplicated
using Lux: Lux
using Setfield: @set!

@concrete struct CachedEnzymeExtras
dparameters
forward
reverse
end
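# `dparameters` holds pre-allocated gradient storage that is re-zeroed and reused on every
# call; `forward`/`reverse` are the split Enzyme thunks compiled for a particular
# objective function, so repeated calls with the same objective skip `autodiff_thunk`.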

# Case I: We have CachedEnzymeExtras and objective_function is unchanged.
function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras, F}) where {F}
Lux.__recursive_make_zero!(ts.cache.dparameters)
loss, st_new, stats = __compute_gradients!(
ts.cache.forward, ts.cache.reverse, objective_function,
ts.model, ts.parameters, ts.cache.dparameters, ts.states, data)
ts_new = __construct_new_trainstate(
st_new, ts.states, ts.cache.forward, ts.cache.reverse,
ts, objective_function, ts.cache.dparameters)
return ts.cache.dparameters, loss, stats, ts_new
end

# Case II: We have CachedEnzymeExtras and objective_function is changed.
function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
ts::Lux.Experimental.TrainState{<:CachedEnzymeExtras}) where {F}
forward, reverse = Enzyme.autodiff_thunk(
Enzyme.ReverseSplitWithPrimal, Const{typeof(objective_function)},
Active, Const{typeof(ts.model)}, Duplicated{typeof(ts.parameters)},
Const{typeof(ts.states)}, Const{typeof(data)})

Lux.__recursive_make_zero!(ts.cache.dparameters)
loss, st_new, stats = __compute_gradients!(
forward, reverse, objective_function, ts.model,
ts.parameters, ts.cache.dparameters, ts.states, data)

ts_new = __construct_new_trainstate(
st_new, ts.states, forward, reverse, ts, objective_function, ts.cache.dparameters)
return ts.cache.dparameters, loss, stats, ts_new
end

# Case III: Nothing is cached
function Lux.Experimental.compute_gradients(::AutoEnzyme, objective_function::F, data,
ts::Lux.Experimental.TrainState) where {F}
dps = Enzyme.make_zero(ts.parameters)
fwd, rev = Enzyme.autodiff_thunk(
Enzyme.ReverseSplitWithPrimal, Enzyme.Const{typeof(objective_function)},
Enzyme.Active, Enzyme.Const{typeof(ts.model)},
Enzyme.Duplicated{typeof(ts.parameters)},
Enzyme.Const{typeof(ts.states)}, Enzyme.Const{typeof(data)})
tape, (loss, st_new, stats), shadow_result = fwd(
Enzyme.Const(objective_function), Enzyme.Const(ts.model),
Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data))
rev(Enzyme.Const(objective_function), Enzyme.Const(ts.model),
Enzyme.Duplicated(ts.parameters, dps), Enzyme.Const(ts.states), Enzyme.Const(data),
(one(loss), Enzyme.make_zero(st_new), Enzyme.make_zero(stats)), tape)
@set! ts.states = st_new
return dps, loss, stats, ts
dps = Lux.__recursive_make_zero(ts.parameters)
forward, reverse = Enzyme.autodiff_thunk(
Enzyme.ReverseSplitWithPrimal, Const{typeof(objective_function)},
Active, Const{typeof(ts.model)}, Duplicated{typeof(ts.parameters)},
Const{typeof(ts.states)}, Const{typeof(data)})

loss, st_new, stats = __compute_gradients!(
forward, reverse, objective_function, ts.model, ts.parameters, dps, ts.states, data)
ts_new = __construct_new_trainstate(
st_new, ts.states, forward, reverse, ts, objective_function, dps)
return dps, loss, stats, ts_new
end

function __compute_gradients!(
forward::F, reverse::R, obj_fn::O, model, ps, dps, st, data) where {F, R, O}
pps = Duplicated(ps, dps)
args = (Const(obj_fn), Const(model), pps, Const(st), Const(data))
tape, (loss, st_new, stats), shadow_result = forward(args...)
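# Reverse pass: seed the loss cotangent with `one(loss)` and give zero cotangents to the
# returned states/stats, so only the loss is differentiated; gradients accumulate into
# `dps` through the `Duplicated` shadow of `ps`.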
reverse(args...,
(one(loss), Lux.__recursive_make_zero(st_new), Lux.__recursive_make_zero(stats)),
tape)
return loss, st_new, stats
end

# If `st_new` is of a new type, we will have to recompute the cache anyway. Force it
# by not storing the objective function.
function __construct_new_trainstate(
st_new::S, ::S, forward::F, reverse::R, ts::Lux.Experimental.TrainState,
objective_fn::O, dps) where {S, F, R, O}
cache = CachedEnzymeExtras(dps, forward, reverse)
return Lux.Experimental.TrainState(
cache, objective_fn, ts.model, ts.parameters,
st_new, ts.optimizer_state, ts.step + 1)
end

function __construct_new_trainstate(
st_new, _, forward::F, reverse::R, ts::Lux.Experimental.TrainState,
objective_fn::O, dps) where {F, R, O}
cache = CachedEnzymeExtras(dps, nothing, nothing)
return Lux.Experimental.TrainState(
cache, nothing, ts.model, ts.parameters, st_new, ts.optimizer_state, ts.step + 1)
end

end
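For context, a minimal training-loop sketch (not part of this commit) showing how these methods are reached. It assumes the `TrainState(rng, model, optimizer)` convenience constructor from the Optimisers extension; the model, loss, and data here are illustrative. Passing the same objective function object on every iteration is what lets dispatch hit Case I and reuse the cached thunks and gradient storage:

using Lux, Optimisers, Random, Enzyme
using ADTypes: AutoEnzyme

rng = Random.default_rng()
model = Dense(4 => 2)
ts = Lux.Experimental.TrainState(rng, model, Adam(0.01f0))

# The objective must return (loss, updated_states, stats).
function mse(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    return sum(abs2, ŷ .- y), st_, (;)
end

function train(ts, objective, data, nsteps)
    for _ in 1:nsteps
        # Same `objective` object every iteration -> the cached Enzyme thunks are reused.
        grads, loss, stats, ts = Lux.Experimental.compute_gradients(
            AutoEnzyme(), objective, data, ts)
        ts = Lux.Experimental.apply_gradients(ts, grads)
    end
    return ts
end

x, y = rand(rng, Float32, 4, 8), rand(rng, Float32, 2, 8)
ts = train(ts, mse, (x, y), 10)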
16 changes: 12 additions & 4 deletions ext/LuxOptimisersExt.jl
@@ -36,10 +36,18 @@ function Lux.Experimental.TrainState(
return Lux.Experimental.TrainState(nothing, nothing, model, ps, st, st_opt, 0)
end

function Lux.Experimental.apply_gradients(ts::Lux.Experimental.TrainState, grads)
optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads)
return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model,
ps, ts.states, optimizer_state, ts.step + 1)
function Lux.Experimental.apply_gradients(
ts::Lux.Experimental.TrainState, grads, update_inplace=false)
if update_inplace
Optimisers.update!(ts.optimizer_state, ts.parameters, grads)
return Lux.Experimental.TrainState(
ts.cache, ts.objective_function, ts.model, ts.parameters,
ts.states, ts.optimizer_state, ts.step + 1)
else
optimizer_state, ps = Optimisers.update(ts.optimizer_state, ts.parameters, grads)
return Lux.Experimental.TrainState(ts.cache, ts.objective_function, ts.model,
ps, ts.states, optimizer_state, ts.step + 1)
end
end
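For reference, a minimal sketch (not part of this commit) of the Optimisers.jl behaviour the `update_inplace` flag switches between; the parameter names are illustrative:

using Optimisers

ps = (; weight = rand(Float32, 2, 2), bias = zeros(Float32, 2))
opt_state = Optimisers.setup(Adam(0.01f0), ps)
grads = (; weight = ones(Float32, 2, 2), bias = ones(Float32, 2))

# Out-of-place: returns a fresh optimiser state and fresh parameters; the inputs are untouched.
new_state, new_ps = Optimisers.update(opt_state, ps, grads)

# In-place: mutates the arrays inside `opt_state` and `ps`, avoiding the extra allocations.
Optimisers.update!(opt_state, ps, grads)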

# DistributedUtils
10 changes: 9 additions & 1 deletion src/contrib/training.jl
@@ -25,14 +25,15 @@ Internal fields:
end

"""
apply_gradients(ts::TrainState, grads)
apply_gradients(ts::TrainState, grads, update_inplace::Bool=false)
Update the parameters stored in `ts` using the gradients `grads`.
## Arguments
- `ts`: [`TrainState`](@ref) object.
- `grads`: Gradients of the loss function wrt `ts.params`.
- `update_inplace`: Whether to update the parameters in place (mutating the existing arrays and optimiser state) instead of allocating new ones.
## Returns
@@ -73,6 +74,13 @@ A 4-Tuple containing:
- `loss`: Loss from the objective function.
- `stats`: Any computed statistics from the objective function.
- `ts`: Updated Training State.
!!! danger
`grads` returned by this function might be aliased by the implementation of the gradient
backend. For example, if you cache the `grads` from step `i`, the gradients returned at
step `i + 1` may share memory with (and overwrite) them. To prevent this, make a copy
with `copy(grads)` or `deepcopy(grads)`.
"""
function compute_gradients(ad::ADTypes.AbstractADType, ::F, _, ::TrainState) where {F}
return __maybe_implemented_compute_gradients(ad)
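Continuing the hypothetical training sketch from after the Enzyme extension above, the aliasing caveat looks like this in practice (illustrative only):

grads1, _, _, ts = Lux.Experimental.compute_gradients(AutoEnzyme(), mse, (x, y), ts)
saved = deepcopy(grads1)   # detach a snapshot from the reused cache buffers

grads2, _, _, ts = Lux.Experimental.compute_gradients(AutoEnzyme(), mse, (x, y), ts)
# With the cached Enzyme path, `grads1` and `grads2` may be the very same arrays
# (re-zeroed and overwritten in place); `saved` still holds the step-1 gradients.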
17 changes: 17 additions & 0 deletions src/utils.jl
@@ -287,3 +287,20 @@ end

@inline __size(x::AbstractArray) = size(x)
@inline __size(x::T) where {T} = hasmethod(size, Tuple{T}) ? size(x) : nothing

@inline __recursive_make_zero(x::AbstractArray{<:Number}) = zero(x)
@inline __recursive_make_zero(x::AbstractArray) = map(__recursive_make_zero, x)
@inline __recursive_make_zero(x::Tuple) = map(__recursive_make_zero, x)
@inline __recursive_make_zero(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map(
__recursive_make_zero, values(x)))
@inline __recursive_make_zero(::Nothing) = nothing
@inline __recursive_make_zero(v::Val) = v
@inline __recursive_make_zero(x) = fmap(__recursive_make_zero, x)

@inline __recursive_make_zero!(x::AbstractArray{<:Number}) = fill!(x, zero(eltype(x)))
@inline __recursive_make_zero!(x::AbstractArray) = map(__recursive_make_zero!, x)
@inline __recursive_make_zero!(x::Tuple) = map(__recursive_make_zero!, x)
@inline __recursive_make_zero!(x::NamedTuple{fields}) where {fields} = NamedTuple{fields}(map(
__recursive_make_zero!, values(x)))
@inline __recursive_make_zero!(::Nothing) = nothing
@inline __recursive_make_zero!(x) = fmap(__recursive_make_zero!, x)
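A small illustration (not part of the commit) of these internal helpers on a Lux-style parameter NamedTuple; the field names are made up:

ps = (; layer_1 = (; weight = rand(Float32, 2, 3), bias = zeros(Float32, 2)),
      layer_2 = (; weight = rand(Float32, 1, 2), bias = nothing))

dps = Lux.__recursive_make_zero(ps)   # same nested structure, freshly allocated zero arrays
Lux.__recursive_make_zero!(dps)       # zero the existing arrays in place, no new allocations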
