From 163112c46b4afa4ca0a0bf624ea2de7bc3edc224 Mon Sep 17 00:00:00 2001 From: mayitzin Date: Mon, 6 Nov 2023 21:38:01 +0100 Subject: [PATCH] Add validity check to constructor of QuaternionArray. --- ahrs/common/quaternion.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ahrs/common/quaternion.py b/ahrs/common/quaternion.py index c73ce01..c1ce914 100644 --- a/ahrs/common/quaternion.py +++ b/ahrs/common/quaternion.py @@ -2116,6 +2116,9 @@ def __new__(subtype, q: np.ndarray = None, versors: bool = True, order: str = 'H q = np.array(q, dtype=float) if q.ndim != 2 or q.shape[-1] not in [3, 4]: raise ValueError(f"Expected array to have shape (N, 4) or (N, 3), got {q.shape}.") + q_norm = np.linalg.norm(q, axis=1) + if sum(~(q_norm > 0)): + raise ValueError("Quaternion values must be non-zero.") # Build pure quaternions if given as N-by-3 array if q.shape[-1] == 3: @@ -2123,7 +2126,7 @@ def __new__(subtype, q: np.ndarray = None, versors: bool = True, order: str = 'H # Normalize quaternions if versors is True if versors: - q /= np.linalg.norm(q, axis=1)[:, None] + q = q / q_norm[:, None] # Create the ndarray instance of type QuaternionArray. This will call # the standard ndarray constructor, but return an object of type