-
Notifications
You must be signed in to change notification settings - Fork 5
/
ctbase.jl
74 lines (61 loc) · 1.96 KB
/
ctbase.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Olivier, here is the extension that saves / reads solutions in Julia format (JLD2) and text format (JSON). The JLD2 load/save part is fairly trivial, but it could be unified with the JSON functions by adding an argument format=:jld [:json]. This extension would then typically move to CTBase, since nothing in it is specific to the direct method.
module CTDirectExt
using CTDirect
using CTBase
using DocStringExtensions
using JLD2 # load / save
using JSON3 # read / export

"""
$(TYPEDSIGNATURES)

Save the optimal control solution `sol` in JLD2 format.

The whole `OptimalControlSolution` object is serialized with `save_object` into the
file `filename_prefix * ".jld2"` (default: `"solution.jld2"`). Returns `nothing`.
"""
function JLD2.save(sol::OptimalControlSolution; filename_prefix = "solution")
    # NOTE(review): this adds a method to a `save` function not owned by this package,
    # for a type owned by CTBase (type piracy). It only matches when the first
    # argument is a solution, so ordinary `save(filename, ...)` calls are unaffected.
    save_object(filename_prefix * ".jld2", sol)
    return nothing
end

"""
$(TYPEDSIGNATURES)

Load an optimal control solution saved with `save(sol; filename_prefix = ...)`.

Reads the object stored in `filename_prefix * ".jld2"` (default: `"solution.jld2"`)
with `load_object` and returns it.
"""
function JLD2.load(filename_prefix::AbstractString = "solution")
    # The argument is restricted to strings on purpose. An untyped one-argument method
    # on this shared `load` function matches *any* value, so it could take precedence
    # over other packages' `load(x)` methods (IO handles, `File` objects, ...) as soon
    # as this extension is loaded.
    # NOTE(review): `load(::AbstractString)` may still shadow generic single-argument
    # `load("file.ext")` calls for other formats; a dedicated, non-pirating entry
    # point would remove the issue entirely — consider it for a breaking release.
    return load_object(filename_prefix * ".jld2")
end

"""
$(TYPEDSIGNATURES)

Export the discretized solution `sol` to the JSON file `filename_prefix * ".json"`
(default: `"solution.json"`).

The file contains a single JSON object with the keys `"objective"`, `"time_grid"`,
`"state"`, `"control"`, `"costate"` and `"variable"`. The last entry along the first
dimension of the discretized costate is dropped before writing. Returns `nothing`.
"""
function CTDirect.export_ocp_solution(
    sol::OptimalControlSolution;
    filename_prefix = "solution",
)
    # TODO: possibly merge with `save` via a format argument.
    blob = Dict(
        "objective" => sol.objective,
        "time_grid" => sol.time_grid,
        "state" => state_discretized(sol),
        "control" => control_discretized(sol),
        # NOTE(review): presumably the costate lives on one fewer grid point than the
        # state, hence the dropped last entry — confirm against CTBase.
        "costate" => costate_discretized(sol)[1:(end - 1), :],
        "variable" => sol.variable,
    )
    # `open` with a do-block guarantees the file is closed even if writing fails.
    open(filename_prefix * ".json", "w") do io
        JSON3.pretty(io, blob)
    end
    return nothing
end

"""
$(TYPEDSIGNATURES)

Read a solution of `ocp` from the JSON file `filename_prefix * ".json"`
(default: `"solution.json"`), in the layout written by `export_ocp_solution`.

The `state`, `control` and `costate` entries are stacked along the first dimension
(one row per outer JSON array element) before being passed, together with
`time_grid`, `variable` and `objective`, to the `OptimalControlSolution` constructor.
"""
function CTDirect.import_ocp_solution(
    ocp::OptimalControlModel;
    filename_prefix = "solution",
)
    # TODO: possibly merge with `load` via a format argument.
    blob = JSON3.read(read(filename_prefix * ".json", String))
    # JSON arrays are read back as vectors of vectors; stack them row-wise to rebuild
    # matrices.
    # NOTE(review): this assumes the export wrote nested arrays (one per row) — verify
    # how JSON3 serializes the objects returned by `*_discretized`.
    state = stack(blob.state, dims = 1)
    control = stack(blob.control, dims = 1)
    costate = stack(blob.costate, dims = 1)
    return OptimalControlSolution(
        ocp,
        blob.time_grid,
        state,
        control,
        blob.variable,
        costate;
        objective = blob.objective,
    )
end
end