From 708c3b8a8a5321e6e51465b0ab4d3ca7d650b56f Mon Sep 17 00:00:00 2001 From: mayitzin Date: Wed, 6 Sep 2023 10:26:44 +0200 Subject: [PATCH] Remove loops for 'hughes' and 'chiaverini' (they vectorize 3d arrays already). Add handling of wrong input in method 'from_DCM()'. --- ahrs/common/quaternion.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/ahrs/common/quaternion.py b/ahrs/common/quaternion.py index c558625..952d7b5 100644 --- a/ahrs/common/quaternion.py +++ b/ahrs/common/quaternion.py @@ -2588,14 +2588,17 @@ def from_DCM(self, DCM: np.ndarray, method: str='chiaverini', inplace: bool = Tr array([0.94371436, 0.26853582, 0.14487813, 0.12767944]) """ + # Handle input + if method.lower() not in ['chiaverini', 'hughes', 'itzhack', 'sarabandi', 'shepperd']: + raise ValueError(f"Method '{method}' not available. Options are: 'chiaverini', 'hughes', 'itzhack', 'sarabandi', and 'shepperd'.") + _assert_iterables(DCM, 'Direction Cosine Matrices') + # Allocate local quaternion array quaternion_array = np.zeros((DCM.shape[0], 4)) try: if method.lower() == 'hughes': - for i, R in enumerate(DCM): - quaternion_array[i] = hughes(R) + quaternion_array = hughes(DCM) if method.lower() == 'chiaverini': - for i, R in enumerate(DCM): - quaternion_array[i] = chiaverini(R) + quaternion_array = chiaverini(DCM) if method.lower() == 'shepperd': for i, R in enumerate(DCM): quaternion_array[i] = shepperd(R) @@ -2603,14 +2606,13 @@ def from_DCM(self, DCM: np.ndarray, method: str='chiaverini', inplace: bool = Tr version = kw.get('version', 3) for i, R in enumerate(DCM): quaternion_array[i] = itzhack(R, version=version) - q = itzhack(self.A, version=kw.get('version', 3)) if method.lower() == 'sarabandi': threshold = kw.get('threshold', 0.0) for i, R in enumerate(DCM): quaternion_array[i] = sarabandi(R, eta=threshold) except RuntimeWarning: failed_DCM = DCM[i] - msg = f"Method '{method}' failed at DCM:\n{failed_DCM}\n" + msg = f"Method '{method}' failed at DCM[{i}]:\n{failed_DCM}\n" raise RuntimeError(msg) quaternion_array /= np.linalg.norm(quaternion_array, axis=1)[:, None] if inplace: