Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RTM after next push] fix: linear continuous multiphase continuity #857

Merged
merged 11 commits into from
Mar 12, 2024
28 changes: 25 additions & 3 deletions bioptim/examples/getting_started/example_multiphase.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
MultinodeObjectiveList,
PhaseDynamics,
ControlType,
QuadratureRule,
)


Expand All @@ -42,6 +43,7 @@ def prepare_ocp(
phase_dynamics: PhaseDynamics = PhaseDynamics.SHARED_DURING_THE_PHASE,
expand_dynamics: bool = True,
control_type: ControlType = ControlType.CONSTANT,
quadrature_rule: QuadratureRule = QuadratureRule.RECTANGLE_LEFT,
) -> OptimalControlProgram:
"""
Prepare the ocp
Expand All @@ -65,6 +67,8 @@ def prepare_ocp(
(for instance IRK is not compatible with expanded dynamics)
control_type: ControlType
The type of the controls
quadrature_rule: QuadratureRule
The quadrature method to use to integrate the objective functions

Returns
-------
Expand All @@ -83,9 +87,27 @@ def prepare_ocp(

# Add objective functions
objective_functions = ObjectiveList()
objective_functions.add(ObjectiveFcn.Lagrange.MINIMIZE_CONTROL, key="tau", weight=100, phase=0)
objective_functions.add(ObjectiveFcn.Lagrange.MINIMIZE_CONTROL, key="tau", weight=100, phase=1)
objective_functions.add(ObjectiveFcn.Lagrange.MINIMIZE_CONTROL, key="tau", weight=100, phase=2)
objective_functions.add(
ObjectiveFcn.Lagrange.MINIMIZE_CONTROL,
key="tau",
weight=100,
phase=0,
integration_rule=quadrature_rule,
)
objective_functions.add(
ObjectiveFcn.Lagrange.MINIMIZE_CONTROL,
key="tau",
weight=100,
phase=1,
integration_rule=quadrature_rule,
)
objective_functions.add(
ObjectiveFcn.Lagrange.MINIMIZE_CONTROL,
key="tau",
weight=100,
phase=2,
integration_rule=quadrature_rule,
)

multinode_objective = MultinodeObjectiveList()
multinode_objective.add(
Expand Down
8 changes: 3 additions & 5 deletions bioptim/limits/penalty_option.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Any, Callable

from casadi import vertcat, Function, MX, SX, jacobian, diag
import numpy as np
from casadi import vertcat, Function, MX, SX, jacobian, diag

from .penalty_controller import PenaltyController
from ..limits.penalty_helpers import PenaltyHelpers
from ..misc.enums import Node, PlotType, ControlType, PenaltyType, QuadratureRule, PhaseDynamics
from ..misc.options import OptionGeneric
from ..misc.mapping import BiMapping
from ..misc.options import OptionGeneric
from ..models.protocols.stochastic_biomodel import StochasticBioModel
from ..limits.penalty_helpers import PenaltyHelpers


class PenaltyOption(OptionGeneric):
Expand Down Expand Up @@ -51,8 +51,6 @@ class PenaltyOption(OptionGeneric):
If the minimization is applied to derivative of the penalty [f(t, t+1)]
integration_rule: QuadratureRule
The integration rule to use for the penalty
transition: bool
If the penalty is a transition
nodes_phase: tuple[int, ...]
The index of the phases when penalty is multinodes
penalty_type: PenaltyType
Expand Down
85 changes: 39 additions & 46 deletions bioptim/limits/phase_transition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@

from casadi import vertcat, MX

from .multinode_penalty import MultinodePenalty, MultinodePenaltyFunctions
from .multinode_constraint import MultinodeConstraint
from .path_conditions import Bounds
from .multinode_penalty import MultinodePenalty, MultinodePenaltyFunctions
from .objective_functions import ObjectiveFunction
from .path_conditions import Bounds
from ..limits.penalty import PenaltyFunctionAbstract, PenaltyController
from ..misc.enums import Node, PenaltyType, InterpolationType
from ..misc.enums import Node, PenaltyType, InterpolationType, ControlType
from ..misc.fcn_enum import FcnEnum
from ..misc.options import UniquePerPhaseOptionList
from ..misc.mapping import BiMapping
from ..misc.options import UniquePerPhaseOptionList


class PhaseTransition(MultinodePenalty):
Expand Down Expand Up @@ -127,48 +127,6 @@ def print(self):
"""
raise NotImplementedError("Printing of PhaseTransitionList is not ready yet")

def prepare_phase_transitions(self, ocp) -> list:
"""
Configure all the phase transitions and put them in a list

Parameters
----------
ocp: OptimalControlProgram
A reference to the ocp

Returns
-------
The list of all the transitions prepared
"""

# By default it assume Continuous. It can be change later
full_phase_transitions = [
PhaseTransition(
phase_pre_idx=i,
transition=PhaseTransitionFcn.CONTINUOUS,
weight=ocp.nlp[i].dynamics_type.state_continuity_weight,
)
for i in range(ocp.n_phases - 1)
]

existing_phases = []

for pt in self:
idx_phase = pt.nodes_phase[0]
if idx_phase >= ocp.n_phases:
raise RuntimeError("Phase index of the phase transition is higher than the number of phases")
existing_phases.append(idx_phase)

if pt.weight:
pt.base = ObjectiveFunction.MayerFunction

if idx_phase % ocp.n_phases == ocp.n_phases - 1:
# Add a cyclic constraint or objective
full_phase_transitions.append(pt)
else:
full_phase_transitions[idx_phase] = pt
return full_phase_transitions


