Skip to content

Commit

Permalink
Merge pull request #278 from control-toolbox/auto-juliaformatter-pr
Browse files Browse the repository at this point in the history
[AUTO] JuliaFormatter.jl run
  • Loading branch information
ocots authored Sep 8, 2024
2 parents 7ae5234 + d53bdcf commit 21397b3
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 109 deletions.
6 changes: 3 additions & 3 deletions benchmark/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function bench(;
precompile = true,
display = false,
verbose = true,
discretization = :trapeze
discretization = :trapeze,
)

#######################################################
Expand Down Expand Up @@ -67,7 +67,7 @@ function bench(;
linear_solver = linear_solver,
max_iter = 0,
display = display,
discretization = discretization
discretization = discretization,
)
t_precomp += t
end
Expand All @@ -86,7 +86,7 @@ function bench(;
linear_solver = linear_solver,
grid_size = grid_size,
tol = tol,
discretization = discretization
discretization = discretization,
)
if !isnothing(problem[:obj]) && !isapprox(sol.objective, problem[:obj], rtol = 5e-2)
error(
Expand Down
17 changes: 12 additions & 5 deletions benchmark/prof.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ if precompile
CTDirect.DOCP_objective(CTDirect.DOCP_initial_guess(docp), docp)
end
if test_constraints
CTDirect.DOCP_constraints!(zeros(docp.dim_NLP_constraints), CTDirect.DOCP_initial_guess(docp), docp)
CTDirect.DOCP_constraints!(
zeros(docp.dim_NLP_constraints),
CTDirect.DOCP_initial_guess(docp),
docp,
)
end
end

Expand All @@ -49,8 +53,12 @@ if test_objective
@btime CTDirect.DOCP_objective(CTDirect.DOCP_initial_guess(docp), docp)
end
if test_constraints
println("Timed constraints")
@btime CTDirect.DOCP_constraints!(zeros(docp.dim_NLP_constraints), CTDirect.DOCP_initial_guess(docp), docp)
println("Timed constraints")
@btime CTDirect.DOCP_constraints!(
zeros(docp.dim_NLP_constraints),
CTDirect.DOCP_initial_guess(docp),
docp,
)
end

# transcription
Expand All @@ -62,10 +70,9 @@ end
# full solve
if test_solve
println("Timed full solve")
@btime sol = direct_solve(ocp, grid_size = grid_size, display=false)
@btime sol = direct_solve(ocp, grid_size = grid_size, display = false)
end


if test_code_warntype
if test_objective
# NB. Pb with the mayer part: obj is type unstable (Any) because ocp.mayer is Union(Mayer,nothing), even for mayer problems (also, we should not even enter this code part for lagrange problems since has_mayer us defined as const in DOCP oO ...).
Expand Down
8 changes: 7 additions & 1 deletion src/gauss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@ struct GaussLegendre2 <: Discretization
butcher_a::Matrix{Float64}
butcher_b::Vector{Float64}
butcher_c::Vector{Float64}
GaussLegendre2() = new(2, 0, [0.25 (0.25 - sqrt(3)/6); (0.25 + sqrt(3)/6) 0.25], [0.5, 0.5], [(0.5 - sqrt(3)/6), (0.5 + sqrt(3)/6)])
GaussLegendre2() = new(
2,
0,
[0.25 (0.25-sqrt(3) / 6); (0.25+sqrt(3) / 6) 0.25],
[0.5, 0.5],
[(0.5 - sqrt(3) / 6), (0.5 + sqrt(3) / 6)],
)
end
52 changes: 19 additions & 33 deletions src/midpoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,17 @@ struct Midpoint <: Discretization
Midpoint() = new(1, 0)
end


"""
$(TYPEDSIGNATURES)
Retrieve state and control variables at given time step from the NLP variables.
"""
function get_variables_at_time_step(xu, docp::DOCP{Midpoint}, i)

nx = docp.dim_NLP_x
n = docp.dim_OCP_x
m = docp.dim_NLP_u
N = docp.dim_NLP_steps
offset = (nx*(1+docp.discretization.stage) + m) * i
offset = (nx * (1 + docp.discretization.stage) + m) * i

# retrieve scalar/vector OCP state (w/o lagrange state)
if n == 1
Expand All @@ -42,7 +40,7 @@ function get_variables_at_time_step(xu, docp::DOCP{Midpoint}, i)
if i < N
offset_u = offset
else
offset_u = (nx*2 + m) * (i-1)
offset_u = (nx * 2 + m) * (i - 1)
end
if m == 1
ui = xu[offset_u + nx + 1]
Expand All @@ -52,53 +50,49 @@ function get_variables_at_time_step(xu, docp::DOCP{Midpoint}, i)

# retrieve vector stage variable (except at final time)
if i < N
ki = xu[(offset + nx + m + 1):(offset + nx + m + nx) ]
ki = xu[(offset + nx + m + 1):(offset + nx + m + nx)]
else
ki = nothing
end

return xi, ui, xli, ki
end


# internal NLP version for solution parsing
# could be fused with one above if
# - using extended dynamics that include lagrange cost
# - scalar case is handled at OCP level
function get_NLP_variables_at_time_step(xu, docp, i, disc::Midpoint)

nx = docp.dim_NLP_x
m = docp.dim_NLP_u
N = docp.dim_NLP_steps
offset = (nx*2 + m) * i
offset = (nx * 2 + m) * i

# state
xi = xu[(offset + 1):(offset + nx)]
# control
if i < N
offset_u = offset
else
offset_u = (nx*2 + m) * (i-1)
end
offset_u = (nx * 2 + m) * (i - 1)
end
ui = xu[(offset_u + nx + 1):(offset_u + nx + m)]
# stage
if i < N
ki = xu[(offset + nx + m + 1):(offset + nx + m + nx) ]
ki = xu[(offset + nx + m + 1):(offset + nx + m + nx)]
else
ki = nothing
end

return xi, ui, ki
end


function set_variables_at_time_step!(xu, x_init, u_init, docp, i, disc::Midpoint)

nx = docp.dim_NLP_x
n = docp.dim_OCP_x
m = docp.dim_NLP_u
N = docp.dim_NLP_steps
offset = (nx*2 + m) * i
offset = (nx * 2 + m) * i

# NB. only set the actual state variables from the OCP
# - skip the possible additional state for lagrange cost
Expand All @@ -111,7 +105,6 @@ function set_variables_at_time_step!(xu, x_init, u_init, docp, i, disc::Midpoint
end
end


# trivial version for now...
# +++multiple dispatch here seems to cause more allocations !
# +++? use abstract type for all Args ?
Expand All @@ -129,20 +122,19 @@ struct ArgsAtTimeStep_Midpoint
next_time::Any
next_state::Any
next_lagrange_state::Any

function ArgsAtTimeStep_Midpoint(xu, docp::DOCP{Midpoint}, v, time_grid, i::Int)

function ArgsAtTimeStep_Midpoint(xu, docp::DOCP{Midpoint}, v, time_grid, i::Int)
disc = docp.discretization

# variables
ti = time_grid[i+1]
ti = time_grid[i + 1]
xi, ui, xli, ki = get_variables_at_time_step(xu, docp, i)

if i == docp.dim_NLP_steps
return new(ti, xi, ui, xli, ki, disc)
else
tip1 = time_grid[i+2]
xip1, uip1, xlip1 = get_variables_at_time_step(xu, docp, i+1)
tip1 = time_grid[i + 2]
xip1, uip1, xlip1 = get_variables_at_time_step(xu, docp, i + 1)
return new(ti, xi, ui, xli, ki, tip1, xip1, xlip1)
end
end
Expand All @@ -151,20 +143,18 @@ function initArgs(xu, docp::DOCP{Midpoint}, time_grid)
v = Float64[]
docp.has_variable && (v = get_optim_variable(xu, docp))
args = ArgsAtTimeStep_Midpoint(xu, docp, v, time_grid, 0)
return args, v
return args, v
end
function updateArgs(args, xu, docp::DOCP{Midpoint}, v, time_grid, i::Int)
return ArgsAtTimeStep_Midpoint(xu, docp, v, time_grid, i+1)
return ArgsAtTimeStep_Midpoint(xu, docp, v, time_grid, i + 1)
end


"""
$(TYPEDSIGNATURES)
Set the constraints corresponding to the state equation
"""
function setStateEquation!(docp::DOCP{Midpoint}, c, index::Int, args, v, i)

ocp = docp.ocp

# +++ later use butcher table in struct ?
Expand All @@ -181,36 +171,32 @@ function setStateEquation!(docp::DOCP{Midpoint}, c, index::Int, args, v, i)
hi = tip1 - ti

# midpoint rule
@. c[index:(index + docp.dim_OCP_x - 1)] =
xip1 - (xi + hi * ki[1:docp.dim_OCP_x])
@. c[index:(index + docp.dim_OCP_x - 1)] = xip1 - (xi + hi * ki[1:(docp.dim_OCP_x)])
# +++ just define extended dynamics !
if docp.has_lagrange
c[index + docp.dim_OCP_x] = xlip1 - (xli + hi * ki[end])
end
index += docp.dim_NLP_x

# stage equation at mid-step
t_s = 0.5 * (ti + tip1)
x_s = 0.5 * (xi + xip1)
c[index:(index + docp.dim_OCP_x - 1)] .=
ki[1:docp.dim_OCP_x] .- ocp.dynamics(t_s, x_s, ui, v)
c[index:(index + docp.dim_OCP_x - 1)] .= ki[1:(docp.dim_OCP_x)] .- ocp.dynamics(t_s, x_s, ui, v)
# +++ just define extended dynamics !
if docp.has_lagrange
c[index + docp.dim_OCP_x] = ki[end] - ocp.lagrange(t_s, x_s, ui, v)
c[index + docp.dim_OCP_x] = ki[end] - ocp.lagrange(t_s, x_s, ui, v)
end
index += docp.dim_NLP_x

return index
end


"""
$(TYPEDSIGNATURES)
Set the path constraints at given time step
"""
function setPathConstraints!(docp::DOCP{Midpoint}, c, index::Int, args, v, i::Int)

ocp = docp.ocp
ti = args.time
xi = args.state
Expand Down
Loading

0 comments on commit 21397b3

Please sign in to comment.