Skip to content

Commit

Permalink
refactor: consolidate base classes
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Sep 15, 2024
1 parent 876c14a commit cd89369
Show file tree
Hide file tree
Showing 30 changed files with 134 additions and 100 deletions.
12 changes: 0 additions & 12 deletions src/coordinax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
from . import operators
from ._coordinax import (
base,
base_acc,
base_pos,
base_vel,
d1,
d2,
d3,
Expand All @@ -26,9 +23,6 @@
utils,
)
from ._coordinax.base import *
from ._coordinax.base_acc import *
from ._coordinax.base_pos import *
from ._coordinax.base_vel import *
from ._coordinax.d1 import *
from ._coordinax.d2 import *
from ._coordinax.d3 import *
Expand All @@ -46,9 +40,6 @@
__all__ = ["__version__", "operators"]
__all__ += funcs.__all__
__all__ += base.__all__
__all__ += base_pos.__all__
__all__ += base_vel.__all__
__all__ += base_acc.__all__
__all__ += d1.__all__
__all__ += d2.__all__
__all__ += d3.__all__
Expand All @@ -70,9 +61,6 @@
# Cleanup
del (
base,
base_vel,
base_pos,
base_acc,
space,
exceptions,
transform,
Expand Down
18 changes: 18 additions & 0 deletions src/coordinax/_coordinax/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Bases."""

__all__ = [
# Base
"AbstractVector",
"ToUnitsOptions",
# Position
"AbstractPosition",
# Velocity
"AbstractVelocity",
# Acceleration
"AbstractAcceleration",
]

from .base import AbstractVector, ToUnitsOptions
from .base_acc import AbstractAcceleration
from .base_pos import AbstractPosition
from .base_vel import AbstractVelocity
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
unitsystem,
)

from .typing import Unit
from .utils import classproperty, full_shaped
from coordinax._coordinax.typing import Unit
from coordinax._coordinax.utils import classproperty, full_shaped

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .base import AbstractVector
from .base_pos import AbstractPosition
from .base_vel import AbstractVelocity
from .utils import classproperty
from coordinax._coordinax.utils import classproperty

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from dataclassish import field_items
from unxt import Quantity

from . import typing as ct
from .base import AbstractVector
from .mixins import AvalMixin
from .utils import classproperty
from coordinax._coordinax import typing as ct
from coordinax._coordinax.utils import classproperty

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -227,17 +227,42 @@ def _add_qq(lhs: AbstractPosition, rhs: AbstractPosition, /) -> AbstractPosition
).represent_as(type(lhs))


@register(jax.lax.sub_p) # type: ignore[misc]
def _sub_qq(lhs: AbstractPosition, rhs: AbstractPosition) -> AbstractPosition:
"""Add another object to this vector."""
# The base implementation is to convert to Cartesian and perform the
# operation. Cartesian coordinates do not have any branch cuts or
# singularities or ranges that need to be handled, so this is a safe
# default.
return qlax.sub(
lhs.represent_as(lhs._cartesian_cls), # noqa: SLF001
rhs.represent_as(lhs._cartesian_cls), # noqa: SLF001
).represent_as(type(lhs))
# ------------------------------------------------


@register(jax.lax.convert_element_type_p) # type: ignore[misc]
def _convert_element_type_p(
operand: AbstractPosition, **kwargs: Any
) -> AbstractPosition:
"""Convert the element type of a quantity."""
# TODO: examples
return replace(
operand,
**{k: qlax.convert_element_type(v, **kwargs) for k, v in field_items(operand)},
)


# ------------------------------------------------


@register(jax.lax.div_p) # type: ignore[misc]
def _div_pos_v(lhs: AbstractPosition, rhs: ArrayLike) -> AbstractPosition:
"""Divide a vector by a scalar.
Examples
--------
>>> import quaxed.array_api as jnp
>>> import coordinax as cx
>>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m")
>>> jnp.divide(vec, 2).x
Quantity['length'](Array(0.5, dtype=float32), unit='m')
>>> (vec / 2).x
Quantity['length'](Array(0.5, dtype=float32), unit='m')
"""
return replace(lhs, **{k: jnp.divide(v, rhs) for k, v in field_items(lhs)})


# ------------------------------------------------
Expand Down Expand Up @@ -389,26 +414,6 @@ def _mul_pos_pos(lhs: AbstractPosition, rhs: AbstractPosition, /) -> Quantity:
return qlax.mul(lq, rq) # re-dispatch to Quantities


@register(jax.lax.div_p) # type: ignore[misc]
def _div_pos_v(lhs: AbstractPosition, rhs: ArrayLike) -> AbstractPosition:
"""Divide a vector by a scalar.
Examples
--------
>>> import quaxed.array_api as jnp
>>> import coordinax as cx
>>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m")
>>> jnp.divide(vec, 2).x
Quantity['length'](Array(0.5, dtype=float32), unit='m')
>>> (vec / 2).x
Quantity['length'](Array(0.5, dtype=float32), unit='m')
"""
return replace(lhs, **{k: jnp.divide(v, rhs) for k, v in field_items(lhs)})


# ------------------------------------------------


Expand Down Expand Up @@ -446,3 +451,19 @@ def _reshape_pos(
for k, v in field_items(operand)
},
)


# ------------------------------------------------


@register(jax.lax.sub_p) # type: ignore[misc]
def _sub_qq(lhs: AbstractPosition, rhs: AbstractPosition) -> AbstractPosition:
"""Add another object to this vector."""
# The base implementation is to convert to Cartesian and perform the
# operation. Cartesian coordinates do not have any branch cuts or
# singularities or ranges that need to be handled, so this is a safe
# default.
return qlax.sub(
lhs.represent_as(lhs._cartesian_cls), # noqa: SLF001
rhs.represent_as(lhs._cartesian_cls), # noqa: SLF001
).represent_as(type(lhs))
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
from dataclassish import field_items
from unxt import Quantity

from .base import AbstractVector
from .base_pos import AbstractPosition
from .utils import classproperty
from coordinax._coordinax.base import AbstractPosition, AbstractVector
from coordinax._coordinax.utils import classproperty

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from unxt import Quantity

from .funcs import represent_as
from coordinax._coordinax.funcs import represent_as


class AvalMixin:
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_coordinax/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dataclassish import field_values
from unxt import AbstractQuantity, Distance, Quantity, UncheckedQuantity

from .base_pos import AbstractPosition
from .base import AbstractPosition
from coordinax._coordinax.utils import full_shaped

#####################################################################
Expand Down
10 changes: 6 additions & 4 deletions src/coordinax/_coordinax/d1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
import quaxed.numpy as jnp
from unxt import Quantity

from coordinax._coordinax.base import AbstractVector
from coordinax._coordinax.base_acc import AbstractAcceleration
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base_vel import AbstractVelocity
from coordinax._coordinax.base import (
AbstractAcceleration,
AbstractPosition,
AbstractVector,
AbstractVelocity,
)
from coordinax._coordinax.utils import classproperty

#####################################################################
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/d1/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import coordinax._coordinax.typing as ct
from .base import AbstractAcceleration1D, AbstractPosition1D, AbstractVelocity1D
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.mixins import AvalMixin
from coordinax._coordinax.base import AbstractPosition
from coordinax._coordinax.base.mixins import AvalMixin
from coordinax._coordinax.utils import classproperty


Expand Down
3 changes: 1 addition & 2 deletions src/coordinax/_coordinax/d1/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from .base import AbstractAcceleration1D, AbstractPosition1D, AbstractVelocity1D
from .cartesian import CartesianAcceleration1D, CartesianPosition1D, CartesianVelocity1D
from .radial import RadialAcceleration, RadialPosition, RadialVelocity
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base_vel import AbstractVelocity
from coordinax._coordinax.base import AbstractPosition, AbstractVelocity

###############################################################################
# 1D
Expand Down
10 changes: 6 additions & 4 deletions src/coordinax/_coordinax/d2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from abc import abstractmethod

from coordinax._coordinax.base import AbstractVector
from coordinax._coordinax.base_acc import AbstractAcceleration
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base_vel import AbstractVelocity
from coordinax._coordinax.base import (
AbstractAcceleration,
AbstractPosition,
AbstractVector,
AbstractVelocity,
)
from coordinax._coordinax.utils import classproperty


Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/d2/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import coordinax._coordinax.typing as ct
from .base import AbstractAcceleration2D, AbstractPosition2D, AbstractVelocity2D
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.mixins import AvalMixin
from coordinax._coordinax.base import AbstractPosition
from coordinax._coordinax.base.mixins import AvalMixin
from coordinax._coordinax.utils import classproperty


Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_coordinax/d2/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .base import AbstractPosition2D, AbstractVelocity2D
from .cartesian import CartesianAcceleration2D, CartesianPosition2D, CartesianVelocity2D
from .polar import PolarPosition, PolarVelocity
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base import AbstractPosition


@dispatch
Expand Down
10 changes: 6 additions & 4 deletions src/coordinax/_coordinax/d3/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from abc import abstractmethod
from typing_extensions import override

from coordinax._coordinax.base import AbstractVector
from coordinax._coordinax.base_acc import AbstractAcceleration
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base_vel import AbstractVelocity
from coordinax._coordinax.base import (
AbstractAcceleration,
AbstractPosition,
AbstractVector,
AbstractVelocity,
)
from coordinax._coordinax.utils import classproperty


Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/d3/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

import coordinax._coordinax.typing as ct
from .base import AbstractAcceleration3D, AbstractPosition3D, AbstractVelocity3D
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.mixins import AvalMixin
from coordinax._coordinax.base import AbstractPosition
from coordinax._coordinax.base.mixins import AvalMixin
from coordinax._coordinax.utils import classproperty


Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_coordinax/d3/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

import coordinax._coordinax.typing as ct
from .base import AbstractAcceleration3D, AbstractPosition3D, AbstractVelocity3D
from coordinax._coordinax.base_acc import AbstractAcceleration
from coordinax._coordinax.base import AbstractAcceleration
from coordinax._coordinax.checks import (
check_azimuth_range,
check_polar_range,
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_coordinax/d3/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
SphericalPosition,
SphericalVelocity,
)
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base import AbstractPosition

###############################################################################
# 3D
Expand Down
3 changes: 1 addition & 2 deletions src/coordinax/_coordinax/d4/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from abc import abstractmethod
from typing import TYPE_CHECKING

from coordinax._coordinax.base import AbstractVector
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base import AbstractPosition, AbstractVector
from coordinax._coordinax.utils import classproperty

if TYPE_CHECKING:
Expand Down
10 changes: 6 additions & 4 deletions src/coordinax/_coordinax/dn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
import quaxed.lax as qlax
import quaxed.numpy as jnp

from coordinax._coordinax.base import AbstractVector
from coordinax._coordinax.base_acc import AbstractAcceleration
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base_vel import AbstractVelocity
from coordinax._coordinax.base import (
AbstractAcceleration,
AbstractPosition,
AbstractVector,
AbstractVelocity,
)
from coordinax._coordinax.utils import classproperty

if TYPE_CHECKING:
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/dn/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

import coordinax._coordinax.typing as ct
from .base import AbstractAccelerationND, AbstractPositionND, AbstractVelocityND
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.mixins import AvalMixin
from coordinax._coordinax.base import AbstractPosition
from coordinax._coordinax.base.mixins import AvalMixin
from coordinax._coordinax.utils import classproperty

##############################################################################
Expand Down
3 changes: 1 addition & 2 deletions src/coordinax/_coordinax/dn/poincare.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from unxt import Quantity

import coordinax._coordinax.typing as ct
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base_vel import AbstractVelocity
from coordinax._coordinax.base import AbstractPosition, AbstractVelocity
from coordinax._coordinax.utils import classproperty


Expand Down
Loading

0 comments on commit cd89369

Please sign in to comment.