diff --git a/src/coordinax/_src/base/base_acc.py b/src/coordinax/_src/base/base_acc.py index c9f706c..52bb485 100644 --- a/src/coordinax/_src/base/base_acc.py +++ b/src/coordinax/_src/base/base_acc.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, TypeVar from typing_extensions import override +import equinox as eqx import jax from quax import register @@ -157,7 +158,7 @@ def represent_as(self, target: type[AccT], /, *args: Any, **kwargs: Any) -> AccT """ return represent_as(self, target, *args, **kwargs) - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm( self, velocity: AbstractVelocity, position: AbstractPosition, / ) -> Quantity["speed"]: diff --git a/src/coordinax/_src/base/base_pos.py b/src/coordinax/_src/base/base_pos.py index ae2f777..863c303 100644 --- a/src/coordinax/_src/base/base_pos.py +++ b/src/coordinax/_src/base/base_pos.py @@ -128,7 +128,7 @@ def represent_as(self, target: type[PosT], /, *args: Any, **kwargs: Any) -> PosT """ return represent_as(self, target, *args, **kwargs) - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm(self) -> ct.BatchableLength: """Return the norm of the vector. diff --git a/src/coordinax/_src/base/base_vel.py b/src/coordinax/_src/base/base_vel.py index a674830..4cb1f2b 100644 --- a/src/coordinax/_src/base/base_vel.py +++ b/src/coordinax/_src/base/base_vel.py @@ -6,6 +6,7 @@ from functools import partial from typing import TYPE_CHECKING, Any, TypeVar +import equinox as eqx import jax from quax import register @@ -177,7 +178,7 @@ def represent_as( return represent_as(self, target, *args, **kwargs) - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm(self, position: AbstractPosition, /) -> Quantity["speed"]: """Return the norm of the vector.""" return self.represent_as(self._cartesian_cls, position).norm() diff --git a/src/coordinax/_src/d1/cartesian.py b/src/coordinax/_src/d1/cartesian.py index abdc684..8686a74 100644 --- a/src/coordinax/_src/d1/cartesian.py +++ b/src/coordinax/_src/d1/cartesian.py @@ -189,7 +189,7 @@ def integral_cls(cls) -> type[CartesianPosition1D]: def differential_cls(cls) -> type["CartesianAcceleration1D"]: return CartesianAcceleration1D - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm(self, _: AbstractPosition1D | None = None, /) -> ct.BatchableSpeed: """Return the norm of the vector. @@ -288,7 +288,7 @@ def integral_cls(cls) -> type[CartesianVelocity1D]: # ----------------------------------------------------- # Methods - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm(self, _: AbstractPosition1D | None = None, /) -> ct.BatchableAcc: """Return the norm of the vector. diff --git a/src/coordinax/_src/d2/cartesian.py b/src/coordinax/_src/d2/cartesian.py index f3192a3..3b7e5a0 100644 --- a/src/coordinax/_src/d2/cartesian.py +++ b/src/coordinax/_src/d2/cartesian.py @@ -291,7 +291,7 @@ def integral_cls(cls) -> type[CartesianVelocity2D]: # ----------------------------------------------------- - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm(self, _: AbstractVelocity2D | None = None, /) -> ct.BatchableAcc: """Return the norm of the vector. diff --git a/src/coordinax/_src/d3/cartesian.py b/src/coordinax/_src/d3/cartesian.py index 2423f43..84ae7ce 100644 --- a/src/coordinax/_src/d3/cartesian.py +++ b/src/coordinax/_src/d3/cartesian.py @@ -159,7 +159,7 @@ def _sub_cart3d_pos( # from coordinax.funcs @dispatch # type: ignore[misc] -@partial(jax.jit, inline=True) +@partial(eqx.filter_jit, inline=True) def normalize_vector(obj: CartesianPosition3D, /) -> CartesianGeneric3D: """Return the norm of the vector. @@ -225,7 +225,7 @@ def integral_cls(cls) -> type[CartesianPosition3D]: def differential_cls(cls) -> type["CartesianAcceleration3D"]: return CartesianAcceleration3D - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm(self, _: AbstractPosition3D | None = None, /) -> ct.BatchableSpeed: """Return the norm of the vector. @@ -342,7 +342,7 @@ def integral_cls(cls) -> type[CartesianVelocity3D]: # ----------------------------------------------------- # Methods - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm(self, _: AbstractVelocity3D | None = None, /) -> ct.BatchableAcc: """Return the norm of the vector. diff --git a/src/coordinax/_src/d3/cylindrical.py b/src/coordinax/_src/d3/cylindrical.py index c68e4d4..4fd44f0 100644 --- a/src/coordinax/_src/d3/cylindrical.py +++ b/src/coordinax/_src/d3/cylindrical.py @@ -10,7 +10,6 @@ from typing import final import equinox as eqx -import jax import quaxed.numpy as xp from unxt import Quantity @@ -57,7 +56,7 @@ def __check_init__(self) -> None: def differential_cls(cls) -> type["CylindricalVelocity"]: return CylindricalVelocity - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm(self) -> ct.BatchableLength: """Return the norm of the vector. diff --git a/src/coordinax/_src/d3/spherical.py b/src/coordinax/_src/d3/spherical.py index e90f762..0fc85e5 100644 --- a/src/coordinax/_src/d3/spherical.py +++ b/src/coordinax/_src/d3/spherical.py @@ -247,7 +247,7 @@ def __check_init__(self) -> None: def differential_cls(cls) -> type["MathSphericalVelocity"]: return MathSphericalVelocity - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm(self) -> ct.BatchableDistance: """Return the norm of the vector. @@ -444,7 +444,7 @@ def __check_init__(self) -> None: def differential_cls(cls) -> type["LonLatSphericalVelocity"]: return LonLatSphericalVelocity - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm(self) -> ct.BatchableDistance: """Return the norm of the vector. diff --git a/src/coordinax/_src/d4/spacetime.py b/src/coordinax/_src/d4/spacetime.py index 020173c..46b1e9f 100644 --- a/src/coordinax/_src/d4/spacetime.py +++ b/src/coordinax/_src/d4/spacetime.py @@ -203,7 +203,7 @@ def __neg__(self) -> "FourVector": # ------------------------------------------- - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def _norm2(self) -> Shaped[Quantity["area"], "*#batch"]: r"""Return the squared vector norm :math:`(ct)^2 - (x^2 + y^2 + z^2)`. @@ -219,7 +219,7 @@ def _norm2(self) -> Shaped[Quantity["area"], "*#batch"]: """ return -(self.q.norm() ** 2) + (self.c * self.t) ** 2 # for units - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm(self) -> BatchableLength: r"""Return the vector norm :math:`\sqrt{(ct)^2 - (x^2 + y^2 + z^2)}`. diff --git a/src/coordinax/_src/dn/cartesian.py b/src/coordinax/_src/dn/cartesian.py index 399d31e..0588ef3 100644 --- a/src/coordinax/_src/dn/cartesian.py +++ b/src/coordinax/_src/dn/cartesian.py @@ -100,7 +100,7 @@ def differential_cls(cls) -> type["CartesianVelocityND"]: # type: ignore[overri # ----------------------------------------------------- - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm(self) -> ct.BatchableLength: """Return the norm of the vector. @@ -353,7 +353,7 @@ def integral_cls(cls) -> type[CartesianPositionND]: def differential_cls(cls) -> type["CartesianAccelerationND"]: return CartesianAccelerationND - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm(self, _: AbstractPositionND | None = None, /) -> ct.BatchableSpeed: """Return the norm of the vector. @@ -520,7 +520,7 @@ def differential_cls(cls) -> NoReturn: msg = "Not yet supported" raise NotImplementedError(msg) # TODO: Implement this - @partial(jax.jit, inline=True) + @partial(eqx.filter_jit, inline=True) def norm( self, velocity: AbstractVelocityND | None = None, # noqa: ARG002 diff --git a/src/coordinax/_src/funcs.py b/src/coordinax/_src/funcs.py index 756bfc2..156f146 100644 --- a/src/coordinax/_src/funcs.py +++ b/src/coordinax/_src/funcs.py @@ -8,7 +8,7 @@ from functools import partial from typing import Any -import jax +import equinox as eqx from jaxtyping import Array, Shaped from plum import dispatch @@ -30,7 +30,7 @@ def represent_as(current: Any, target: type[Any], /, **kwargs: Any) -> Any: @dispatch -@partial(jax.jit, inline=True) +@partial(eqx.filter_jit, inline=True) def normalize_vector(x: Shaped[Array, "*batch N"], /) -> Shaped[Array, "*batch N"]: """Return the unit vector. @@ -52,7 +52,7 @@ def normalize_vector(x: Shaped[Array, "*batch N"], /) -> Shaped[Array, "*batch N @dispatch -@partial(jax.jit, inline=True) +@partial(eqx.filter_jit, inline=True) def normalize_vector( x: Shaped[AbstractQuantity, "*batch N"], / ) -> Shaped[AbstractQuantity, "*batch N"]: diff --git a/src/coordinax/_src/transform/accelerations.py b/src/coordinax/_src/transform/accelerations.py index 4f85a33..7a194ae 100644 --- a/src/coordinax/_src/transform/accelerations.py +++ b/src/coordinax/_src/transform/accelerations.py @@ -6,6 +6,7 @@ from math import prod from typing import Any +import equinox as eqx import jax from plum import dispatch @@ -202,6 +203,4 @@ def represent_as( # TODO: situate this better to show how represent_as is used -jac_rep_as = jax.jit( - jax.vmap(jax.jacfwd(represent_as), in_axes=(0, None)), static_argnums=(1,) -) +jac_rep_as = eqx.filter_jit(jax.vmap(jax.jacfwd(represent_as), in_axes=(0, None))) diff --git a/src/coordinax/_src/transform/differentials.py b/src/coordinax/_src/transform/differentials.py index 01e9774..abe9d0a 100644 --- a/src/coordinax/_src/transform/differentials.py +++ b/src/coordinax/_src/transform/differentials.py @@ -6,6 +6,7 @@ from math import prod from typing import Any +import equinox as eqx import jax from plum import dispatch @@ -174,6 +175,4 @@ def represent_as( # TODO: situate this better to show how represent_as is used -jac_rep_as = jax.jit( - jax.vmap(jax.jacfwd(represent_as), in_axes=(0, None)), static_argnums=(1,) -) +jac_rep_as = eqx.filter_jit(jax.vmap(jax.jacfwd(represent_as), in_axes=(0, None))) diff --git a/tests/test_jax_ops.py b/tests/test_jax_ops.py index 140aac2..886c1ee 100644 --- a/tests/test_jax_ops.py +++ b/tests/test_jax_ops.py @@ -1,8 +1,6 @@ """Test using Jax operations.""" -from functools import partial - -import jax +import equinox as eqx import pytest from dataclassish import field_items @@ -24,7 +22,7 @@ def q(request) -> cx.AbstractPosition: return q.represent_as(request.param) -@partial(jax.jit, static_argnums=(1,)) +@eqx.filter_jit def func( q: cx.AbstractPosition, target: type[cx.AbstractPosition] ) -> cx.AbstractPosition: