Skip to content

Commit

Permalink
refactor: prefer filter_jit (#190)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Sep 18, 2024
1 parent e5759f8 commit 77bbb12
Show file tree
Hide file tree
Showing 14 changed files with 28 additions and 31 deletions.
3 changes: 2 additions & 1 deletion src/coordinax/_src/base/base_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, Any, TypeVar
from typing_extensions import override

import equinox as eqx
import jax
from quax import register

Expand Down Expand Up @@ -157,7 +158,7 @@ def represent_as(self, target: type[AccT], /, *args: Any, **kwargs: Any) -> AccT
"""
return represent_as(self, target, *args, **kwargs)

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(
self, velocity: AbstractVelocity, position: AbstractPosition, /
) -> Quantity["speed"]:
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_src/base/base_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def represent_as(self, target: type[PosT], /, *args: Any, **kwargs: Any) -> PosT
"""
return represent_as(self, target, *args, **kwargs)

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(self) -> ct.BatchableLength:
"""Return the norm of the vector.
Expand Down
3 changes: 2 additions & 1 deletion src/coordinax/_src/base/base_vel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import partial
from typing import TYPE_CHECKING, Any, TypeVar

import equinox as eqx
import jax
from quax import register

Expand Down Expand Up @@ -177,7 +178,7 @@ def represent_as(

return represent_as(self, target, *args, **kwargs)

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(self, position: AbstractPosition, /) -> Quantity["speed"]:
"""Return the norm of the vector."""
return self.represent_as(self._cartesian_cls, position).norm()
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_src/d1/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def integral_cls(cls) -> type[CartesianPosition1D]:
def differential_cls(cls) -> type["CartesianAcceleration1D"]:
return CartesianAcceleration1D

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(self, _: AbstractPosition1D | None = None, /) -> ct.BatchableSpeed:
"""Return the norm of the vector.
Expand Down Expand Up @@ -288,7 +288,7 @@ def integral_cls(cls) -> type[CartesianVelocity1D]:
# -----------------------------------------------------
# Methods

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(self, _: AbstractPosition1D | None = None, /) -> ct.BatchableAcc:
"""Return the norm of the vector.
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_src/d2/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def integral_cls(cls) -> type[CartesianVelocity2D]:

# -----------------------------------------------------

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(self, _: AbstractVelocity2D | None = None, /) -> ct.BatchableAcc:
"""Return the norm of the vector.
Expand Down
6 changes: 3 additions & 3 deletions src/coordinax/_src/d3/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _sub_cart3d_pos(

# from coordinax.funcs
@dispatch # type: ignore[misc]
@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def normalize_vector(obj: CartesianPosition3D, /) -> CartesianGeneric3D:
"""Return the norm of the vector.
Expand Down Expand Up @@ -225,7 +225,7 @@ def integral_cls(cls) -> type[CartesianPosition3D]:
def differential_cls(cls) -> type["CartesianAcceleration3D"]:
return CartesianAcceleration3D

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(self, _: AbstractPosition3D | None = None, /) -> ct.BatchableSpeed:
"""Return the norm of the vector.
Expand Down Expand Up @@ -342,7 +342,7 @@ def integral_cls(cls) -> type[CartesianVelocity3D]:
# -----------------------------------------------------
# Methods

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(self, _: AbstractVelocity3D | None = None, /) -> ct.BatchableAcc:
"""Return the norm of the vector.
Expand Down
3 changes: 1 addition & 2 deletions src/coordinax/_src/d3/cylindrical.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import final

import equinox as eqx
import jax

import quaxed.numpy as xp
from unxt import Quantity
Expand Down Expand Up @@ -57,7 +56,7 @@ def __check_init__(self) -> None:
def differential_cls(cls) -> type["CylindricalVelocity"]:
return CylindricalVelocity

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(self) -> ct.BatchableLength:
"""Return the norm of the vector.
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_src/d3/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def __check_init__(self) -> None:
def differential_cls(cls) -> type["MathSphericalVelocity"]:
return MathSphericalVelocity

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(self) -> ct.BatchableDistance:
"""Return the norm of the vector.
Expand Down Expand Up @@ -444,7 +444,7 @@ def __check_init__(self) -> None:
def differential_cls(cls) -> type["LonLatSphericalVelocity"]:
return LonLatSphericalVelocity

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(self) -> ct.BatchableDistance:
"""Return the norm of the vector.
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_src/d4/spacetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def __neg__(self) -> "FourVector":

# -------------------------------------------

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def _norm2(self) -> Shaped[Quantity["area"], "*#batch"]:
r"""Return the squared vector norm :math:`(ct)^2 - (x^2 + y^2 + z^2)`.
Expand All @@ -219,7 +219,7 @@ def _norm2(self) -> Shaped[Quantity["area"], "*#batch"]:
"""
return -(self.q.norm() ** 2) + (self.c * self.t) ** 2 # for units

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(self) -> BatchableLength:
r"""Return the vector norm :math:`\sqrt{(ct)^2 - (x^2 + y^2 + z^2)}`.
Expand Down
6 changes: 3 additions & 3 deletions src/coordinax/_src/dn/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def differential_cls(cls) -> type["CartesianVelocityND"]: # type: ignore[overri

# -----------------------------------------------------

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(self) -> ct.BatchableLength:
"""Return the norm of the vector.
Expand Down Expand Up @@ -353,7 +353,7 @@ def integral_cls(cls) -> type[CartesianPositionND]:
def differential_cls(cls) -> type["CartesianAccelerationND"]:
return CartesianAccelerationND

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(self, _: AbstractPositionND | None = None, /) -> ct.BatchableSpeed:
"""Return the norm of the vector.
Expand Down Expand Up @@ -520,7 +520,7 @@ def differential_cls(cls) -> NoReturn:
msg = "Not yet supported"
raise NotImplementedError(msg) # TODO: Implement this

@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def norm(
self,
velocity: AbstractVelocityND | None = None, # noqa: ARG002
Expand Down
6 changes: 3 additions & 3 deletions src/coordinax/_src/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from functools import partial
from typing import Any

import jax
import equinox as eqx
from jaxtyping import Array, Shaped
from plum import dispatch

Expand All @@ -30,7 +30,7 @@ def represent_as(current: Any, target: type[Any], /, **kwargs: Any) -> Any:


@dispatch
@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def normalize_vector(x: Shaped[Array, "*batch N"], /) -> Shaped[Array, "*batch N"]:
"""Return the unit vector.
Expand All @@ -52,7 +52,7 @@ def normalize_vector(x: Shaped[Array, "*batch N"], /) -> Shaped[Array, "*batch N


@dispatch
@partial(jax.jit, inline=True)
@partial(eqx.filter_jit, inline=True)
def normalize_vector(
x: Shaped[AbstractQuantity, "*batch N"], /
) -> Shaped[AbstractQuantity, "*batch N"]:
Expand Down
5 changes: 2 additions & 3 deletions src/coordinax/_src/transform/accelerations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from math import prod
from typing import Any

import equinox as eqx
import jax
from plum import dispatch

Expand Down Expand Up @@ -202,6 +203,4 @@ def represent_as(


# TODO: situate this better to show how represent_as is used
jac_rep_as = jax.jit(
jax.vmap(jax.jacfwd(represent_as), in_axes=(0, None)), static_argnums=(1,)
)
jac_rep_as = eqx.filter_jit(jax.vmap(jax.jacfwd(represent_as), in_axes=(0, None)))
5 changes: 2 additions & 3 deletions src/coordinax/_src/transform/differentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from math import prod
from typing import Any

import equinox as eqx
import jax
from plum import dispatch

Expand Down Expand Up @@ -174,6 +175,4 @@ def represent_as(


# TODO: situate this better to show how represent_as is used
jac_rep_as = jax.jit(
jax.vmap(jax.jacfwd(represent_as), in_axes=(0, None)), static_argnums=(1,)
)
jac_rep_as = eqx.filter_jit(jax.vmap(jax.jacfwd(represent_as), in_axes=(0, None)))
6 changes: 2 additions & 4 deletions tests/test_jax_ops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""Test using Jax operations."""

from functools import partial

import jax
import equinox as eqx
import pytest

from dataclassish import field_items
Expand All @@ -24,7 +22,7 @@ def q(request) -> cx.AbstractPosition:
return q.represent_as(request.param)


@partial(jax.jit, static_argnums=(1,))
@eqx.filter_jit
def func(
q: cx.AbstractPosition, target: type[cx.AbstractPosition]
) -> cx.AbstractPosition:
Expand Down

0 comments on commit 77bbb12

Please sign in to comment.