diff --git a/src/coordinax/_dn/base.py b/src/coordinax/_dn/base.py index 4e1112e..8aedd80 100644 --- a/src/coordinax/_dn/base.py +++ b/src/coordinax/_dn/base.py @@ -1,16 +1,19 @@ """Representation of coordinates in different systems.""" -__all__ = ["AbstractPositionND", "AbstractPositionNDDifferential"] +__all__ = ["AbstractPositionND", "AbstractVelocityND", "AbstractAccelerationND"] from abc import abstractmethod from dataclasses import replace from typing import TYPE_CHECKING, Any +import equinox as eqx + import quaxed.lax as qlax import quaxed.numpy as qnp from coordinax._base import AbstractVector +from coordinax._base_acc import AbstractAcceleration from coordinax._base_pos import AbstractPosition from coordinax._base_vel import AbstractVelocity from coordinax._utils import classproperty @@ -82,7 +85,7 @@ def reshape(self, *hape: Any, order: str = "C") -> "Self": return replace(self, q=self.q.reshape(*hape, self.q.shape[-1], order=order)) -class AbstractPositionNDDifferential(AbstractVelocity): +class AbstractVelocityND(AbstractVelocity): """Abstract representation of N-D vector differentials.""" @classproperty @@ -146,3 +149,143 @@ def flatten(self) -> "Self": def reshape(self, *hape: Any, order: str = "C") -> "Self": """Reshape the vector.""" return replace(self, q=self.q.reshape(*hape, self.q.shape[-1], order=order)) + + +class AbstractAccelerationND(AbstractAcceleration): + """Abstract representation of N-D vector differentials.""" + + @classproperty + @classmethod + def _cartesian_cls(cls) -> type[AbstractVector]: + """Get the Cartesian acceleration class. + + Examples + -------- + >>> import coordinax as cx + >>> cx.CartesianAccelerationND._cartesian_cls + + + """ + from .cartesian import CartesianAccelerationND + + return CartesianAccelerationND + + @classproperty + @classmethod + @abstractmethod + def integral_cls(cls) -> type[AbstractVelocityND]: + raise NotImplementedError # pragma: no cover + + # =============================================================== + # Array API + + @property + def mT(self) -> "Self": # noqa: N802 + """Transpose the vector. + + The last axis is interpreted as the feature axis. The matrix + transpose is performed on the last two non-feature axes. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + >>> vec = cx.CartesianAccelerationND(Quantity([[[1, 2, 3]], + ... [[4, 5, 6]]], "m/s^2")) + >>> vec.shape + (2, 1) + + >>> vec.mT.shape + (1, 2) + + """ + ndim = self.ndim + ndim = eqx.error_if( + ndim, + ndim < 2, + f"x must be at least two-dimensional for matrix_transpose; got {ndim=}", + ) + axes = (*range(ndim - 3), ndim - 1, ndim - 2, ndim) + return replace(self, d2_q=qlax.transpose(self.d2_q, axes)) + + @property + def shape(self) -> tuple[int, ...]: + """Get the shape of the vector's components. + + When represented as a single array, the vector has an additional + dimension at the end for the components. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + >>> vec = cx.CartesianAccelerationND(Quantity([[[1, 2, 3]], + ... [[4, 5, 6]]], "m/s^2")) + >>> vec.shape + (2, 1) + + """ + return self.d2_q.shape[:-1] + + @property + def T(self) -> "Self": # noqa: N802 + """Transpose the vector's batch axes, preserving the feature axis. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + >>> vec = cx.CartesianAccelerationND(Quantity([[[1, 2, 3]], + ... [[4, 5, 6]]], "m/s^2")) + >>> vec.shape + (2, 1) + >>> vec.T.shape + (1, 2) + + """ + return replace( + self, + d2_q=qlax.transpose(self.d2_q, (*range(self.ndim)[::-1], self.ndim)), + ) + + # =============================================================== + # Further array methods + + def flatten(self) -> "Self": + """Flatten the vector's batch dimensions, preserving the component axis. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + >>> vec = cx.CartesianAccelerationND(Quantity([[[1, 2, 3]], + ... [[4, 5, 6]]], "m/s^2")) + >>> vec.shape + (2, 1) + + >>> vec.flatten().shape + (2,) + + """ + return replace( + self, d2_q=qnp.reshape(self.d2_q, (self.size, self.d2_q.shape[-1]), "C") + ) + + def reshape(self, *shape: Any, order: str = "C") -> "Self": + """Reshape the vector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + >>> vec = cx.CartesianAccelerationND(Quantity([1, 2, 3], "m/s^2")) + >>> vec.shape + () + + >>> vec.reshape(1, 1).shape + (1, 1) + + """ + return replace( + self, d2_q=self.d2_q.reshape(*shape, self.d2_q.shape[-1], order=order) + ) diff --git a/src/coordinax/_dn/cartesian.py b/src/coordinax/_dn/cartesian.py index 2e0989f..fc2bb6e 100644 --- a/src/coordinax/_dn/cartesian.py +++ b/src/coordinax/_dn/cartesian.py @@ -1,13 +1,10 @@ """Built-in vector classes.""" -__all__ = [ - "CartesianPositionND", - "CartesianVelocityND", -] +__all__ = ["CartesianPositionND", "CartesianVelocityND", "CartesianAccelerationND"] from dataclasses import replace from functools import partial -from typing import final +from typing import NoReturn, final from typing_extensions import override import equinox as eqx @@ -20,7 +17,7 @@ from unxt import Quantity import coordinax._typing as ct -from .base import AbstractPositionND, AbstractPositionNDDifferential +from .base import AbstractAccelerationND, AbstractPositionND, AbstractVelocityND from coordinax._base import AbstractVector from coordinax._base_pos import AbstractPosition from coordinax._utils import classproperty @@ -233,7 +230,7 @@ def _mul_vcnd(lhs: ArrayLike, rhs: CartesianPositionND, /) -> CartesianPositionN @final -class CartesianVelocityND(AbstractPositionNDDifferential): +class CartesianVelocityND(AbstractVelocityND): """Cartesian differential representation. Examples @@ -300,8 +297,7 @@ def integral_cls(cls) -> type[CartesianPositionND]: @classproperty @classmethod def differential_cls(cls) -> type["CartesianAccelerationND"]: - msg = "Not yet supported" - raise NotImplementedError(msg) # TODO: Implement this + return CartesianAccelerationND @partial(jax.jit) def norm(self, _: AbstractPositionND | None = None, /) -> ct.BatchableSpeed: @@ -320,3 +316,121 @@ def norm(self, _: AbstractPositionND | None = None, /) -> ct.BatchableSpeed: """ return xp.linalg.vector_norm(self.d_q, axis=-1) + + +############################################################################## +# Acceleration + + +@final +class CartesianAccelerationND(AbstractAccelerationND): + """Cartesian N-dimensional acceleration representation. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + A 1D vector: + + >>> q = cx.CartesianAccelerationND(Quantity([[1]], "km/s2")) + >>> q.d2_q + Quantity['acceleration'](Array([[1.]], dtype=float32), unit='km / s2') + >>> q.shape + (1,) + + A 2D vector: + + >>> q = cx.CartesianAccelerationND(Quantity([1, 2], "km/s2")) + >>> q.d2_q + Quantity['acceleration'](Array([1., 2.], dtype=float32), unit='km / s2') + >>> q.shape + () + + A 3D vector: + + >>> q = cx.CartesianAccelerationND(Quantity([1, 2, 3], "km/s2")) + >>> q.d2_q + Quantity['acceleration'](Array([1., 2., 3.], dtype=float32), unit='km / s2') + >>> q.shape + () + + A 4D vector: + + >>> q = cx.CartesianAccelerationND(Quantity([1, 2, 3, 4], "km/s2")) + >>> q.d2_q + Quantity['acceleration'](Array([1., 2., 3., 4.], dtype=float32), unit='km / s2') + >>> q.shape + () + + A 5D vector: + + >>> q = cx.CartesianAccelerationND(Quantity([1, 2, 3, 4, 5], "km/s2")) + >>> q.d2_q + Quantity['acceleration'](Array([1., 2., 3., 4., 5.], dtype=float32), unit='km / s2') + >>> q.shape + () + + """ + + d2_q: ct.BatchableAcc = eqx.field( + converter=partial(Quantity["acceleration"].constructor, dtype=float) + ) + r"""N-D acceleration :math:`d\vec{x}/dt^2 \in (-\infty, \infty). + + Should have shape (*batch, F) where F is the number of features / + dimensions. Arbitrary batch shapes are supported. + """ + + @classproperty + @classmethod + def integral_cls(cls) -> type[CartesianVelocityND]: + """Return the integral class. + + Examples + -------- + >>> import coordinax as cx + >>> cx.CartesianAccelerationND.integral_cls.__name__ + 'CartesianVelocityND' + + """ + return CartesianVelocityND + + @classproperty + @classmethod + def differential_cls(cls) -> NoReturn: + """Return the differential class. + + Examples + -------- + >>> import coordinax as cx + >>> try: cx.CartesianAccelerationND.differential_cls + ... except NotImplementedError as e: print(e) + Not yet supported + + """ + msg = "Not yet supported" + raise NotImplementedError(msg) # TODO: Implement this + + @partial(jax.jit) + def norm( + self, + velocity: AbstractVelocityND | None = None, # noqa: ARG002 + position: AbstractPositionND | None = None, # noqa: ARG002 + /, + ) -> ct.BatchableSpeed: + """Return the norm of the vector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + A 3D vector: + + >>> c = cx.CartesianAccelerationND(Quantity([1, 2, 3], "km/s2")) + >>> c.norm() + Quantity['acceleration'](Array(3.7416575, dtype=float32), unit='km / s2') + + """ + return xp.linalg.vector_norm(self.d2_q, axis=-1)