Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement quaternion arithmetic from numpy-quaternion #281

Merged
merged 11 commits into from
Feb 18, 2022
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ this project adheres to `Semantic Versioning <https://semver.org/spec/v2.0.0.htm
Unreleased
==========

Changed
-------
- `orix.quaternion.Quaternion` now relies on `numpy-quaternion <https://quaternion.readthedocs.io/en/latest/>`_
for quaternion conjugation, quaternion-quaternion and quaternion-vector multiplication,
and quaternion-quaternion and quaternion-vector outer products.

2022-02-14 - version 0.8.1
==========================

Expand Down
1 change: 1 addition & 0 deletions orix/quaternion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
__all__ = [
"check_quaternion",
"Quaternion",
"QuaternionNumpy",
"Rotation",
"von_mises",
"Misorientation",
Expand Down
94 changes: 44 additions & 50 deletions orix/quaternion/quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import dask.array as da
import numpy as np
import quaternion

from orix.base import check, Object3d
from orix.scalar import Scalar
Expand All @@ -39,6 +40,31 @@ class Quaternion(Object3d):
- Inversion.
- Multiplication with other quaternions and vectors.

Quaternion-quaternion multiplication for two quaternions
:math:`q_1 = (a_1, b_1, c_1, d_1)`
and :math:`q_2 = (a_2, b_2, c_2, d_2)`
with :math:`q_3 = (a_3, b_3, c_3, d_3) = q_1 * q_2` follows as:

.. math::
a_3 = (a_1 * a_2 - b_1 * b_2 - c_1 * c_2 - d_1 * d_2)

b_3 = (a_1 * b_2 + b_1 * a_2 + c_1 * d_2 - d_1 * c_2)

c_3 = (a_1 * c_2 - b_1 * d_2 + c_1 * a_2 + d_1 * b_2)

d_3 = (a_1 * d_2 + b_1 * c_2 - c_1 * b_2 + d_1 * a_2)

Quaternion-vector multiplication with a three-dimensional vector
:math:`v = (x, y, z)` calculates a rotated vector
:math:`v' = q * v * q^{-1}` and follows as:

.. math::
v'_x = x(a^2 + b^2 - c^2 - d^2) + 2z(a * c + b * d) + y(b * c - a * d)

v'_y = y(a^2 - b^2 + c^2 - d^2) + 2x(a * d + b * c) + z(c * d - a * b)

v'_z = z(a^2 - b^2 - c^2 + d^2) + 2y(a * b + c * d) + x(b * d - a * c)

Attributes
----------
data : numpy.ndarray
Expand Down Expand Up @@ -89,39 +115,25 @@ def antipodal(self):

@property
def conj(self):
a = self.a.data
b, c, d = -self.b.data, -self.c.data, -self.d.data
q = np.stack((a, b, c, d), axis=-1)
return Quaternion(q)
q = quaternion.from_float_array(self.data).conj()
return Quaternion(quaternion.as_float_array(q))

def __invert__(self):
return self.__class__(self.conj.data / (self.norm.data**2)[..., np.newaxis])

def __mul__(self, other):
hakonanes marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(other, Quaternion):
sa, oa = self.a.data, other.a.data
sb, ob = self.b.data, other.b.data
sc, oc = self.c.data, other.c.data
sd, od = self.d.data, other.d.data
a = sa * oa - sb * ob - sc * oc - sd * od
b = sb * oa + sa * ob - sd * oc + sc * od
c = sc * oa + sd * ob + sa * oc - sb * od
d = sd * oa - sc * ob + sb * oc + sa * od
q = np.stack((a, b, c, d), axis=-1)
return other.__class__(q)
q1 = quaternion.from_float_array(self.data)
q2 = quaternion.from_float_array(other.data)
return other.__class__(quaternion.as_float_array(q1 * q2))
elif isinstance(other, Vector3d):
a, b, c, d = self.a.data, self.b.data, self.c.data, self.d.data
x, y, z = other.x.data, other.y.data, other.z.data
x_new = (a**2 + b**2 - c**2 - d**2) * x + 2 * (
(a * c + b * d) * z + (b * c - a * d) * y
)
y_new = (a**2 - b**2 + c**2 - d**2) * y + 2 * (
(a * d + b * c) * x + (c * d - a * b) * z
# check broadcast shape is correct before calculation, as
# quaternion.rotat_vectors will perform outer product
# this keeps current __mul__ broadcast behaviour
q1 = quaternion.from_float_array(self.data)
v = quaternion.as_vector_part(
(q1 * quaternion.from_vector_part(other.data)) * ~q1
)
z_new = (a**2 - b**2 - c**2 + d**2) * z + 2 * (
(a * b + c * d) * y + (b * d - a * c) * x
)
v = np.stack((x_new, y_new, z_new), axis=-1)
if isinstance(other, Miller):
m = other.__class__(xyz=v, phase=other.phase)
m.coordinate_format = other.coordinate_format
Expand Down Expand Up @@ -224,33 +236,15 @@ def outer(self, other):
orix.quaternion.Quaternion or orix.vector.Vector3d
"""

def e(x, y):
return np.multiply.outer(x, y)

if isinstance(other, Quaternion):
q = np.zeros(self.shape + other.shape + (4,), dtype=float)
sa, oa = self.data[..., 0], other.data[..., 0]
sb, ob = self.data[..., 1], other.data[..., 1]
sc, oc = self.data[..., 2], other.data[..., 2]
sd, od = self.data[..., 3], other.data[..., 3]
q[..., 0] = e(sa, oa) - e(sb, ob) - e(sc, oc) - e(sd, od)
q[..., 1] = e(sb, oa) + e(sa, ob) - e(sd, oc) + e(sc, od)
q[..., 2] = e(sc, oa) + e(sd, ob) + e(sa, oc) - e(sb, od)
q[..., 3] = e(sd, oa) - e(sc, ob) + e(sb, oc) + e(sa, od)
return other.__class__(q)
q1 = quaternion.from_float_array(self.data)
hakonanes marked this conversation as resolved.
Show resolved Hide resolved
q2 = quaternion.from_float_array(other.data)
# np.outer works with flattened array
q = np.outer(q1, q2).reshape(q1.shape + q2.shape)
return other.__class__(quaternion.as_float_array(q))
elif isinstance(other, Vector3d):
a, b, c, d = self.a.data, self.b.data, self.c.data, self.d.data
x, y, z = other.x.data, other.y.data, other.z.data
x_new = e(a**2 + b**2 - c**2 - d**2, x) + 2 * (
e(a * c + b * d, z) + e(b * c - a * d, y)
)
y_new = e(a**2 - b**2 + c**2 - d**2, y) + 2 * (
e(a * d + b * c, x) + e(c * d - a * b, z)
)
z_new = e(a**2 - b**2 - c**2 + d**2, z) + 2 * (
e(a * b + c * d, y) + e(b * d - a * c, x)
)
v = np.stack((x_new, y_new, z_new), axis=-1)
q = quaternion.from_float_array(self.data)
v = quaternion.rotate_vectors(q, other.data)
if isinstance(other, Miller):
m = other.__class__(xyz=v, phase=other.phase)
m.coordinate_format = other.coordinate_format
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@
"matplotlib-scalebar",
"numba",
"numpy",
"numpy-quaternion",
"scipy",
"tqdm",
"tqdm"
],
# fmt: on
package_data={"": ["LICENSE", "README.rst", "readthedocs.yml"], "orix": ["*.py"]},
Expand Down