Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: consolidate base classes #183

Merged
merged 2 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions src/coordinax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
from . import operators
from ._coordinax import (
base,
base_acc,
base_pos,
base_vel,
d1,
d2,
d3,
Expand All @@ -26,9 +23,6 @@
utils,
)
from ._coordinax.base import *
from ._coordinax.base_acc import *
from ._coordinax.base_pos import *
from ._coordinax.base_vel import *
from ._coordinax.d1 import *
from ._coordinax.d2 import *
from ._coordinax.d3 import *
Expand All @@ -46,9 +40,6 @@
__all__ = ["__version__", "operators"]
__all__ += funcs.__all__
__all__ += base.__all__
__all__ += base_pos.__all__
__all__ += base_vel.__all__
__all__ += base_acc.__all__
__all__ += d1.__all__
__all__ += d2.__all__
__all__ += d3.__all__
Expand All @@ -70,9 +61,6 @@
# Cleanup
del (
base,
base_vel,
base_pos,
base_acc,
space,
exceptions,
transform,
Expand Down
18 changes: 18 additions & 0 deletions src/coordinax/_coordinax/base/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Bases."""

__all__ = [
# Base
"AbstractVector",
"ToUnitsOptions",
# Position
"AbstractPosition",
# Velocity
"AbstractVelocity",
# Acceleration
"AbstractAcceleration",
]

from .base import AbstractVector, ToUnitsOptions
from .base_acc import AbstractAcceleration
from .base_pos import AbstractPosition
from .base_vel import AbstractVelocity
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
unitsystem,
)

from .typing import Unit
from .utils import classproperty, full_shaped
from coordinax._coordinax.typing import Unit
from coordinax._coordinax.utils import classproperty, full_shaped

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .base import AbstractVector
from .base_pos import AbstractPosition
from .base_vel import AbstractVelocity
from .utils import classproperty
from coordinax._coordinax.utils import classproperty

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,18 @@
from dataclassish import field_items
from unxt import Quantity

from . import typing as ct
from .base import AbstractVector
from .mixins import AvalMixin
from .utils import classproperty
from coordinax._coordinax import typing as ct
from coordinax._coordinax.utils import classproperty

if TYPE_CHECKING:
from typing_extensions import Self

PosT = TypeVar("PosT", bound="AbstractPosition")

VECTOR_CLASSES: set[type["AbstractPosition"]] = set()
# TODO: figure out public API for this
POSITION_CLASSES: set[type["AbstractPosition"]] = set()


class AbstractPosition(AvalMixin, AbstractVector): # pylint: disable=abstract-method
Expand All @@ -44,7 +45,7 @@
if isabstract(cls) or cls.__name__.startswith("Abstract"):
return

VECTOR_CLASSES.add(cls)
POSITION_CLASSES.add(cls)

@classproperty
@classmethod
Expand Down Expand Up @@ -227,17 +228,42 @@
).represent_as(type(lhs))


@register(jax.lax.sub_p) # type: ignore[misc]
def _sub_qq(lhs: AbstractPosition, rhs: AbstractPosition) -> AbstractPosition:
"""Add another object to this vector."""
# The base implementation is to convert to Cartesian and perform the
# operation. Cartesian coordinates do not have any branch cuts or
# singularities or ranges that need to be handled, so this is a safe
# default.
return qlax.sub(
lhs.represent_as(lhs._cartesian_cls), # noqa: SLF001
rhs.represent_as(lhs._cartesian_cls), # noqa: SLF001
).represent_as(type(lhs))
# ------------------------------------------------


@register(jax.lax.convert_element_type_p) # type: ignore[misc]
def _convert_element_type_p(
operand: AbstractPosition, **kwargs: Any
) -> AbstractPosition:
"""Convert the element type of a quantity."""
# TODO: examples
return replace(

Check warning on line 240 in src/coordinax/_coordinax/base/base_pos.py

View check run for this annotation

Codecov / codecov/patch

src/coordinax/_coordinax/base/base_pos.py#L240

Added line #L240 was not covered by tests
operand,
**{k: qlax.convert_element_type(v, **kwargs) for k, v in field_items(operand)},
)


# ------------------------------------------------


@register(jax.lax.div_p) # type: ignore[misc]
def _div_pos_v(lhs: AbstractPosition, rhs: ArrayLike) -> AbstractPosition:
"""Divide a vector by a scalar.

Examples
--------
>>> import quaxed.array_api as jnp
>>> import coordinax as cx

>>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m")
>>> jnp.divide(vec, 2).x
Quantity['length'](Array(0.5, dtype=float32), unit='m')

>>> (vec / 2).x
Quantity['length'](Array(0.5, dtype=float32), unit='m')

"""
return replace(lhs, **{k: jnp.divide(v, rhs) for k, v in field_items(lhs)})


# ------------------------------------------------
Expand Down Expand Up @@ -389,26 +415,6 @@
return qlax.mul(lq, rq) # re-dispatch to Quantities


@register(jax.lax.div_p) # type: ignore[misc]
def _div_pos_v(lhs: AbstractPosition, rhs: ArrayLike) -> AbstractPosition:
"""Divide a vector by a scalar.

Examples
--------
>>> import quaxed.array_api as jnp
>>> import coordinax as cx

>>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m")
>>> jnp.divide(vec, 2).x
Quantity['length'](Array(0.5, dtype=float32), unit='m')

>>> (vec / 2).x
Quantity['length'](Array(0.5, dtype=float32), unit='m')

"""
return replace(lhs, **{k: jnp.divide(v, rhs) for k, v in field_items(lhs)})


# ------------------------------------------------


Expand Down Expand Up @@ -446,3 +452,19 @@
for k, v in field_items(operand)
},
)


# ------------------------------------------------


@register(jax.lax.sub_p) # type: ignore[misc]
def _sub_qq(lhs: AbstractPosition, rhs: AbstractPosition) -> AbstractPosition:
"""Add another object to this vector."""
# The base implementation is to convert to Cartesian and perform the
# operation. Cartesian coordinates do not have any branch cuts or
# singularities or ranges that need to be handled, so this is a safe
# default.
return qlax.sub(

Check warning on line 467 in src/coordinax/_coordinax/base/base_pos.py

View check run for this annotation

Codecov / codecov/patch

src/coordinax/_coordinax/base/base_pos.py#L467

Added line #L467 was not covered by tests
lhs.represent_as(lhs._cartesian_cls), # noqa: SLF001
rhs.represent_as(lhs._cartesian_cls), # noqa: SLF001
).represent_as(type(lhs))
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .base import AbstractVector
from .base_pos import AbstractPosition
from .utils import classproperty
from coordinax._coordinax.utils import classproperty

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from unxt import Quantity

from .funcs import represent_as
from coordinax._coordinax.funcs import represent_as


class AvalMixin:
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_coordinax/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dataclassish import field_values
from unxt import AbstractQuantity, Distance, Quantity, UncheckedQuantity

from .base_pos import AbstractPosition
from .base import AbstractPosition
from coordinax._coordinax.utils import full_shaped

#####################################################################
Expand Down
10 changes: 6 additions & 4 deletions src/coordinax/_coordinax/d1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
import quaxed.numpy as jnp
from unxt import Quantity

from coordinax._coordinax.base import AbstractVector
from coordinax._coordinax.base_acc import AbstractAcceleration
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base_vel import AbstractVelocity
from coordinax._coordinax.base import (
AbstractAcceleration,
AbstractPosition,
AbstractVector,
AbstractVelocity,
)
from coordinax._coordinax.utils import classproperty

#####################################################################
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/d1/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import coordinax._coordinax.typing as ct
from .base import AbstractAcceleration1D, AbstractPosition1D, AbstractVelocity1D
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.mixins import AvalMixin
from coordinax._coordinax.base import AbstractPosition
from coordinax._coordinax.base.mixins import AvalMixin
from coordinax._coordinax.utils import classproperty


Expand Down
3 changes: 1 addition & 2 deletions src/coordinax/_coordinax/d1/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from .base import AbstractAcceleration1D, AbstractPosition1D, AbstractVelocity1D
from .cartesian import CartesianAcceleration1D, CartesianPosition1D, CartesianVelocity1D
from .radial import RadialAcceleration, RadialPosition, RadialVelocity
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base_vel import AbstractVelocity
from coordinax._coordinax.base import AbstractPosition, AbstractVelocity

###############################################################################
# 1D
Expand Down
10 changes: 6 additions & 4 deletions src/coordinax/_coordinax/d2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from abc import abstractmethod

from coordinax._coordinax.base import AbstractVector
from coordinax._coordinax.base_acc import AbstractAcceleration
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base_vel import AbstractVelocity
from coordinax._coordinax.base import (
AbstractAcceleration,
AbstractPosition,
AbstractVector,
AbstractVelocity,
)
from coordinax._coordinax.utils import classproperty


Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/d2/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import coordinax._coordinax.typing as ct
from .base import AbstractAcceleration2D, AbstractPosition2D, AbstractVelocity2D
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.mixins import AvalMixin
from coordinax._coordinax.base import AbstractPosition
from coordinax._coordinax.base.mixins import AvalMixin
from coordinax._coordinax.utils import classproperty


Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_coordinax/d2/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .base import AbstractPosition2D, AbstractVelocity2D
from .cartesian import CartesianAcceleration2D, CartesianPosition2D, CartesianVelocity2D
from .polar import PolarPosition, PolarVelocity
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base import AbstractPosition


@dispatch
Expand Down
10 changes: 6 additions & 4 deletions src/coordinax/_coordinax/d3/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from abc import abstractmethod
from typing_extensions import override

from coordinax._coordinax.base import AbstractVector
from coordinax._coordinax.base_acc import AbstractAcceleration
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base_vel import AbstractVelocity
from coordinax._coordinax.base import (
AbstractAcceleration,
AbstractPosition,
AbstractVector,
AbstractVelocity,
)
from coordinax._coordinax.utils import classproperty


Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/d3/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

import coordinax._coordinax.typing as ct
from .base import AbstractAcceleration3D, AbstractPosition3D, AbstractVelocity3D
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.mixins import AvalMixin
from coordinax._coordinax.base import AbstractPosition
from coordinax._coordinax.base.mixins import AvalMixin
from coordinax._coordinax.utils import classproperty


Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_coordinax/d3/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

import coordinax._coordinax.typing as ct
from .base import AbstractAcceleration3D, AbstractPosition3D, AbstractVelocity3D
from coordinax._coordinax.base_acc import AbstractAcceleration
from coordinax._coordinax.base import AbstractAcceleration
from coordinax._coordinax.checks import (
check_azimuth_range,
check_polar_range,
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_coordinax/d3/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
SphericalPosition,
SphericalVelocity,
)
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base import AbstractPosition

###############################################################################
# 3D
Expand Down
3 changes: 1 addition & 2 deletions src/coordinax/_coordinax/d4/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from abc import abstractmethod
from typing import TYPE_CHECKING

from coordinax._coordinax.base import AbstractVector
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base import AbstractPosition, AbstractVector
from coordinax._coordinax.utils import classproperty

if TYPE_CHECKING:
Expand Down
10 changes: 6 additions & 4 deletions src/coordinax/_coordinax/dn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
import quaxed.lax as qlax
import quaxed.numpy as jnp

from coordinax._coordinax.base import AbstractVector
from coordinax._coordinax.base_acc import AbstractAcceleration
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.base_vel import AbstractVelocity
from coordinax._coordinax.base import (
AbstractAcceleration,
AbstractPosition,
AbstractVector,
AbstractVelocity,
)
from coordinax._coordinax.utils import classproperty

if TYPE_CHECKING:
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/dn/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

import coordinax._coordinax.typing as ct
from .base import AbstractAccelerationND, AbstractPositionND, AbstractVelocityND
from coordinax._coordinax.base_pos import AbstractPosition
from coordinax._coordinax.mixins import AvalMixin
from coordinax._coordinax.base import AbstractPosition
from coordinax._coordinax.base.mixins import AvalMixin
from coordinax._coordinax.utils import classproperty

##############################################################################
Expand Down
Loading