diff --git a/.gitignore b/.gitignore index 64e7b3df..36868b0f 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ /examples/**/plots/ /examples/**/trajectories/ pardiso.lic -/.CondaPkg/ \ No newline at end of file +/.CondaPkg/ +*.code-workspace \ No newline at end of file diff --git a/src/objectives.jl b/src/objectives.jl index 1de467da..742b8b14 100644 --- a/src/objectives.jl +++ b/src/objectives.jl @@ -299,7 +299,8 @@ function QuadraticRegularizer(; times::AbstractVector{Int}=1:traj.T, dim::Int=nothing, R::AbstractVector{<:Real}=ones(traj.dims[name]), - eval_hessian=true + eval_hessian=true, + timestep_symbol=:Δt ) @assert !isnothing(name) "name must be specified" @@ -318,18 +319,34 @@ function QuadraticRegularizer(; @views function L(Z⃗::AbstractVector{<:Real}, Z::NamedTrajectory) J = 0.0 for t ∈ times + if Z.timestep isa Symbol + Δt = Z⃗[slice(t, Z.components[timestep_symbol], Z.dim)] + else + Δt = Z.timestep + end + vₜ = Z⃗[slice(t, Z.components[name], Z.dim)] - J += 0.5 * vₜ' * (R .* vₜ) + rₜ = Δt .* vₜ + J += 0.5 * rₜ' * (R .* rₜ) end return J end @views function ∇L(Z⃗::AbstractVector{<:Real}, Z::NamedTrajectory) - ∇ = zeros(Z.dim * Z.T) + ∇ = zeros(Z.dim * Z.T) Threads.@threads for t ∈ times - vₜ_slice = slice(t, Z.components[name], Z.dim) + vₜ_slice = slice(t, Z.components[name], Z.dim) vₜ = Z⃗[vₜ_slice] - ∇[vₜ_slice] = R .* vₜ + + if Z.timestep isa Symbol + Δt_slice = slice(t, Z.components[timestep_symbol], Z.dim) + Δt = Z⃗[Δt_slice] + ∇[Δt_slice] .= vₜ' * (R .* (Δt .* vₜ)) + else + Δt = Z.timestep + end + + ∇[vₜ_slice] .= R .* (Δt.^2 .* vₜ) end return ∇ end @@ -341,16 +358,48 @@ function QuadraticRegularizer(; ∂²L_structure = Z -> begin structure = [] - # vₜ Hessian structure (eq. 17) + # Hessian structure (eq. 17) for t ∈ times vₜ_slice = slice(t, Z.components[name], Z.dim) - diag_inds = collect(zip(vₜ_slice, vₜ_slice)) - append!(structure, diag_inds) + vₜ_vₜ_inds = collect(zip(vₜ_slice, vₜ_slice)) + append!(structure, vₜ_vₜ_inds) + + if Z.timestep isa Symbol + Δt_slice = slice(t, Z.components[timestep_symbol], Z.dim) + # ∂²_vₜ_Δt + vₜ_Δt_inds = [(i, j) for i ∈ vₜ_slice for j ∈ Δt_slice] + append!(structure, vₜ_Δt_inds) + # ∂²_Δt_vₜ + Δt_vₜ_inds = [(i, j) for i ∈ Δt_slice for j ∈ vₜ_slice] + append!(structure, Δt_vₜ_inds) + # ∂²_Δt_Δt + Δt_Δt_inds = collect(zip(Δt_slice, Δt_slice)) + append!(structure, Δt_Δt_inds) + end end return structure end - ∂²L = (Z⃗, Z) -> vcat(fill(R, length(times))...) + ∂²L = (Z⃗, Z) -> begin + values = [] + # Match Hessian structure indices + for t ∈ times + if Z.timestep isa Symbol + Δt = Z⃗[slice(t, Z.components[timestep_symbol], Z.dim)] + append!(values, R .* Δt.^2) + # ∂²_vₜ_Δt, ∂²_Δt_vₜ + vₜ = Z⃗[slice(t, Z.components[name], Z.dim)] + append!(values, 2 * (R .* (Δt .* vₜ))) + append!(values, 2 * (R .* (Δt .* vₜ))) + # ∂²_Δt_Δt + append!(values, vₜ' * (R .* vₜ)) + else + Δt = Z.timestep + append!(values, R .* Δt.^2) + end + end + return values + end end return Objective(L, ∇L, ∂²L, ∂²L_structure, Dict[params])