diff --git a/Quaternion/Quaternion.py b/Quaternion/Quaternion.py index 017c938..305b1b3 100644 --- a/Quaternion/Quaternion.py +++ b/Quaternion/Quaternion.py @@ -116,29 +116,49 @@ class Quat(object): """ - def __init__(self, attitude, intype=None): + def __init__(self, attitude=None, transform=None, q=None, equatorial=None): + npar = (int(attitude is not None) + int(transform is not None) + + int(q is not None) + int(equatorial is not None)) + if npar != 1: + raise Exception('One and only one of attitude, transform, quaternion, equatorial must be given ({})'.format(npar)) self._q = None self._equatorial = None self._T = None # checks to see if we've been passed a Quat if isinstance(attitude, Quat): - self._set_q(attitude.q) - else: - # make it an array and check to see if it is a supported shape + q = attitude.q + elif attitude is not None: + # check to see if it is a supported shape attitude = np.array(attitude) - if ((attitude.shape == (3, 3) and (intype is None or intype == 'transform')) - or (attitude.ndim == 3 and attitude.shape[-1] == 3 and attitude.shape[2] == 3)): - self._set_transform(attitude) - elif (intype == 'quaternion' or attitude.shape == (4,) or - (attitude.ndim == 2 and attitude.shape[-1] == 4)): - self._set_q(attitude) - elif (intype == 'equatorial' or attitude.shape == (3,) - or (attitude.ndim == 2 and attitude.shape[-1] == 3)): - self._set_equatorial(attitude) + if attitude.shape == (4,): + q = attitude + elif attitude.shape == (3, 3): + transform = attitude + elif attitude.shape == (3,): + equatorial = attitude else: raise TypeError( - "attitude is not one of possible types " - "(3 or 4 elements, Quat, or 3x3 matrix, or N x (those types))") + "attitude is not one of possible types (3 or 4 elements, Quat, or 3x3 matrix): {}".format(attitude.shape)) + + # checking correct shapes + if q is not None: + q = np.atleast_1d(q) + if q.shape[-1:] != (4,): + raise TypeError("Creating a Quaternion from quaternion(s) " + "requires shape (..., 4), not {}".format(q.shape)) + self._set_q(q) + elif transform is not None: + transform = np.atleast_2d(transform) + if transform.shape[-2:] != (3, 3): + raise TypeError("Creating a Quaternion from quaternion(s) " + "requires shape (..., 3, 3), not {}".format(transform.shape)) + self._set_transform(transform) + elif equatorial is not None: + equatorial = np.atleast_1d(equatorial) + if equatorial.shape[-1:] != (3,): + raise TypeError("Creating a Quaternion from ra, dec, roll " + "requires shape (..., 3), not {}".format(equatorial.shape)) + self._set_equatorial(equatorial) def _set_q(self, q): """ @@ -147,14 +167,12 @@ def _set_q(self, q): :param q: list or array of normalized quaternion elements """ - q = np.array(q) - if q.ndim == 1: - q = q[np.newaxis] - if np.any((np.sum(q ** 2, axis=-1)[:, np.newaxis] - 1.0) > 1e-6): + q = np.atleast_2d(np.array(q)) + if np.any((np.sum(q ** 2, axis=-1, keepdims=True) - 1.0) > 1e-6): raise ValueError( 'Quaternions must be normalized so sum(q**2) == 1; use Quaternion.normalize') self._q = q - flip_q = q[:, 3] < 0 + flip_q = q[..., 3] < 0 self._q[flip_q] = -1 * q[flip_q] # Erase internal values of other representations self._equatorial = None @@ -196,9 +214,7 @@ def _set_equatorial(self, equatorial): :param equatorial: list or array [ RA, Dec, Roll] in degrees """ - self._equatorial = np.array(equatorial) - if self._equatorial.ndim == 1: - self._equatorial = self._equatorial[np.newaxis] + self._equatorial = np.atleast_2d(np.array(equatorial)) def _get_equatorial(self): """Retrieve [RA, Dec, Roll] @@ -219,24 +235,15 @@ def _get_equatorial(self): def _get_ra(self): """Retrieve RA term from equatorial system in degrees""" - if self.equatorial.ndim == 2: - return self.equatorial[:, 0] - else: - return self.equatorial[0] + return np.squeeze(self.equatorial[..., 0]) def _get_dec(self): """Retrieve Dec term from equatorial system in degrees""" - if self.equatorial.ndim == 2: - return self.equatorial[:, 1] - else: - return self.equatorial[1] + return np.squeeze(self.equatorial[..., 1]) def _get_roll(self): """Retrieve Roll term from equatorial system in degrees""" - if self.equatorial.ndim == 2: - return self.equatorial[:, 2] - else: - return self.equatorial[2] + return np.squeeze(self.equatorial[..., 2]) ra = property(_get_ra) dec = property(_get_dec) @@ -247,9 +254,7 @@ def _get_zero(self, attr): Return a version of attr that is between -180 <= val < 180 """ if not hasattr(self, '_' + attr): - val = getattr(self, attr) - if val.ndim == 0: - val = val[np.newaxis] + val = np.atleast_1d(getattr(self, attr)) val = val % 360 val[val >= 180] -= 360 return val @@ -324,17 +329,15 @@ def _quat2equatorial(self): :rtype: numpy array [ra,dec,roll] """ - q = self.q - if q.ndim == 1: - q = q[np.newaxis] + q = np.atleast_2d(self.q) q2 = q ** 2 # calculate direction cosine matrix elements from $quaternions - xa = q2[:, 0] - q2[:, 1] - q2[:, 2] + q2[:, 3] - xb = 2 * (q[:, 0] * q[:, 1] + q[:, 2] * q[:, 3]) - xn = 2 * (q[:, 0] * q[:, 2] - q[:, 1] * q[:, 3]) - yn = 2 * (q[:, 1] * q[:, 2] + q[:, 0] * q[:, 3]) - zn = q2[:, 3] + q2[:, 2] - q2[:, 0] - q2[:, 1] + xa = q2[..., 0] - q2[..., 1] - q2[..., 2] + q2[..., 3] + xb = 2 * (q[..., 0] * q[..., 1] + q[..., 2] * q[..., 3]) + xn = 2 * (q[..., 0] * q[..., 2] - q[..., 1] * q[..., 3]) + yn = 2 * (q[..., 1] * q[..., 2] + q[..., 0] * q[..., 3]) + zn = q2[..., 3] + q2[..., 2] - q2[..., 0] - q2[..., 1] # Due to numerical precision this can go negative. Allow *slightly* negative # values but raise an exception otherwise. @@ -350,7 +353,24 @@ def _quat2equatorial(self): roll = np.degrees(np.arctan2(yn, zn)) ra[ra < 0] = ra[ra < 0] + 360 roll[roll < 0] = roll[roll < 0] + 360 - return np.array([ra, dec, roll]).transpose() + return np.moveaxis(np.array([ra, dec, roll]), 0, -1) + + + def _quat2equatorial_norm(self): + q = np.atleast_2d(self.q) + q2 = q ** 2 + + # calculate direction cosine matrix elements from $quaternions + xa = q2[..., 0] - q2[..., 1] - q2[..., 2] + q2[..., 3] + xb = 2 * (q[..., 0] * q[..., 1] + q[..., 2] * q[..., 3]) + xn = 2 * (q[..., 0] * q[..., 2] - q[..., 1] * q[..., 3]) + yn = 2 * (q[..., 1] * q[..., 2] + q[..., 0] * q[..., 3]) + zn = q2[..., 3] + q2[..., 2] - q2[..., 0] - q2[..., 1] + + # Due to numerical precision this can go negative. Allow *slightly* negative + # values but raise an exception otherwise. + one_minus_xn2 = 1 - xn**2 + return one_minus_xn2 # _quat2transform is largely from Enthought's quaternion.rotmat, though this math is @@ -394,11 +414,9 @@ def _quat2transform(self): :rtype: numpy array """ - q = self.q - if q.ndim == 1: - q = q[np.newaxis] + q = np.atleast_2d(self.q) - x, y, z, w = q[:, 0], q[:, 1], q[:, 2], q[:, 3] + x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3] xx2 = x * x * 2. yy2 = y * y * 2. zz2 = z * z * 2. @@ -409,16 +427,16 @@ def _quat2transform(self): yz2 = y * z * 2. wx2 = w * x * 2. - t = np.empty((len(q), 3, 3), float) - t[:, 0, 0] = 1. - yy2 - zz2 - t[:, 0, 1] = xy2 - wz2 - t[:, 0, 2] = zx2 + wy2 - t[:, 1, 0] = xy2 + wz2 - t[:, 1, 1] = 1. - xx2 - zz2 - t[:, 1, 2] = yz2 - wx2 - t[:, 2, 0] = zx2 - wy2 - t[:, 2, 1] = yz2 + wx2 - t[:, 2, 2] = 1. - xx2 - yy2 + t = np.empty(tuple(q.shape[:-1] + (3, 3)), float) + t[..., 0, 0] = 1. - yy2 - zz2 + t[..., 0, 1] = xy2 - wz2 + t[..., 0, 2] = zx2 + wy2 + t[..., 1, 0] = xy2 + wz2 + t[..., 1, 1] = 1. - xx2 - zz2 + t[..., 1, 2] = yz2 - wx2 + t[..., 2, 0] = zx2 - wy2 + t[..., 2, 1] = yz2 + wx2 + t[..., 2, 2] = 1. - xx2 - yy2 return t @@ -454,7 +472,7 @@ def _equatorial2transform(self): [-ca * sd * sr - sa * cr, -sa * sd * sr + ca * cr, cd * sr], [-ca * sd * cr + sa * sr, -sa * sd * cr - ca * sr, cd * cr]]) - return rmat.transpose() + return np.moveaxis(np.moveaxis(rmat, 0, -1), 0, -2) def _transform2quat(self): """Construct quaternions from the transform/rotation matrices @@ -467,42 +485,42 @@ def _transform2quat(self): if transform.ndim == 2: transform = transform[np.newaxis] # Code was copied from perl PDL code that uses backwards index ordering - T = transform.transpose(0, 2, 1) - den = np.array([1.0 + T[:, 0, 0] - T[:, 1, 1] - T[:, 2, 2], - 1.0 - T[:, 0, 0] + T[:, 1, 1] - T[:, 2, 2], - 1.0 - T[:, 0, 0] - T[:, 1, 1] + T[:, 2, 2], - 1.0 + T[:, 0, 0] + T[:, 1, 1] + T[:, 2, 2]]) + T = transform.swapaxes(-2, -1) + den = np.array([1.0 + T[..., 0, 0] - T[..., 1, 1] - T[..., 2, 2], + 1.0 - T[..., 0, 0] + T[..., 1, 1] - T[..., 2, 2], + 1.0 - T[..., 0, 0] - T[..., 1, 1] + T[..., 2, 2], + 1.0 + T[..., 0, 0] + T[..., 1, 1] + T[..., 2, 2]]) half_rt_q_max = 0.5 * np.sqrt(np.max(den, axis=0)) max_idx = np.argmax(den, axis=0) - poss_quat = np.zeros((4, len(T), 4)) + poss_quat = np.zeros(tuple((4,) + T.shape[:-2] + (4,))) denom = 4.0 * half_rt_q_max - poss_quat[0] = np.transpose( + poss_quat[0] = np.moveaxis( np.array( [half_rt_q_max, - (T[:, 1, 0] + T[:, 0, 1]) / denom, - (T[:, 2, 0] + T[:, 0, 2]) / denom, - -(T[:, 2, 1] - T[:, 1, 2]) / denom])) - poss_quat[1] = np.transpose( + (T[..., 1, 0] + T[..., 0, 1]) / denom, + (T[..., 2, 0] + T[..., 0, 2]) / denom, + -(T[..., 2, 1] - T[..., 1, 2]) / denom]), 0, -1) + poss_quat[1] = np.moveaxis( np.array( - [(T[:, 1, 0] + T[:, 0, 1]) / denom, + [(T[..., 1, 0] + T[..., 0, 1]) / denom, half_rt_q_max, - (T[:, 2, 1] + T[:, 1, 2]) / denom, - -(T[:, 0, 2] - T[:, 2, 0]) / denom])) - poss_quat[2] = np.transpose( + (T[..., 2, 1] + T[..., 1, 2]) / denom, + -(T[..., 0, 2] - T[..., 2, 0]) / denom]), 0, -1) + poss_quat[2] = np.moveaxis( np.array( - [(T[:, 2, 0] + T[:, 0, 2]) / denom, - (T[:, 2, 1] + T[:, 1, 2]) / denom, + [(T[..., 2, 0] + T[..., 0, 2]) / denom, + (T[..., 2, 1] + T[..., 1, 2]) / denom, half_rt_q_max, - -(T[:, 1, 0] - T[:, 0, 1]) / denom])) - poss_quat[3] = np.transpose( + -(T[..., 1, 0] - T[..., 0, 1]) / denom]), 0, -1) + poss_quat[3] = np.moveaxis( np.array( - [-(T[:, 2, 1] - T[:, 1, 2]) / denom, - -(T[:, 0, 2] - T[:, 2, 0]) / denom, - -(T[:, 1, 0] - T[:, 0, 1]) / denom, - half_rt_q_max])) + [-(T[..., 2, 1] - T[..., 1, 2]) / denom, + -(T[..., 0, 2] - T[..., 2, 0]) / denom, + -(T[..., 1, 0] - T[..., 0, 1]) / denom, + half_rt_q_max]), 0, -1) - q = np.zeros((len(T), 4)) + q = np.zeros(tuple(T.shape[:-2] + (4,))) for idx in range(0, 4): max_match = max_idx == idx q[max_match] = poss_quat[idx][max_match] @@ -563,18 +581,14 @@ def __mul__(self, quat2): :rtype: Quat """ - q1 = self.q - if q1.ndim == 1: - q1 = q1[np.newaxis] - q2 = quat2.q - if q2.ndim == 1: - q2 = q2[np.newaxis] + q1 = np.atleast_2d(self.q) + q2 = np.atleast_2d(quat2.q) mult = np.zeros((len(q1), 4)) - mult[:,0] = q1[:,3] * q2[:,0] - q1[:,2] * q2[:,1] + q1[:,1] * q2[:,2] + q1[:,0] * q2[:,3] - mult[:,1] = q1[:,2] * q2[:,0] + q1[:,3] * q2[:,1] - q1[:,0] * q2[:,2] + q1[:,1] * q2[:,3] - mult[:,2] = -q1[:,1] * q2[:,0] + q1[:,0] * q2[:,1] + q1[:,3] * q2[:,2] + q1[:,2] * q2[:,3] - mult[:,3] = -q1[:,0] * q2[:,0] - q1[:,1] * q2[:,1] - q1[:,2] * q2[:,2] + q1[:,3] * q2[:,3] - return Quat(mult) + mult[...,0] = q1[...,3] * q2[...,0] - q1[...,2] * q2[...,1] + q1[...,1] * q2[...,2] + q1[...,0] * q2[...,3] + mult[...,1] = q1[...,2] * q2[...,0] + q1[...,3] * q2[...,1] - q1[...,0] * q2[...,2] + q1[...,1] * q2[...,3] + mult[...,2] = -q1[...,1] * q2[...,0] + q1[...,0] * q2[...,1] + q1[...,3] * q2[...,2] + q1[...,2] * q2[...,3] + mult[...,3] = -q1[...,0] * q2[...,0] - q1[...,1] * q2[...,1] - q1[...,2] * q2[...,2] + q1[...,3] * q2[...,3] + return Quat(q=mult) def inv(self): """ @@ -583,10 +597,8 @@ def inv(self): :returns: inverted quaternion :rtype: Quat """ - q = self.q - if q.ndim == 1: - q = q[np.newaxis] - return Quat(np.array([q[:, 0], q[:, 1], q[:, 2], -1.0 * q[:, 3]]).transpose()) + q = np.atleast_2d(self.q) + return Quat(q=np.array([q[..., 0], q[..., 1], q[..., 2], -1.0 * q[..., 3]]).swapaxes(-2, -1)) def dq(self, q2): """ @@ -621,9 +633,4 @@ def normalize(array): """ quat = np.array(array) - if quat.ndim == 1: - return quat / np.sqrt(np.dot(quat, quat)) - elif quat.ndim == 2: - return quat / np.sqrt(np.sum(quat * quat, axis=-1)[:, np.newaxis]) - else: - raise TypeError("Input must be 1 or 2d") + return np.squeeze(quat/np.sqrt(np.sum(quat * quat, axis=-1, keepdims=True))) diff --git a/Quaternion/tests/test_all.py b/Quaternion/tests/test_all.py index acb63c2..333aef6 100644 --- a/Quaternion/tests/test_all.py +++ b/Quaternion/tests/test_all.py @@ -1,4 +1,6 @@ import numpy as np +import pytest + from .. import Quat ra = 10. @@ -6,6 +8,52 @@ roll = 30. q0 = Quat([ra, dec, roll]) +equatorial_23 = np.array([[[ 10, 20, 30], + [ 10, 20, -30], + [ 10, -90, 30]], + [[ 10, 20, 0], + [ 10, 90, 30], + [ 10, 90, -30]]]) + +q_23 = np.array([[[ 0.26853582, -0.14487813, 0.12767944, 0.94371436], + [-0.23929834, -0.18930786, 0.03813458, 0.95154852], + [ 0.1227878 , 0.69636424, -0.1227878 , 0.69636424]], + [[ 0.01513444, -0.17298739, 0.08583165, 0.98106026], + [ 0.24184476, -0.66446302, 0.24184476, 0.66446302], + [ 0.1227878 , 0.69636424, 0.1227878 , -0.69636424]]]) + +transform_23 = np.array([[[[ 9.25416578e-01, -3.18795778e-01, -2.04874129e-01], + [ 1.63175911e-01, 8.23172945e-01, -5.43838142e-01], + [ 3.42020143e-01, 4.69846310e-01, 8.13797681e-01]], + [[ 9.25416578e-01, 1.80283112e-02, -3.78522306e-01], + [ 1.63175911e-01, 8.82564119e-01, 4.40969611e-01], + [ 3.42020143e-01, -4.69846310e-01, 8.13797681e-01]], + [[ 6.03020831e-17, 3.42020143e-01, 9.39692621e-01], + [ 1.06328842e-17, 9.39692621e-01, -3.42020143e-01], + [ -1.00000000e+00, 3.06161700e-17, 5.30287619e-17]]], + [[[ 9.25416578e-01, -1.73648178e-01, -3.36824089e-01], + [ 1.63175911e-01, 9.84807753e-01, -5.93911746e-02], + [ 3.42020143e-01, 0.00000000e+00, 9.39692621e-01]], + [[ 6.03020831e-17, -6.42787610e-01, -7.66044443e-01], + [ 1.06328842e-17, 7.66044443e-01, -6.42787610e-01], + [ 1.00000000e+00, 3.06161700e-17, 5.30287619e-17]], + [[ 6.03020831e-17, 3.42020143e-01, -9.39692621e-01], + [ 1.06328842e-17, 9.39692621e-01, 3.42020143e-01], + [ 1.00000000e+00, -3.06161700e-17, 5.30287619e-17]]]]) + + +def test_init_exceptions(): + with pytest.raises(TypeError): + q = Quat(np.zeros((2,))) + with pytest.raises(TypeError): + q = Quat(np.zeros((5,))) + with pytest.raises(TypeError): + q = Quat(equatorial_23) + with pytest.raises(TypeError): + q = Quat(q_23) + with pytest.raises(TypeError): + q = Quat(transform_23) + def test_from_eq(): q = Quat([ra, dec, roll]) @@ -17,6 +65,28 @@ def test_from_eq(): assert np.allclose(q.ra0, 10) +def test_from_eq_vectorized(): + # the following line would give unexpected results + # because the input is interpreted as a (non-vectorized) transform + # the shape of the input is (3,3) + # q = Quat(equatorial_23[0]) + + # this is the proper way: + q = Quat(equatorial=equatorial_23[0]) + assert q.q.shape == (3, 4) + assert np.allclose(q.q, q_23[0]) + + q = Quat(equatorial=equatorial_23) + assert q.q.shape == (2, 3, 4) + assert np.allclose(q.q, q_23) + + +def test_transform_from_eq(): + q = Quat(equatorial=equatorial_23) + assert q.transform.shape == (2, 3, 3, 3) + assert np.allclose(q.transform, transform_23) + + def test_from_transform(): """Initialize from inverse of q0 via transform matrix""" q = Quat(q0.transform.transpose()) @@ -30,6 +100,33 @@ def test_from_transform(): assert np.allclose(q.ra0, 10) +def test_from_transform_vectorized(): + q = Quat(transform=transform_23) + assert q.q.shape == (2, 3, 4) + assert np.allclose(q.q, q_23) + + +def test_eq_from_transform(): + # this raises 'Unexpected negative norm' exception due to roundoff in copy/paste above + #q = Quat(transform=transform_23) + #assert q.equatorial.shape == (2, 3, 3) + #assert np.allclose(q.equatorial, equatorial_23) + + # this one fails (quaternion -> equatorial -> quaternion is not an identity) + #q = Quat(transform=np.vstack([q0.transform[np.newaxis], q0.transform[np.newaxis]])) + #assert np.allclose(q.roll0, 30) + #assert np.allclose(q.ra0, 10) + + t = np.zeros((4,5,3,3)) + t[:] = q0.transform[np.newaxis][np.newaxis] + q = Quat(transform=t) + print('roll', q.roll0) + assert np.allclose(q.roll0, 30) + assert np.allclose(q.ra0, 10) + + assert q.equatorial.shape == (4, 5, 3) + + def test_inv_eq(): q = Quat(q0.equatorial) t = q.transform