From 331ec8befba7ef3bc9600a822a6daae91f967212 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Thu, 4 Mar 2021 02:30:17 -0500 Subject: [PATCH] fix bounds translation fixes https://github.com/SciML/DiffEqFlux.jl/issues/498 --- src/train.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/train.jl b/src/train.jl index cdad55021a..aa286c05ed 100644 --- a/src/train.jl +++ b/src/train.jl @@ -1,6 +1,7 @@ -function sciml_train(loss, θ, opt, adtype::DiffEqBase.AbstractADType = GalacticOptim.AutoZygote(), args...; kwargs...) +function sciml_train(loss, θ, opt, adtype::DiffEqBase.AbstractADType = GalacticOptim.AutoZygote(), args...; + lower_bounds = nothing, upper_bounds = nothing, kwargs...) optf = GalacticOptim.OptimizationFunction((x, p) -> loss(x), adtype) optfunc = GalacticOptim.instantiate_function(optf, θ, adtype, nothing) - optprob = GalacticOptim.OptimizationProblem(optfunc, θ; kwargs...) + optprob = GalacticOptim.OptimizationProblem(optfunc, θ; lb = lower_bounds, ub = upper_bounds, kwargs...) GalacticOptim.solve(optprob, opt, args...; kwargs...) end