Skip to content

Commit

Permalink
generalized for arbitrary shape and updated interface according to is…
Browse files Browse the repository at this point in the history
…sue #9
  • Loading branch information
javierggt committed Sep 7, 2019
1 parent 4fd5f18 commit 2971311
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 107 deletions.
221 changes: 114 additions & 107 deletions Quaternion/Quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,29 +116,49 @@ class Quat(object):
"""

def __init__(self, attitude, intype=None):
def __init__(self, attitude=None, transform=None, q=None, equatorial=None):
npar = (int(attitude is not None) + int(transform is not None) +
int(q is not None) + int(equatorial is not None))
if npar != 1:
raise Exception('One and only one of attitude, transform, quaternion, equatorial must be given ({})'.format(npar))
self._q = None
self._equatorial = None
self._T = None
# checks to see if we've been passed a Quat
if isinstance(attitude, Quat):
self._set_q(attitude.q)
else:
# make it an array and check to see if it is a supported shape
q = attitude.q
elif attitude is not None:
# check to see if it is a supported shape
attitude = np.array(attitude)
if ((attitude.shape == (3, 3) and (intype is None or intype == 'transform'))
or (attitude.ndim == 3 and attitude.shape[-1] == 3 and attitude.shape[2] == 3)):
self._set_transform(attitude)
elif (intype == 'quaternion' or attitude.shape == (4,) or
(attitude.ndim == 2 and attitude.shape[-1] == 4)):
self._set_q(attitude)
elif (intype == 'equatorial' or attitude.shape == (3,)
or (attitude.ndim == 2 and attitude.shape[-1] == 3)):
self._set_equatorial(attitude)
if attitude.shape == (4,):
q = attitude
elif attitude.shape == (3, 3):
transform = attitude
elif attitude.shape == (3,):
equatorial = attitude
else:
raise TypeError(
"attitude is not one of possible types "
"(3 or 4 elements, Quat, or 3x3 matrix, or N x (those types))")
"attitude is not one of possible types (3 or 4 elements, Quat, or 3x3 matrix): {}".format(attitude.shape))

# checking correct shapes
if q is not None:
q = np.atleast_1d(q)
if q.shape[-1:] != (4,):
raise TypeError("Creating a Quaternion from quaternion(s) "
"requires shape (..., 4), not {}".format(q.shape))
self._set_q(q)
elif transform is not None:
transform = np.atleast_2d(transform)
if transform.shape[-2:] != (3, 3):
raise TypeError("Creating a Quaternion from quaternion(s) "
"requires shape (..., 3, 3), not {}".format(transform.shape))
self._set_transform(transform)
elif equatorial is not None:
equatorial = np.atleast_1d(equatorial)
if equatorial.shape[-1:] != (3,):
raise TypeError("Creating a Quaternion from ra, dec, roll "
"requires shape (..., 3), not {}".format(equatorial.shape))
self._set_equatorial(equatorial)

def _set_q(self, q):
"""
Expand All @@ -147,14 +167,12 @@ def _set_q(self, q):
:param q: list or array of normalized quaternion elements
"""
q = np.array(q)
if q.ndim == 1:
q = q[np.newaxis]
if np.any((np.sum(q ** 2, axis=-1)[:, np.newaxis] - 1.0) > 1e-6):
q = np.atleast_2d(np.array(q))
if np.any((np.sum(q ** 2, axis=-1, keepdims=True) - 1.0) > 1e-6):
raise ValueError(
'Quaternions must be normalized so sum(q**2) == 1; use Quaternion.normalize')
self._q = q
flip_q = q[:, 3] < 0
flip_q = q[..., 3] < 0
self._q[flip_q] = -1 * q[flip_q]
# Erase internal values of other representations
self._equatorial = None
Expand Down Expand Up @@ -196,9 +214,7 @@ def _set_equatorial(self, equatorial):
:param equatorial: list or array [ RA, Dec, Roll] in degrees
"""
self._equatorial = np.array(equatorial)
if self._equatorial.ndim == 1:
self._equatorial = self._equatorial[np.newaxis]
self._equatorial = np.atleast_2d(np.array(equatorial))

