Skip to content

Commit

Permalink
refactor: inline jit (#168)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Aug 27, 2024
1 parent c909a1c commit fc5d5c2
Show file tree
Hide file tree
Showing 10 changed files with 23 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/coordinax/_coordinax/base_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def represent_as(self, target: type[AccT], /, *args: Any, **kwargs: Any) -> AccT

return represent_as(self, target, *args, **kwargs)

@partial(jax.jit)
@partial(jax.jit, inline=True)
def norm(
self, velocity: AbstractVelocity, position: AbstractPosition, /
) -> Quantity["speed"]:
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_coordinax/base_vel.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def represent_as(

return represent_as(self, target, *args, **kwargs)

@partial(jax.jit)
@partial(jax.jit, inline=True)
def norm(self, position: AbstractPosition, /) -> Quantity["speed"]:
"""Return the norm of the vector."""
return self.represent_as(self._cartesian_cls, position).norm()
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/d1/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def integral_cls(cls) -> type[CartesianPosition1D]:
def differential_cls(cls) -> type["CartesianAcceleration1D"]:
return CartesianAcceleration1D

@partial(jax.jit)
@partial(jax.jit, inline=True)
def norm(self, _: AbstractPosition1D | None = None, /) -> ct.BatchableSpeed:
"""Return the norm of the vector.
Expand Down Expand Up @@ -291,7 +291,7 @@ def integral_cls(cls) -> type[CartesianVelocity1D]:
# -----------------------------------------------------
# Methods

@partial(jax.jit)
@partial(jax.jit, inline=True)
def norm(self, _: AbstractPosition1D | None = None, /) -> ct.BatchableAcc:
"""Return the norm of the vector.
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_coordinax/d2/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def integral_cls(cls) -> type[CartesianVelocity2D]:

# -----------------------------------------------------

@partial(jax.jit)
@partial(jax.jit, inline=True)
def norm(self, _: AbstractVelocity2D | None = None, /) -> ct.BatchableAcc:
"""Return the norm of the vector.
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/d3/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def integral_cls(cls) -> type[CartesianPosition3D]:
def differential_cls(cls) -> type["CartesianAcceleration3D"]:
return CartesianAcceleration3D

@partial(jax.jit)
@partial(jax.jit, inline=True)
def norm(self, _: AbstractPosition3D | None = None, /) -> ct.BatchableSpeed:
"""Return the norm of the vector.
Expand Down Expand Up @@ -235,7 +235,7 @@ def integral_cls(cls) -> type[CartesianVelocity3D]:
# -----------------------------------------------------
# Methods

@partial(jax.jit)
@partial(jax.jit, inline=True)
def norm(self, _: AbstractVelocity3D | None = None, /) -> ct.BatchableAcc:
"""Return the norm of the vector.
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_coordinax/d3/cylindrical.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __check_init__(self) -> None:
def differential_cls(cls) -> type["CylindricalVelocity"]:
return CylindricalVelocity

@partial(jax.jit)
@partial(jax.jit, inline=True)
def norm(self) -> ct.BatchableLength:
"""Return the norm of the vector.
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/d3/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def __check_init__(self) -> None:
def differential_cls(cls) -> type["MathSphericalVelocity"]:
return MathSphericalVelocity

@partial(jax.jit)
@partial(jax.jit, inline=True)
def norm(self) -> ct.BatchableDistance:
"""Return the norm of the vector.
Expand Down Expand Up @@ -450,7 +450,7 @@ def __check_init__(self) -> None:
def differential_cls(cls) -> type["LonLatSphericalVelocity"]:
return LonLatSphericalVelocity

@partial(jax.jit)
@partial(jax.jit, inline=True)
def norm(self) -> ct.BatchableDistance:
"""Return the norm of the vector.
Expand Down
16 changes: 8 additions & 8 deletions src/coordinax/_coordinax/d3/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,57 +110,57 @@ def represent_as(
/,
**kwargs: Any,
) -> AbstractVelocity3D:
"""Self transforms for 3D differentials.
"""Self transforms for 3D velocity.
Examples
--------
>>> from unxt import Quantity
>>> import coordinax as cx
For these transformations the position does not matter since the
self-transform returns the differential unchanged.
self-transform returns the velocity unchanged.
>>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "kpc")
Cartesian to Cartesian differential:
Cartesian to Cartesian velocity:
>>> dif = cx.CartesianVelocity3D.constructor([1, 2, 3], "km/s")
>>> cx.represent_as(dif, cx.CartesianVelocity3D, vec) is dif
True
Cylindrical to Cylindrical differential:
Cylindrical to Cylindrical velocity:
>>> dif = cx.CylindricalVelocity(d_rho=Quantity(1, "km/s"),
... d_phi=Quantity(2, "mas/yr"),
... d_z=Quantity(3, "km/s"))
>>> cx.represent_as(dif, cx.CylindricalVelocity, vec) is dif
True
Spherical to Spherical differential:
Spherical to Spherical velocity:
>>> dif = cx.SphericalVelocity(d_r=Quantity(1, "km/s"),
... d_theta=Quantity(2, "mas/yr"),
... d_phi=Quantity(3, "mas/yr"))
>>> cx.represent_as(dif, cx.SphericalVelocity, vec) is dif
True
LonLatSpherical to LonLatSpherical differential:
LonLatSpherical to LonLatSpherical velocity:
>>> dif = cx.LonLatSphericalVelocity(d_lon=Quantity(1, "mas/yr"),
... d_lat=Quantity(2, "mas/yr"),
... d_distance=Quantity(3, "km/s"))
>>> cx.represent_as(dif, cx.LonLatSphericalVelocity, vec) is dif
True
LonCosLatSpherical to LonCosLatSpherical differential:
LonCosLatSpherical to LonCosLatSpherical velocity:
>>> dif = cx.LonCosLatSphericalVelocity(d_lon_coslat=Quantity(1, "mas/yr"),
... d_lat=Quantity(2, "mas/yr"),
... d_distance=Quantity(3, "km/s"))
>>> cx.represent_as(dif, cx.LonCosLatSphericalVelocity, vec) is dif
True
MathSpherical to MathSpherical differential:
MathSpherical to MathSpherical velocity:
>>> dif = cx.MathSphericalVelocity(d_r=Quantity(1, "km/s"),
... d_theta=Quantity(2, "mas/yr"),
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/d4/spacetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __neg__(self) -> "FourVector":

# -------------------------------------------

@partial(jax.jit)
@partial(jax.jit, inline=True)
def _norm2(self) -> Shaped[Quantity["area"], "*#batch"]:
r"""Return the squared vector norm :math:`(ct)^2 - (x^2 + y^2 + z^2)`.
Expand All @@ -223,7 +223,7 @@ def _norm2(self) -> Shaped[Quantity["area"], "*#batch"]:
"""
return -(self.q.norm() ** 2) + (self.c * self.t) ** 2 # for units

@partial(jax.jit)
@partial(jax.jit, inline=True)
def norm(self) -> BatchableLength:
r"""Return the vector norm :math:`\sqrt{(ct)^2 - (x^2 + y^2 + z^2)}`.
Expand Down
6 changes: 3 additions & 3 deletions src/coordinax/_coordinax/dn/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __neg__(self) -> "Self":

# -----------------------------------------------------

@partial(jax.jit)
@partial(jax.jit, inline=True)
def norm(self) -> ct.BatchableLength:
"""Return the norm of the vector.
Expand Down Expand Up @@ -350,7 +350,7 @@ def integral_cls(cls) -> type[CartesianPositionND]:
def differential_cls(cls) -> type["CartesianAccelerationND"]:
return CartesianAccelerationND

@partial(jax.jit)
@partial(jax.jit, inline=True)
def norm(self, _: AbstractPositionND | None = None, /) -> ct.BatchableSpeed:
"""Return the norm of the vector.
Expand Down Expand Up @@ -517,7 +517,7 @@ def differential_cls(cls) -> NoReturn:
msg = "Not yet supported"
raise NotImplementedError(msg) # TODO: Implement this

@partial(jax.jit)
@partial(jax.jit, inline=True)
def norm(
self,
velocity: AbstractVelocityND | None = None, # noqa: ARG002
Expand Down

0 comments on commit fc5d5c2

Please sign in to comment.