From d6c9c441253ff2d01fe7fa4b8ea956c67316d407 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 15 Sep 2024 12:02:20 -0400 Subject: [PATCH 1/2] refactor: consolidate base classes Signed-off-by: nstarman --- src/coordinax/__init__.py | 12 --- src/coordinax/_coordinax/base/__init__.py | 18 ++++ src/coordinax/_coordinax/{ => base}/base.py | 4 +- .../_coordinax/{ => base}/base_acc.py | 2 +- .../_coordinax/{ => base}/base_pos.py | 87 ++++++++++++------- .../_coordinax/{ => base}/base_vel.py | 2 +- src/coordinax/_coordinax/{ => base}/mixins.py | 2 +- src/coordinax/_coordinax/compat.py | 2 +- src/coordinax/_coordinax/d1/base.py | 10 ++- src/coordinax/_coordinax/d1/cartesian.py | 4 +- src/coordinax/_coordinax/d1/transform.py | 3 +- src/coordinax/_coordinax/d2/base.py | 10 ++- src/coordinax/_coordinax/d2/cartesian.py | 4 +- src/coordinax/_coordinax/d2/transform.py | 2 +- src/coordinax/_coordinax/d3/base.py | 10 ++- src/coordinax/_coordinax/d3/cartesian.py | 4 +- src/coordinax/_coordinax/d3/spherical.py | 2 +- src/coordinax/_coordinax/d3/transform.py | 2 +- src/coordinax/_coordinax/d4/base.py | 3 +- src/coordinax/_coordinax/dn/base.py | 10 ++- src/coordinax/_coordinax/dn/cartesian.py | 4 +- src/coordinax/_coordinax/dn/poincare.py | 3 +- src/coordinax/_coordinax/operators/base.py | 2 +- .../_coordinax/operators/composite.py | 2 +- .../operators/galilean/translation.py | 2 +- .../_coordinax/operators/identity.py | 2 +- src/coordinax/_coordinax/space.py | 10 ++- .../_coordinax/transform/accelerations.py | 8 +- .../_coordinax/transform/differentials.py | 3 +- tests/test_jax_ops.py | 2 +- 30 files changed, 133 insertions(+), 98 deletions(-) create mode 100644 src/coordinax/_coordinax/base/__init__.py rename src/coordinax/_coordinax/{ => base}/base.py (99%) rename src/coordinax/_coordinax/{ => base}/base_acc.py (99%) rename src/coordinax/_coordinax/{ => base}/base_pos.py (95%) rename src/coordinax/_coordinax/{ => base}/base_vel.py (99%) rename src/coordinax/_coordinax/{ => base}/mixins.py (98%) diff --git a/src/coordinax/__init__.py b/src/coordinax/__init__.py index 9902cbd5..a951f30f 100644 --- a/src/coordinax/__init__.py +++ b/src/coordinax/__init__.py @@ -10,9 +10,6 @@ from . import operators from ._coordinax import ( base, - base_acc, - base_pos, - base_vel, d1, d2, d3, @@ -26,9 +23,6 @@ utils, ) from ._coordinax.base import * -from ._coordinax.base_acc import * -from ._coordinax.base_pos import * -from ._coordinax.base_vel import * from ._coordinax.d1 import * from ._coordinax.d2 import * from ._coordinax.d3 import * @@ -46,9 +40,6 @@ __all__ = ["__version__", "operators"] __all__ += funcs.__all__ __all__ += base.__all__ -__all__ += base_pos.__all__ -__all__ += base_vel.__all__ -__all__ += base_acc.__all__ __all__ += d1.__all__ __all__ += d2.__all__ __all__ += d3.__all__ @@ -70,9 +61,6 @@ # Cleanup del ( base, - base_vel, - base_pos, - base_acc, space, exceptions, transform, diff --git a/src/coordinax/_coordinax/base/__init__.py b/src/coordinax/_coordinax/base/__init__.py new file mode 100644 index 00000000..e0989195 --- /dev/null +++ b/src/coordinax/_coordinax/base/__init__.py @@ -0,0 +1,18 @@ +"""Bases.""" + +__all__ = [ + # Base + "AbstractVector", + "ToUnitsOptions", + # Position + "AbstractPosition", + # Velocity + "AbstractVelocity", + # Acceleration + "AbstractAcceleration", +] + +from .base import AbstractVector, ToUnitsOptions +from .base_acc import AbstractAcceleration +from .base_pos import AbstractPosition +from .base_vel import AbstractVelocity diff --git a/src/coordinax/_coordinax/base.py b/src/coordinax/_coordinax/base/base.py similarity index 99% rename from src/coordinax/_coordinax/base.py rename to src/coordinax/_coordinax/base/base.py index e85609df..9ebe2de8 100644 --- a/src/coordinax/_coordinax/base.py +++ b/src/coordinax/_coordinax/base/base.py @@ -36,8 +36,8 @@ unitsystem, ) -from .typing import Unit -from .utils import classproperty, full_shaped +from coordinax._coordinax.typing import Unit +from coordinax._coordinax.utils import classproperty, full_shaped if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/coordinax/_coordinax/base_acc.py b/src/coordinax/_coordinax/base/base_acc.py similarity index 99% rename from src/coordinax/_coordinax/base_acc.py rename to src/coordinax/_coordinax/base/base_acc.py index 197aa570..babb792a 100644 --- a/src/coordinax/_coordinax/base_acc.py +++ b/src/coordinax/_coordinax/base/base_acc.py @@ -18,7 +18,7 @@ from .base import AbstractVector from .base_pos import AbstractPosition from .base_vel import AbstractVelocity -from .utils import classproperty +from coordinax._coordinax.utils import classproperty if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/coordinax/_coordinax/base_pos.py b/src/coordinax/_coordinax/base/base_pos.py similarity index 95% rename from src/coordinax/_coordinax/base_pos.py rename to src/coordinax/_coordinax/base/base_pos.py index 1b8c0b2e..b8d3ebe1 100644 --- a/src/coordinax/_coordinax/base_pos.py +++ b/src/coordinax/_coordinax/base/base_pos.py @@ -19,10 +19,10 @@ from dataclassish import field_items from unxt import Quantity -from . import typing as ct from .base import AbstractVector from .mixins import AvalMixin -from .utils import classproperty +from coordinax._coordinax import typing as ct +from coordinax._coordinax.utils import classproperty if TYPE_CHECKING: from typing_extensions import Self @@ -227,17 +227,42 @@ def _add_qq(lhs: AbstractPosition, rhs: AbstractPosition, /) -> AbstractPosition ).represent_as(type(lhs)) -@register(jax.lax.sub_p) # type: ignore[misc] -def _sub_qq(lhs: AbstractPosition, rhs: AbstractPosition) -> AbstractPosition: - """Add another object to this vector.""" - # The base implementation is to convert to Cartesian and perform the - # operation. Cartesian coordinates do not have any branch cuts or - # singularities or ranges that need to be handled, so this is a safe - # default. - return qlax.sub( - lhs.represent_as(lhs._cartesian_cls), # noqa: SLF001 - rhs.represent_as(lhs._cartesian_cls), # noqa: SLF001 - ).represent_as(type(lhs)) +# ------------------------------------------------ + + +@register(jax.lax.convert_element_type_p) # type: ignore[misc] +def _convert_element_type_p( + operand: AbstractPosition, **kwargs: Any +) -> AbstractPosition: + """Convert the element type of a quantity.""" + # TODO: examples + return replace( + operand, + **{k: qlax.convert_element_type(v, **kwargs) for k, v in field_items(operand)}, + ) + + +# ------------------------------------------------ + + +@register(jax.lax.div_p) # type: ignore[misc] +def _div_pos_v(lhs: AbstractPosition, rhs: ArrayLike) -> AbstractPosition: + """Divide a vector by a scalar. + + Examples + -------- + >>> import quaxed.array_api as jnp + >>> import coordinax as cx + + >>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m") + >>> jnp.divide(vec, 2).x + Quantity['length'](Array(0.5, dtype=float32), unit='m') + + >>> (vec / 2).x + Quantity['length'](Array(0.5, dtype=float32), unit='m') + + """ + return replace(lhs, **{k: jnp.divide(v, rhs) for k, v in field_items(lhs)}) # ------------------------------------------------ @@ -389,26 +414,6 @@ def _mul_pos_pos(lhs: AbstractPosition, rhs: AbstractPosition, /) -> Quantity: return qlax.mul(lq, rq) # re-dispatch to Quantities -@register(jax.lax.div_p) # type: ignore[misc] -def _div_pos_v(lhs: AbstractPosition, rhs: ArrayLike) -> AbstractPosition: - """Divide a vector by a scalar. - - Examples - -------- - >>> import quaxed.array_api as jnp - >>> import coordinax as cx - - >>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m") - >>> jnp.divide(vec, 2).x - Quantity['length'](Array(0.5, dtype=float32), unit='m') - - >>> (vec / 2).x - Quantity['length'](Array(0.5, dtype=float32), unit='m') - - """ - return replace(lhs, **{k: jnp.divide(v, rhs) for k, v in field_items(lhs)}) - - # ------------------------------------------------ @@ -446,3 +451,19 @@ def _reshape_pos( for k, v in field_items(operand) }, ) + + +# ------------------------------------------------ + + +@register(jax.lax.sub_p) # type: ignore[misc] +def _sub_qq(lhs: AbstractPosition, rhs: AbstractPosition) -> AbstractPosition: + """Add another object to this vector.""" + # The base implementation is to convert to Cartesian and perform the + # operation. Cartesian coordinates do not have any branch cuts or + # singularities or ranges that need to be handled, so this is a safe + # default. + return qlax.sub( + lhs.represent_as(lhs._cartesian_cls), # noqa: SLF001 + rhs.represent_as(lhs._cartesian_cls), # noqa: SLF001 + ).represent_as(type(lhs)) diff --git a/src/coordinax/_coordinax/base_vel.py b/src/coordinax/_coordinax/base/base_vel.py similarity index 99% rename from src/coordinax/_coordinax/base_vel.py rename to src/coordinax/_coordinax/base/base_vel.py index 6dc6e2ec..7e9cbc5f 100644 --- a/src/coordinax/_coordinax/base_vel.py +++ b/src/coordinax/_coordinax/base/base_vel.py @@ -15,7 +15,7 @@ from .base import AbstractVector from .base_pos import AbstractPosition -from .utils import classproperty +from coordinax._coordinax.utils import classproperty if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/coordinax/_coordinax/mixins.py b/src/coordinax/_coordinax/base/mixins.py similarity index 98% rename from src/coordinax/_coordinax/mixins.py rename to src/coordinax/_coordinax/base/mixins.py index ea9e6615..d0c08f7f 100644 --- a/src/coordinax/_coordinax/mixins.py +++ b/src/coordinax/_coordinax/base/mixins.py @@ -7,7 +7,7 @@ from unxt import Quantity -from .funcs import represent_as +from coordinax._coordinax.funcs import represent_as class AvalMixin: diff --git a/src/coordinax/_coordinax/compat.py b/src/coordinax/_coordinax/compat.py index 4badb45b..096bf06d 100644 --- a/src/coordinax/_coordinax/compat.py +++ b/src/coordinax/_coordinax/compat.py @@ -10,7 +10,7 @@ from dataclassish import field_values from unxt import AbstractQuantity, Distance, Quantity, UncheckedQuantity -from .base_pos import AbstractPosition +from .base import AbstractPosition from coordinax._coordinax.utils import full_shaped ##################################################################### diff --git a/src/coordinax/_coordinax/d1/base.py b/src/coordinax/_coordinax/d1/base.py index 88461526..8b5b2201 100644 --- a/src/coordinax/_coordinax/d1/base.py +++ b/src/coordinax/_coordinax/d1/base.py @@ -11,10 +11,12 @@ import quaxed.numpy as jnp from unxt import Quantity -from coordinax._coordinax.base import AbstractVector -from coordinax._coordinax.base_acc import AbstractAcceleration -from coordinax._coordinax.base_pos import AbstractPosition -from coordinax._coordinax.base_vel import AbstractVelocity +from coordinax._coordinax.base import ( + AbstractAcceleration, + AbstractPosition, + AbstractVector, + AbstractVelocity, +) from coordinax._coordinax.utils import classproperty ##################################################################### diff --git a/src/coordinax/_coordinax/d1/cartesian.py b/src/coordinax/_coordinax/d1/cartesian.py index f765298a..2ae24580 100644 --- a/src/coordinax/_coordinax/d1/cartesian.py +++ b/src/coordinax/_coordinax/d1/cartesian.py @@ -21,8 +21,8 @@ import coordinax._coordinax.typing as ct from .base import AbstractAcceleration1D, AbstractPosition1D, AbstractVelocity1D -from coordinax._coordinax.base_pos import AbstractPosition -from coordinax._coordinax.mixins import AvalMixin +from coordinax._coordinax.base import AbstractPosition +from coordinax._coordinax.base.mixins import AvalMixin from coordinax._coordinax.utils import classproperty diff --git a/src/coordinax/_coordinax/d1/transform.py b/src/coordinax/_coordinax/d1/transform.py index ea3ff3bb..b989bc82 100644 --- a/src/coordinax/_coordinax/d1/transform.py +++ b/src/coordinax/_coordinax/d1/transform.py @@ -9,8 +9,7 @@ from .base import AbstractAcceleration1D, AbstractPosition1D, AbstractVelocity1D from .cartesian import CartesianAcceleration1D, CartesianPosition1D, CartesianVelocity1D from .radial import RadialAcceleration, RadialPosition, RadialVelocity -from coordinax._coordinax.base_pos import AbstractPosition -from coordinax._coordinax.base_vel import AbstractVelocity +from coordinax._coordinax.base import AbstractPosition, AbstractVelocity ############################################################################### # 1D diff --git a/src/coordinax/_coordinax/d2/base.py b/src/coordinax/_coordinax/d2/base.py index df21253f..863bda10 100644 --- a/src/coordinax/_coordinax/d2/base.py +++ b/src/coordinax/_coordinax/d2/base.py @@ -5,10 +5,12 @@ from abc import abstractmethod -from coordinax._coordinax.base import AbstractVector -from coordinax._coordinax.base_acc import AbstractAcceleration -from coordinax._coordinax.base_pos import AbstractPosition -from coordinax._coordinax.base_vel import AbstractVelocity +from coordinax._coordinax.base import ( + AbstractAcceleration, + AbstractPosition, + AbstractVector, + AbstractVelocity, +) from coordinax._coordinax.utils import classproperty diff --git a/src/coordinax/_coordinax/d2/cartesian.py b/src/coordinax/_coordinax/d2/cartesian.py index 3891e125..beffa8d7 100644 --- a/src/coordinax/_coordinax/d2/cartesian.py +++ b/src/coordinax/_coordinax/d2/cartesian.py @@ -21,8 +21,8 @@ import coordinax._coordinax.typing as ct from .base import AbstractAcceleration2D, AbstractPosition2D, AbstractVelocity2D -from coordinax._coordinax.base_pos import AbstractPosition -from coordinax._coordinax.mixins import AvalMixin +from coordinax._coordinax.base import AbstractPosition +from coordinax._coordinax.base.mixins import AvalMixin from coordinax._coordinax.utils import classproperty diff --git a/src/coordinax/_coordinax/d2/transform.py b/src/coordinax/_coordinax/d2/transform.py index f265022c..e0dce789 100644 --- a/src/coordinax/_coordinax/d2/transform.py +++ b/src/coordinax/_coordinax/d2/transform.py @@ -11,7 +11,7 @@ from .base import AbstractPosition2D, AbstractVelocity2D from .cartesian import CartesianAcceleration2D, CartesianPosition2D, CartesianVelocity2D from .polar import PolarPosition, PolarVelocity -from coordinax._coordinax.base_pos import AbstractPosition +from coordinax._coordinax.base import AbstractPosition @dispatch diff --git a/src/coordinax/_coordinax/d3/base.py b/src/coordinax/_coordinax/d3/base.py index d53542a7..9ed86723 100644 --- a/src/coordinax/_coordinax/d3/base.py +++ b/src/coordinax/_coordinax/d3/base.py @@ -6,10 +6,12 @@ from abc import abstractmethod from typing_extensions import override -from coordinax._coordinax.base import AbstractVector -from coordinax._coordinax.base_acc import AbstractAcceleration -from coordinax._coordinax.base_pos import AbstractPosition -from coordinax._coordinax.base_vel import AbstractVelocity +from coordinax._coordinax.base import ( + AbstractAcceleration, + AbstractPosition, + AbstractVector, + AbstractVelocity, +) from coordinax._coordinax.utils import classproperty diff --git a/src/coordinax/_coordinax/d3/cartesian.py b/src/coordinax/_coordinax/d3/cartesian.py index 11a9ef94..9303225c 100644 --- a/src/coordinax/_coordinax/d3/cartesian.py +++ b/src/coordinax/_coordinax/d3/cartesian.py @@ -22,8 +22,8 @@ import coordinax._coordinax.typing as ct from .base import AbstractAcceleration3D, AbstractPosition3D, AbstractVelocity3D -from coordinax._coordinax.base_pos import AbstractPosition -from coordinax._coordinax.mixins import AvalMixin +from coordinax._coordinax.base import AbstractPosition +from coordinax._coordinax.base.mixins import AvalMixin from coordinax._coordinax.utils import classproperty diff --git a/src/coordinax/_coordinax/d3/spherical.py b/src/coordinax/_coordinax/d3/spherical.py index 8448fb25..993f7a3e 100644 --- a/src/coordinax/_coordinax/d3/spherical.py +++ b/src/coordinax/_coordinax/d3/spherical.py @@ -37,7 +37,7 @@ import coordinax._coordinax.typing as ct from .base import AbstractAcceleration3D, AbstractPosition3D, AbstractVelocity3D -from coordinax._coordinax.base_acc import AbstractAcceleration +from coordinax._coordinax.base import AbstractAcceleration from coordinax._coordinax.checks import ( check_azimuth_range, check_polar_range, diff --git a/src/coordinax/_coordinax/d3/transform.py b/src/coordinax/_coordinax/d3/transform.py index 4e734c03..7ec0a932 100644 --- a/src/coordinax/_coordinax/d3/transform.py +++ b/src/coordinax/_coordinax/d3/transform.py @@ -22,7 +22,7 @@ SphericalPosition, SphericalVelocity, ) -from coordinax._coordinax.base_pos import AbstractPosition +from coordinax._coordinax.base import AbstractPosition ############################################################################### # 3D diff --git a/src/coordinax/_coordinax/d4/base.py b/src/coordinax/_coordinax/d4/base.py index a325bcbe..2955a5dc 100644 --- a/src/coordinax/_coordinax/d4/base.py +++ b/src/coordinax/_coordinax/d4/base.py @@ -6,8 +6,7 @@ from abc import abstractmethod from typing import TYPE_CHECKING -from coordinax._coordinax.base import AbstractVector -from coordinax._coordinax.base_pos import AbstractPosition +from coordinax._coordinax.base import AbstractPosition, AbstractVector from coordinax._coordinax.utils import classproperty if TYPE_CHECKING: diff --git a/src/coordinax/_coordinax/dn/base.py b/src/coordinax/_coordinax/dn/base.py index 96fc3ee8..7dc34c9b 100644 --- a/src/coordinax/_coordinax/dn/base.py +++ b/src/coordinax/_coordinax/dn/base.py @@ -12,10 +12,12 @@ import quaxed.lax as qlax import quaxed.numpy as jnp -from coordinax._coordinax.base import AbstractVector -from coordinax._coordinax.base_acc import AbstractAcceleration -from coordinax._coordinax.base_pos import AbstractPosition -from coordinax._coordinax.base_vel import AbstractVelocity +from coordinax._coordinax.base import ( + AbstractAcceleration, + AbstractPosition, + AbstractVector, + AbstractVelocity, +) from coordinax._coordinax.utils import classproperty if TYPE_CHECKING: diff --git a/src/coordinax/_coordinax/dn/cartesian.py b/src/coordinax/_coordinax/dn/cartesian.py index aea4dfe3..3373d02f 100644 --- a/src/coordinax/_coordinax/dn/cartesian.py +++ b/src/coordinax/_coordinax/dn/cartesian.py @@ -18,8 +18,8 @@ import coordinax._coordinax.typing as ct from .base import AbstractAccelerationND, AbstractPositionND, AbstractVelocityND -from coordinax._coordinax.base_pos import AbstractPosition -from coordinax._coordinax.mixins import AvalMixin +from coordinax._coordinax.base import AbstractPosition +from coordinax._coordinax.base.mixins import AvalMixin from coordinax._coordinax.utils import classproperty ############################################################################## diff --git a/src/coordinax/_coordinax/dn/poincare.py b/src/coordinax/_coordinax/dn/poincare.py index c32312f6..3de4bda6 100644 --- a/src/coordinax/_coordinax/dn/poincare.py +++ b/src/coordinax/_coordinax/dn/poincare.py @@ -11,8 +11,7 @@ from unxt import Quantity import coordinax._coordinax.typing as ct -from coordinax._coordinax.base_pos import AbstractPosition -from coordinax._coordinax.base_vel import AbstractVelocity +from coordinax._coordinax.base import AbstractPosition, AbstractVelocity from coordinax._coordinax.utils import classproperty diff --git a/src/coordinax/_coordinax/operators/base.py b/src/coordinax/_coordinax/operators/base.py index e4671862..64e808f0 100644 --- a/src/coordinax/_coordinax/operators/base.py +++ b/src/coordinax/_coordinax/operators/base.py @@ -12,7 +12,7 @@ from dataclassish import field_items from unxt import Quantity -from coordinax._coordinax.base_pos import AbstractPosition +from coordinax._coordinax.base import AbstractPosition if TYPE_CHECKING: from coordinax.operators import OperatorSequence diff --git a/src/coordinax/_coordinax/operators/composite.py b/src/coordinax/_coordinax/operators/composite.py index c8940c7a..a35c4764 100644 --- a/src/coordinax/_coordinax/operators/composite.py +++ b/src/coordinax/_coordinax/operators/composite.py @@ -10,7 +10,7 @@ from unxt import Quantity from .base import AbstractOperator, op_call_dispatch -from coordinax._coordinax.base_pos import AbstractPosition +from coordinax._coordinax.base import AbstractPosition if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/coordinax/_coordinax/operators/galilean/translation.py b/src/coordinax/_coordinax/operators/galilean/translation.py index 59d55b86..11fbe8d3 100644 --- a/src/coordinax/_coordinax/operators/galilean/translation.py +++ b/src/coordinax/_coordinax/operators/galilean/translation.py @@ -15,7 +15,7 @@ from unxt import Quantity from .base import AbstractGalileanOperator -from coordinax._coordinax.base_pos import AbstractPosition +from coordinax._coordinax.base import AbstractPosition from coordinax._coordinax.d1.cartesian import CartesianPosition1D from coordinax._coordinax.d2.cartesian import CartesianPosition2D from coordinax._coordinax.d3.base import AbstractPosition3D diff --git a/src/coordinax/_coordinax/operators/identity.py b/src/coordinax/_coordinax/operators/identity.py index ab863080..06b3a6e9 100644 --- a/src/coordinax/_coordinax/operators/identity.py +++ b/src/coordinax/_coordinax/operators/identity.py @@ -9,7 +9,7 @@ from unxt import Quantity from .base import AbstractOperator, op_call_dispatch -from coordinax._coordinax.base_pos import AbstractPosition +from coordinax._coordinax.base import AbstractPosition @final diff --git a/src/coordinax/_coordinax/space.py b/src/coordinax/_coordinax/space.py index cc90aa6f..aef21e3a 100644 --- a/src/coordinax/_coordinax/space.py +++ b/src/coordinax/_coordinax/space.py @@ -19,12 +19,14 @@ from unxt import Quantity, dimensions from xmmutablemap import ImmutableMap +from .base import ( + AbstractAcceleration, + AbstractPosition, + AbstractVector, + AbstractVelocity, +) from .typing import Unit from .utils import classproperty -from coordinax._coordinax.base import AbstractVector -from coordinax._coordinax.base_acc import AbstractAcceleration -from coordinax._coordinax.base_pos import AbstractPosition -from coordinax._coordinax.base_vel import AbstractVelocity from coordinax._coordinax.funcs import represent_as DimensionLike: TypeAlias = Dimension | str diff --git a/src/coordinax/_coordinax/transform/accelerations.py b/src/coordinax/_coordinax/transform/accelerations.py index 8c41c484..2526fa76 100644 --- a/src/coordinax/_coordinax/transform/accelerations.py +++ b/src/coordinax/_coordinax/transform/accelerations.py @@ -13,9 +13,11 @@ from dataclassish import field_items from unxt import AbstractDistance, Quantity -from coordinax._coordinax.base_acc import AbstractAcceleration -from coordinax._coordinax.base_pos import AbstractPosition -from coordinax._coordinax.base_vel import AbstractVelocity +from coordinax._coordinax.base import ( + AbstractAcceleration, + AbstractPosition, + AbstractVelocity, +) from coordinax._coordinax.d1.base import AbstractAcceleration1D from coordinax._coordinax.d2.base import AbstractAcceleration2D from coordinax._coordinax.d3.base import AbstractAcceleration3D diff --git a/src/coordinax/_coordinax/transform/differentials.py b/src/coordinax/_coordinax/transform/differentials.py index 4d3a97ea..7be23360 100644 --- a/src/coordinax/_coordinax/transform/differentials.py +++ b/src/coordinax/_coordinax/transform/differentials.py @@ -13,8 +13,7 @@ from dataclassish import field_items from unxt import AbstractDistance, Quantity -from coordinax._coordinax.base_pos import AbstractPosition -from coordinax._coordinax.base_vel import AbstractVelocity +from coordinax._coordinax.base import AbstractPosition, AbstractVelocity from coordinax._coordinax.d1.base import AbstractVelocity1D from coordinax._coordinax.d2.base import AbstractVelocity2D from coordinax._coordinax.d3.base import AbstractVelocity3D diff --git a/tests/test_jax_ops.py b/tests/test_jax_ops.py index 44663442..0941a640 100644 --- a/tests/test_jax_ops.py +++ b/tests/test_jax_ops.py @@ -9,7 +9,7 @@ from unxt import AbstractQuantity import coordinax as cx -from coordinax._coordinax.base_pos import VECTOR_CLASSES +from coordinax._coordinax.base import VECTOR_CLASSES VECTOR_CLASSES_3D = [c for c in VECTOR_CLASSES if issubclass(c, cx.AbstractPosition3D)] From 7bf8c53b0b8e09af38c7977b9400a9406f727e40 Mon Sep 17 00:00:00 2001 From: nstarman Date: Sun, 15 Sep 2024 12:45:47 -0400 Subject: [PATCH 2/2] test: position class set import Signed-off-by: nstarman --- src/coordinax/_coordinax/base/base_pos.py | 5 +++-- tests/test_jax_ops.py | 10 ++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/coordinax/_coordinax/base/base_pos.py b/src/coordinax/_coordinax/base/base_pos.py index b8d3ebe1..b141752a 100644 --- a/src/coordinax/_coordinax/base/base_pos.py +++ b/src/coordinax/_coordinax/base/base_pos.py @@ -29,7 +29,8 @@ PosT = TypeVar("PosT", bound="AbstractPosition") -VECTOR_CLASSES: set[type["AbstractPosition"]] = set() +# TODO: figure out public API for this +POSITION_CLASSES: set[type["AbstractPosition"]] = set() class AbstractPosition(AvalMixin, AbstractVector): # pylint: disable=abstract-method @@ -44,7 +45,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None: if isabstract(cls) or cls.__name__.startswith("Abstract"): return - VECTOR_CLASSES.add(cls) + POSITION_CLASSES.add(cls) @classproperty @classmethod diff --git a/tests/test_jax_ops.py b/tests/test_jax_ops.py index 0941a640..549b882e 100644 --- a/tests/test_jax_ops.py +++ b/tests/test_jax_ops.py @@ -9,13 +9,15 @@ from unxt import AbstractQuantity import coordinax as cx -from coordinax._coordinax.base import VECTOR_CLASSES +from coordinax._coordinax.base.base_pos import POSITION_CLASSES -VECTOR_CLASSES_3D = [c for c in VECTOR_CLASSES if issubclass(c, cx.AbstractPosition3D)] +POSITION_CLASSES_3D = [ + c for c in POSITION_CLASSES if issubclass(c, cx.AbstractPosition3D) +] # TODO: cycle through all representations -@pytest.fixture(params=VECTOR_CLASSES_3D) +@pytest.fixture(params=POSITION_CLASSES_3D) def q(request) -> cx.AbstractPosition: """Fixture for 3D Vectors.""" q = cx.CartesianPosition3D.constructor([1, 2, 3], "kpc") @@ -29,7 +31,7 @@ def func( return q.represent_as(target) -@pytest.mark.parametrize("target", VECTOR_CLASSES_3D) +@pytest.mark.parametrize("target", POSITION_CLASSES_3D) def test_jax_through_representation( q: cx.AbstractPosition, target: type[cx.AbstractPosition] ) -> None: