Skip to content

Commit

Permalink
Merge pull request #116 from SciML/DIv6
Browse files Browse the repository at this point in the history
fix symbolic analysis dispatches
  • Loading branch information
Vaibhavdixit02 authored Oct 3, 2024
2 parents 0edc12b + 8acf17c commit 1cb8a90
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 70 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
name = "OptimizationBase"
uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com> and contributors"]
version = "2.2.0"
version = "2.2.1"


[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
117 changes: 63 additions & 54 deletions ext/OptimizationSymbolicAnalysisExt.jl
Original file line number Diff line number Diff line change
@@ -1,79 +1,84 @@
module OptimizationSymbolicAnalysisExt

using OptimizationBase, SciMLBase, SymbolicAnalysis, SymbolicAnalysis.Symbolics
using OptimizationBase, SciMLBase, SymbolicAnalysis, SymbolicAnalysis.Symbolics,
OptimizationBase.ArrayInterface
using SymbolicAnalysis: AnalysisResult
import Symbolics: variable, Equation, Inequality, unwrap, @variables
import SymbolicAnalysis.Symbolics: variable, Equation, Inequality, unwrap, @variables

function OptimizationBase.symify_cache(
f::OptimizationFunction{iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP,
CJP, CHP, O, EX, CEX, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV},
prob) where {iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O,
prob, num_cons,
manifold) where {
iip, AD, F, G, FG, H, FGH, HV, C, CJ, CJV, CVJ, CH, HP, CJP, CHP, O,
EX <: Nothing, CEX <: Nothing, SYS, LH, LHP, HCV, CJCV, CHCV, LHCV}
try
vars = if prob.u0 isa Matrix
@variables X[1:size(prob.u0, 1), 1:size(prob.u0, 2)]
else
ArrayInterface.restructure(
prob.u0, [variable(:x, i) for i in eachindex(prob.u0)])
end
params = if prob.p isa SciMLBase.NullParameters
[]
elseif prob.p isa MTK.MTKParameters
[variable(, i) for i in eachindex(vcat(p...))]
else
ArrayInterface.restructure(p, [variable(, i) for i in eachindex(p)])
end
obj_expr = f.expr
cons_expr = f.cons_expr === nothing ? nothing : getfield.(f.cons_expr, Ref(:lhs))

if prob.u0 isa Matrix
vars = vars[1]
end
if obj_expr === nothing || cons_expr === nothing
try
vars = if prob.u0 isa Matrix
@variables X[1:size(prob.u0, 1), 1:size(prob.u0, 2)]
else
ArrayInterface.restructure(
prob.u0, [variable(:x, i) for i in eachindex(prob.u0)])
end
params = if prob.p isa SciMLBase.NullParameters
[]
elseif prob.p isa MTK.MTKParameters
[variable(, i) for i in eachindex(vcat(p...))]
else
ArrayInterface.restructure(p, [variable(, i) for i in eachindex(p)])
end

if prob.u0 isa Matrix
vars = vars[1]
end

obj_expr = f.f(vars, params)
if obj_expr === nothing
obj_expr = f.f(vars, params)
end

if SciMLBase.isinplace(prob) && !isnothing(prob.f.cons)
lhs = Array{Symbolics.Num}(undef, num_cons)
f.cons(lhs, vars)
cons = Union{Equation, Inequality}[]
if cons_expr === nothing && SciMLBase.isinplace(prob) && !isnothing(prob.f.cons)
lhs = Array{Symbolics.Num}(undef, num_cons)
f.cons(lhs, vars)
cons = Union{Equation, Inequality}[]

if !isnothing(prob.lcons)
for i in 1:num_cons
if !isinf(prob.lcons[i])
if prob.lcons[i] != prob.ucons[i]
push!(cons, prob.lcons[i] lhs[i])
else
push!(cons, lhs[i] ~ prob.ucons[i])
if !isnothing(prob.lcons)
for i in 1:num_cons
if !isinf(prob.lcons[i])
if prob.lcons[i] != prob.ucons[i]
push!(cons, prob.lcons[i] lhs[i])
else
push!(cons, lhs[i] ~ prob.ucons[i])
end
end
end
end
end

if !isnothing(prob.ucons)
for i in 1:num_cons
if !isinf(prob.ucons[i]) && prob.lcons[i] != prob.ucons[i]
push!(cons, lhs[i] prob.ucons[i])
if !isnothing(prob.ucons)
for i in 1:num_cons
if !isinf(prob.ucons[i]) && prob.lcons[i] != prob.ucons[i]
push!(cons, lhs[i] prob.ucons[i])
end
end
end
if (isnothing(prob.lcons) || all(isinf, prob.lcons)) &&
(isnothing(prob.ucons) || all(isinf, prob.ucons))
throw(ArgumentError("Constraints passed have no proper bounds defined.
Ensure you pass equal bounds (the scalar that the constraint should evaluate to) for equality constraints
or pass the lower and upper bounds for inequality constraints."))
end
cons_expr = lhs
elseif cons_expr === nothing && !isnothing(prob.f.cons)
cons_expr = f.cons(vars, params)
end
if (isnothing(prob.lcons) || all(isinf, prob.lcons)) &&
(isnothing(prob.ucons) || all(isinf, prob.ucons))
throw(ArgumentError("Constraints passed have no proper bounds defined.
Ensure you pass equal bounds (the scalar that the constraint should evaluate to) for equality constraints
or pass the lower and upper bounds for inequality constraints."))
end
cons_expr = lhs
elseif !isnothing(prob.f.cons)
cons_expr = f.cons(vars, params)
else
cons_expr = nothing
catch err
throw(ArgumentError("Automatic symbolic expression generation with failed with error: $err.
Try by setting `structural_analysis = false` instead if the solver doesn't require symbolic expressions."))
end
catch err
throw(ArgumentError("Automatic symbolic expression generation with failed with error: $err.
Try by setting `structural_analysis = false` instead if the solver doesn't require symbolic expressions."))
end
return obj_expr, cons_expr
end

function analysis(obj_expr, cons_expr)
if obj_expr !== nothing
obj_expr = obj_expr |> Symbolics.unwrap
if manifold === nothing
Expand All @@ -85,6 +90,8 @@ function analysis(obj_expr, cons_expr)
if obj_res.gcurvature !== nothing
@info "Objective Geodesic curvature: $(obj_res.gcurvature)"
end
else
obj_res = nothing
end

if cons_expr !== nothing
Expand All @@ -101,6 +108,8 @@ function analysis(obj_expr, cons_expr)
@info "Constraints Geodesic curvature: $(cons_res[i].gcurvature)"
end
end
else
cons_res = nothing
end

return obj_res, cons_res
Expand Down
7 changes: 1 addition & 6 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,7 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt;
cons_vjp = SciMLBase.allowsconsjvp(opt), cons_jvp = SciMLBase.allowsconsjvp(opt), lag_h = SciMLBase.requireslagh(opt))

if structural_analysis
obj_expr, cons_expr = symify_cache(f, prob)
try
obj_res, cons_res = analysis(obj_expr, cons_expr)
catch err
throw("Structural analysis requires SymbolicAnalysis.jl to be loaded, either add `using SymbolicAnalysis` to your script or set `structural_analysis = false`.")
end
obj_res, cons_res = symify_cache(f, prob, num_cons, manifold)
else
obj_res = nothing
cons_res = nothing
Expand Down
7 changes: 2 additions & 5 deletions src/symify.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
function symify_cache(f::OptimizationFunction, prob)
obj_expr = f.expr
cons_expr = f.cons_expr === nothing ? nothing : getfield.(f.cons_expr, Ref(:lhs))

return obj_expr, cons_expr
function symify_cache(f::OptimizationFunction, prob, num_cons, manifold)
throw("Structural analysis requires SymbolicAnalysis.jl to be loaded, either add `using SymbolicAnalysis` to your script or set `structural_analysis = false`.")
end
9 changes: 9 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,22 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationManopt = "e57b7fff-7ee7-4550-b4f0-90e9476e9fb6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SymbolicAnalysis = "4297ee4d-0239-47d8-ba5d-195ecdf594fe"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand All @@ -35,4 +40,8 @@ Lux = ">= 0.4.50"
Manifolds = "0.9"
Optim = ">= 1.4.1"
Optimisers = ">= 0.2.5"
Optimization = "4"
OptimizationManopt = "0.0.4"
SparseConnectivityTracer = "0.6"
SymbolicAnalysis = "0.3.0"
SafeTestsets = ">= 0.0.1"
6 changes: 3 additions & 3 deletions test/cvxtest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ optf = OptimizationFunction(rosenbrock, AutoZygote(), cons = con2_c)
prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf], ucons = [1.0, 0.0],
lb = [-1.0, -1.0], ub = [1.0, 1.0], structural_analysis = true)
@time res = solve(prob, Optimization.LBFGS(), maxiters = 100)
@test res.cache.analysis_results.objective.curvature == SymbolicAnalysis.Convex
@test res.cache.analysis_results.objective.curvature == SymbolicAnalysis.UnknownCurvature
@test res.cache.analysis_results.constraints[1].curvature == SymbolicAnalysis.Convex
@test res.cache.analysis_results.constraints[2].curvature ==
SymbolicAnalysis.UnknownCurvature
Expand All @@ -46,7 +46,7 @@ optf = OptimizationFunction(f, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, data2[1]; manifold = M, structural_analysis = true)

opt = OptimizationManopt.GradientDescentOptimizer()
@time sol = solve(prob, Optimization.LBFGS(), maxiters = 100)
@test sol.minimizer < 1e-1
@time sol = solve(prob, opt, maxiters = 100)
@test sol.minimum < 1e-1
@test sol.cache.analysis_results.objective.curvature == SymbolicAnalysis.UnknownCurvature
@test sol.cache.analysis_results.objective.gcurvature == SymbolicAnalysis.GConvex
2 changes: 1 addition & 1 deletion test/matrixvalued.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ using Test, ReverseDiff
Omega_mc = rand(4, 4) .> 0.5 # Mask for observed entries (boolean matrix)
X_mc = rand(4, 4) # Matrix to be completed
optf = OptimizationFunction{false}(
matrix_completion_objective, adtype, cons = rank_constraint)
matrix_completion_objective, adtype)
optf = OptimizationBase.instantiate_function(
optf, X_mc, adtype, (A_mc, Omega_mc), g = true, h = true)
optf.grad(X_mc)
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,6 @@ using Test

@testset "OptimizationBase.jl" begin
include("adtests.jl")
include("cvxtest.jl")
include("matrixvalued.jl")
end

0 comments on commit 1cb8a90

Please sign in to comment.