feat: allow no grad option for reactant #1190

Merged · 1 commit · Jan 8, 2025
3 changes: 2 additions & 1 deletion examples/ConditionalVAE/main.jl
@@ -244,7 +244,8 @@ function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_f
start_time = time()
for (i, X) in enumerate(train_dataloader)
(_, loss, _, train_state) = Training.single_train_step!(
-AutoEnzyme(), loss_function, X, train_state)
+AutoEnzyme(), loss_function, X, train_state; return_gradients=Val(false)
+)

loss_total += loss
total_samples += size(X, ndims(X))
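The hunk above is the only call-site change in the PR. For readers skimming the diff, here is a minimal, self-contained sketch of the new keyword in use; the toy model, loss function, optimiser, and device setup are illustrative assumptions, not code from this PR:

```julia
using Lux, Enzyme, Optimisers, Random, Reactant

# Illustrative setup (not from this PR): a tiny model and a batch moved to the
# Reactant/XLA device, so AutoEnzyme() gets wrapped into the Reactant backend.
model = Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model) |> reactant_device()
x = rand(Float32, 4, 16) |> reactant_device()

# Objective function with the (model, ps, st, data) -> (loss, st, stats)
# signature expected by Lux.Training.
function loss_fn(model, ps, st, x)
    y, st_ = model(x, ps, st)
    return sum(abs2, y), st_, NamedTuple()
end

train_state = Training.TrainState(model, ps, st, Optimisers.Adam(0.001f0))

# With return_gradients=Val(false) the compiled train step still updates the
# parameters, but the first return value is `nothing` instead of the gradients.
grads, loss, stats, train_state = Training.single_train_step!(
    AutoEnzyme(), loss_fn, x, train_state; return_gradients=Val(false))
@assert grads === nothing
```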
2 changes: 1 addition & 1 deletion ext/LuxReactantExt/LuxReactantExt.jl
@@ -4,7 +4,7 @@ using Enzyme: Enzyme, Const, Duplicated, Active
using Optimisers: Optimisers
using Reactant: Reactant, @compile, @code_hlo, AnyTracedRArray, TracedRArray, TracedRNumber
using Setfield: @set!
-using Static: False
+using Static: True, False

using Lux: Lux, LuxOps, Training, Utils
using Lux.Training: TrainingBackendCache, ReactantBackend
27 changes: 18 additions & 9 deletions ext/LuxReactantExt/training.jl
@@ -55,7 +55,7 @@ function Lux.Training.compute_gradients_impl(
end

function Lux.Training.compute_gradients_impl(::ReactantBackend, obj_fn::F, data,
-ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
+ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend}, F}) where {F}
grads, loss, stats, st = ts.cache.extras.compiled_gradient_function(
obj_fn, ts.model, data, ts.parameters, ts.states)
@set! ts.states = st
@@ -70,7 +70,7 @@ for inplace in ("!", "")

# Ideally users never hit this dispatch but it is still good to have as a fallback
@eval function Lux.Training.$(apply_gradients_fn)(
-ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}}, grads
+ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend}}, grads
)
if hasfield(typeof(ts.cache.extras), :update_function)
update_function = ts.cache.extras.update_function
@@ -94,15 +94,15 @@ for inplace in ("!", "")
@eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
data, ts::Training.TrainState) where {F}
maybe_dump_to_mlir_file!($(internal_fn), objective_function, ts.model, data,
-ts.parameters, ts.states, ts.optimizer_state)
+ts.parameters, ts.states, ts.optimizer_state, backend.return_gradients)

compiled_grad_and_step_function = @compile $(internal_fn)(
objective_function, ts.model, data, ts.parameters, ts.states,
-ts.optimizer_state)
+ts.optimizer_state, backend.return_gradients)

grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
objective_function, ts.model, data, ts.parameters, ts.states,
-ts.optimizer_state)
+ts.optimizer_state, backend.return_gradients)

cache = TrainingBackendCache(
backend, False(), nothing, (; compiled_grad_and_step_function))
@@ -116,10 +116,11 @@ for inplace in ("!", "")
return grads, loss, stats, ts
end

-@eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data,
-ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
+@eval function Lux.Training.$(fname)(backend::ReactantBackend, obj_fn::F, data,
+ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend}, F}) where {F}
grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
-obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)
+obj_fn, ts.model, data, ts.parameters, ts.states,
+ts.optimizer_state, backend.return_gradients)

@set! ts.states = st
@set! ts.parameters = ps
@@ -131,7 +132,15 @@

# XXX: Inplace version not actually inplace
@eval function $(internal_fn)(
-objective_function::F, model, data, ps, st, opt_state) where {F}
+objective_function::F, model, data, ps, st, opt_state, ::False) where {F}
+dps, loss, stats, stₙ = compute_gradients_internal(
+objective_function, model, data, ps, st)
+opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
+return nothing, ps, loss, stats, stₙ, opt_state
+end
+
+@eval function $(internal_fn)(
+objective_function::F, model, data, ps, st, opt_state, ::True) where {F}
dps, loss, stats, stₙ = compute_gradients_internal(
objective_function, model, data, ps, st)
opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
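The `::False`/`::True` methods above are where the new flag is consumed: it travels as a `Static.StaticBool` stored on `ReactantBackend` (see `src/helpers/training.jl` below), so the compiled step function specialises on the flag rather than branching on it at run time, and the `::False` method simply returns `nothing` in place of the gradients. A standalone sketch of that dispatch pattern, with toy names that are not Lux internals:

```julia
using Static: StaticBool, True, False, static

# Toy stand-in for the backend wrapper: the flag lives in the type, so each
# value of the flag gets its own specialised method.
struct ToyBackend{RG <: StaticBool}
    return_gradients::RG
end
ToyBackend(flag::Bool) = ToyBackend(static(flag))

# Two methods instead of a runtime `if`: the ::False method drops the
# gradients before returning, mirroring `return nothing, ps, ...` above.
toy_step(::True, dps, ps) = (dps, ps)
toy_step(::False, dps, ps) = (nothing, ps)

backend = ToyBackend(false)
grads, ps = toy_step(backend.return_gradients, (; w = 0.1), (; w = 1.0))
@assert grads === nothing
```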
45 changes: 32 additions & 13 deletions src/helpers/training.jl
@@ -7,7 +7,7 @@ using FastClosures: @closure
using Functors: Functors, fmap
using Optimisers: Optimisers
using Setfield: @set!
-using Static: StaticBool, Static, False, True
+using Static: StaticBool, Static, False, True, static

using ..Lux: Lux, Utils, ReactantCompatibleOptimisers
using LuxCore: LuxCore, AbstractLuxLayer
@@ -104,7 +104,9 @@ function Base.show(io::IO, ::MIME"text/plain", ts::TrainState)
print(io, "\n objective_function: ", nameof(typeof(ts.objective_function)))
end

-struct ReactantBackend end
+@concrete struct ReactantBackend
+return_gradients <: StaticBool
+end

const APPLY_GRAD_DOCSTRING = """
## Arguments
@@ -198,10 +200,13 @@ function compute_gradients(ad, obj_fn::F, data, ts::TrainState) where {F}
return compute_gradients_impl(maybe_wrap_adtype(ad, dev_type), obj_fn, data, ts)
end

-maybe_wrap_adtype(backend::ReactantBackend, _) = backend
-maybe_wrap_adtype(ad::AbstractADType, _) = ad
-function maybe_wrap_adtype(ad::AbstractADType, ::Type{ReactantDevice})
-ad isa AutoEnzyme && return ReactantBackend()
+maybe_wrap_adtype(backend::ReactantBackend, ::Any; kwargs...) = backend
+maybe_wrap_adtype(ad::AbstractADType, ::Any; kwargs...) = ad
+function maybe_wrap_adtype(
+ad::AbstractADType, ::Type{ReactantDevice};
+return_gradients::Utils.BoolType=True()
+)
+ad isa AutoEnzyme && return ReactantBackend(static(return_gradients))
throw(ArgumentError("Computing gradients for models on XLA is supported only with \
Enzyme.jl (`AutoEnzyme`)."))
end
@@ -258,39 +263,53 @@ function wrap_objective_function(
end

"""
-single_train_step!(backend, obj_fn::F, data, ts::TrainState)
+single_train_step!(backend, obj_fn::F, data, ts::TrainState; return_gradients=True())

Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
updates the parameters using [`apply_gradients!`](@ref). All backends supported via
[`compute_gradients`](@ref) are supported here.

+## Keyword Arguments
+
+- `return_gradients`: If `True()`, the gradients are returned. If `False()`, the returned
+gradients are `nothing`. Defaults to `True()`. This is only used for Reactant Backend.
+
## Return

Returned values are the same as [`compute_gradients`](@ref). Note that despite the `!`,
only the parameters in `ts` are updated inplace. Users should be using the returned `ts`
object for further training steps, else there is no caching and performance will be
suboptimal (and absolutely terrible for backends like `AutoReactant`).
"""
-function single_train_step!(backend, obj_fn::F, data, ts::TrainState) where {F}
-backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states)))
+function single_train_step!(backend, obj_fn::F, data, ts::TrainState;
+return_gradients::Utils.BoolType=True()) where {F}
+backend = maybe_wrap_adtype(
+backend, get_device_type((ts.parameters, ts.states)); return_gradients)
return single_train_step_impl!(backend, obj_fn, data, ts)
end

"""
-single_train_step(backend, obj_fn::F, data, ts::TrainState)
+single_train_step(backend, obj_fn::F, data, ts::TrainState; return_gradients=True())

Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
updates the parameters using [`apply_gradients`](@ref). All backends supported via
[`compute_gradients`](@ref) are supported here.

In most cases you should use [`single_train_step!`](@ref) instead of this function.

+## Keyword Arguments
+
+- `return_gradients`: If `True()`, the gradients are returned. If `False()`, the returned
+gradients are `nothing`. Defaults to `True()`. This is only used for Reactant Backend.
+
## Return

-Returned values are the same as [`compute_gradients`](@ref).
+Returned values are the same as [`single_train_step!`](@ref).
"""
-function single_train_step(backend, obj_fn::F, data, ts::TrainState) where {F}
-backend = maybe_wrap_adtype(backend, get_device_type((ts.parameters, ts.states)))
+function single_train_step(backend, obj_fn::F, data, ts::TrainState;
+return_gradients::Utils.BoolType=True()) where {F}
+backend = maybe_wrap_adtype(
+backend, get_device_type((ts.parameters, ts.states)); return_gradients)
return single_train_step_impl(backend, obj_fn, data, ts)
end