def _get_equatorial(self):
"""Retrieve [RA, Dec, Roll]
Expand All @@ -219,24 +235,15 @@ def _get_equatorial(self):

def _get_ra(self):
"""Retrieve RA term from equatorial system in degrees"""
if self.equatorial.ndim == 2:
return self.equatorial[:, 0]
else:
return self.equatorial[0]
return np.squeeze(self.equatorial[..., 0])

def _get_dec(self):
"""Retrieve Dec term from equatorial system in degrees"""
if self.equatorial.ndim == 2:
return self.equatorial[:, 1]
else:
return self.equatorial[1]
return np.squeeze(self.equatorial[..., 1])

def _get_roll(self):
"""Retrieve Roll term from equatorial system in degrees"""
if self.equatorial.ndim == 2:
return self.equatorial[:, 2]
else:
return self.equatorial[2]
return np.squeeze(self.equatorial[..., 2])

ra = property(_get_ra)
dec = property(_get_dec)
Expand All @@ -247,9 +254,7 @@ def _get_zero(self, attr):
Return a version of attr that is between -180 <= val < 180
"""
if not hasattr(self, '_' + attr):
val = getattr(self, attr)
if val.ndim == 0:
val = val[np.newaxis]
val = np.atleast_1d(getattr(self, attr))
val = val % 360
val[val >= 180] -= 360
return val
Expand Down Expand Up @@ -324,17 +329,15 @@ def _quat2equatorial(self):
:rtype: numpy array [ra,dec,roll]
"""

q = self.q
if q.ndim == 1:
q = q[np.newaxis]
q = np.atleast_2d(self.q)
q2 = q ** 2

# calculate direction cosine matrix elements from $quaternions
xa = q2[:, 0] - q2[:, 1] - q2[:, 2] + q2[:, 3]
xb = 2 * (q[:, 0] * q[:, 1] + q[:, 2] * q[:, 3])
xn = 2 * (q[:, 0] * q[:, 2] - q[:, 1] * q[:, 3])
yn = 2 * (q[:, 1] * q[:, 2] + q[:, 0] * q[:, 3])
zn = q2[:, 3] + q2[:, 2] - q2[:, 0] - q2[:, 1]
xa = q2[..., 0] - q2[..., 1] - q2[..., 2] + q2[..., 3]
xb = 2 * (q[..., 0] * q[..., 1] + q[..., 2] * q[..., 3])
xn = 2 * (q[..., 0] * q[..., 2] - q[..., 1] * q[..., 3])
yn = 2 * (q[..., 1] * q[..., 2] + q[..., 0] * q[..., 3])
zn = q2[..., 3] + q2[..., 2] - q2[..., 0] - q2[..., 1]

# Due to numerical precision this can go negative. Allow *slightly* negative
# values but raise an exception otherwise.
Expand All @@ -350,7 +353,24 @@ def _quat2equatorial(self):
roll = np.degrees(np.arctan2(yn, zn))
ra[ra < 0] = ra[ra < 0] + 360
roll[roll < 0] = roll[roll < 0] + 360
return np.array([ra, dec, roll]).transpose()
return np.moveaxis(np.array([ra, dec, roll]), 0, -1)


def _quat2equatorial_norm(self):
q = np.atleast_2d(self.q)
q2 = q ** 2

# calculate direction cosine matrix elements from $quaternions
xa = q2[..., 0] - q2[..., 1] - q2[..., 2] + q2[..., 3]
xb = 2 * (q[..., 0] * q[..., 1] + q[..., 2] * q[..., 3])
xn = 2 * (q[..., 0] * q[..., 2] - q[..., 1] * q[..., 3])
yn = 2 * (q[..., 1] * q[..., 2] + q[..., 0] * q[..., 3])
zn = q2[..., 3] + q2[..., 2] - q2[..., 0] - q2[..., 1]

# Due to numerical precision this can go negative. Allow *slightly* negative
# values but raise an exception otherwise.
one_minus_xn2 = 1 - xn**2
return one_minus_xn2


# _quat2transform is largely from Enthought's quaternion.rotmat, though this math is
Expand Down Expand Up @@ -394,11 +414,9 @@ def _quat2transform(self):
:rtype: numpy array
"""
q = self.q
if q.ndim == 1:
q = q[np.newaxis]
q = np.atleast_2d(self.q)

x, y, z, w = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
xx2 = x * x * 2.
yy2 = y * y * 2.
zz2 = z * z * 2.
Expand All @@ -409,16 +427,16 @@ def _quat2transform(self):
yz2 = y * z * 2.
wx2 = w * x * 2.

