From 3ca61cf768bafc7a7f1ea6f927246f6a771e72c7 Mon Sep 17 00:00:00 2001
From: Vaibhav Dixit
Date: Thu, 26 Sep 2024 07:53:51 -0400
Subject: [PATCH 1/2] Copy auglag from lbfgsb.jl

---
 src/auglag.jl | 182 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 182 insertions(+)
 create mode 100644 src/auglag.jl

diff --git a/src/auglag.jl b/src/auglag.jl
new file mode 100644
index 000000000..e7234334a
--- /dev/null
+++ b/src/auglag.jl
@@ -0,0 +1,182 @@
+
+SciMLBase.supports_opt_cache_interface(::LBFGS) = true
+SciMLBase.allowsbounds(::LBFGS) = true
+SciMLBase.requiresgradient(::LBFGS) = true
+SciMLBase.allowsconstraints(::LBFGS) = true
+SciMLBase.requiresconsjac(::LBFGS) = true
+
+function task_message_to_string(task::Vector{UInt8})
+    return String(task)
+end
+
+function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::LBFGS;
+        callback = nothing,
+        maxiters::Union{Number, Nothing} = nothing,
+        maxtime::Union{Number, Nothing} = nothing,
+        abstol::Union{Number, Nothing} = nothing,
+        reltol::Union{Number, Nothing} = nothing,
+        verbose::Bool = false,
+        kwargs...)
+    if !isnothing(abstol)
+        @warn "common abstol is currently not used by $(opt)"
+    end
+    if !isnothing(maxtime)
+        @warn "common maxtime is currently not used by $(opt)"
+    end
+
+    mapped_args = (;)
+
+    if cache.lb !== nothing && cache.ub !== nothing
+        mapped_args = (; mapped_args..., lb = cache.lb, ub = cache.ub)
+    end
+
+    if !isnothing(maxiters)
+        mapped_args = (; mapped_args..., maxiter = maxiters)
+    end
+
+    if !isnothing(reltol)
+        mapped_args = (; mapped_args..., pgtol = reltol)
+    end
+
+    return mapped_args
+end
+
+function SciMLBase.__solve(cache::OptimizationCache{
+        F,
+        RC,
+        LB,
+        UB,
+        LC,
+        UC,
+        S,
+        O,
+        D,
+        P,
+        C
+}) where {
+        F,
+        RC,
+        LB,
+        UB,
+        LC,
+        UC,
+        S,
+        O <:
+        LBFGS,
+        D,
+        P,
+        C
+}
+maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
+
+local x
+
+solver_kwargs = __map_optimizer_args(cache, cache.opt; maxiters, cache.solver_args...)
+
+if !isnothing(cache.f.cons)
+    eq_inds = [cache.lcons[i] == cache.ucons[i] for i in eachindex(cache.lcons)]
+    ineq_inds = (!).(eq_inds)
+
+    τ = cache.opt.τ
+    γ = cache.opt.γ
+    λmin = cache.opt.λmin
+    λmax = cache.opt.λmax
+    μmin = cache.opt.μmin
+    μmax = cache.opt.μmax
+    ϵ = cache.opt.ϵ
+
+    λ = zeros(eltype(cache.u0), sum(eq_inds))
+    μ = zeros(eltype(cache.u0), sum(ineq_inds))
+
+    cons_tmp = zeros(eltype(cache.u0), length(cache.lcons))
+    cache.f.cons(cons_tmp, cache.u0)
+    ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, cache.p))) / norm(cons_tmp)))
+
+    _loss = function (θ)
+        x = cache.f(θ, cache.p)
+        cons_tmp .= zero(eltype(θ))
+        cache.f.cons(cons_tmp, θ)
+        cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
+        cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds]
+        opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
+        if cache.callback(opt_state, x...)
+            error("Optimization halted by callback.")
+        end
+        return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) +
+               1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2)
+    end
+
+    prev_eqcons = zero(λ)
+    θ = cache.u0
+    β = max.(cons_tmp[ineq_inds], Ref(0.0))
+    prevβ = zero(β)
+    eqidxs = [eq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
+    ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
+    eqidxs = eqidxs[eqidxs .!= nothing]
+    ineqidxs = ineqidxs[ineqidxs .!= nothing]
+    function aug_grad(G, θ)
+        cache.f.grad(G, θ)
+        if !isnothing(cache.f.cons_jac_prototype)
+            J = Float64.(cache.f.cons_jac_prototype)
+        else
+            J = zeros((length(cache.lcons), length(θ)))
+        end
+        cache.f.cons_j(J, θ)
+        __tmp = zero(cons_tmp)
+        cache.f.cons(__tmp, θ)
+        __tmp[eq_inds] .= __tmp[eq_inds] .- cache.lcons[eq_inds]
+        __tmp[ineq_inds] .= __tmp[ineq_inds] .- cache.ucons[ineq_inds]
+        G .+= sum(
+            λ[i] .* J[idx, :] + ρ * (__tmp[idx] .* J[idx, :])
+            for (i, idx) in enumerate(eqidxs);
+            init = zero(G)) #should be jvp
+        G .+= sum(
+            1 / ρ * (max.(Ref(0.0), μ[i] .+ (ρ .* __tmp[idx])) .* J[idx, :])
+            for (i, idx) in enumerate(ineqidxs);
+            init = zero(G)) #should be jvp
+    end
+
+    opt_ret = ReturnCode.MaxIters
+    n = length(cache.u0)
+
+    sol = solve(....)
+
+    solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))
+
+    for i in 1:maxiters
+        prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds]
+        prevβ .= copy(β)
+        res = optimizer(_loss, aug_grad, θ, bounds; solver_kwargs...,
+            m = cache.opt.m, pgtol = sqrt(ϵ), maxiter = maxiters / 100)
+        # @show res[2]
+        # @show res[1]
+        # @show cons_tmp
+        # @show λ
+        # @show β
+        # @show μ
+        # @show ρ
+        θ = res[2]
+        cons_tmp .= 0.0
+        cache.f.cons(cons_tmp, θ)
+        λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin)
+        β = max.(cons_tmp[ineq_inds], -1 .* μ ./ ρ)
+        μ = min.(μmax, max.(μ .+ ρ * cons_tmp[ineq_inds], μmin))
+        if max(norm(cons_tmp[eq_inds] .- cache.lcons[eq_inds], Inf), norm(β, Inf)) >
+           τ * max(norm(prev_eqcons, Inf), norm(prevβ, Inf))
+            ρ = γ * ρ
+        end
+        if norm(
+            (cons_tmp[eq_inds] .- cache.lcons[eq_inds]) ./ cons_tmp[eq_inds], Inf) <
+           ϵ && norm(β, Inf) < ϵ
+            opt_ret = ReturnCode.Success
+            break
+        end
+    end
+end
+
+stats = Optimization.OptimizationStats(; iterations = maxiters,
+    time = 0.0, fevals = maxiters, gevals = maxiters)
+return SciMLBase.build_solution(
+    cache, cache.opt, res[2], cache.f(res[2], cache.p)[1],
+    stats = stats, retcode = opt_ret)
+end
\ No newline at end of file

From b4287a1a060aef155476f55894f9e0b3af3169bc Mon Sep 17 00:00:00 2001
From: Vaibhav Dixit
Date: Fri, 25 Oct 2024 23:08:41 -0400
Subject: [PATCH 2/2] Work with stochastic optimizers too

---
 src/Optimization.jl |  1 +
 src/auglag.jl       | 63 +++++++++++++++++++++------------------------
 src/lbfgsb.jl       |  2 +-
 test/lbfgsb.jl      | 28 ++++++++++++++++++++
 4 files changed, 60 insertions(+), 34 deletions(-)

diff --git a/src/Optimization.jl b/src/Optimization.jl
index 4cfeead6e..8d0257dd1 100644
--- a/src/Optimization.jl
+++ b/src/Optimization.jl
@@ -24,6 +24,7 @@ include("utils.jl")
 include("state.jl")
 include("lbfgsb.jl")
 include("sophia.jl")
+include("auglag.jl")
 
 export solve
 
diff --git a/src/auglag.jl b/src/auglag.jl
index e7234334a..390659c8e 100644
--- a/src/auglag.jl
+++ b/src/auglag.jl
@@ -1,15 +1,21 @@
-
-SciMLBase.supports_opt_cache_interface(::LBFGS) = true
-SciMLBase.allowsbounds(::LBFGS) = true
-SciMLBase.requiresgradient(::LBFGS) = true
-SciMLBase.allowsconstraints(::LBFGS) = true
-SciMLBase.requiresconsjac(::LBFGS) = true
-
-function task_message_to_string(task::Vector{UInt8})
-    return String(task)
+@kwdef struct AugLag
+    inner
+    τ = 0.5
+    γ = 10.0
+    λmin = -1e20
+    λmax = 1e20
+    μmin = 0.0
+    μmax = 1e20
+    ϵ = 1e-8
 end
 
-function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::LBFGS;
+SciMLBase.supports_opt_cache_interface(::AugLag) = true
+SciMLBase.allowsbounds(::AugLag) = true
+SciMLBase.requiresgradient(::AugLag) = true
+SciMLBase.allowsconstraints(::AugLag) = true
+SciMLBase.requiresconsjac(::AugLag) = true
+
+function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::AugLag;
         callback = nothing,
         maxiters::Union{Number, Nothing} = nothing,
         maxtime::Union{Number, Nothing} = nothing,
@@ -62,7 +68,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
         UC,
         S,
         O <:
-        LBFGS,
+        AugLag,
         D,
         P,
         C
@@ -90,10 +96,10 @@ function SciMLBase.__solve(cache::OptimizationCache{
 
     cons_tmp = zeros(eltype(cache.u0), length(cache.lcons))
     cache.f.cons(cons_tmp, cache.u0)
-    ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, cache.p))) / norm(cons_tmp)))
+    ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, iterate(cache.p)[1]))) / norm(cons_tmp)))
 
-    _loss = function (θ)
-        x = cache.f(θ, cache.p)
+    _loss = function (θ, p = cache.p)
+        x = cache.f(θ, p)
         cons_tmp .= zero(eltype(θ))
         cache.f.cons(cons_tmp, θ)
         cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
@@ -114,8 +120,8 @@ function SciMLBase.__solve(cache::OptimizationCache{
     ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
     eqidxs = eqidxs[eqidxs .!= nothing]
     ineqidxs = ineqidxs[ineqidxs .!= nothing]
-    function aug_grad(G, θ)
-        cache.f.grad(G, θ)
+    function aug_grad(G, θ, p)
+        cache.f.grad(G, θ, p)
         if !isnothing(cache.f.cons_jac_prototype)
             J = Float64.(cache.f.cons_jac_prototype)
         else
@@ -139,23 +145,15 @@ function SciMLBase.__solve(cache::OptimizationCache{
     opt_ret = ReturnCode.MaxIters
     n = length(cache.u0)
 
-    sol = solve(....)
+    augprob = OptimizationProblem(OptimizationFunction(_loss; grad = aug_grad), cache.u0, cache.p)
 
     solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))
 
-    for i in 1:maxiters
+    for i in 1:(maxiters/10)
         prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds]
         prevβ .= copy(β)
-        res = optimizer(_loss, aug_grad, θ, bounds; solver_kwargs...,
-            m = cache.opt.m, pgtol = sqrt(ϵ), maxiter = maxiters / 100)
-        # @show res[2]
-        # @show res[1]
-        # @show cons_tmp
-        # @show λ
-        # @show β
-        # @show μ
-        # @show ρ
-        θ = res[2]
+        res = solve(augprob, cache.opt.inner, maxiters = maxiters / 10)
+        θ = res.u
         cons_tmp .= 0.0
         cache.f.cons(cons_tmp, θ)
         λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin)
@@ -171,12 +169,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
             opt_ret = ReturnCode.Success
             break
         end
     end
-end
-
-stats = Optimization.OptimizationStats(; iterations = maxiters,
+    stats = Optimization.OptimizationStats(; iterations = maxiters,
     time = 0.0, fevals = maxiters, gevals = maxiters)
-return SciMLBase.build_solution(
-    cache, cache.opt, res[2], cache.f(res[2], cache.p)[1],
+    return SciMLBase.build_solution(
+        cache, cache.opt, θ, x,
     stats = stats, retcode = opt_ret)
+end
 end
\ No newline at end of file
diff --git a/src/lbfgsb.jl b/src/lbfgsb.jl
index 514b20666..f3d3c79c7 100644
--- a/src/lbfgsb.jl
+++ b/src/lbfgsb.jl
@@ -45,7 +45,7 @@ function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::LBFGS;
         @warn "common abstol is currently not used by $(opt)"
     end
     if !isnothing(maxtime)
-        @warn "common abstol is currently not used by $(opt)"
+        @warn "common maxtime is currently not used by $(opt)"
     end
 
     mapped_args = (;)
diff --git a/test/lbfgsb.jl b/test/lbfgsb.jl
index 2b5ec1691..b981cc2f2 100644
--- a/test/lbfgsb.jl
+++ b/test/lbfgsb.jl
@@ -25,3 +25,31 @@ prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf],
     ub = [1.0, 1.0])
 @time res = solve(prob, Optimization.LBFGS(), maxiters = 100)
 @test res.retcode == SciMLBase.ReturnCode.Success
+
+using MLUtils, OptimizationOptimisers
+
+x0 = -pi:0.001:pi
+y0 = sin.(x0)
+data = MLUtils.DataLoader((x0, y0), batchsize = 100)
+function loss(coeffs, data)
+    ypred = [evalpoly(data[1][i], coeffs) for i in eachindex(data[1])]
+    return sum(abs2, ypred .- data[2])
+end
+
+function cons1(res, coeffs, p = nothing)
+    res[1] = coeffs[1] * coeffs[5] - 1
+    return nothing
+end
+
+optf = OptimizationFunction(loss, AutoSparseForwardDiff(), cons = cons1)
+callback = (st, l) -> (@show l; return false)
+
+prob = OptimizationProblem(optf, rand(5), (x0, y0), lcons = [-0.5], ucons = [0.5], lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
+opt1 = solve(prob, Optimization.LBFGS(), maxiters = 1000, callback = callback)
+
+prob = OptimizationProblem(optf, rand(5), data, lcons = [0.0], ucons = [0.0], lb = [-10.0, -10.0, -10.0, -10.0, -10.0], ub = [10.0, 10.0, 10.0, 10.0, 10.0])
+opt = solve(prob, Optimization.AugLag(; inner = Adam()), maxiters = 500, callback = callback)
+
+optf1 = OptimizationFunction(loss, AutoSparseForwardDiff())
+prob1 = OptimizationProblem(optf1, rand(5), data)
+sol1 = solve(prob1, OptimizationOptimisers.Adam(), maxiters = 1000, callback = callback)
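
For readers tracing the algorithm in src/auglag.jl: each outer iteration minimizes the augmented Lagrangian

    Lρ(θ) = f(θ) + Σᵢ λᵢ cᵢ(θ) + (ρ/2) Σᵢ cᵢ(θ)² + 1/(2ρ) Σⱼ max(0, μⱼ + ρ gⱼ(θ))²

where cᵢ are the equality residuals (cons .- lcons on the rows with lcons == ucons) and gⱼ the inequality residuals (cons .- ucons on the remaining rows), then updates λ, μ, and ρ before the next inner solve. The sketch below restates that loop against plain closures rather than the OptimizationCache internals; it is illustrative only: f, cons, and inner_solve are stand-ins that are not part of this PR's API, and the λ/μ clamping to [λmin, λmax] and [μmin, μmax] is omitted for brevity.

using LinearAlgebra

function auglag_sketch(f, cons, θ0, lcons, ucons, inner_solve;
        maxouter = 50, τ = 0.5, γ = 10.0, ϵ = 1e-8)
    eq = lcons .== ucons                 # rows with lcons == ucons are equalities
    ineq = .!eq
    λ = zeros(count(eq))                 # equality multipliers
    μ = zeros(count(ineq))               # inequality multipliers
    ρ = max(1e-6, min(10, 2 * abs(f(θ0)) / norm(cons(θ0))))  # same initial-penalty heuristic as the PR
    θ = θ0
    prev_viol = Inf
    for _ in 1:maxouter
        # Penalized objective, mirroring `_loss` above:
        Lρ = θv -> begin
            cv = cons(θv)
            ceqv = cv[eq] .- lcons[eq]
            cinv = cv[ineq] .- ucons[ineq]
            f(θv) + sum(λ .* ceqv .+ (ρ / 2) .* ceqv .^ 2) +
                1 / (2ρ) * sum(max.(0.0, μ .+ ρ .* cinv) .^ 2)
        end
        θ = inner_solve(Lρ, θ)           # a few steps of any unconstrained optimizer
        # Multiplier and penalty updates, mirroring the loop body above:
        c = cons(θ)
        ceq = c[eq] .- lcons[eq]
        cin = c[ineq] .- ucons[ineq]
        β = max.(cin, -1 .* μ ./ ρ)      # inequality-violation measure
        λ = λ .+ ρ .* ceq
        μ = max.(μ .+ ρ .* cin, 0.0)
        viol = max(norm(ceq, Inf), norm(β, Inf))
        viol > τ * prev_viol && (ρ *= γ) # too little progress: grow the penalty
        viol < ϵ && break
        prev_viol = viol
    end
    return θ
end

Any stochastic first-order method can play inner_solve here, which is exactly what the inner field of the new AugLag struct supplies through solve(augprob, cache.opt.inner, ...) in the second patch.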