Skip to content

Commit

Permalink
feat: allow no grad option for reactant
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 8, 2025
1 parent cab3e1f commit 32406c9
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 24 deletions.
3 changes: 2 additions & 1 deletion examples/ConditionalVAE/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_f
start_time = time()
for (i, X) in enumerate(train_dataloader)
(_, loss, _, train_state) = Training.single_train_step!(
AutoEnzyme(), loss_function, X, train_state)
AutoEnzyme(), loss_function, X, train_state; return_gradients=Val(false)
)

loss_total += loss
total_samples += size(X, ndims(X))
Expand Down
2 changes: 1 addition & 1 deletion ext/LuxReactantExt/LuxReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Enzyme: Enzyme, Const, Duplicated, Active
using Optimisers: Optimisers
using Reactant: Reactant, @compile, @code_hlo, AnyTracedRArray, TracedRArray, TracedRNumber
using Setfield: @set!
using Static: False
using Static: True, False

using Lux: Lux, LuxOps, Training, Utils
using Lux.Training: TrainingBackendCache, ReactantBackend
Expand Down
27 changes: 18 additions & 9 deletions ext/LuxReactantExt/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function Lux.Training.compute_gradients_impl(
end

function Lux.Training.compute_gradients_impl(::ReactantBackend, obj_fn::F, data,
ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend}, F}) where {F}
grads, loss, stats, st = ts.cache.extras.compiled_gradient_function(
obj_fn, ts.model, data, ts.parameters, ts.states)
@set! ts.states = st
Expand All @@ -70,7 +70,7 @@ for inplace in ("!", "")

# Ideally users never hit this dispatch but it is still good to have as a fallback
@eval function Lux.Training.$(apply_gradients_fn)(
ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}}, grads
ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend}}, grads
)
if hasfield(typeof(ts.cache.extras), :update_function)
update_function = ts.cache.extras.update_function
Expand All @@ -94,15 +94,15 @@ for inplace in ("!", "")
@eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
data, ts::Training.TrainState) where {F}
maybe_dump_to_mlir_file!($(internal_fn), objective_function, ts.model, data,
ts.parameters, ts.states, ts.optimizer_state)
ts.parameters, ts.states, ts.optimizer_state, backend.return_gradients)

compiled_grad_and_step_function = @compile $(internal_fn)(
objective_function, ts.model, data, ts.parameters, ts.states,
ts.optimizer_state)
ts.optimizer_state, backend.return_gradients)

grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
objective_function, ts.model, data, ts.parameters, ts.states,
ts.optimizer_state)
ts.optimizer_state, backend.return_gradients)

cache = TrainingBackendCache(
backend, False(), nothing, (; compiled_grad_and_step_function))
Expand All @@ -116,10 +116,11 @@ for inplace in ("!", "")
return grads, loss, stats, ts
end

@eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data,
ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
@eval function Lux.Training.$(fname)(backend::ReactantBackend, obj_fn::F, data,
ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend}, F}) where {F}
grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)
obj_fn, ts.model, data, ts.parameters, ts.states,
ts.optimizer_state, backend.return_gradients)

@set! ts.states = st
@set! ts.parameters = ps
Expand All @@ -131,7 +132,15 @@ for inplace in ("!", "")

# XXX: Inplace version not actually inplace
@eval function $(internal_fn)(
objective_function::F, model, data, ps, st, opt_state) where {F}
objective_function::F, model, data, ps, st, opt_state, ::False) where {F}
dps, loss, stats, stₙ = compute_gradients_internal(
objective_function, model, data, ps, st)
opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
return nothing, ps, loss, stats, stₙ, opt_state
end

@eval function $(internal_fn)(
objective_function::F, model, data, ps, st, opt_state, ::True) where {F}
dps, loss, stats, stₙ = compute_gradients_internal(
objective_function, model, data, ps, st)
opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
Expand Down
45 changes: 32 additions & 13 deletions src/helpers/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using FastClosures: @closure
using Functors: Functors, fmap
using Optimisers: Optimisers
using Setfield: @set!
using Static: StaticBool, Static, False, True
using Static: StaticBool, Static, False, True, static

