fix GPU errors by doing a little type piracy. will remove once FluxML…
vpuri3 committed Oct 3, 2024
1 parent 6c64238 commit 0a11fb5
Showing 4 changed files with 57 additions and 53 deletions.
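The fix has two parts: early-stopping bookkeeping (`count`, `loss_`) moves out of `TrainState` into `trainer.opt_iter`/`STATS`, and an `Adapt` rule is added for `Optimisers.Leaf` so that optimizer state follows the parameters onto the GPU. Below is a minimal sketch of what `Adapt.@adapt_structure` provides and why the commit title calls it type piracy; `MyLeaf` is a hypothetical stand-in for `Optimisers.Leaf`, not code from this commit.

```julia
using Adapt

struct MyLeaf{R,S}
    rule::R    # optimiser rule, e.g. Adam()
    state::S   # per-parameter state arrays that must move with the parameters
end

# Roughly: adapt_structure(to, x::MyLeaf) = MyLeaf(adapt(to, x.rule), adapt(to, x.state)),
# i.e. rebuild the struct after adapting each field so device transfers can traverse it.
Adapt.@adapt_structure MyLeaf

# Doing this for `Optimisers.Leaf` extends a function owned by Adapt.jl on a type
# owned by Optimisers.jl from a third package -- the "type piracy" the commit
# message promises to remove once the upstream PR lands.
```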
12 changes: 4 additions & 8 deletions src/train/backend.jl
@@ -39,9 +39,6 @@ function train_loop!(
opt_iter.epoch_dt[] = time() - opt_iter.epoch_time[] - opt_iter.start_time[]
trigger_callback!(trainer, :EPOCH_END)

# update state loss
@set! state.loss_ = trainer.STATS.LOSS_[end]

# save state to CPU if loss improves
state, ifbreak = update_trainer_state!(trainer, state)
ifbreak && break
@@ -67,7 +64,8 @@ function doepoch(
)
end

state.st = Lux.trainmode(state.st)
# state.st = Lux.trainmode(state.st)
@set! state.st = Lux.trainmode(state.st)

for batch in _loader
state, (l, stats) = step(trainer, state, opt, batch)
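The `@set!` form above is needed because `TrainState` is presumably an immutable struct: the commented-out plain assignment would throw. A minimal sketch, assuming `@set!` comes from Setfield.jl (Accessors.jl exports the same macro):

```julia
using Setfield

struct DemoState        # hypothetical stand-in for TrainState
    st
    opt_st
end

s = DemoState((; training = false), nothing)
# s.st = (; training = true)        # ERROR: immutable struct cannot be changed
@set! s.st = (; training = true)    # rebuilds the struct and rebinds `s`
```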
@@ -104,9 +102,7 @@ function step(
throw(ErrorException("Loss is NaN"))
end

# wrong l. want STATS.LOSS_[end]
state = TrainState(NN, p, st, opt_st, state.count, state.loss_)

state = TrainState(NN, p, st, opt_st)
return state, (l, stats)
end

@@ -138,7 +134,7 @@ function train_loop!(

function optcb(optx, l, st, stats)
evaluate(trainer, state, loaders)
state = TrainState(state.NN, optx.u, Lux.trainmode(st), state.opt_st, state.count, trainer.STATS.LOSS_[end])
state = TrainState(state.NN, optx.u, Lux.trainmode(st), state.opt_st)

if !isempty(stats) & verbose
println(io, stats)
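Both `step` and the Optim.jl callback `optcb` now rebuild `TrainState` from just `(NN, p, st, opt_st)`; the running loss they used to carry is read from `trainer.STATS.LOSS_[end]` instead. A minimal sketch of that bookkeeping (field names from the diff, values hypothetical): the NamedTuple itself is immutable, but its vectors grow in place, presumably during `evaluate`.

```julia
STATS = (; EPOCH = Int[], LOSS_ = Float32[])   # subset of the fields in trainer.jl

push!(STATS.EPOCH, 1)         # epoch index
push!(STATS.LOSS_, 0.34f0)    # hypothetical epoch loss appended after evaluation
STATS.LOSS_[end]              # -> 0.34f0, the value the training loop compares against
```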
22 changes: 11 additions & 11 deletions src/train/train.jl
@@ -1,6 +1,6 @@
#
# function train_model(
# NN::Lux.AbstractExplicitLayer,
# NN::AbstractLuxLayer,
# _data::NTuple{2, Any},
# data_::Union{Nothing,NTuple{2, Any}} = nothing;
# #
@@ -26,7 +26,7 @@
# p = nothing,
# st = nothing,
# lossfun = mse,
# device = Lux.gpu_device(),
# device = gpu_device(),
# #
# cb_epoch = nothing, # (NN, p, st) -> nothing
# ) where{M}
@@ -117,7 +117,7 @@
# save_trainer(trainer, dir, name_; metadata)
# end
#
# p, st = Lux.cpu_device()((p, st))
# p, st = cpu_device()((p, st))
# (NN, p, st), STATS
# end
#===============================================================#
@@ -140,7 +140,7 @@ $SIGNATURES
- `p/st`: initial model parameter, state. if nothing, initialized with `Lux.setup(rng, NN)`
"""
function train_model(
NN::Lux.AbstractExplicitLayer,
NN::AbstractLuxLayer,
_data::NTuple{2, Any},
data_::Union{Nothing,NTuple{2, Any}} = nothing;
#
@@ -166,7 +166,7 @@ function train_model(
p = nothing,
st = nothing,
lossfun = mse,
device = Lux.gpu_device(),
device = gpu_device(),
#
cb_epoch = nothing, # (NN, p, st) -> nothing
) where{M}
@@ -326,7 +326,7 @@ function train_model(
println(io, "Optimization done")
println(io, "#======================#")

p, st = Lux.cpu_device()((p, st))
p, st = cpu_device()((p, st))

(NN, p, st), STATS
end
@@ -336,7 +336,7 @@ end
$SIGNATURES
"""
function makecallback(
NN::Lux.AbstractExplicitLayer,
NN::AbstractLuxLayer,
_loader::Union{CuIterator, MLUtils.DataLoader},
loader_::Union{CuIterator, MLUtils.DataLoader},
lossfun;
@@ -510,7 +510,7 @@ Train parameters `p` to minimize `loss` using optimization strategy `opt`.
"""
function optimize(
opt::Optimisers.AbstractRule,
NN::Lux.AbstractExplicitLayer,
NN::AbstractLuxLayer,
p::Union{NamedTuple, AbstractVector},
st::NamedTuple,
nepoch::Integer,
@@ -609,7 +609,7 @@ https://lux.csail.mit.edu/dev/tutorials/advanced/1_GravitationalWaveForm#training
"""
function optimize(
opt::Optim.AbstractOptimizer,
NN::Lux.AbstractExplicitLayer,
NN::AbstractLuxLayer,
p::Union{NamedTuple, AbstractVector},
st::NamedTuple,
nepoch::Integer,
@@ -718,7 +718,7 @@ end

#===============================================================#
function savemodel!( # modifies STATS
NN::Lux.AbstractExplicitLayer,
NN::AbstractLuxLayer,
p::Union{NamedTuple, AbstractVector},
st::NamedTuple,
metadata,
@@ -741,7 +741,7 @@ function savemodel!( # modifies STATS
close(statsio)

# transfer model to host device
p, st = (p, st) |> Lux.cpu_device()
p, st = (p, st) |> cpu_device()
model = NN, p, st

# training plot
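The `train.jl` changes are mechanical renames, apparently tracking Lux 1.0: `Lux.AbstractExplicitLayer` becomes `AbstractLuxLayer`, and `gpu_device`/`cpu_device` are used unqualified (they are re-exported by Lux from MLDataDevices). A minimal round-trip sketch under that assumption:

```julia
using Lux, Random

function roundtrip(NN::AbstractLuxLayer)
    p, st = Lux.setup(Random.default_rng(), NN)  # initial parameters and states
    dev   = gpu_device()                         # falls back to CPUDevice without a GPU
    p, st = dev((p, st))                         # move to the accelerator
    p, st = cpu_device()((p, st))                # and back to the host, as train_model does
    return p, st
end

roundtrip(Dense(4 => 4))
```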
68 changes: 38 additions & 30 deletions src/train/trainer.jl
@@ -1,6 +1,9 @@
#
#===============================================================#
# 1. make io async with @async or Threads.@spawn
# 2. call evaluate, update_trainer_state! every so many steps
# 3. assert that evaluate has been called before update_trainer_state!
# 4. make data transfer in update_trainer_state! async
#===============================================================#
abstract type AbstractTrainState end

@@ -9,23 +12,25 @@ abstract type AbstractTrainState end
p
st
opt_st
count
loss_
end

# rm once https://github.com/FluxML/Optimisers.jl/pull/180
# is merged
Adapt.@adapt_structure Optimisers.Leaf

function Adapt.adapt_structure(to, state::TrainState)
p = Adapt.adapt_structure(to, state.p )
st = Adapt.adapt_structure(to, state.st)
opt_st = Adapt.adapt_structure(to, state.opt_st)

TrainState(state.NN, p, st, opt_st, state.count, state.loss_)
TrainState(state.NN, p, st, opt_st)
end

#===============================================================#
abstract type AbstractTrainer end

@concrete mutable struct Trainer <: AbstractTrainer
state # (NN, p, st, opt_st, count, patience, early_stopping)
state # (NN, p, st, opt_st)
data # (_data, data_)
batchsizes
notestdata
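The hunk above also adds an `Adapt.adapt_structure` method for the slimmed-down `TrainState`, adapting `p`, `st`, and `opt_st` while passing the layer `NN` through untouched; combined with the `Optimisers.Leaf` rule, this is what lets `train!` below move the whole state with `state |> device`. A self-contained sketch of the same shape, using a hypothetical struct and an eltype change as a stand-in for a CPU/GPU transfer:

```julia
using Adapt

struct SketchState      # hypothetical 4-field stand-in for TrainState
    NN; p; st; opt_st
end

function Adapt.adapt_structure(to, s::SketchState)
    SketchState(s.NN,                    # layer definition is left untouched
        Adapt.adapt(to, s.p),
        Adapt.adapt(to, s.st),
        Adapt.adapt(to, s.opt_st))
end

s  = SketchState(identity, (; w = rand(Float32, 2, 2)), (;), nothing)
s2 = Adapt.adapt(Array{Float64}, s)      # a GPU array adaptor would be used instead
eltype(s2.p.w)                           # Float64
```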
@@ -47,7 +52,7 @@ end
end

function Trainer(
NN::Lux.AbstractExplicitLayer,
NN::AbstractLuxLayer,
_data::NTuple{2, Any},
data_::Union{Nothing,NTuple{2, Any}} = nothing;
# MODEL PARAMETER/ STATES
@@ -71,7 +76,7 @@ function Trainer(
io::IO = stdout,
rng::Random.AbstractRNG = Random.default_rng(),
name::String = "model",
device = Lux.gpu_device(),
device = gpu_device(),
verbose::Bool = true,
callbacks = nothing,
)
@@ -138,21 +143,24 @@ function Trainer(

patience = round(Int, nepochs * patience_frac)
opt_args = (; nepochs, schedule, early_stopping, patience, return_last)
opt_iter = (; epoch = [0], start_time=[0f0], epoch_time=[0f0], epoch_dt=[0f0],)

opt_iter = (;
epoch = [0], start_time=[0f0], epoch_time=[0f0], epoch_dt=[0f0],
# early stopping
count = [0], loss_mincfg = [Inf32],
)

#========= MISC =========#

STATS = (;
EPOCH = Int[] , TIME = Float32[],
_LOSS = Float32[], LOSS_ = Float32[],
_MSE = Float32[], MSE_ = Float32[],
_MAE = Float32[], MAE_ = Float32[],

)

#==========================#
data = (; _data, data_)
state = TrainState(NN, p, st, opt_st, 0, Inf32)
state = TrainState(NN, p, st, opt_st)

Trainer(
state, data, batchsizes, notestdata,
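Since the early-stopping counters left `TrainState`, they now live in `opt_iter` as one-element arrays. A minimal sketch of that pattern, under the assumption that `opt_iter` stays an immutable NamedTuple: the arrays act as mutable cells updated in place via `[]`.

```julia
opt_iter = (; epoch = [0], count = [0], loss_mincfg = [Inf32])

opt_iter.epoch[] += 1               # advance the epoch counter
loss = 0.12f0                       # hypothetical epoch loss
if loss < opt_iter.loss_mincfg[]
    opt_iter.loss_mincfg[] = loss   # record the new best configuration's loss
    opt_iter.count[] = 0            # reset the patience counter
else
    opt_iter.count[] += 1           # one more epoch without improvement
end
```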
@@ -182,7 +190,7 @@ function train!(
end

state = state |> device
loaders = make_dataloaders(trainer)
loaders = make_dataloaders(trainer) |> device

verbose && printstatistics(trainer, state, loaders)
evaluate(trainer, state, loaders)
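Piping the loaders through `device` replaces the explicit `CuIterator` wrapping that `make_dataloaders` used to do (removed below). The assumption is that applying an MLDataDevices device to an `MLUtils.DataLoader` (or a NamedTuple of them) returns an iterator that moves each batch to the device as it is yielded. A minimal sketch:

```julia
using Lux, MLUtils   # gpu_device/cpu_device come via MLDataDevices, re-exported by Lux

x, y   = rand(Float32, 4, 128), rand(Float32, 1, 128)
loader = DataLoader((x, y); batchsize = 32, shuffle = true)
dev    = gpu_device()                 # CPUDevice no-op when no GPU is functional

for (xb, yb) in dev(loader)           # each batch arrives already on the device
    # training step would go here
end
```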
@@ -194,7 +202,7 @@
state = train_loop!(trainer, state, loaders) # loop over epochs

if opt_args.return_last
trainer.state = state |> Lux.cpu_device()
trainer.state = state |> cpu_device()
verbose && println(io, "Returning state at final iteration.")
else
state = trainer.state |> device
@@ -319,52 +327,52 @@ end

#============================================================#
function make_dataloaders(trainer::Trainer)
@unpack device, rng = trainer
@unpack rng = trainer
@unpack _data, data_ = trainer.data
@unpack _batchsize, batchsize_, __batchsize = trainer.batchsizes

_loader = DataLoader(_data; batchsize = _batchsize , rng, shuffle = true)
loader_ = DataLoader(data_; batchsize = batchsize_ , rng, shuffle = false)
__loader = DataLoader(_data; batchsize = __batchsize, rng, shuffle = false)

if device isa Lux.LuxCUDADevice
_loader = _loader |> CuIterator
loader_ = loader_ |> CuIterator
__loader = __loader |> CuIterator
end

(; _loader, loader_, __loader)
end

#===============================================================#
function update_trainer_state!(trainer::Trainer, state::TrainState)
@unpack opt_args, device, io, verbose = trainer
@unpack device, io, verbose = trainer
@unpack opt_args, opt_iter, STATS = trainer
ifbreak = false

transfer = if device isa LuxDeviceUtils.AbstractLuxGPUDevice
Lux.cpu_deivce()
# make this transfer async
transfer = if device isa AbstractGPUDevice
cpu_device()
else
deepcopy
end

if state.loss_ < trainer.state.loss_
# LATEST EPOCH LOSS: STATS.LOSS_[end]
# MIN-CONFIG LOSS: opt_iter.loss_mincfg[]

if STATS.LOSS_[end] < opt_iter.loss_mincfg[]
if verbose
msg = "Improvement in loss found: $(state.loss_) < $(trainer.state.loss_)\n"
msg = "Improvement in loss found: $(STATS.LOSS_[end]) < $(opt_iter.loss_mincfg[])\n"
printstyled(io, msg, color = :green)
end
state.count = 0
opt_iter.count[] = 0
trainer.state = state |> transfer
opt_iter.loss_mincfg[] = STATS.LOSS_[end] # new
else
@set! state.count = state.count + 1
opt_iter.count[] += 1
if verbose
msg = "No improvement in loss found in the last $(state.count) epochs. $(state.loss_) > $(trainer.state.loss_)\n"
msg = "No improvement in loss found in the last $(opt_iter.count[]) epochs. $(STATS.LOSS_[end]) > $(opt_iter.loss_mincfg[])\n"
printstyled(io, msg, color = :red)
end
end

if (state.count >= opt_args.patience) & opt_args.early_stopping
if (opt_iter.count[] >= opt_args.patience) & opt_args.early_stopping
if verbose
msg = "Early Stopping triggered after $(state.count) epochs of no improvement.\n"
msg = "Early Stopping triggered after $(opt_iter.count[]) epochs of no improvement.\n"
printstyled(io, msg, color = :red)
end
ifbreak = true
@@ -414,7 +422,7 @@ function save_trainer(
verbose && @info "Saving model at $(modelfile)"

# CHECKPOINT
opt_st = opt_st |> Lux.cpu_device()
opt_st = opt_st |> cpu_device()
jldsave(chkptfile; opt_st, STATS)
verbose && @info "Saving model at $(chkptfile)"

8 changes: 4 additions & 4 deletions src/train/utils.jl
@@ -12,10 +12,10 @@ Only for callbacks. Enforce this by setting Lux.testmode
- `lossfun`: loss function: (x::Array, y::Array) -> l::Real
"""
function fullbatch_metric(
NN::Lux.AbstractExplicitLayer,
NN::AbstractLuxLayer,
p::Union{NamedTuple, AbstractVector},
st::NamedTuple,
loader::Union{CuIterator, MLUtils.DataLoader},
loader,
lossfun,
)
N = 0
@@ -59,10 +59,10 @@
"""
function statistics(
NN::Lux.AbstractExplicitLayer,
NN::AbstractLuxLayer,
p::Union{NamedTuple, AbstractVector},
st::NamedTuple,
loader::Union{CuIterator, MLUtils.DataLoader},
loader,
)
st = Lux.testmode(st)

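Dropping the `Union{CuIterator, MLUtils.DataLoader}` annotations in `utils.jl` is consistent with the loader change above: after `|> device` the loaders are no longer literally `CuIterator`s or `DataLoader`s, so the helpers accept any iterator of `(input, target)` batches. A minimal sketch of `fullbatch_metric` under that duck-typed assumption (the repo's actual loss-function signature may differ):

```julia
using Lux

function fullbatch_metric_sketch(NN, p, st, loader, lossfun)
    st = Lux.testmode(st)            # callbacks must not mutate training state
    L, N = 0.0f0, 0
    for (x, y) in loader             # any iterator of (input, target) batches
        ŷ, _ = NN(x, p, st)          # Lux forward pass
        n  = size(y)[end]            # batch size from the trailing dimension
        L += lossfun(ŷ, y) * n       # batch-size-weighted accumulation
        N += n
    end
    return L / N                     # dataset-averaged metric
end
```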
