Skip to content

Commit

Permalink
fixing regression on t_i types; still some errors in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreMartinon committed Sep 30, 2024
1 parent 8c80e65 commit 6931a08
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 497 deletions.
183 changes: 39 additions & 144 deletions benchmark/prof.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ function init(;in_place, grid_size, disc_method)
end


function test_unit(;test_get=false, test_dyn=false, test_unit_cons=false, test_mayer=false, test_obj=false, test_block=true, test_cons=false, test_trans=false, test_solve=false, warntype=false, jet=false, profile=false, grid_size=100, disc_method=:trapeze, in_place=true)
function test_unit(;test_get=false, test_dyn=false, test_unit_cons=false, test_mayer=false, test_obj=false, test_block=false, test_cons=false, test_trans=false, test_solve=false, warntype=false, jet=false, profile=false, grid_size=100, disc_method=:trapeze, in_place=true)

# define problem and variables
prob, docp, xu = init(in_place=in_place, grid_size=grid_size, disc_method=disc_method)
Expand All @@ -45,24 +45,18 @@ function test_unit(;test_get=false, test_dyn=false, test_unit_cons=false, test_m
if test_get
println("Getters")
print("t"); @btime CTDirect.get_final_time($xu, $docp)
print("v"); @btime CTDirect.get_optim_variable($xu, $docp)
print("vv"); @btime CTDirect.get_OCP_variable($xu, $docp)
print("x"); @btime CTDirect.get_state_at_time_step($xu, $docp, $docp.dim_NLP_steps)
print("xx"); @btime CTDirect.get_OCP_state_at_time_step($xu, $docp, $docp.dim_NLP_steps)
print("u"); @btime CTDirect.get_control_at_time_step($xu, $docp, $docp.dim_NLP_steps)
print("x"); @btime CTDirect.get_OCP_state_at_time_step($xu, $docp, 1)
print("u"); @btime CTDirect.get_OCP_control_at_time_step($xu, $docp, 1)
print("v"); @btime CTDirect.get_OCP_variable($xu, $docp)
if warntype
@code_warntype CTDirect.get_final_time(xu, docp)
@code_warntype CTDirect.get_optim_variable(xu, docp)
@code_warntype CTDirect.get_state_at_time_step(xu, docp, docp.dim_NLP_steps)
@code_warntype CTDirect.get_control_at_time_step(xu, docp, docp.dim_NLP_steps)
@code_warntype CTDirect.get_time_grid!(xu, docp)
@code_warntype CTDirect.get_OCP_state_at_time_step(xu, docp, 1)
@code_warntype CTDirect.get_OCP_control_at_time_step(xu, docp, 1)
@code_warntype CTDirect.get_OCP_variable(xu, docp)
end
end

t = CTDirect.get_final_time(xu, docp)
v = CTDirect.get_optim_variable(xu, docp)
x = CTDirect.get_state_at_time_step(xu, docp, docp.dim_NLP_steps)
u = CTDirect.get_control_at_time_step(xu, docp, docp.dim_NLP_steps)
f = similar(xu, docp.dim_NLP_x)

if test_dyn
Expand Down Expand Up @@ -107,73 +101,30 @@ function test_unit(;test_get=false, test_dyn=false, test_unit_cons=false, test_m
nx = docp.dim_NLP_x
m = docp.dim_NLP_u
N = docp.dim_NLP_steps
x0 = CTDirect.get_state_at_time_step(xu, docp, 1)
xx0 = CTDirect.get_OCP_state_at_time_step(xu, docp, 1)
xf = CTDirect.get_state_at_time_step(xu, docp, N+1)
xxf = CTDirect.get_OCP_state_at_time_step(xu, docp, N+1)
v = CTDirect.get_optim_variable(xu, docp)
vv = CTDirect.get_OCP_variable(xu, docp)
x0 = CTDirect.get_OCP_state_at_time_step(xu, docp, 1)
xf = CTDirect.get_OCP_state_at_time_step(xu, docp, N+1)
v = CTDirect.get_OCP_variable(xu, docp)
obj = similar(xu,1)

#=println("")
print("Local Mayer: views for x0/xf and scalar v"); @btime local_mayer($obj, (@view $xu[1:$n]), (@view $xu[($nx + $m) * $N + 1: ($nx + $m) * $N + $n]), $xu[end])
print("Local Mayer: raw getters for x0/xf and v"); @btime local_mayer($obj, $x0, $xf, $v)
print("Local Mayer: getters with scalarization"); @btime local_mayer($obj, $docp._x($x0), $docp._x($xf), $docp._v($v))
print("Local Mayer: scal/vec getters"); @btime local_mayer($obj, $xx0, $xxf, $vv)
=#

# local mayer
println("")
print("Local Mayer: views for x0/xf and scalar v"); @btime local_mayer($obj, (@view $xu[1:$n]), (@view $xu[($nx + $m) * $N + 1: ($nx + $m) * $N + $n]), $xu[end]) # OK
print("Local Mayer: param scal/vec getters"); @btime local_mayer($obj, $x0, $xf, $v) # OK
print("OCP Mayer: param scal/vec getters"); @btime $docp.ocp.mayer($obj, $x0, $xf, $v) # 3 allocs (112)

#print("OCP Mayer: views for x0/xf and scalar v"); @btime $docp.ocp.mayer($obj, (@view $xu[1:$n]), (@view $xu[($nx + $m) * $N + 1: ($nx + $m) * $N + $n]), $xu[end])

#print("OCP Mayer: raw getters for x0/xf and v"); @btime $docp.ocp.mayer($obj, $x0, $xf, $v)

print("OCP Mayer: getters with scalarization"); @btime $docp.ocp.mayer($obj, $docp._x($x0), $docp._x($xf), $docp._v($v))

print("OCP Mayer: scal/vec getters"); @btime $docp.ocp.mayer($obj, $xx0, $xxf, $vv)

if warntype
#println("code warntype local mayer")
#@code_warntype local_mayer(obj, (@view xu[1:n]), (@view xu[(nx + m) * N + 1: (nx + m) * N + n]), xu[end])
#println("code warntype end")
println("code warntype ocp mayer getters + scalarization")
@code_warntype docp.ocp.mayer(obj, docp._x(x0), docp._x(xf), docp._v(v))
println("code warntype end")
println("code warntype ocp mayer scal/vect getters")
@code_warntype docp.ocp.mayer(obj, xx0, xxf, vv)
println("code warntype end")
warntype && @code_warntype docp.ocp.mayer(obj, x0, xf, v)
jet && display(@report_opt docp.ocp.mayer(obj, x0, xf, v))
if profile
Profile.Allocs.@profile sample_rate=1.0 docp.ocp.mayer(obj, x0, xf, v)
results = Profile.Allocs.fetch()
PProf.Allocs.pprof()
end
if jet
println("JET ocp mayer getters + scalarization")
display(@report_opt docp.ocp.mayer(obj, docp._x(x0), docp._x(xf), docp._v(v)))
println("JET end")
println("JET ocp mayer scal/vect getters")
display(@report_opt docp.ocp.mayer(obj, xx0, xxf, vv))
println("JET end")
end
end

if test_obj
print("Objective"); @btime CTDirect.DOCP_objective($xu, $docp)
print("Objective_param"); @btime CTDirect.DOCP_objective_param($xu, $docp)
if warntype
println("code warntype Objective")
@code_warntype CTDirect.DOCP_objective(xu, docp)
println("code warntype end")
println("code warntype Objective_param")
@code_warntype CTDirect.DOCP_objective_param(xu, docp)
println("code warntype end")
end
if jet
println("JET Objective")
display(@report_opt CTDirect.DOCP_objective(xu, docp))
println("JET end")
println("JET Objective_param")
display(@report_opt CTDirect.DOCP_objective_param(xu, docp))
println("JET end")
end
warntype && @code_warntype CTDirect.DOCP_objective(xu, docp) # quasi OK (inplace/outplace for ocp.mayer return ?)
jet && display(@report_opt CTDirect.DOCP_objective(xu, docp))
if profile
Profile.Allocs.@profile sample_rate=1.0 CTDirect.DOCP_objective(xu, docp)
results = Profile.Allocs.fetch()
Expand All @@ -182,34 +133,29 @@ function test_unit(;test_get=false, test_dyn=false, test_unit_cons=false, test_m
end

if test_block
print("Constraints block")
CTDirect.get_time_grid!(xu, docp)
v = CTDirect.get_OCP_variable_param(xu, docp)
work = CTDirect.setWorkArray_param(docp, xu, docp.NLP_time_grid, v)
CTDirect.get_time_grid!(xu, docp) # type OK
i = 1
@btime CTDirect.setConstraintBlock_param!($docp, $c, $xu, $v, $docp.NLP_time_grid, $i, $work)
+warntype
+profile

v = CTDirect.get_OCP_variable(xu, docp)
work = CTDirect.setWorkArray(docp, xu, docp.NLP_time_grid, v)
print("Constraints block")
@btime CTDirect.setConstraintBlock!($docp, $c, $xu, $v, $docp.NLP_time_grid, $i, $work)
warntype && @code_warntype CTDirect.setConstraintBlock!(docp, c, xu, v, docp.NLP_time_grid, i, work)
jet && display(@report_opt CTDirect.setConstraintBlock!(docp, c, xu, v, docp.NLP_time_grid, i, work))
if profile
Profile.Allocs.@profile sample_rate=1.0 CTDirect.setConstraintBlock!(docp, c, xu, v, docp.NLP_time_grid, i, work)
results = Profile.Allocs.fetch()
PProf.Allocs.pprof()
end
end

# DOCP_constraints
if test_cons
#print("Constraints"); @btime CTDirect.DOCP_constraints!($c, $xu, $docp)
print("Constraints param"); @btime CTDirect.DOCP_constraints_param!($c, $xu, $docp)
if any(c.==666.666)
error("undefined values in constraints ",c)
end
if warntype
#@code_warntype CTDirect.DOCP_constraints!(c, xu, docp)
@code_warntype CTDirect.DOCP_constraints_param!(c, xu, docp)
end
if jet
#display(@report_opt CTDirect.DOCP_constraints!(c, xu, docp))
display(@report_opt CTDirect.DOCP_constraints_param!(c, xu, docp))
end
print("Constraints"); @btime CTDirect.DOCP_constraints!($c, $xu, $docp)
any(c.==666.666) && error("undefined values in constraints ",c)
warntype && @code_warntype CTDirect.DOCP_constraints!(c, xu, docp)
jet && display(@report_opt CTDirect.DOCP_constraints!(c, xu, docp))
if profile
Profile.Allocs.@profile sample_rate=1.0 CTDirect.DOCP_constraints_param!(c, xu, docp)
Profile.Allocs.@profile sample_rate=1.0 CTDirect.DOCP_constraints!(c, xu, docp)
results = Profile.Allocs.fetch()
PProf.Allocs.pprof()
end
Expand All @@ -231,54 +177,3 @@ function test_unit(;test_get=false, test_dyn=false, test_unit_cons=false, test_m

end



#= constraints profile
Total: 0 6998 (flat, cum) 89.08%
95 . . ui = get_control_at_time_step(xu, docp, i)
96 . .
97 . . #1. state equation
98 . . if i <= docp.dim_NLP_steps
99 . . # more variables
100 . 100 fi = copy(work) # create new copy, not just a reference
101 . . tip1 = time_grid[i+1]
102 . . xip1 = get_state_at_time_step(xu, docp, i+1)
103 . . uip1 = get_control_at_time_step(xu, docp, i+1)
104 . . if docp.has_inplace
105 . . docp.dynamics_ext(work, tip1, xip1, uip1, v)
106 . . else
107 . . # copy, do not create a new variable !
108 . 1000 work[:] = docp.dynamics_ext(tip1, xip1, uip1, v) #+++ jet runtime dispatch here
109 . . end
110 . .
111 . . # trapeze rule with 'smart' update for dynamics (similar with @.)
112 . 800 c[offset+1:offset+docp.dim_NLP_x] = xip1 - (xi + 0.5 * (tip1 - ti) * (fi + work)) #+++ jet runtime dispatch here even with explicit index ranges
113 . . offset += docp.dim_NLP_x
114 . . end
115 . .
116 . . # 2. path constraints
117 . . # Notes on allocations:.= seems similar
118 . . if docp.dim_u_cons > 0
119 . . if docp.has_inplace
120 . . docp.control_constraints[2]((@view c[offset+1:offset+docp.dim_u_cons]),ti, docp._u(ui), v)
121 . . else
122 . 1632 c[offset+1:offset+docp.dim_u_cons] = docp.control_constraints[2](ti, docp._u(ui), v)
123 . . end
124 . . end
125 . . if docp.dim_x_cons > 0
126 . . if docp.has_inplace
127 . . docp.state_constraints[2]((@view c[offset+docp.dim_u_cons+1:offset+docp.dim_u_cons+docp.dim_x_cons]),ti, docp._x(xi), v)
128 . . else
129 . 1531 c[offset+docp.dim_u_cons+1:offset+docp.dim_u_cons+docp.dim_x_cons] = docp.state_constraints[2](ti, docp._x(xi), v)
130 . . end
131 . . end
132 . . if docp.dim_mixed_cons > 0
133 . . if docp.has_inplace
134 . . docp.mixed_constraints[2]((@view c[offset+docp.dim_u_cons+docp.dim_x_cons+1:offset+docp.dim_u_cons+docp.dim_x_cons+docp.dim_mixed_cons]), ti, docp._x(xi), docp._u(ui), v)
135 . . else
136 . 1935 c[offset+docp.dim_u_cons+docp.dim_x_cons+1:offset+docp.dim_u_cons+docp.dim_x_cons+docp.dim_mixed_cons] = docp.mixed_constraints[2](ti, docp._x(xi), docp._u(ui), v)
137 . . end
138 . . end
139 . .
140 . . end
=#
Loading

0 comments on commit 6931a08

Please sign in to comment.