Skip to content

Commit

Permalink
Forced equality in sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
pariterre committed Feb 8, 2024
1 parent 2a81e46 commit deab218
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 76 deletions.
26 changes: 13 additions & 13 deletions bioptim/models/biorbd/biorbd_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,27 +117,27 @@ def mass(self) -> MX:
return self.model.mass().to_mx()

def check_q_size(self, q):
if q.shape[0] > self.nb_q:
raise ValueError(f"Length of q is too big. Expected size: {self.nb_q}, but got: {q.shape[0]}")
if q.shape[0] != self.nb_q:
raise ValueError(f"Length of q size should be: {self.nb_q}, but got: {q.shape[0]}")

def check_qdot_size(self, qdot):
if qdot.shape[0] > self.nb_qdot:
raise ValueError(f"Length of qdot is too big. Expected size: {self.nb_qdot}, but got: {qdot.shape[0]}")
if qdot.shape[0] != self.nb_qdot:
raise ValueError(f"Length of qdot size should be: {self.nb_qdot}, but got: {qdot.shape[0]}")

def check_qddot_size(self, qddot):
if qddot.shape[0] > self.nb_qddot:
raise ValueError(f"Length of qddot is too big. Expected size: {self.nb_qddot}, but got: {qddot.shape[0]}")
if qddot.shape[0] != self.nb_qddot:
raise ValueError(f"Length of qddot size should be: {self.nb_qddot}, but got: {qddot.shape[0]}")

def check_qddot_joints_size(self, qddot_joints):
nb_qddot_joints = self.nb_q - self.nb_root
if qddot_joints.shape[0] > nb_qddot_joints:
if qddot_joints.shape[0] != nb_qddot_joints:
raise ValueError(
f"Length of qddot_joints is too big. Expected size: {nb_qddot_joints}, but got: {qddot_joints.shape[0]}"
f"Length of qddot_joints size should be: {nb_qddot_joints}, but got: {qddot_joints.shape[0]}"
)

def check_tau_size(self, tau):
if tau.shape[0] > self.nb_tau:
raise ValueError(f"Length of tau is too big. Expected size: {self.nb_tau}, but got: {tau.shape[0]}")
if tau.shape[0] != self.nb_tau:
raise ValueError(f"Length of tau size should be: {self.nb_tau}, but got: {tau.shape[0]}")

def check_muscle_size(self, muscle):
if isinstance(muscle, list):
Expand All @@ -147,8 +147,8 @@ def check_muscle_size(self, muscle):
else:
raise TypeError("Unsupported type for muscle.")

if muscle_size > self.nb_muscles:
raise ValueError(f"Length of muscle is too big. Expected size: {self.nb_muscles}, but got: {muscle_size}")
if muscle_size != self.nb_muscles:
raise ValueError(f"Length of muscle size should be: {self.nb_muscles}, but got: {muscle_size}")

def center_of_mass(self, q) -> MX:
self.check_q_size(q)
Expand Down Expand Up @@ -277,7 +277,7 @@ def torque(self, tau_activations, q, qdot) -> MX:
def forward_dynamics_free_floating_base(self, q, qdot, qddot_joints) -> MX:
self.check_q_size(q)
self.check_qdot_size(qdot)
self.check_qddot_size(qddot_joints)
self.check_qddot_joints_size(qddot_joints)
q_biorbd = GeneralizedCoordinates(q)
qdot_biorbd = GeneralizedVelocity(qdot)
qddot_joints_biorbd = GeneralizedAcceleration(qddot_joints)
Expand Down
Loading

0 comments on commit deab218

Please sign in to comment.