Skip to content

Commit

Permalink
updated json format export/import using new OptimalControlSolution fr…
Browse files Browse the repository at this point in the history
…om CTBase 0.13
  • Loading branch information
PierreMartinon committed Aug 26, 2024
1 parent 0de20e7 commit e1c4211
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ CTSolveExtMadNLP = ["MadNLP"]

[compat]
ADNLPModels = "0.8"
CTBase = "0.12, 0.13"
CTBase = "0.13"
DocStringExtensions = "0.9"
HSL = "0.4"
JLD2 = "0.4"
Expand Down
35 changes: 27 additions & 8 deletions ext/CTDirectExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,18 @@ $(TYPEDSIGNATURES)
Export OCP solution in JSON format
"""
function CTDirect.export_ocp_solution(sol::OptimalControlSolution; filename_prefix="solution")
# +++ redo this, start with basics, fuse into save
#open(filename_prefix * ".json", "w") do io
# JSON3.pretty(io, CTDirect.OCPDiscreteSolution(sol))
#end
# fuse into save ?
blob = Dict(
"objective" => sol.objective,
"time_grid" => sol.time_grid,
"state" => state_discretized(sol),
"control" => control_discretized(sol),
"costate" => costate_discretized(sol)[1:end-1,:],
"variable" => sol.variable
)
open(filename_prefix * ".json", "w") do io
JSON3.pretty(io, blob)
end
return nothing
end

Expand All @@ -45,10 +53,21 @@ $(TYPEDSIGNATURES)
Read OCP solution in JSON format
"""
function CTDirect.import_ocp_solution(filename_prefix="solution")
# +++ add constructor from json blob, fuse into load
#json_string = read(filename_prefix * ".json", String)
#return OptimalControlSolution(JSON3.read(json_string))
function CTDirect.import_ocp_solution(ocp::OptimalControlModel; filename_prefix="solution")
# fuse into load ?
json_string = read(filename_prefix * ".json", String)
blob = JSON3.read(json_string)

# NB. convert vect{vect} to matrix
return OptimalControlSolution(
ocp,
blob.time_grid,
stack(blob.state, dims=1),
stack(blob.control, dims=1),
blob.variable,
stack(blob.costate, dims=1);
objective = blob.objective
)
end


Expand Down
12 changes: 6 additions & 6 deletions src/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ function CTBase.OptimalControlSolution(

# call lowest level constructor
return OptimalControlSolution(
docp,
docp.ocp,
T,
X,
U,
Expand Down Expand Up @@ -263,7 +263,7 @@ $(TYPEDSIGNATURES)
Build OCP functional solution from DOCP vector solution (given as raw variables and multipliers plus some optional infos)
"""
function CTBase.OptimalControlSolution(
docp,
ocp::OptimalControlModel,
T,
X,
U,
Expand All @@ -280,7 +280,6 @@ function CTBase.OptimalControlSolution(
box_multipliers = ((nothing, nothing), (nothing, nothing), (nothing, nothing)),
)

ocp = docp.ocp
dim_x = state_dimension(ocp)
dim_u = control_dimension(ocp)
dim_v = variable_dimension(ocp)
Expand All @@ -292,13 +291,14 @@ function CTBase.OptimalControlSolution(
"WARNING: time grid at solution is not strictly increasing, replacing with list of indices...",
)
println(T)
T = LinRange(0, docp.dim_NLP_steps, docp.dim_NLP_steps + 1)
dim_NLP_steps = length(T) - 1
T = LinRange(0, dim_NLP_steps, dim_NLP_steps + 1)
end

# variables: remove additional state for lagrange cost
x = ctinterpolate(T, matrix2vec(X[:, 1:dim_x], 1))
p = ctinterpolate(T[1:end-1], matrix2vec(P[:, 1:dim_x], 1))
u = ctinterpolate(T, matrix2vec(U, 1))
u = ctinterpolate(T, matrix2vec(U[:, 1:dim_u], 1))

# force scalar output when dimension is 1
fx = (dim_x == 1) ? deepcopy(t -> x(t)[1]) : deepcopy(t -> x(t))
Expand Down Expand Up @@ -335,7 +335,7 @@ function CTBase.OptimalControlSolution(
) = set_box_multipliers(T, box_multipliers, dim_x, dim_u)

# build and return solution
if docp.has_variable
if is_variable_dependent(ocp)
return OptimalControlSolution(
ocp;
state = fx,
Expand Down
10 changes: 5 additions & 5 deletions test/suite/test_misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ sol0 = direct_solve(ocp, display = false)

# test save / load solution in JLD2 format
@testset verbose = true showtiming = true ":save_load :JLD2" begin
save(sol0, filename_prefix = "solution_test")
save(sol0; filename_prefix = "solution_test")
sol_reloaded = load("solution_test")
@test sol0.objective == sol_reloaded.objective
end

#=

# test export / read solution in JSON format
@testset verbose = true showtiming = true ":export_read :JSON" begin
export_ocp_solution(sol0, filename_prefix = "solution_test")
sol_reloaded = import_ocp_solution("solution_test")
export_ocp_solution(sol0; filename_prefix = "solution_test")
sol_reloaded = import_ocp_solution(ocp; filename_prefix = "solution_test")
@test sol0.objective == sol_reloaded.objective
end
=#

0 comments on commit e1c4211

Please sign in to comment.