Skip to content

Commit

Permalink
Merge pull request #55 from andgoldschmidt/quad_reg_fix
Browse files Browse the repository at this point in the history
Add timesteps to quadratic regularizers
  • Loading branch information
aarontrowbridge authored Nov 29, 2023
2 parents 86d57ac + 13b054e commit fdeea38
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
/examples/**/plots/
/examples/**/trajectories/
pardiso.lic
/.CondaPkg/
/.CondaPkg/
*.code-workspace
67 changes: 58 additions & 9 deletions src/objectives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ function QuadraticRegularizer(;
times::AbstractVector{Int}=1:traj.T,
dim::Int=nothing,
R::AbstractVector{<:Real}=ones(traj.dims[name]),
eval_hessian=true
eval_hessian=true,
timestep_symbol=:Δt
)

@assert !isnothing(name) "name must be specified"
Expand All @@ -318,18 +319,34 @@ function QuadraticRegularizer(;
@views function L(Z⃗::AbstractVector{<:Real}, Z::NamedTrajectory)
J = 0.0
for t times
if Z.timestep isa Symbol
Δt = Z⃗[slice(t, Z.components[timestep_symbol], Z.dim)]
else
Δt = Z.timestep
end

vₜ = Z⃗[slice(t, Z.components[name], Z.dim)]
J += 0.5 * vₜ' * (R .* vₜ)
rₜ = Δt .* vₜ
J += 0.5 * rₜ' * (R .* rₜ)
end
return J
end

@views function ∇L(Z⃗::AbstractVector{<:Real}, Z::NamedTrajectory)
= zeros(Z.dim * Z.T)
= zeros(Z.dim * Z.T)
Threads.@threads for t times
vₜ_slice = slice(t, Z.components[name], Z.dim)
vₜ_slice = slice(t, Z.components[name], Z.dim)
vₜ = Z⃗[vₜ_slice]
∇[vₜ_slice] = R .* vₜ

if Z.timestep isa Symbol
Δt_slice = slice(t, Z.components[timestep_symbol], Z.dim)
Δt = Z⃗[Δt_slice]
∇[Δt_slice] .= vₜ' * (R .* (Δt .* vₜ))
else
Δt = Z.timestep
end

∇[vₜ_slice] .= R .* (Δt.^2 .* vₜ)
end
return
end
Expand All @@ -341,16 +358,48 @@ function QuadraticRegularizer(;

∂²L_structure = Z -> begin
structure = []
# vₜ Hessian structure (eq. 17)
# Hessian structure (eq. 17)
for t times
vₜ_slice = slice(t, Z.components[name], Z.dim)
diag_inds = collect(zip(vₜ_slice, vₜ_slice))
append!(structure, diag_inds)
vₜ_vₜ_inds = collect(zip(vₜ_slice, vₜ_slice))
append!(structure, vₜ_vₜ_inds)

if Z.timestep isa Symbol
Δt_slice = slice(t, Z.components[timestep_symbol], Z.dim)
# ∂²_vₜ_Δt
vₜ_Δt_inds = [(i, j) for i vₜ_slice for j Δt_slice]
append!(structure, vₜ_Δt_inds)
# ∂²_Δt_vₜ
Δt_vₜ_inds = [(i, j) for i Δt_slice for j vₜ_slice]
append!(structure, Δt_vₜ_inds)
# ∂²_Δt_Δt
Δt_Δt_inds = collect(zip(Δt_slice, Δt_slice))
append!(structure, Δt_Δt_inds)
end
end
return structure
end

∂²L = (Z⃗, Z) -> vcat(fill(R, length(times))...)
∂²L = (Z⃗, Z) -> begin
values = []
# Match Hessian structure indices
for t times
if Z.timestep isa Symbol
Δt = Z⃗[slice(t, Z.components[timestep_symbol], Z.dim)]
append!(values, R .* Δt.^2)
# ∂²_vₜ_Δt, ∂²_Δt_vₜ
vₜ = Z⃗[slice(t, Z.components[name], Z.dim)]
append!(values, 2 * (R .* (Δt .* vₜ)))
append!(values, 2 * (R .* (Δt .* vₜ)))
# ∂²_Δt_Δt
append!(values, vₜ' * (R .* vₜ))
else
Δt = Z.timestep
append!(values, R .* Δt.^2)
end
end
return values
end
end

return Objective(L, ∇L, ∂²L, ∂²L_structure, Dict[params])
Expand Down

0 comments on commit fdeea38

Please sign in to comment.