Skip to content

Commit

Permalink
refactor: aval
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Aug 14, 2024
1 parent 1551207 commit 9e00066
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 37 deletions.
28 changes: 2 additions & 26 deletions src/coordinax/_base_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import coordinax._typing as ct
from ._base import AbstractVector
from ._mixins import AvalMixin
from ._utils import classproperty

if TYPE_CHECKING:
Expand All @@ -32,7 +33,7 @@
VECTOR_CLASSES: set[type["AbstractPosition"]] = set()


class AbstractPosition(AbstractVector): # pylint: disable=abstract-method
class AbstractPosition(AvalMixin, AbstractVector): # pylint: disable=abstract-method
"""Abstract representation of coordinates in different systems."""

def __init_subclass__(cls, **kwargs: Any) -> None:
Expand Down Expand Up @@ -85,31 +86,6 @@ def differential_cls(cls) -> type["AbstractVelocity"]:
"""
raise NotImplementedError

# ===============================================================
# Quax

def aval(self) -> jax.core.ShapedArray:
"""Return the vector as a JAX array.
Examples
--------
>>> from unxt import Quantity
>>> import coordinax as cx
>>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m")
>>> vec.aval()
ConcreteArray([1. 2. 3.], dtype=float32)
>>> vec = cx.CartesianPosition3D.constructor([[1, 2, 3], [4, 5, 6]], "m")
>>> vec.aval()
ConcreteArray([[1. 2. 3.]
[4. 5. 6.]], dtype=float32)
"""
return jax.core.get_aval(
convert(self.represent_as(self._cartesian_cls), Quantity).value
)

# ===============================================================
# Array

Expand Down
5 changes: 3 additions & 2 deletions src/coordinax/_d1/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .base import AbstractAcceleration1D, AbstractPosition1D, AbstractVelocity1D
from coordinax._base import AbstractVector
from coordinax._base_pos import AbstractPosition
from coordinax._mixins import AvalMixin
from coordinax._utils import classproperty


Expand Down Expand Up @@ -135,7 +136,7 @@ def __sub__(


@final
class CartesianVelocity1D(AbstractVelocity1D):
class CartesianVelocity1D(AvalMixin, AbstractVelocity1D):
"""Cartesian differential representation."""

d_x: ct.BatchableSpeed = eqx.field(converter=Quantity["speed"].constructor)
Expand Down Expand Up @@ -168,7 +169,7 @@ def norm(self, _: AbstractPosition1D | None = None, /) -> ct.BatchableSpeed:


@final
class CartesianAcceleration1D(AbstractAcceleration1D):
class CartesianAcceleration1D(AvalMixin, AbstractAcceleration1D):
"""Cartesian differential representation."""

d2_x: ct.BatchableAcc = eqx.field(converter=Quantity["acceleration"].constructor)
Expand Down
5 changes: 3 additions & 2 deletions src/coordinax/_d2/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .base import AbstractAcceleration2D, AbstractPosition2D, AbstractVelocity2D
from coordinax._base import AbstractVector
from coordinax._base_pos import AbstractPosition
from coordinax._mixins import AvalMixin
from coordinax._utils import classproperty


Expand Down Expand Up @@ -107,7 +108,7 @@ def __sub__(


@final
class CartesianVelocity2D(AbstractVelocity2D):
class CartesianVelocity2D(AvalMixin, AbstractVelocity2D):
"""Cartesian differential representation."""

d_x: ct.BatchableSpeed = eqx.field(
Expand All @@ -132,7 +133,7 @@ def differential_cls(cls) -> type["CartesianAcceleration2D"]:


@final
class CartesianAcceleration2D(AbstractAcceleration2D):
class CartesianAcceleration2D(AvalMixin, AbstractAcceleration2D):
"""Cartesian acceleration representation."""

d2_x: ct.BatchableSpeed = eqx.field(
Expand Down
7 changes: 4 additions & 3 deletions src/coordinax/_d2/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@

import quaxed.array_api as xp
from dataclassish import field_values
from unxt import Quantity
from unxt import AbstractQuantity, Quantity

from .base import AbstractPosition2D
from .cartesian import CartesianPosition2D, CartesianVelocity2D
from .cartesian import CartesianAcceleration2D, CartesianPosition2D, CartesianVelocity2D
from coordinax._utils import full_shaped

#####################################################################
Expand All @@ -25,7 +25,8 @@ def vec_to_q(obj: AbstractPosition2D, /) -> Shaped[Quantity["length"], "*batch 2
return xp.stack(tuple(field_values(cart)), axis=-1)


@conversion_method(type_from=CartesianAcceleration2D, type_to=Quantity) # type: ignore[misc]
@conversion_method(type_from=CartesianVelocity2D, type_to=Quantity) # type: ignore[misc]
def vec_diff_to_q(obj: CartesianVelocity2D, /) -> Shaped[Quantity["speed"], "*batch 2"]:
def vec_diff_to_q(obj: CartesianVelocity2D, /) -> Shaped[AbstractQuantity, "*batch 2"]:
"""`coordinax.CartesianVelocity2D` -> `unxt.Quantity`."""
return xp.stack(tuple(field_values(full_shaped(obj))), axis=-1)
5 changes: 3 additions & 2 deletions src/coordinax/_d3/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from coordinax._base import AbstractVector
from coordinax._base_pos import AbstractPosition
from coordinax._base_vel import AdditionMixin
from coordinax._mixins import AvalMixin
from coordinax._utils import classproperty


Expand Down Expand Up @@ -111,7 +112,7 @@ def __sub__(


@final
class CartesianVelocity3D(AbstractVelocity3D, AdditionMixin):
class CartesianVelocity3D(AvalMixin, AdditionMixin, AbstractVelocity3D):
"""Cartesian differential representation."""

d_x: ct.BatchableSpeed = eqx.field(
Expand Down Expand Up @@ -156,7 +157,7 @@ def norm(self, _: AbstractPosition3D | None = None, /) -> ct.BatchableSpeed:


@final
class CartesianAcceleration3D(AbstractAcceleration3D):
class CartesianAcceleration3D(AvalMixin, AbstractAcceleration3D):
"""Cartesian differential representation."""

d2_x: ct.BatchableSpeed = eqx.field(
Expand Down
5 changes: 3 additions & 2 deletions src/coordinax/_dn/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .base import AbstractAccelerationND, AbstractPositionND, AbstractVelocityND
from coordinax._base import AbstractVector
from coordinax._base_pos import AbstractPosition
from coordinax._mixins import AvalMixin
from coordinax._utils import classproperty

##############################################################################
Expand Down Expand Up @@ -282,7 +283,7 @@ def _mul_vcnd(lhs: ArrayLike, rhs: CartesianPositionND, /) -> CartesianPositionN


@final
class CartesianVelocityND(AbstractVelocityND):
class CartesianVelocityND(AvalMixin, AbstractVelocityND):
"""Cartesian differential representation.
Examples
Expand Down Expand Up @@ -429,7 +430,7 @@ def constructor(


@final
class CartesianAccelerationND(AbstractAccelerationND):
class CartesianAccelerationND(AvalMixin, AbstractAccelerationND):
"""Cartesian N-dimensional acceleration representation.
Examples
Expand Down
120 changes: 120 additions & 0 deletions src/coordinax/_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Mixin classes."""

__all__: list[str] = []

import jax
from plum import convert

from unxt import Quantity

from ._funcs import represent_as


class AvalMixin:
"""Mixin class to add an ``aval`` method.
See [quax doc](https://docs.kidger.site/quax/examples/custom_rules/) for
more details.
"""

def aval(self) -> jax.core.ShapedArray:
"""Return the vector as a JAX array.
Examples
--------
>>> from unxt import Quantity
>>> import coordinax as cx
1 dimensional vectors:
>>> vec = cx.CartesianPosition1D.constructor([1], "m")
>>> vec.aval()
ConcreteArray([1.], dtype=float32)
>>> vec = cx.RadialPosition.constructor([1], "m")
>>> vec.aval()
ConcreteArray([1.], dtype=float32)
>>> vec = cx.CartesianVelocity1D.constructor([1], "m/s")
>>> vec.aval()
ConcreteArray([1], dtype=int32)
>>> vec = cx.RadialVelocity.constructor([1], "m/s")
>>> try: vec.aval()
... except NotImplementedError as e: print("nope")
nope
>>> vec = cx.CartesianAcceleration1D.constructor([1], "m/s2")
>>> vec.aval()
ConcreteArray([1], dtype=int32)
>>> vec = cx.RadialAcceleration.constructor([1], "m/s2")
>>> try: vec.aval()
... except NotImplementedError as e: print("nope")
nope
2 dimensional vectors:
>>> vec = cx.CartesianPosition2D.constructor([1, 2], "m")
>>> vec.aval()
ConcreteArray([1. 2.], dtype=float32)
>>> vec = cx.PolarPosition(r=Quantity(1, "m"), phi=Quantity(0, "rad"))
>>> vec.aval()
ConcreteArray([1. 0.], dtype=float32)
>>> vec = cx.CartesianVelocity2D.constructor([1, 2], "m/s")
>>> vec.aval()
ConcreteArray([1. 2.], dtype=float32)
>>> vec = cx.PolarVelocity(d_r=Quantity(1, "m/s"), d_phi=Quantity(0, "rad/s"))
>>> try: vec.aval()
... except NotImplementedError as e: print("nope")
nope
>>> vec = cx.CartesianAcceleration2D.constructor([1,2], "m/s2")
>>> vec.aval()
ConcreteArray([1. 2.], dtype=float32)
>>> vec = cx.PolarAcceleration(d2_r=Quantity(1, "m/s2"), d2_phi=Quantity(0, "rad/s2"))
>>> try: vec.aval()
... except NotImplementedError as e: print("nope")
nope
3 dimensional vectors:
>>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m")
>>> vec.aval()
ConcreteArray([1. 2. 3.], dtype=float32)
>>> vec = cx.CartesianPosition3D.constructor([[1, 2, 3], [4, 5, 6]], "m")
>>> vec.aval()
ConcreteArray([[1. 2. 3.]
[4. 5. 6.]], dtype=float32)
>>> vec = cx.SphericalPosition(r=Quantity(1, "m"), phi=Quantity(0, "rad"), theta=Quantity(0, "rad"))
>>> vec.aval()
ConcreteArray([0. 0. 1.], dtype=float32)
>>> vec = cx.CartesianVelocity3D.constructor([1,2,3], "m/s")
>>> vec.aval()
ConcreteArray([1. 2. 3.], dtype=float32)
>>> vec = cx.SphericalVelocity(d_r=Quantity(1, "m/s"), d_phi=Quantity(0, "rad/s"), d_theta=Quantity(0, "rad/s"))
>>> try: vec.aval()
... except NotImplementedError as e: print("nope")
nope
>>> vec = cx.CartesianAcceleration3D.constructor([1,2,3], "m/s2")
>>> vec.aval()
ConcreteArray([1. 2. 3.], dtype=float32)
>>> vec = cx.SphericalAcceleration(d2_r=Quantity(1, "m/s2"), d2_phi=Quantity(0, "rad/s2"), d2_theta=Quantity(0, "rad/s2"))
>>> try: vec.aval()
... except NotImplementedError as e: print("nope")
nope
""" # noqa: E501
# TODO: change to UncheckedQuantity
target = self._cartesian_cls # type: ignore[attr-defined]
return jax.core.get_aval(convert(represent_as(target), Quantity).value)

0 comments on commit 9e00066

Please sign in to comment.