using ..Lux: Lux, Utils, ReactantCompatibleOptimisers
using LuxCore: LuxCore, AbstractLuxLayer
Expand Down Expand Up @@ -104,7 +104,9 @@ function Base.show(io::IO, ::MIME"text/plain", ts::TrainState)
print(io, "\n objective_function: ", nameof(typeof(ts.objective_function)))
end

struct ReactantBackend end
@concrete struct ReactantBackend
return_gradients <: StaticBool
end

const APPLY_GRAD_DOCSTRING = """
## Arguments
Expand Down Expand Up @@ -198,10 +200,13 @@ function compute_gradients(ad, obj_fn::F, data, ts::TrainState) where {F}
return compute_gradients_impl(maybe_wrap_adtype(ad, dev_type), obj_fn, data, ts)
end

maybe_wrap_adtype(backend::ReactantBackend, _) = backend
maybe_wrap_adtype(ad::AbstractADType, _) = ad
function maybe_wrap_adtype(ad::AbstractADType, ::Type{ReactantDevice})
ad isa AutoEnzyme && return ReactantBackend()
maybe_wrap_adtype(backend::ReactantBackend, ::Any; kwargs...) = backend
maybe_wrap_adtype(ad::AbstractADType, ::Any; kwargs...) = ad
function maybe_wrap_adtype(
ad::AbstractADType, ::Type{ReactantDevice};
return_gradients::Utils.BoolType=True()
)
ad isa AutoEnzyme && return ReactantBackend(static(return_gradients))
throw(ArgumentError("Computing gradients for models on XLA is supported only with \
Enzyme.jl (`AutoEnzyme`)."))
end
Expand Down Expand Up @@ -258,39 +263,53 @@ function wrap_objective_function(
end

"""
single_train_step!(backend, obj_fn::F, data, ts::TrainState)
single_train_step!(backend, obj_fn::F, data, ts::TrainState; return_gradients=True())
Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
updates the parameters using [`apply_gradients!`](@ref). All backends supported via
[`compute_gradients`](@ref) are supported here.
## Keyword Arguments
- `return_gradients`: If `True()`, the gradients are returned. If `False()`, the returned
gradients are `nothing`. Defaults to `True()`. This is only used for Reactant Backend.
## Return
Returned values are the same as [`compute_gradients`](@ref). Note that despite the `!`,
only the parameters in `ts` are updated inplace. Users should be using the returned `ts`
object for further training steps, else there is no caching and performance will be
suboptimal (and absolutely terrible for backends like `AutoReactant`).
"""
function single_train_step!(backend, obj_fn::F, data, ts::TrainState) where {F}
backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states)))
function single_train_step!(backend, obj_fn::F, data, ts::TrainState;
return_gradients::Utils.BoolType=True()) where {F}
backend = maybe_wrap_adtype(
backend, get_device_type((ts.parameters, ts.states)); return_gradients)
return single_train_step_impl!(backend, obj_fn, data, ts)
end

"""
single_train_step(backend, obj_fn::F, data, ts::TrainState)
single_train_step(backend, obj_fn::F, data, ts::TrainState; return_gradients=True())
Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
updates the parameters using [`apply_gradients`](@ref). All backends supported via
[`compute_gradients`](@ref) are supported here.
In most cases you should use [`single_train_step!`](@ref) instead of this function.
## Keyword Arguments
- `return_gradients`: If `True()`, the gradients are returned. If `False()`, the returned
gradients are `nothing`. Defaults to `True()`. This is only used for Reactant Backend.
## Return
Returned values are the same as [`compute_gradients`](@ref).
Returned values are the same as [`single_train_step!`](@ref).
"""
function single_train_step(backend, obj_fn::F, data, ts::TrainState) where {F}
backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states)))
function single_train_step(backend, obj_fn::F, data, ts::TrainState;
return_gradients::Utils.BoolType=True()) where {F}
backend = maybe_wrap_adtype(
backend, get_device_type((ts.parameters, ts.states)); return_gradients)
return single_train_step_impl(backend, obj_fn, data, ts)
end

Expand Down

0 comments on commit 32406c9

Please sign in to comment.