Skip to content

Commit

Permalink
typing: override
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Sep 19, 2024
1 parent 77bbb12 commit 9f83291
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/coordinax/_src/base/base_vel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import abstractmethod
from functools import partial
from typing import TYPE_CHECKING, Any, TypeVar
from typing_extensions import override

import equinox as eqx
import jax
Expand Down Expand Up @@ -129,6 +130,7 @@ def __neg__(self) -> "Self":
# ===============================================================
# Convenience methods

@override
def represent_as(
self,
target: type[VelT],
Expand Down
5 changes: 5 additions & 0 deletions src/coordinax/_src/d1/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dataclasses import replace
from functools import partial
from typing import final
from typing_extensions import override

import equinox as eqx
import jax
Expand Down Expand Up @@ -179,16 +180,19 @@ class CartesianVelocity1D(AvalMixin, AbstractVelocity1D):
d_x: ct.BatchableSpeed = eqx.field(converter=Quantity["speed"].constructor)
r"""X differential :math:`dx/dt \in (-\infty,+\infty`)`."""

@override
@classproperty
@classmethod
def integral_cls(cls) -> type[CartesianPosition1D]:
return CartesianPosition1D

@override
@classproperty
@classmethod
def differential_cls(cls) -> type["CartesianAcceleration1D"]:
return CartesianAcceleration1D

@override
@partial(eqx.filter_jit, inline=True)
def norm(self, _: AbstractPosition1D | None = None, /) -> ct.BatchableSpeed:
"""Return the norm of the vector.
Expand Down Expand Up @@ -288,6 +292,7 @@ def integral_cls(cls) -> type[CartesianVelocity1D]:
# -----------------------------------------------------
# Methods

@override
@partial(eqx.filter_jit, inline=True)
def norm(self, _: AbstractPosition1D | None = None, /) -> ct.BatchableAcc:
"""Return the norm of the vector.
Expand Down
2 changes: 2 additions & 0 deletions src/coordinax/_src/d2/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dataclasses import fields, replace
from functools import partial
from typing import final
from typing_extensions import override

import equinox as eqx
import jax
Expand Down Expand Up @@ -291,6 +292,7 @@ def integral_cls(cls) -> type[CartesianVelocity2D]:

# -----------------------------------------------------

@override
@partial(eqx.filter_jit, inline=True)
def norm(self, _: AbstractVelocity2D | None = None, /) -> ct.BatchableAcc:
"""Return the norm of the vector.
Expand Down
3 changes: 3 additions & 0 deletions src/coordinax/_src/d3/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,13 @@ class CartesianVelocity3D(AvalMixin, AbstractVelocity3D):
)
r"""Z speed :math:`dz/dt \in [-\infty, \infty]."""

@override
@classproperty
@classmethod
def integral_cls(cls) -> type[CartesianPosition3D]:
return CartesianPosition3D

@override
@classproperty
@classmethod
def differential_cls(cls) -> type["CartesianAcceleration3D"]:
Expand Down Expand Up @@ -342,6 +344,7 @@ def integral_cls(cls) -> type[CartesianVelocity3D]:
# -----------------------------------------------------
# Methods

@override
@partial(eqx.filter_jit, inline=True)
def norm(self, _: AbstractVelocity3D | None = None, /) -> ct.BatchableAcc:
"""Return the norm of the vector.
Expand Down
3 changes: 3 additions & 0 deletions src/coordinax/_src/d3/cylindrical.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from functools import partial
from typing import final
from typing_extensions import override

import equinox as eqx

Expand Down Expand Up @@ -51,11 +52,13 @@ def __check_init__(self) -> None:
check_r_non_negative(self.rho)
check_azimuth_range(self.phi)

@override
@classproperty
@classmethod
def differential_cls(cls) -> type["CylindricalVelocity"]:
return CylindricalVelocity

@override
@partial(eqx.filter_jit, inline=True)
def norm(self) -> ct.BatchableLength:
"""Return the norm of the vector.
Expand Down
4 changes: 4 additions & 0 deletions src/coordinax/_src/d3/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from abc import abstractmethod
from functools import partial
from typing import final
from typing_extensions import override

import equinox as eqx
import jax
Expand Down Expand Up @@ -242,6 +243,7 @@ def __check_init__(self) -> None:
check_azimuth_range(self.theta)
check_polar_range(self.phi)

@override
@classproperty
@classmethod
def differential_cls(cls) -> type["MathSphericalVelocity"]:
Expand Down Expand Up @@ -439,11 +441,13 @@ def __check_init__(self) -> None:
check_polar_range(self.lat, -Quantity(90, "deg"), Quantity(90, "deg"))
check_r_non_negative(self.distance)

@override
@classproperty
@classmethod
def differential_cls(cls) -> type["LonLatSphericalVelocity"]:
return LonLatSphericalVelocity

@override
@partial(eqx.filter_jit, inline=True)
def norm(self) -> ct.BatchableDistance:
"""Return the norm of the vector.
Expand Down
5 changes: 5 additions & 0 deletions src/coordinax/_src/d4/spacetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import KW_ONLY, fields, replace
from functools import partial
from typing import TYPE_CHECKING, Any, final
from typing_extensions import override

import equinox as eqx
import jax
Expand Down Expand Up @@ -168,12 +169,14 @@ def __getattr__(self, name: str) -> Any:

# -------------------------------------------

@override
@classproperty
@classmethod
def _cartesian_cls(cls) -> type[AbstractVector]:
msg = "Not yet implemented"
raise NotImplementedError(msg)

@override
@classproperty
@classmethod
def differential_cls(cls) -> "Never": # type: ignore[override]
Expand All @@ -183,6 +186,7 @@ def differential_cls(cls) -> "Never": # type: ignore[override]
# -------------------------------------------
# Unary operations

@override
def __neg__(self) -> "FourVector":
"""Negate the vector.
Expand Down Expand Up @@ -219,6 +223,7 @@ def _norm2(self) -> Shaped[Quantity["area"], "*#batch"]:
"""
return -(self.q.norm() ** 2) + (self.c * self.t) ** 2 # for units

@override
@partial(eqx.filter_jit, inline=True)
def norm(self) -> BatchableLength:
r"""Return the vector norm :math:`\sqrt{(ct)^2 - (x^2 + y^2 + z^2)}`.
Expand Down
9 changes: 6 additions & 3 deletions src/coordinax/_src/dn/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ class CartesianPositionND(AbstractPositionND):
dimensions. Arbitrary batch shapes are supported.
"""

@override
@classproperty
@classmethod
@override
def differential_cls(cls) -> type["CartesianVelocityND"]: # type: ignore[override]
return CartesianVelocityND

Expand Down Expand Up @@ -348,6 +348,7 @@ class CartesianVelocityND(AvalMixin, AbstractVelocityND):
def integral_cls(cls) -> type[CartesianPositionND]:
return CartesianPositionND

@override
@classproperty
@classmethod
def differential_cls(cls) -> type["CartesianAccelerationND"]:
Expand Down Expand Up @@ -490,6 +491,7 @@ class CartesianAccelerationND(AvalMixin, AbstractAccelerationND):
dimensions. Arbitrary batch shapes are supported.
"""

@override
@classproperty
@classmethod
def integral_cls(cls) -> type[CartesianVelocityND]:
Expand Down Expand Up @@ -520,11 +522,12 @@ def differential_cls(cls) -> NoReturn:
msg = "Not yet supported"
raise NotImplementedError(msg) # TODO: Implement this

@override
@partial(eqx.filter_jit, inline=True)
def norm(
self,
velocity: AbstractVelocityND | None = None, # noqa: ARG002
position: AbstractPositionND | None = None, # noqa: ARG002
velocity: AbstractVelocityND | None = None,
position: AbstractPositionND | None = None,
/,
) -> ct.BatchableSpeed:
"""Return the norm of the vector.
Expand Down

0 comments on commit 9f83291

Please sign in to comment.