Skip to content

Commit

Permalink
basic load/save solution in julia JLD2 and JSON format (#104)
Browse files Browse the repository at this point in the history
* basic load/save solution in julia JLD2 format

* added functions for load/save; move to CTBase once finished

* add new struct for OCP solution in discrete form for export

* load / save in JSON format (interpolated solution)
  • Loading branch information
PierreMartinon authored Jun 4, 2024
1 parent 8a0b685 commit 5477756
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
*.jl.mem
*.jl.*.mem

# Julia data files
# Save / load solution data files
*.jld2
*.json

# System-specific files and directories generated by the BinaryProvider and BinDeps packages
# They contain absolute paths specific to the host computer, and so should not be committed
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ ADNLPModels = "54578032-b7ea-4c30-94aa-7cbd1cce6c9a"
CTBase = "54762871-cc72-4466-b8e8-f6c8b58076cd"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6"
NLPModelsIpopt = "f4238b75-b362-5c4c-b852-0801c9a21d71"
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

Expand Down
8 changes: 8 additions & 0 deletions src/CTDirect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ using Symbolics # for optimized AD
using ADNLPModels # docp model with AD
using NLPModelsIpopt # NLP solver
using LinearAlgebra # norm
using JLD2
using JSON3
using StructTypes

# Other declarations
const __grid_size_direct() = 100
Expand All @@ -30,6 +33,9 @@ export getNLP
export setDOCPInit
export OCPSolutionFromDOCP
export OCPSolutionFromDOCP_raw
export save_OCP_solution
export load_OCP_solution
export OCP_Solution_discrete

# CTBase reexports
export @def
Expand All @@ -46,4 +52,6 @@ export remove_constraint!
export OCPInit
export plot



end
101 changes: 100 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,103 @@ function DOCP_initial_guess(docp, init::OCPInit=OCPInit())
end

return xuv
end
end



#+++ to be moved to CTBase !

#struct for interpolated ocp solution with only basic data types that can be exported as json
# +++todo:
# - pass time grid / grid size
# - add more fields from OptimalControlSolution
# - constructor to recreate OptimalControlSolution from this one
mutable struct OCP_Solution_discrete

grid_size
objective
times
#initial_time_name::Union{String, Nothing}=nothing
#final_time_name::Union{String, Nothing}=nothing
#time_name::Union{String, Nothing}=nothing
control_dimension
#control_components_names::Union{Vector{String}, Nothing}=nothing
#control_name::Union{String, Nothing}=nothing
control
state_dimension
#state_components_names::Union{Vector{String}, Nothing}=nothing
#state_name::Union{String, Nothing}=nothing
state
variable_dimension
#variable_components_names::Union{Vector{String}, Nothing}=nothing
#variable_name::Union{String, Nothing}=nothing
variable
costate
#objective::Union{Nothing, ctNumber}=nothing
#iterations::Union{Nothing, Integer}=nothing
#stopping::Union{Nothing, Symbol}=nothing # the stopping criterion
#message::Union{Nothing, String}=nothing # the message corresponding to the stopping criterion
#success::Union{Nothing, Bool}=nothing # whether or not the method has finished successfully: CN1, stagnation vs iterations max
#infos::Dict{Symbol, Any}=Dict{Symbol, Any}()
#OCP_Solution_discrete() = new() # for StructTypes / JSON

function OCP_Solution_discrete(solution::OptimalControlSolution)
solution_d = new()

# raw copy
solution_d.objective = solution.objective
solution_d.times = solution.times
solution_d.state_dimension = solution.state_dimension
solution_d.control_dimension = solution.control_dimension
solution_d.variable_dimension = solution.variable_dimension
solution_d.variable = solution.variable

# interpolate functions into vectors
# +++ ther *must* be a quicker way to do this -_-
solution_d.grid_size = length(solution_d.times) - 1
solution_d.state = zeros(solution_d.grid_size+1, solution_d.state_dimension)
solution_d.control = zeros(solution_d.grid_size+1, solution_d.control_dimension)
solution_d.costate = zeros(solution_d.grid_size+1, solution_d.state_dimension)
for i in 1:solution_d.grid_size
solution_d.state[i,:] .= solution.state(solution_d.times[i])
solution_d.control[i,:] .= solution.control(solution_d.times[i])
solution_d.costate[i,:] .= solution.costate(solution_d.times[i])
end
return solution_d
end
end

"""
$(TYPEDSIGNATURES)
Save OCP solution in JLD2/JSON format
"""
function save_OCP_solution(sol::OptimalControlSolution; filename_prefix="solution", format="JLD2")
if format == "JLD2"
save_object(filename_prefix * ".jld2", sol)
elseif format == "JSON"
open(filename_prefix * ".json", "w") do io
JSON3.pretty(io, OCP_Solution_discrete(sol))
end
else
println("ERROR: save_OCP_solution: format should be JLD2 or JSON, received ", format)
end
return nothing
end

"""
$(TYPEDSIGNATURES)
Load OCP solution in JLD2/JSON format
"""
function load_OCP_solution(filename_prefix="solution"; format="JLD2")
if format == "JLD2"
return load_object(filename_prefix * ".jld2")
elseif format == "JSON"
json_string = read(filename_prefix * ".json", String)
return JSON3.read(json_string)
else
println("ERROR: save_OCP_solution: format should be JLD2 or JSON, received ", format)
return nothing
end
end
12 changes: 9 additions & 3 deletions test/suite/misc.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using CTDirect
using JLD2

println("Test: misc")

Expand Down Expand Up @@ -45,7 +44,14 @@ sol_raw = OCPSolutionFromDOCP_raw(docp, dsol.solution)

# test save / load solution in JLD2 format
@testset verbose = true showtiming = true ":save_load :JLD2" begin
save_object("sol.jld2", sol0)
sol_reloaded = load_object("sol.jld2")
save_OCP_solution(sol0, filename_prefix="solution_test")
sol_reloaded = load_OCP_solution("solution_test")
@test sol0.objective == sol_reloaded.objective
end

# test save / load solution in JSON format
@testset verbose = true showtiming = true ":save_load :JSON" begin
save_OCP_solution(sol0, filename_prefix="solution_test", format="JSON")
sol_reloaded = load_OCP_solution("solution_test", format="JSON")
@test sol0.objective == sol_reloaded.objective
end
14 changes: 9 additions & 5 deletions test/test_misc.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using CTDirect
using JLD2

println("Test: misc")

Expand Down Expand Up @@ -42,11 +41,16 @@ dsol2 = solve(docp2, print_level=5, tol=1e-12)
println("\nRebuild OCP solution from raw vector")
sol3 = OCPSolutionFromDOCP_raw(docp2, dsol2.solution)

# save / load solution in JLD2 format (solution includes complex data such as interpolated functions which are less suitable for more generic formats such as JSON)
save_object("sol.jld2", sol)
sol4 = load_object("sol.jld2")
# save / load solution in JLD2 format
save_OCP_solution(sol, filename_prefix="solution_test")
sol4 = load_OCP_solution("solution_test")
plot(sol4, show=true)
println(sol.objective == sol4.objective)

# save / load discrete solution in JSON format
# NB. we recover here a JSON Object...
save_OCP_solution(sol, filename_prefix="solution_test", format="JSON")
sol_disc_reloaded = load_OCP_solution("solution_test", format="JSON")
println(sol.objective == sol_disc_reloaded.objective)

println("")
#

0 comments on commit 5477756

Please sign in to comment.