Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: generic representations #186

Merged
merged 1 commit into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
if TYPE_CHECKING:
from typing_extensions import Self

BT = TypeVar("BT", bound="AbstractVector")
VT = TypeVar("VT", bound="AbstractVector")


class ToUnitsOptions(Enum):
Expand Down Expand Up @@ -640,7 +640,7 @@ def sizes(self) -> MappingProxyType[str, int]:
# Convenience methods

@abstractmethod
def represent_as(self, target: type[BT], /, *args: Any, **kwargs: Any) -> BT:
def represent_as(self, target: type[VT], /, *args: Any, **kwargs: Any) -> VT:
"""Represent the vector as another type."""
raise NotImplementedError # pragma: no cover

Expand Down
5 changes: 3 additions & 2 deletions src/coordinax/_coordinax/base/base_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import abstractmethod
from functools import partial
from typing import TYPE_CHECKING, Any, TypeVar
from typing_extensions import override

import jax
from quax import register
Expand All @@ -17,6 +18,7 @@
from .base import AbstractVector
from .base_pos import AbstractPosition
from .base_vel import AbstractVelocity
from coordinax._coordinax.funcs import represent_as
from coordinax._coordinax.utils import classproperty

if TYPE_CHECKING:
Expand Down Expand Up @@ -111,6 +113,7 @@ def __neg__(self) -> "Self":
# ===============================================================
# Convenience methods

@override
def represent_as(self, target: type[AccT], /, *args: Any, **kwargs: Any) -> AccT:
"""Represent the vector as another type.
Expand Down Expand Up @@ -152,8 +155,6 @@ def represent_as(self, target: type[AccT], /, *args: Any, **kwargs: Any) -> AccT
Quantity['acceleration'](Array(13.363062, dtype=float32), unit='m / s2')
"""
from coordinax import represent_as # pylint: disable=import-outside-toplevel

return represent_as(self, target, *args, **kwargs)

@partial(jax.jit, inline=True)
Expand Down
13 changes: 12 additions & 1 deletion src/coordinax/_coordinax/d3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
# pylint: disable=duplicate-code
"""3-dimensional representations."""

from . import base, cartesian, compat, constructor, cylindrical, spherical, transform
from . import (
base,
cartesian,
compat,
constructor,
cylindrical,
generic,
spherical,
transform,
)
from .base import *
from .cartesian import *
from .compat import *
from .constructor import *
from .cylindrical import *
from .generic import *
from .spherical import *
from .transform import *

Expand All @@ -15,6 +25,7 @@
__all__ += cartesian.__all__
__all__ += cylindrical.__all__
__all__ += spherical.__all__
__all__ += generic.__all__
__all__ += transform.__all__
__all__ += compat.__all__
__all__ += constructor.__all__
108 changes: 108 additions & 0 deletions src/coordinax/_coordinax/d3/generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Built-in vector classes."""

__all__ = [
"CartesianGeneric3D",
]

from dataclasses import fields
from functools import partial
from typing import Any, TypeVar, final

import equinox as eqx
import jax

import quaxed.numpy as jnp
from unxt import AbstractQuantity, Quantity

import coordinax._coordinax.typing as ct
from coordinax._coordinax.base import AbstractVector
from coordinax._coordinax.base.mixins import AvalMixin

VT = TypeVar("VT", bound="CartesianGeneric3D")


@final
class CartesianGeneric3D(AvalMixin, AbstractVector):
"""Generic 3D Cartesian coordinates.
The fields of this class are not restricted to any specific dimensions.
Examples
--------
>>> from unxt import Quantity
>>> import coordinax as cx
>>> vec = cx.CartesianGeneric3D.constructor([1, 2, 3], "kg m /s")
>>> vec
CartesianGeneric3D(
x=Quantity[...]( value=f32[], unit=Unit("kg m / s") ),
y=Quantity[...]( value=f32[], unit=Unit("kg m / s") ),
z=Quantity[...]( value=f32[], unit=Unit("kg m / s") )
)
"""

x: ct.BatchableFloatScalarQ = eqx.field(
converter=partial(Quantity.constructor, dtype=float)
)

y: ct.BatchableFloatScalarQ = eqx.field(
converter=partial(Quantity.constructor, dtype=float)
)

z: ct.BatchableFloatScalarQ = eqx.field(
converter=partial(Quantity.constructor, dtype=float)
)

# -----------------------------------------------------
# Unary operations

def __neg__(self) -> "CartesianGeneric3D":
"""Negate the `coordinax.CartesianGeneric3D`.
Examples
--------
>>> import coordinax as cx
>>> q = cx.CartesianGeneric3D.constructor([1, 2, 3], "kpc")
>>> (-q).x
Quantity['length'](Array(-1., dtype=float32), unit='kpc')
"""
return jax.tree.map(jnp.negative, self)

# ===============================================================
# Convenience methods

def represent_as(self, target: type[VT], /, *args: Any, **kwargs: Any) -> VT:
"""Represent the vector as another type."""
raise NotImplementedError # pragma: no cover


# =====================================================
# Constructors


@CartesianGeneric3D.constructor._f.dispatch # type: ignore[attr-defined, misc] # noqa: SLF001
def constructor(
cls: type[CartesianGeneric3D],
obj: AbstractQuantity, # TODO: Shaped[AbstractQuantity, "*batch 3"]
/,
) -> CartesianGeneric3D:
"""Construct a 3D Cartesian position.
Examples
--------
>>> from unxt import Quantity
>>> import coordinax as cx
>>> vec = cx.CartesianGeneric3D.constructor(Quantity([1, 2, 3], "m"))
>>> vec
CartesianGeneric3D(
x=Quantity[...](value=f32[], unit=Unit("m")),
y=Quantity[...](value=f32[], unit=Unit("m")),
z=Quantity[...](value=f32[], unit=Unit("m"))
)
"""
comps = {f.name: obj[..., i] for i, f in enumerate(fields(cls))}
return cls(**comps)
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import astropy.units as u
from jaxtyping import Float, Int, Shaped

from unxt import AbstractDistance, Quantity
from unxt import AbstractDistance, AbstractQuantity, Quantity

Unit: TypeAlias = u.Unit | u.UnitBase | u.CompositeUnit

FloatScalarQ = Float[Quantity, ""]
FloatScalarQ = Float[AbstractQuantity, ""]
BatchFloatScalarQ = Shaped[FloatScalarQ, "*batch"]
BatchableFloatScalarQ = Shaped[FloatScalarQ, "*#batch"]

Expand Down