Skip to content

Commit

Permalink
axis_angle representation of rotations
Browse files Browse the repository at this point in the history
Summary: We can represent a rotation as a vector in the axis direction, whose length is the rotation anticlockwise in radians around that axis.

Reviewed By: gkioxari

Differential Revision: D24306293

fbshipit-source-id: 2e0f138eda8329f6cceff600a6e5f17a00e4deb7
  • Loading branch information
bottler authored and facebook-github-bot committed Oct 21, 2020
1 parent 005a334 commit c93c4dd
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 7 deletions.
95 changes: 95 additions & 0 deletions pytorch3d/transforms/rotation_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,101 @@ def quaternion_apply(quaternion, point):
return out[..., 1:]


def axis_angle_to_matrix(axis_angle):
"""
Convert rotations given as axis/angle to rotation matrices.
Args:
axis_angle: Rotations given as a vector in axis angle form,
as a tensor of shape (..., 3), where the magnitude is
the angle turned anticlockwise in radians around the
vector's direction.
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))


def matrix_to_axis_angle(matrix):
"""
Convert rotations given as rotation matrices to axis/angle.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
Rotations given as a vector in axis angle form, as a tensor
of shape (..., 3), where the magnitude is the angle
turned anticlockwise in radians around the vector's
direction.
"""
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))


def axis_angle_to_quaternion(axis_angle):
"""
Convert rotations given as axis/angle to quaternions.
Args:
axis_angle: Rotations given as a vector in axis angle form,
as a tensor of shape (..., 3), where the magnitude is
the angle turned anticlockwise in radians around the
vector's direction.
Returns:
quaternions with real part first, as tensor of shape (..., 4).
"""
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
half_angles = 0.5 * angles
eps = 1e-6
small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles)
sin_half_angles_over_angles[~small_angles] = (
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
)
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - torch.square(angles[small_angles]) / 48
)
quaternions = torch.cat(
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
)
return quaternions


def quaternion_to_axis_angle(quaternions):
"""
Convert rotations given as quaternions to axis/angle.
Args:
quaternions: quaternions with real part first,
as tensor of shape (..., 4).
Returns:
Rotations given as a vector in axis angle form, as a tensor
of shape (..., 3), where the magnitude is the angle
turned anticlockwise in radians around the vector's
direction.
"""
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
half_angles = torch.atan2(norms, quaternions[..., :1])
angles = 2 * half_angles
eps = 1e-6
small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles)
sin_half_angles_over_angles[~small_angles] = (
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
)
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - torch.square(angles[small_angles]) / 48
)
return quaternions[..., 1:] / sin_half_angles_over_angles


def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
"""
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
Expand Down
43 changes: 36 additions & 7 deletions tests/test_rotation_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@
import torch
from common_testing import TestCaseMixin
from pytorch3d.transforms.rotation_conversions import (
axis_angle_to_matrix,
axis_angle_to_quaternion,
euler_angles_to_matrix,
matrix_to_axis_angle,
matrix_to_euler_angles,
matrix_to_quaternion,
matrix_to_rotation_6d,
quaternion_apply,
quaternion_multiply,
quaternion_to_axis_angle,
quaternion_to_matrix,
random_quaternions,
random_rotation,
Expand Down Expand Up @@ -60,13 +64,13 @@ def test_from_quat(self):
"""quat -> mtx -> quat"""
data = random_quaternions(13, dtype=torch.float64)
mdata = matrix_to_quaternion(quaternion_to_matrix(data))
self.assertTrue(torch.allclose(data, mdata))
self.assertClose(data, mdata)

def test_to_quat(self):
"""mtx -> quat -> mtx"""
data = random_rotations(13, dtype=torch.float64)
mdata = quaternion_to_matrix(matrix_to_quaternion(data))
self.assertTrue(torch.allclose(data, mdata))
self.assertClose(data, mdata)

def test_quat_grad_exists(self):
"""Quaternion calculations are differentiable."""
Expand Down Expand Up @@ -107,21 +111,21 @@ def test_from_euler(self):
for convention in self._tait_bryan_conventions():
matrices = euler_angles_to_matrix(data, convention)
mdata = matrix_to_euler_angles(matrices, convention)
self.assertTrue(torch.allclose(data, mdata))
self.assertClose(data, mdata)

data[:, 1] += half_pi
for convention in self._proper_euler_conventions():
matrices = euler_angles_to_matrix(data, convention)
mdata = matrix_to_euler_angles(matrices, convention)
self.assertTrue(torch.allclose(data, mdata))
self.assertClose(data, mdata)

def test_to_euler(self):
"""mtx -> euler -> mtx"""
data = random_rotations(13, dtype=torch.float64)
for convention in self._all_euler_angle_conventions():
euler_angles = matrix_to_euler_angles(data, convention)
mdata = euler_angles_to_matrix(euler_angles, convention)
self.assertTrue(torch.allclose(data, mdata))
self.assertClose(data, mdata)

def test_euler_grad_exists(self):
"""Euler angle calculations are differentiable."""
Expand All @@ -143,7 +147,7 @@ def test_quaternion_multiplication(self):
ab_matrix = torch.matmul(a_matrix, b_matrix)
ab_from_matrix = matrix_to_quaternion(ab_matrix)
self.assertEqual(ab.shape, ab_from_matrix.shape)
self.assertTrue(torch.allclose(ab, ab_from_matrix))
self.assertClose(ab, ab_from_matrix)

def test_matrix_to_quaternion_corner_case(self):
"""Check no bad gradients from sqrt(0)."""
Expand All @@ -159,14 +163,39 @@ def test_matrix_to_quaternion_corner_case(self):

self.assertClose(matrix, 0.95 * torch.eye(3))

def test_from_axis_angle(self):
"""axis_angle -> mtx -> axis_angle"""
n_repetitions = 20
data = torch.rand(n_repetitions, 3)
matrices = axis_angle_to_matrix(data)
mdata = matrix_to_axis_angle(matrices)
self.assertClose(data, mdata, atol=2e-6)

def test_from_axis_angle_has_grad(self):
n_repetitions = 20
data = torch.rand(n_repetitions, 3, requires_grad=True)
matrices = axis_angle_to_matrix(data)
mdata = matrix_to_axis_angle(matrices)
quats = axis_angle_to_quaternion(data)
mdata2 = quaternion_to_axis_angle(quats)
(grad,) = torch.autograd.grad(mdata.sum() + mdata2.sum(), data)
self.assertTrue(torch.isfinite(grad).all())

def test_to_axis_angle(self):
"""mtx -> axis_angle -> mtx"""
data = random_rotations(13, dtype=torch.float64)
euler_angles = matrix_to_axis_angle(data)
mdata = axis_angle_to_matrix(euler_angles)
self.assertClose(data, mdata)

def test_quaternion_application(self):
"""Applying a quaternion is the same as applying the matrix."""
quaternions = random_quaternions(3, torch.float64, requires_grad=True)
matrices = quaternion_to_matrix(quaternions)
points = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
transform1 = quaternion_apply(quaternions, points)
transform2 = torch.matmul(matrices, points[..., None])[..., 0]
self.assertTrue(torch.allclose(transform1, transform2))
self.assertClose(transform1, transform2)

[p, q] = torch.autograd.grad(transform1.sum(), [points, quaternions])
self.assertTrue(torch.isfinite(p).all())
Expand Down

0 comments on commit c93c4dd

Please sign in to comment.