-
Notifications
You must be signed in to change notification settings - Fork 27
/
gravity.py
131 lines (102 loc) · 4.32 KB
/
gravity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Tensor class for gravity vector in camera frame."""
import torch
from torch.nn import functional as F
from geocalib.misc import EuclideanManifold, SphericalManifold, TensorWrapper, autocast
from geocalib.utils import rad2rotmat
# mypy: ignore-errors
class Gravity(TensorWrapper):
"""Gravity vector in camera frame."""
eps = 1e-4
@autocast
def __init__(self, data: torch.Tensor) -> None:
"""Create gravity vector from data.
Args:
data (torch.Tensor): gravity vector as 3D vector in camera frame.
"""
assert data.shape[-1] == 3, data.shape
data = F.normalize(data, dim=-1)
super().__init__(data)
@classmethod
def from_rp(cls, roll: torch.Tensor, pitch: torch.Tensor) -> "Gravity":
"""Create gravity vector from roll and pitch angles."""
if not isinstance(roll, torch.Tensor):
roll = torch.tensor(roll)
if not isinstance(pitch, torch.Tensor):
pitch = torch.tensor(pitch)
sr, cr = torch.sin(roll), torch.cos(roll)
sp, cp = torch.sin(pitch), torch.cos(pitch)
return cls(torch.stack([-sr * cp, -cr * cp, sp], dim=-1))
@property
def vec3d(self) -> torch.Tensor:
"""Return the gravity vector in the representation."""
return self._data
@property
def x(self) -> torch.Tensor:
"""Return first component of the gravity vector."""
return self._data[..., 0]
@property
def y(self) -> torch.Tensor:
"""Return second component of the gravity vector."""
return self._data[..., 1]
@property
def z(self) -> torch.Tensor:
"""Return third component of the gravity vector."""
return self._data[..., 2]
@property
def roll(self) -> torch.Tensor:
"""Return the roll angle of the gravity vector."""
roll = torch.asin(-self.x / (torch.sqrt(1 - self.z**2) + self.eps))
offset = -torch.pi * torch.sign(self.x)
return torch.where(self.y < 0, roll, -roll + offset)
def J_roll(self) -> torch.Tensor:
"""Return the Jacobian of the roll angle of the gravity vector."""
cp, _ = torch.cos(self.pitch), torch.sin(self.pitch)
cr, sr = torch.cos(self.roll), torch.sin(self.roll)
Jr = self.new_zeros(self.shape + (3,))
Jr[..., 0] = -cr * cp
Jr[..., 1] = sr * cp
return Jr
@property
def pitch(self) -> torch.Tensor:
"""Return the pitch angle of the gravity vector."""
return torch.asin(self.z)
def J_pitch(self) -> torch.Tensor:
"""Return the Jacobian of the pitch angle of the gravity vector."""
cp, sp = torch.cos(self.pitch), torch.sin(self.pitch)
cr, sr = torch.cos(self.roll), torch.sin(self.roll)
Jp = self.new_zeros(self.shape + (3,))
Jp[..., 0] = sr * sp
Jp[..., 1] = cr * sp
Jp[..., 2] = cp
return Jp
@property
def rp(self) -> torch.Tensor:
"""Return the roll and pitch angles of the gravity vector."""
return torch.stack([self.roll, self.pitch], dim=-1)
def J_rp(self) -> torch.Tensor:
"""Return the Jacobian of the roll and pitch angles of the gravity vector."""
return torch.stack([self.J_roll(), self.J_pitch()], dim=-1)
@property
def R(self) -> torch.Tensor:
"""Return the rotation matrix from the gravity vector."""
return rad2rotmat(roll=self.roll, pitch=self.pitch)
def J_R(self) -> torch.Tensor:
"""Return the Jacobian of the rotation matrix from the gravity vector."""
raise NotImplementedError
def update(self, delta: torch.Tensor, spherical: bool = False) -> "Gravity":
"""Update the gravity vector by adding a delta."""
if spherical:
data = SphericalManifold.plus(self.vec3d, delta)
return self.__class__(data)
data = EuclideanManifold.plus(self.rp, delta)
return self.from_rp(data[..., 0], data[..., 1])
def J_update(self, spherical: bool = False) -> torch.Tensor:
"""Return the Jacobian of the update."""
return (
SphericalManifold.J_plus(self.vec3d)
if spherical
else EuclideanManifold.J_plus(self.vec3d)
)
def __repr__(self):
"""Print the Camera object."""
return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}"