Skip to content

Commit

Permalink
feat: generic (#186)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Sep 16, 2024
1 parent 964e879 commit e3f002d
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
if TYPE_CHECKING:
from typing_extensions import Self

BT = TypeVar("BT", bound="AbstractVector")
VT = TypeVar("VT", bound="AbstractVector")


class ToUnitsOptions(Enum):
Expand Down Expand Up @@ -640,7 +640,7 @@ def sizes(self) -> MappingProxyType[str, int]:
# Convenience methods

@abstractmethod
def represent_as(self, target: type[BT], /, *args: Any, **kwargs: Any) -> BT:
def represent_as(self, target: type[VT], /, *args: Any, **kwargs: Any) -> VT:
"""Represent the vector as another type."""
raise NotImplementedError # pragma: no cover

Expand Down
5 changes: 3 additions & 2 deletions src/coordinax/_coordinax/base/base_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import abstractmethod
from functools import partial
from typing import TYPE_CHECKING, Any, TypeVar
from typing_extensions import override

import jax
from quax import register
Expand All @@ -17,6 +18,7 @@
from .base import AbstractVector
from .base_pos import AbstractPosition
from .base_vel import AbstractVelocity
from coordinax._coordinax.funcs import represent_as
from coordinax._coordinax.utils import classproperty

if TYPE_CHECKING:
Expand Down Expand Up @@ -111,6 +113,7 @@ def __neg__(self) -> "Self":
# ===============================================================
# Convenience methods

@override
def represent_as(self, target: type[AccT], /, *args: Any, **kwargs: Any) -> AccT:
"""Represent the vector as another type.
Expand Down Expand Up @@ -152,8 +155,6 @@ def represent_as(self, target: type[AccT], /, *args: Any, **kwargs: Any) -> AccT
Quantity['acceleration'](Array(13.363062, dtype=float32), unit='m / s2')
"""
from coordinax import represent_as # pylint: disable=import-outside-toplevel

return represent_as(self, target, *args, **kwargs)

@partial(jax.jit, inline=True)
Expand Down
13 changes: 12 additions & 1 deletion src/coordinax/_coordinax/d3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
# pylint: disable=duplicate-code
"""3-dimensional representations."""

from . import base, cartesian, compat, constructor, cylindrical, spherical, transform
from . import (
base,
cartesian,
compat,
constructor,
cylindrical,
generic,
spherical,
transform,
)
from .base import *
from .cartesian import *
from .compat import *
from .constructor import *
from .cylindrical import *
from .generic import *
from .spherical import *
from .transform import *

Expand All @@ -15,6 +25,7 @@
__all__ += cartesian.__all__
__all__ += cylindrical.__all__
__all__ += spherical.__all__
__all__ += generic.__all__
__all__ += transform.__all__
__all__ += compat.__all__
__all__ += constructor.__all__
108 changes: 108 additions & 0 deletions src/coordinax/_coordinax/d3/generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Built-in vector classes."""

__all__ = [
"CartesianGeneric3D",
]

from dataclasses import fields
from functools import partial
from typing import Any, TypeVar, final

import equinox as eqx
import jax

import quaxed.numpy as jnp
from unxt import AbstractQuantity, Quantity

import coordinax._coordinax.typing as ct
from coordinax._coordinax.base import AbstractVector
from coordinax._coordinax.base.mixins import AvalMixin

VT = TypeVar("VT", bound="CartesianGeneric3D")


@final
class CartesianGeneric3D(AvalMixin, AbstractVector):
"""Generic 3D Cartesian coordinates.
The fields of this class are not restricted to any specific dimensions.
Examples
--------
>>> from unxt import Quantity
>>> import coordinax as cx
>>> vec = cx.CartesianGeneric3D.constructor([1, 2, 3], "kg m /s")
>>> vec
CartesianGeneric3D(
x=Quantity[...]( value=f32[], unit=Unit("kg m / s") ),
y=Quantity[...]( value=f32[], unit=Unit("kg m / s") ),
z=Quantity[...]( value=f32[], unit=Unit("kg m / s") )
)
"""

x: ct.BatchableFloatScalarQ = eqx.field(
converter=partial(Quantity.constructor, dtype=float)
)

y: ct.BatchableFloatScalarQ = eqx.field(
converter=partial(Quantity.constructor, dtype=float)
)

z: ct.BatchableFloatScalarQ = eqx.field(
converter=partial(Quantity.constructor, dtype=float)
)

# -----------------------------------------------------
# Unary operations

def __neg__(self) -> "CartesianGeneric3D":
"""Negate the `coordinax.CartesianGeneric3D`.
Examples
--------
>>> import coordinax as cx
>>> q = cx.CartesianGeneric3D.constructor([1, 2, 3], "kpc")
>>> (-q).x
Quantity['length'](Array(-1., dtype=float32), unit='kpc')
"""
return jax.tree.map(jnp.negative, self)

# ===============================================================
# Convenience methods

def represent_as(self, target: type[VT], /, *args: Any, **kwargs: Any) -> VT:
"""Represent the vector as another type."""
raise NotImplementedError # pragma: no cover


# =====================================================
# Constructors


@CartesianGeneric3D.constructor._f.dispatch # type: ignore[attr-defined, misc] # noqa: SLF001
def constructor(
cls: type[CartesianGeneric3D],
obj: AbstractQuantity, # TODO: Shaped[AbstractQuantity, "*batch 3"]
/,
) -> CartesianGeneric3D:
"""Construct a 3D Cartesian position.
Examples
--------
>>> from unxt import Quantity
>>> import coordinax as cx
>>> vec = cx.CartesianGeneric3D.constructor(Quantity([1, 2, 3], "m"))
>>> vec
CartesianGeneric3D(
x=Quantity[...](value=f32[], unit=Unit("m")),
y=Quantity[...](value=f32[], unit=Unit("m")),
z=Quantity[...](value=f32[], unit=Unit("m"))
)
"""
comps = {f.name: obj[..., i] for i, f in enumerate(fields(cls))}
return cls(**comps)
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import astropy.units as u
from jaxtyping import Float, Int, Shaped

from unxt import AbstractDistance, Quantity
from unxt import AbstractDistance, AbstractQuantity, Quantity

Unit: TypeAlias = u.Unit | u.UnitBase | u.CompositeUnit

FloatScalarQ = Float[Quantity, ""]
FloatScalarQ = Float[AbstractQuantity, ""]
BatchFloatScalarQ = Shaped[FloatScalarQ, "*batch"]
BatchableFloatScalarQ = Shaped[FloatScalarQ, "*#batch"]

Expand Down

0 comments on commit e3f002d

Please sign in to comment.