class PhaseTransitionFunctions(PenaltyFunctionAbstract):
"""
Expand Down Expand Up @@ -211,6 +169,40 @@ def continuous(
transition, controllers, "all", states_mapping=states_mapping
)

@staticmethod
def continuous_controls(
transition,
controllers: list[PenaltyController, PenaltyController],
controls_mapping: list[BiMapping, ...] = None,
):
"""
This continuity function is only relevant for ControlType.LINEAR_CONTINUOUS otherwise don't use it.

Parameters
----------
transition : PhaseTransition
A reference to the phase transition
controllers: list[PenaltyController, PenaltyController]
The penalty node elements
controls_mapping: list
A list of the mapping for the states between nodes. It should provide a mapping between 0 and i, where
the first (0) link the controllers[0].controls to a number of values using to_second. Thereafter, the
to_first is used sequentially for all the controllers (meaning controllers[1] uses the
controls_mapping[0].to_first. Therefore, the dimension of the states_mapping
should be 'len(controllers) - 1'

Returns
-------
The difference between the controls after and before
"""
if controls_mapping is not None:
raise NotImplementedError(
"Controls_mapping is not yet implemented "
"for continuous_controls with linear continuous control type."
)

return MultinodePenaltyFunctions.Functions.controls_equality(transition, controllers, "all")

@staticmethod
def discontinuous(transition, controllers: list[PenaltyController, PenaltyController]):
"""
Expand Down Expand Up @@ -353,6 +345,7 @@ class PhaseTransitionFcn(FcnEnum):
"""

CONTINUOUS = (PhaseTransitionFunctions.Functions.continuous,)
CONTINUOUS_CONTROLS = (PhaseTransitionFunctions.Functions.continuous_controls,)
DISCONTINUOUS = (PhaseTransitionFunctions.Functions.discontinuous,)
IMPACT = (PhaseTransitionFunctions.Functions.impact,)
CYCLIC = (PhaseTransitionFunctions.Functions.cyclic,)
Expand Down
100 changes: 100 additions & 0 deletions bioptim/limits/phase_transtion_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from .objective_functions import ObjectiveFunction
from .phase_transition import PhaseTransition, PhaseTransitionFcn, PhaseTransitionList
from ..misc.enums import ControlType


class PhaseTransitionFactory:
"""
A class to prepare the phase transitions for the ocp builder

Methods
-------
create_default_transitions()
Create the default phase transitions for states continuity between phases.
extend_transitions_for_linear_continuous()
Add phase transitions for linear continuous controls.
update_existing_transitions()
Update the existing phase transitions with Mayer functions and add cyclic transitions

Attributes
----------
ocp: OptimalControlProgram
A reference to the ocp
full_phase_transitions: list[PhaseTransition]
The list of all the transitions prepared
"""

def __init__(self, ocp):
self.ocp = ocp
self.full_phase_transitions = self.create_default_transitions()

def create_default_transitions(self) -> list[PhaseTransition]:
"""Create the default phase transitions for states continuity between phases."""
return [
PhaseTransition(
phase_pre_idx=i,
transition=PhaseTransitionFcn.CONTINUOUS,
weight=self.ocp.nlp[i].dynamics_type.state_continuity_weight,
)
for i in range(self.ocp.n_phases - 1)
]

def extend_transitions_for_linear_continuous_controls(self):
"""Add phase transitions for linear continuous controls.
This is a special case where the controls are continuous"""

for phase, nlp in enumerate(self.ocp.nlp[:-1]):
if nlp.control_type == ControlType.LINEAR_CONTINUOUS:
self.full_phase_transitions.append(
PhaseTransition(
phase_pre_idx=phase,
transition=PhaseTransitionFcn.CONTINUOUS_CONTROLS,
weight=None, # Continuity always enforced by the linear continuous control
)
)

def check_phase_index(self, idx_phase):
"""Check if the phase index is valid."""
if idx_phase >= self.ocp.n_phases:
raise RuntimeError("Phase index of the phase transition is higher than the number of phases")

def update_transition_base(self, pt):
"""Update the transition base with Mayer functions
if the user provided a weight like for an objective function."""
if pt.weight:
pt.base = ObjectiveFunction.MayerFunction

def handle_cyclic_transition(self, idx_phase, pt):
"""The case of a cyclic transition, the terminal phase is linked (n) to the initial phase (0)."""
if idx_phase % self.ocp.n_phases == self.ocp.n_phases - 1:
self.full_phase_transitions.append(pt)
else:
self.full_phase_transitions[idx_phase] = pt

def update_existing_transitions(self, phase_transition_list) -> list[PhaseTransition]:
"""Update the existing phase transitions with Mayer functions and add cyclic transitions."""
existing_phases = []
for pt in phase_transition_list:
idx_phase = pt.nodes_phase[0]
self.check_phase_index(idx_phase)
existing_phases.append(idx_phase)
self.update_transition_base(pt)
self.handle_cyclic_transition(idx_phase, pt)
return self.full_phase_transitions

def prepare_phase_transitions(self, phase_transition_list: PhaseTransitionList) -> list[PhaseTransition]:
"""
Configure all the phase transitions and put them in a list

Parameters
----------
phase_transition_list: PhaseTransitionList
The phase transitions to prepare added by the user

Returns
-------
list[PhaseTransition]
The list of all the transitions prepared
"""
self.extend_transitions_for_linear_continuous_controls()
return self.update_existing_transitions(phase_transition_list)
Loading
Loading