Skip to content

Commit

Permalink
feat: consolidate compat (#166)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Aug 21, 2024
1 parent 38e6619 commit 9d577db
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 56 deletions.
3 changes: 3 additions & 0 deletions src/coordinax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
__all__ += utils.__all__

# Interoperability
from ._coordinax import compat # noqa: E402

# Astropy
from ._interop import coordinax_interop_astropy # noqa: E402

Expand All @@ -86,4 +88,5 @@
funcs,
RUNTIME_TYPECHECKER,
coordinax_interop_astropy,
compat,
)
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,9 +698,9 @@ def to_units(self, units: Any, /) -> "AbstractVector":
Parameters
----------
units : `unxt.AbstractUnitSystem`
units : Any
The units to convert to according to the physical type of the
components.
components. This is passed to [`unxt.unitsystem`][].
Examples
--------
Expand Down
184 changes: 184 additions & 0 deletions src/coordinax/_coordinax/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""Intra-ecosystem Compatibility."""

__all__: list[str] = []


from jaxtyping import Shaped
from plum import conversion_method, convert

import quaxed.array_api as xp
from dataclassish import field_values
from unxt import AbstractQuantity, Distance, Quantity, UncheckedQuantity

from .base_pos import AbstractPosition
from coordinax._coordinax.utils import full_shaped

#####################################################################
# Convert to Quantity


@conversion_method(type_from=AbstractPosition, type_to=AbstractQuantity) # type: ignore[misc]
def convert_pos_to_absquantity(obj: AbstractPosition, /) -> AbstractQuantity:
"""`coordinax.AbstractPosition` -> `unxt.AbstractQuantity`.
Examples
--------
>>> import coordinax as cx
>>> from unxt import AbstractQuantity, Quantity
>>> pos = cx.CartesianPosition1D.constructor([1.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1.], dtype=float32), unit='km')
>>> pos = cx.RadialPosition.constructor([1.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition2D.constructor([1.0, 2.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 2.], dtype=float32), unit='km')
>>> pos = cx.PolarPosition(Quantity(1.0, "km"), Quantity(0, "deg"))
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 0.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition3D.constructor([1.0, 2.0, 3.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 2., 3.], dtype=float32), unit='km')
>>> pos = cx.SphericalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "deg"))
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([0., 0., 1.], dtype=float32), unit='km')
>>> pos = cx.CylindricalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "km"))
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 0., 0.], dtype=float32), unit='km')
""" # noqa: E501
cart = full_shaped(obj.represent_as(obj._cartesian_cls)) # noqa: SLF001
return xp.stack(tuple(field_values(cart)), axis=-1)


@conversion_method(type_from=AbstractPosition, type_to=Quantity) # type: ignore[misc]
def convert_pos_to_q(obj: AbstractPosition, /) -> Quantity["length"]:
"""`coordinax.AbstractPosition` -> `unxt.Quantity`.
Examples
--------
>>> import coordinax as cx
>>> from unxt import AbstractQuantity, Quantity
>>> pos = cx.CartesianPosition1D.constructor([1.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1.], dtype=float32), unit='km')
>>> pos = cx.RadialPosition.constructor([1.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition2D.constructor([1.0, 2.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 2.], dtype=float32), unit='km')
>>> pos = cx.PolarPosition(Quantity(1.0, "km"), Quantity(0, "deg"))
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 0.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition3D.constructor([1.0, 2.0, 3.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 2., 3.], dtype=float32), unit='km')
>>> pos = cx.SphericalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "deg"))
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([0., 0., 1.], dtype=float32), unit='km')
>>> pos = cx.CylindricalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "km"))
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 0., 0.], dtype=float32), unit='km')
""" # noqa: E501
return convert(convert(obj, AbstractQuantity), Quantity)


@conversion_method(type_from=AbstractPosition, type_to=UncheckedQuantity) # type: ignore[misc]
def convert_pos_to_uncheckedq(
obj: AbstractPosition, /
) -> Shaped[UncheckedQuantity, "*batch 1"]:
"""`coordinax.AbstractPosition` -> `unxt.UncheckedQuantity`.
Examples
--------
>>> import coordinax as cx
>>> from unxt import AbstractQuantity, Quantity, UncheckedQuantity
>>> pos = cx.CartesianPosition1D.constructor([1.0], "km")
>>> convert(pos, UncheckedQuantity)
UncheckedQuantity(Array([1.], dtype=float32), unit='km')
>>> pos = cx.RadialPosition.constructor([1.0], "km")
>>> convert(pos, UncheckedQuantity)
UncheckedQuantity(Array([1.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition2D.constructor([1.0, 2.0], "km")
>>> convert(pos, UncheckedQuantity)
UncheckedQuantity(Array([1., 2.], dtype=float32), unit='km')
>>> pos = cx.PolarPosition(Quantity(1.0, "km"), Quantity(0, "deg"))
>>> convert(pos, UncheckedQuantity)
UncheckedQuantity(Array([1., 0.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition3D.constructor([1.0, 2.0, 3.0], "km")
>>> convert(pos, UncheckedQuantity)
UncheckedQuantity(Array([1., 2., 3.], dtype=float32), unit='km')
>>> pos = cx.SphericalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "deg"))
>>> convert(pos, UncheckedQuantity)
UncheckedQuantity(Array([0., 0., 1.], dtype=float32), unit='km')
>>> pos = cx.CylindricalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "km"))
>>> convert(pos, UncheckedQuantity)
UncheckedQuantity(Array([1., 0., 0.], dtype=float32), unit='km')
""" # noqa: E501
return convert(convert(obj, AbstractQuantity), UncheckedQuantity)


@conversion_method(type_from=AbstractPosition, type_to=Distance) # type: ignore[misc]
def convert_pos_to_distance(obj: AbstractPosition, /) -> Shaped[Distance, "*batch 1"]:
"""`coordinax.AbstractPosition` -> `unxt.Distance`.
Examples
--------
>>> import coordinax as cx
>>> from unxt import AbstractQuantity, Quantity, Distance
>>> pos = cx.CartesianPosition1D.constructor([1.0], "km")
>>> convert(pos, Distance)
Distance(Array([1.], dtype=float32), unit='km')
>>> pos = cx.RadialPosition.constructor([1.0], "km")
>>> convert(pos, Distance)
Distance(Array([1.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition2D.constructor([1.0, 2.0], "km")
>>> convert(pos, Distance)
Distance(Array([1., 2.], dtype=float32), unit='km')
>>> pos = cx.PolarPosition(Quantity(1.0, "km"), Quantity(0, "deg"))
>>> convert(pos, Distance)
Distance(Array([1., 0.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition3D.constructor([1.0, 2.0, 3.0], "km")
>>> convert(pos, Distance)
Distance(Array([1., 2., 3.], dtype=float32), unit='km')
>>> pos = cx.SphericalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "deg"))
>>> convert(pos, Distance)
Distance(Array([0., 0., 1.], dtype=float32), unit='km')
>>> pos = cx.CylindricalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "km"))
>>> convert(pos, Distance)
Distance(Array([1., 0., 0.], dtype=float32), unit='km')
""" # noqa: E501
return convert(convert(obj, AbstractQuantity), Distance)
16 changes: 4 additions & 12 deletions src/coordinax/_coordinax/d1/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@

import quaxed.array_api as xp
from dataclassish import field_values
from unxt import Quantity
from unxt import AbstractQuantity, Quantity

from .base import AbstractPosition1D
from .cartesian import CartesianAcceleration1D, CartesianPosition1D, CartesianVelocity1D
from coordinax._coordinax.operators.base import AbstractOperator, op_call_dispatch
from coordinax._coordinax.typing import TimeBatchOrScalar
Expand All @@ -22,18 +21,11 @@
# Convert to Quantity


@conversion_method(type_from=AbstractPosition1D, type_to=Quantity) # type: ignore[misc]
def vec_to_q(obj: AbstractPosition1D, /) -> Shaped[Quantity["length"], "*batch 1"]:
"""`coordinax.AbstractPosition1D` -> `unxt.Quantity`."""
cart = full_shaped(obj.represent_as(CartesianPosition1D))
return xp.stack(tuple(field_values(cart)), axis=-1)


@conversion_method(type_from=CartesianAcceleration1D, type_to=Quantity) # type: ignore[misc]
@conversion_method(type_from=CartesianVelocity1D, type_to=Quantity) # type: ignore[misc]
@conversion_method(type_from=CartesianAcceleration1D, type_to=AbstractQuantity) # type: ignore[misc]
@conversion_method(type_from=CartesianVelocity1D, type_to=AbstractQuantity) # type: ignore[misc]
def vec_diff_to_q(
obj: CartesianVelocity1D | CartesianAcceleration1D, /
) -> Shaped[Quantity["speed"], "*batch 1"]:
) -> Shaped[AbstractQuantity, "*batch 1"]:
"""`coordinax.CartesianVelocity1D` -> `unxt.Quantity`."""
return xp.stack(tuple(field_values(full_shaped(obj))), axis=-1)

Expand Down
8 changes: 0 additions & 8 deletions src/coordinax/_coordinax/d2/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from dataclassish import field_values
from unxt import AbstractQuantity, Quantity

from .base import AbstractPosition2D
from .cartesian import CartesianAcceleration2D, CartesianPosition2D, CartesianVelocity2D
from coordinax._coordinax.operators.base import AbstractOperator, op_call_dispatch
from coordinax._coordinax.typing import TimeBatchOrScalar
Expand All @@ -22,13 +21,6 @@
# Convert to Quantity


@conversion_method(type_from=AbstractPosition2D, type_to=Quantity) # type: ignore[misc]
def vec_to_q(obj: AbstractPosition2D, /) -> Shaped[Quantity["length"], "*batch 2"]:
"""`coordinax.AbstractPosition2D` -> `unxt.Quantity`."""
cart = full_shaped(obj.represent_as(CartesianPosition2D))
return xp.stack(tuple(field_values(cart)), axis=-1)


@conversion_method(type_from=CartesianAcceleration2D, type_to=Quantity) # type: ignore[misc]
@conversion_method(type_from=CartesianVelocity2D, type_to=Quantity) # type: ignore[misc]
def vec_diff_to_q(obj: CartesianVelocity2D, /) -> Shaped[AbstractQuantity, "*batch 2"]:
Expand Down
34 changes: 0 additions & 34 deletions src/coordinax/_coordinax/d3/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from dataclassish import field_values
from unxt import Quantity

from .base import AbstractPosition3D
from .cartesian import CartesianAcceleration3D, CartesianPosition3D, CartesianVelocity3D
from coordinax._coordinax.operators.base import AbstractOperator, op_call_dispatch
from coordinax._coordinax.typing import TimeBatchOrScalar
Expand All @@ -22,39 +21,6 @@
# Convert to Quantity


@conversion_method(AbstractPosition3D, Quantity) # type: ignore[misc]
def vec_to_q(obj: AbstractPosition3D, /) -> Shaped[Quantity["length"], "*batch 3"]:
"""`coordinax.AbstractPosition3D` -> `unxt.Quantity`.
Examples
--------
>>> import coordinax as cx
>>> from plum import convert
>>> from unxt import Quantity
>>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "kpc")
>>> convert(vec, Quantity)
Quantity['length'](Array([1., 2., 3.], dtype=float32), unit='kpc')
>>> vec = cx.SphericalPosition(r=Quantity(1, unit="kpc"),
... theta=Quantity(2, unit="deg"),
... phi=Quantity(3, unit="deg"))
>>> convert(vec, Quantity)
Quantity['length'](Array([0.03485167, 0.0018265 , 0.99939084], dtype=float32),
unit='kpc')
>>> vec = cx.CylindricalPosition(rho=Quantity(1, unit="kpc"),
... phi=Quantity(2, unit="deg"),
... z=Quantity(3, unit="pc"))
>>> convert(vec, Quantity)
Quantity['length'](Array([0.99939084, 0.0348995 , 0.003 ], dtype=float32),
unit='kpc')
"""
cart = full_shaped(obj.represent_as(CartesianPosition3D))
return xp.stack(tuple(field_values(cart)), axis=-1)


@conversion_method(CartesianAcceleration3D, Quantity) # type: ignore[misc]
@conversion_method(CartesianVelocity3D, Quantity) # type: ignore[misc]
def vec_diff_to_q(obj: CartesianVelocity3D, /) -> Shaped[Quantity["speed"], "*batch 3"]:
Expand Down

0 comments on commit 9d577db

Please sign in to comment.