From d50505c35f03d25c951fef5f6667025b4cfa9181 Mon Sep 17 00:00:00 2001 From: andy Date: Mon, 27 Nov 2023 23:59:57 -0600 Subject: [PATCH 1/2] Add timesteps to quadratic regularizers --- .gitignore | 3 ++- src/objectives.jl | 57 +++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 64e7b3df..36868b0f 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ /examples/**/plots/ /examples/**/trajectories/ pardiso.lic -/.CondaPkg/ \ No newline at end of file +/.CondaPkg/ +*.code-workspace \ No newline at end of file diff --git a/src/objectives.jl b/src/objectives.jl index 7f3247a5..bfd7107c 100644 --- a/src/objectives.jl +++ b/src/objectives.jl @@ -297,7 +297,8 @@ function QuadraticRegularizer(; times::AbstractVector{Int}=1:traj.T, dim::Int=nothing, R::AbstractVector{<:Real}=ones(traj.dims[name]), - eval_hessian=true + eval_hessian=true, + timestep_symbol=:Δt ) @assert !isnothing(name) "name must be specified" @@ -316,8 +317,10 @@ function QuadraticRegularizer(; @views function L(Z⃗::AbstractVector{<:Real}, Z::NamedTrajectory) J = 0.0 for t ∈ times + Δt = Z⃗[slice(t, Z.components[timestep_symbol], Z.dim)] vₜ = Z⃗[slice(t, Z.components[name], Z.dim)] - J += 0.5 * vₜ' * (R .* vₜ) + rₜ = Δt .* vₜ + J += 0.5 * rₜ' * (R .* rₜ) end return J end @@ -325,9 +328,15 @@ function QuadraticRegularizer(; @views function ∇L(Z⃗::AbstractVector{<:Real}, Z::NamedTrajectory) ∇ = zeros(Z.dim * Z.T) Threads.@threads for t ∈ times - vₜ_slice = slice(t, Z.components[name], Z.dim) + Δt_slice = slice(t, Z.components[timestep_symbol], Z.dim) + vₜ_slice = slice(t, Z.components[name], Z.dim) + Δt = Z⃗[Δt_slice] vₜ = Z⃗[vₜ_slice] - ∇[vₜ_slice] = R .* vₜ + ∇[vₜ_slice] .= R .* (Δt.^2 .* vₜ) + + if Z.timestep isa Symbol + ∇[Δt_slice] .= vₜ' * (R .* (Δt .* vₜ)) + end end return ∇ end @@ -339,16 +348,46 @@ function QuadraticRegularizer(; ∂²L_structure = Z -> begin structure = [] - # vₜ Hessian structure (eq. 17) + # Hessian structure (eq. 17) for t ∈ times vₜ_slice = slice(t, Z.components[name], Z.dim) - diag_inds = collect(zip(vₜ_slice, vₜ_slice)) - append!(structure, diag_inds) + vₜ_vₜ_inds = collect(zip(vₜ_slice, vₜ_slice)) + append!(structure, vₜ_vₜ_inds) + + if Z.timestep isa Symbol + Δt_slice = slice(t, Z.components[timestep_symbol], Z.dim) + # ∂²_vₜ_Δt + vₜ_Δt_inds = [(i, j) for i ∈ vₜ_slice for j ∈ Δt_slice] + append!(structure, vₜ_Δt_inds) + # ∂²_Δt_vₜ + Δt_vₜ_inds = [(i, j) for i ∈ Δt_slice for j ∈ vₜ_slice] + append!(structure, Δt_vₜ_inds) + # ∂²_Δt_Δt + Δt_Δt_inds = collect(zip(Δt_slice, Δt_slice)) + append!(structure, Δt_Δt_inds) + end end return structure end - - ∂²L = (Z⃗, Z) -> vcat(fill(R, length(times))...) + + ∂²L = (Z⃗, Z) -> begin + values = [] + # Match Hessian structure indices + for t ∈ times + Δt = Z⃗[slice(t, Z.components[timestep_symbol], Z.dim)] + append!(values, R .* Δt.^2) + + if Z.timestep isa Symbol + vₜ = Z⃗[slice(t, Z.components[name], Z.dim)] + # ∂²_vₜ_Δt, ∂²_Δt_vₜ + append!(values, 2 * (R .* (Δt .* vₜ))) + append!(values, 2 * (R .* (Δt .* vₜ))) + # ∂²_Δt_Δt + append!(values, vₜ' * (R .* vₜ)) + end + end + return values + end end return Objective(L, ∇L, ∂²L, ∂²L_structure, Dict[params]) From 13b054e3bbd4ef9886cddcdf1abce6e941bbe17f Mon Sep 17 00:00:00 2001 From: andy Date: Tue, 28 Nov 2023 14:30:58 -0600 Subject: [PATCH 2/2] Bug fix: timestep for free_time=false --- src/objectives.jl | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/objectives.jl b/src/objectives.jl index bfd7107c..5a377a33 100644 --- a/src/objectives.jl +++ b/src/objectives.jl @@ -317,7 +317,12 @@ function QuadraticRegularizer(; @views function L(Z⃗::AbstractVector{<:Real}, Z::NamedTrajectory) J = 0.0 for t ∈ times - Δt = Z⃗[slice(t, Z.components[timestep_symbol], Z.dim)] + if Z.timestep isa Symbol + Δt = Z⃗[slice(t, Z.components[timestep_symbol], Z.dim)] + else + Δt = Z.timestep + end + vₜ = Z⃗[slice(t, Z.components[name], Z.dim)] rₜ = Δt .* vₜ J += 0.5 * rₜ' * (R .* rₜ) @@ -326,17 +331,20 @@ function QuadraticRegularizer(; end @views function ∇L(Z⃗::AbstractVector{<:Real}, Z::NamedTrajectory) - ∇ = zeros(Z.dim * Z.T) + ∇ = zeros(Z.dim * Z.T) Threads.@threads for t ∈ times - Δt_slice = slice(t, Z.components[timestep_symbol], Z.dim) vₜ_slice = slice(t, Z.components[name], Z.dim) - Δt = Z⃗[Δt_slice] vₜ = Z⃗[vₜ_slice] - ∇[vₜ_slice] .= R .* (Δt.^2 .* vₜ) if Z.timestep isa Symbol + Δt_slice = slice(t, Z.components[timestep_symbol], Z.dim) + Δt = Z⃗[Δt_slice] ∇[Δt_slice] .= vₜ' * (R .* (Δt .* vₜ)) + else + Δt = Z.timestep end + + ∇[vₜ_slice] .= R .* (Δt.^2 .* vₜ) end return ∇ end @@ -369,21 +377,23 @@ function QuadraticRegularizer(; end return structure end - + ∂²L = (Z⃗, Z) -> begin values = [] # Match Hessian structure indices for t ∈ times - Δt = Z⃗[slice(t, Z.components[timestep_symbol], Z.dim)] - append!(values, R .* Δt.^2) - if Z.timestep isa Symbol - vₜ = Z⃗[slice(t, Z.components[name], Z.dim)] + Δt = Z⃗[slice(t, Z.components[timestep_symbol], Z.dim)] + append!(values, R .* Δt.^2) # ∂²_vₜ_Δt, ∂²_Δt_vₜ + vₜ = Z⃗[slice(t, Z.components[name], Z.dim)] append!(values, 2 * (R .* (Δt .* vₜ))) append!(values, 2 * (R .* (Δt .* vₜ))) # ∂²_Δt_Δt append!(values, vₜ' * (R .* vₜ)) + else + Δt = Z.timestep + append!(values, R .* Δt.^2) end end return values