t = np.empty((len(q), 3, 3), float)
t[:, 0, 0] = 1. - yy2 - zz2
t[:, 0, 1] = xy2 - wz2
t[:, 0, 2] = zx2 + wy2
t[:, 1, 0] = xy2 + wz2
t[:, 1, 1] = 1. - xx2 - zz2
t[:, 1, 2] = yz2 - wx2
t[:, 2, 0] = zx2 - wy2
t[:, 2, 1] = yz2 + wx2
t[:, 2, 2] = 1. - xx2 - yy2
t = np.empty(tuple(q.shape[:-1] + (3, 3)), float)
t[..., 0, 0] = 1. - yy2 - zz2
t[..., 0, 1] = xy2 - wz2
t[..., 0, 2] = zx2 + wy2
t[..., 1, 0] = xy2 + wz2
t[..., 1, 1] = 1. - xx2 - zz2
t[..., 1, 2] = yz2 - wx2
t[..., 2, 0] = zx2 - wy2
t[..., 2, 1] = yz2 + wx2
t[..., 2, 2] = 1. - xx2 - yy2

return t

Expand Down Expand Up @@ -454,7 +472,7 @@ def _equatorial2transform(self):
[-ca * sd * sr - sa * cr, -sa * sd * sr + ca * cr, cd * sr],
[-ca * sd * cr + sa * sr, -sa * sd * cr - ca * sr, cd * cr]])

return rmat.transpose()
return np.moveaxis(np.moveaxis(rmat, 0, -1), 0, -2)

