Skip to content

Commit

Permalink
build: bump quaxed (#182)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Sep 15, 2024
1 parent 56cf691 commit 876c14a
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 19 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"optional_dependencies >= 0.3",
"plum-dispatch>=2.5.1",
"quax>=0.0.3",
"quaxed >= 0.4",
"quaxed >= 0.5.3",
"unxt >= 0.16",
]
description = "Coordinates in JAX"
Expand Down
10 changes: 5 additions & 5 deletions src/coordinax/_coordinax/base_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from plum import convert, dispatch
from quax import quaxify, register

import quaxed.array_api as xp
import quaxed.lax as qlax
import quaxed.numpy as jnp
from dataclassish import field_items
from unxt import Quantity

Expand Down Expand Up @@ -183,7 +183,7 @@ def norm(self) -> ct.BatchableLength:
Quantity['length'](Array(3.7416575, dtype=float32), unit='m')
"""
return xp.linalg.vector_norm(self, axis=-1)
return jnp.linalg.vector_norm(self, axis=-1)


# ===================================================================
Expand Down Expand Up @@ -251,7 +251,7 @@ def _mul_v_pos(lhs: ArrayLike, rhs: AbstractPosition, /) -> AbstractPosition:
--------
>>> from unxt import Quantity
>>> import coordinax as cx
>>> import quaxed.array_api as jnp
>>> import quaxed.numpy as jnp
>>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m")
>>> jnp.multiply(2, vec)
Expand All @@ -277,7 +277,7 @@ def _mul_v_pos(lhs: ArrayLike, rhs: AbstractPosition, /) -> AbstractPosition:
>>> from plum import conversion_method
>>> @conversion_method(MyCartesian, Quantity)
... def _to_quantity(x: MyCartesian, /) -> Quantity:
... return xp.stack((x.x, x.y, x.z), axis=-1)
... return jnp.stack((x.x, x.y, x.z), axis=-1)
Add representation transformation
Expand Down Expand Up @@ -406,7 +406,7 @@ def _div_pos_v(lhs: AbstractPosition, rhs: ArrayLike) -> AbstractPosition:
Quantity['length'](Array(0.5, dtype=float32), unit='m')
"""
return replace(lhs, **{k: xp.divide(v, rhs) for k, v in field_items(lhs)})
return replace(lhs, **{k: jnp.divide(v, rhs) for k, v in field_items(lhs)})


# ------------------------------------------------
Expand Down
6 changes: 3 additions & 3 deletions src/coordinax/_coordinax/d2/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ def _mul_p_vpolar(lhs: ArrayLike, rhs: PolarPosition, /) -> PolarPosition:
--------
>>> from unxt import Quantity
>>> import coordinax as cx
>>> import quaxed.array_api as xp
>>> import quaxed.array_api as jnp
>>> v = cx.PolarPosition(r=Quantity(1, "m"), phi=Quantity(90, "deg"))
>>> xp.linalg.vector_norm(v, axis=-1)
>>> jnp.linalg.vector_norm(v, axis=-1)
Quantity['length'](Array(1., dtype=float32), unit='m')
>>> nv = xp.multiply(2, v)
>>> nv = jnp.multiply(2, v)
>>> nv
PolarPosition(
r=Distance(value=f32[], unit=Unit("m")),
Expand Down
6 changes: 3 additions & 3 deletions src/coordinax/_coordinax/d3/spherical.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,16 +797,16 @@ def _mul_p_vmsph(
--------
>>> from unxt import Quantity
>>> import coordinax as cx
>>> import quaxed.array_api as xp
>>> import quaxed.numpy as jnp
>>> v = cx.MathSphericalPosition(r=Quantity(3, "kpc"),
... theta=Quantity(90, "deg"),
... phi=Quantity(0, "deg"))
>>> xp.linalg.vector_norm(v, axis=-1)
>>> jnp.linalg.vector_norm(v, axis=-1)
Quantity['length'](Array(3., dtype=float32), unit='kpc')
>>> nv = xp.multiply(2, v)
>>> nv = jnp.multiply(2, v)
>>> nv
MathSphericalPosition(
r=Distance(value=f32[], unit=Unit("kpc")),
Expand Down
7 changes: 3 additions & 4 deletions src/coordinax/_coordinax/dn/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from plum import conversion_method
from quax import register

import quaxed.array_api as xp
import quaxed.numpy as jnp
from unxt import Quantity

Expand Down Expand Up @@ -131,7 +130,7 @@ def norm(self) -> ct.BatchableLength:
Quantity['length'](Array(3.7416575, dtype=float32), unit='kpc')
"""
return xp.linalg.vector_norm(self.q, axis=-1)
return jnp.linalg.vector_norm(self.q, axis=-1)


# -------------------------------------------------------------------
Expand Down Expand Up @@ -365,7 +364,7 @@ def norm(self, _: AbstractPositionND | None = None, /) -> ct.BatchableSpeed:
Quantity['speed'](Array(3.7416575, dtype=float32), unit='km / s')
"""
return xp.linalg.vector_norm(self.d_q, axis=-1)
return jnp.linalg.vector_norm(self.d_q, axis=-1)


# -------------------------------------------------------------------
Expand Down Expand Up @@ -537,7 +536,7 @@ def norm(
Quantity['acceleration'](Array(3.7416575, dtype=float32), unit='km / s2')
"""
return xp.linalg.vector_norm(self.d2_q, axis=-1)
return jnp.linalg.vector_norm(self.d2_q, axis=-1)


# -------------------------------------------------------------------
Expand Down
6 changes: 3 additions & 3 deletions src/coordinax/_coordinax/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from jaxtyping import Array, Shaped
from plum import dispatch

import quaxed.array_api as xp
import quaxed.numpy as jnp
from unxt import AbstractQuantity


Expand Down Expand Up @@ -48,7 +48,7 @@ def normalize_vector(x: Shaped[Array, "*batch N"], /) -> Shaped[Array, "*batch N
Array([0., 1.], dtype=float32)
"""
return x / xp.linalg.vector_norm(x, axis=-1, keepdims=True)
return x / jnp.linalg.vector_norm(x, axis=-1, keepdims=True)


@dispatch
Expand All @@ -72,4 +72,4 @@ def normalize_vector(
Quantity['dimensionless'](Array([0., 1.], dtype=float32), unit='')
"""
return x / xp.linalg.vector_norm(x, axis=-1, keepdims=True)
return x / jnp.linalg.vector_norm(x, axis=-1, keepdims=True)

0 comments on commit 876c14a

Please sign in to comment.