def _transform2quat(self):
"""Construct quaternions from the transform/rotation matrices
Expand All @@ -467,42 +485,42 @@ def _transform2quat(self):
if transform.ndim == 2:
transform = transform[np.newaxis]
# Code was copied from perl PDL code that uses backwards index ordering
T = transform.transpose(0, 2, 1)
den = np.array([1.0 + T[:, 0, 0] - T[:, 1, 1] - T[:, 2, 2],
1.0 - T[:, 0, 0] + T[:, 1, 1] - T[:, 2, 2],
1.0 - T[:, 0, 0] - T[:, 1, 1] + T[:, 2, 2],
1.0 + T[:, 0, 0] + T[:, 1, 1] + T[:, 2, 2]])
T = transform.swapaxes(-2, -1)
den = np.array([1.0 + T[..., 0, 0] - T[..., 1, 1] - T[..., 2, 2],
1.0 - T[..., 0, 0] + T[..., 1, 1] - T[..., 2, 2],
1.0 - T[..., 0, 0] - T[..., 1, 1] + T[..., 2, 2],
1.0 + T[..., 0, 0] + T[..., 1, 1] + T[..., 2, 2]])

half_rt_q_max = 0.5 * np.sqrt(np.max(den, axis=0))
max_idx = np.argmax(den, axis=0)
poss_quat = np.zeros((4, len(T), 4))
poss_quat = np.zeros(tuple((4,) + T.shape[:-2] + (4,)))
denom = 4.0 * half_rt_q_max
poss_quat[0] = np.transpose(
poss_quat[0] = np.moveaxis(
np.array(
[half_rt_q_max,
(T[:, 1, 0] + T[:, 0, 1]) / denom,
(T[:, 2, 0] + T[:, 0, 2]) / denom,
-(T[:, 2, 1] - T[:, 1, 2]) / denom]))
poss_quat[1] = np.transpose(
(T[..., 1, 0] + T[..., 0, 1]) / denom,
(T[..., 2, 0] + T[..., 0, 2]) / denom,
-(T[..., 2, 1] - T[..., 1, 2]) / denom]), 0, -1)
poss_quat[1] = np.moveaxis(
np.array(
[(T[:, 1, 0] + T[:, 0, 1]) / denom,
[(T[..., 1, 0] + T[..., 0, 1]) / denom,
half_rt_q_max,
(T[:, 2, 1] + T[:, 1, 2]) / denom,
-(T[:, 0, 2] - T[:, 2, 0]) / denom]))
poss_quat[2] = np.transpose(
(T[..., 2, 1] + T[..., 1, 2]) / denom,
-(T[..., 0, 2] - T[..., 2, 0]) / denom]), 0, -1)
poss_quat[2] = np.moveaxis(
np.array(
[(T[:, 2, 0] + T[:, 0, 2]) / denom,
(T[:, 2, 1] + T[:, 1, 2]) / denom,
[(T[..., 2, 0] + T[..., 0, 2]) / denom,
(T[..., 2, 1] + T[..., 1, 2]) / denom,
half_rt_q_max,
-(T[:, 1, 0] - T[:, 0, 1]) / denom]))
poss_quat[3] = np.transpose(
-(T[..., 1, 0] - T[..., 0, 1]) / denom]), 0, -1)
poss_quat[3] = np.moveaxis(
np.array(
[-(T[:, 2, 1] - T[:, 1, 2]) / denom,
-(T[:, 0, 2] - T[:, 2, 0]) / denom,
-(T[:, 1, 0] - T[:, 0, 1]) / denom,
half_rt_q_max]))
[-(T[..., 2, 1] - T[..., 1, 2]) / denom,
-(T[..., 0, 2] - T[..., 2, 0]) / denom,
-(T[..., 1, 0] - T[..., 0, 1]) / denom,
half_rt_q_max]), 0, -1)

q = np.zeros((len(T), 4))
q = np.zeros(tuple(T.shape[:-2] + (4,)))
for idx in range(0, 4):
max_match = max_idx == idx
q[max_match] = poss_quat[idx][max_match]
Expand Down Expand Up @@ -563,18 +581,14 @@ def __mul__(self, quat2):
:rtype: Quat
"""
q1 = self.q
if q1.ndim == 1:
q1 = q1[np.newaxis]
q2 = quat2.q
if q2.ndim == 1:
q2 = q2[np.newaxis]
q1 = np.atleast_2d(self.q)
q2 = np.atleast_2d(quat2.q)
mult = np.zeros((len(q1), 4))
mult[:,0] = q1[:,3] * q2[:,0] - q1[:,2] * q2[:,1] + q1[:,1] * q2[:,2] + q1[:,0] * q2[:,3]
mult[:,1] = q1[:,2] * q2[:,0] + q1[:,3] * q2[:,1] - q1[:,0] * q2[:,2] + q1[:,1] * q2[:,3]
mult[:,2] = -q1[:,1] * q2[:,0] + q1[:,0] * q2[:,1] + q1[:,3] * q2[:,2] + q1[:,2] * q2[:,3]
mult[:,3] = -q1[:,0] * q2[:,0] - q1[:,1] * q2[:,1] - q1[:,2] * q2[:,2] + q1[:,3] * q2[:,3]
return Quat(mult)
mult[...,0] = q1[...,3] * q2[...,0] - q1[...,2] * q2[...,1] + q1[...,1] * q2[...,2] + q1[...,0] * q2[...,3]
mult[...,1] = q1[...,2] * q2[...,0] + q1[...,3] * q2[...,1] - q1[...,0] * q2[...,2] + q1[...,1] * q2[...,3]
mult[...,2] = -q1[...,1] * q2[...,0] + q1[...,0] * q2[...,1] + q1[...,3] * q2[...,2] + q1[...,2] * q2[...,3]
mult[...,3] = -q1[...,0] * q2[...,0] - q1[...,1] * q2[...,1] - q1[...,2] * q2[...,2] + q1[...,3] * q2[...,3]
return Quat(q=mult)

def inv(self):
"""
Expand All @@ -583,10 +597,8 @@ def inv(self):
:returns: inverted quaternion
:rtype: Quat
"""
q = self.q
if q.ndim == 1:
q = q[np.newaxis]
return Quat(np.array([q[:, 0], q[:, 1], q[:, 2], -1.0 * q[:, 3]]).transpose())
q = np.atleast_2d(self.q)
return Quat(q=np.array([q[..., 0], q[..., 1], q[..., 2], -1.0 * q[..., 3]]).swapaxes(-2, -1))

def dq(self, q2):
"""
Expand Down Expand Up @@ -621,9 +633,4 @@ def normalize(array):
"""
quat = np.array(array)
if quat.ndim == 1:
return quat / np.sqrt(np.dot(quat, quat))
elif quat.ndim == 2:
return quat / np.sqrt(np.sum(quat * quat, axis=-1)[:, np.newaxis])
else:
raise TypeError("Input must be 1 or 2d")
return np.squeeze(quat/np.sqrt(np.sum(quat * quat, axis=-1, keepdims=True)))
Loading

0 comments on commit 2971311

Please sign in to comment.