Skip to content

Commit

Permalink
refactor: rename qnp to jnp
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Aug 27, 2024
1 parent 9d577db commit f07317b
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 206 deletions.
8 changes: 4 additions & 4 deletions src/coordinax/_coordinax/dn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import equinox as eqx

import quaxed.lax as qlax
import quaxed.numpy as qnp
import quaxed.numpy as jnp

from coordinax._coordinax.base import AbstractVector
from coordinax._coordinax.base_acc import AbstractAcceleration
Expand Down Expand Up @@ -127,7 +127,7 @@ def flatten(self) -> "Self":
(2,)
"""
return replace(self, q=qnp.reshape(self.q, (self.size, self.q.shape[-1]), "C"))
return replace(self, q=jnp.reshape(self.q, (self.size, self.q.shape[-1]), "C"))

def reshape(self, *shape: Any, order: str = "C") -> "Self":
"""Reshape the N-dimensional position.
Expand Down Expand Up @@ -263,7 +263,7 @@ def flatten(self) -> "Self":
"""
return replace(
self, d_q=qnp.reshape(self.d_q, (self.size, self.d_q.shape[-1]), "C")
self, d_q=jnp.reshape(self.d_q, (self.size, self.d_q.shape[-1]), "C")
)

def reshape(self, *shape: Any, order: str = "C") -> "Self":
Expand Down Expand Up @@ -406,7 +406,7 @@ def flatten(self) -> "Self":
"""
return replace(
self, d2_q=qnp.reshape(self.d2_q, (self.size, self.d2_q.shape[-1]), "C")
self, d2_q=jnp.reshape(self.d2_q, (self.size, self.d2_q.shape[-1]), "C")
)

def reshape(self, *shape: Any, order: str = "C") -> "Self":
Expand Down
12 changes: 6 additions & 6 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest

import quaxed.array_api as xp
import quaxed.numpy as qnp
import quaxed.numpy as jnp
from dataclassish import field_items
from unxt import AbstractQuantity

Expand Down Expand Up @@ -108,7 +108,7 @@ def test_flatten(self, vector):
flat = vector.flatten()
assert isinstance(flat, type(vector))
assert all(
qnp.array_equal(getattr(flat, c), getattr(vector, c).flatten())
jnp.array_equal(getattr(flat, c), getattr(vector, c).flatten())
for c in vector.components
)

Expand All @@ -120,7 +120,7 @@ def test_flatten(self, vector):
flat = vec.flatten()
assert isinstance(flat, type(vec))
assert all(
qnp.array_equal(getattr(flat, c).value, xp.ones(8)) for c in vec.components
jnp.array_equal(getattr(flat, c).value, xp.ones(8)) for c in vec.components
)

def test_reshape(self, vector):
Expand All @@ -129,7 +129,7 @@ def test_reshape(self, vector):
reshaped = vector.reshape(2, -1)
assert isinstance(reshaped, type(vector))
assert all(
qnp.array_equal(getattr(reshaped, c), getattr(vector, c).reshape(2, -1))
jnp.array_equal(getattr(reshaped, c), getattr(vector, c).reshape(2, -1))
for c in vector.components
)

Expand All @@ -141,7 +141,7 @@ def test_reshape(self, vector):
reshaped = vec.reshape(1, 8)
assert isinstance(reshaped, type(vec))
assert all(
qnp.array_equal(getattr(reshaped, c).value, xp.ones((1, 8)))
jnp.array_equal(getattr(reshaped, c).value, xp.ones((1, 8)))
for c in vec.components
)

Expand All @@ -156,7 +156,7 @@ def test_asdict(self, vector):
for k, v in adict.items():
assert isinstance(k, str)
assert isinstance(v, AbstractQuantity)
assert qnp.array_equal(v, getattr(vector, k))
assert jnp.array_equal(v, getattr(vector, k))

# Test with a different dict_factory
adict = vector.asdict(dict_factory=UserDict)
Expand Down
110 changes: 55 additions & 55 deletions tests/test_d1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

import quaxed.numpy as qnp
import quaxed.numpy as jnp
from unxt import Quantity

import coordinax as cx
Expand Down Expand Up @@ -41,7 +41,7 @@ def test_cartesian1d_to_radial(self, vector):
radial = vector.represent_as(cx.RadialPosition)

assert isinstance(radial, cx.RadialPosition)
assert qnp.array_equal(radial.r, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(radial.r, Quantity([1, 2, 3, 4], "kpc"))

def test_cartesian1d_to_cartesian2d(self, vector):
"""Test ``coordinax.represent_as(CartesianPosition2D)``."""
Expand All @@ -50,16 +50,16 @@ def test_cartesian1d_to_cartesian2d(self, vector):
)

assert isinstance(cart2d, cx.CartesianPosition2D)
assert qnp.array_equal(cart2d.x, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(cart2d.y, Quantity([5, 6, 7, 8], "km"))
assert jnp.array_equal(cart2d.x, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(cart2d.y, Quantity([5, 6, 7, 8], "km"))

def test_cartesian1d_to_polar(self, vector):
"""Test ``coordinax.represent_as(PolarPosition)``."""
polar = vector.represent_as(cx.PolarPosition, phi=Quantity([0, 1, 2, 3], "rad"))

assert isinstance(polar, cx.PolarPosition)
assert qnp.array_equal(polar.r, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(polar.phi, Quantity([0, 1, 2, 3], "rad"))
assert jnp.array_equal(polar.r, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(polar.phi, Quantity([0, 1, 2, 3], "rad"))

def test_cartesian1d_to_cartesian3d(self, vector):
"""Test ``coordinax.represent_as(CartesianPosition3D)``."""
Expand All @@ -70,9 +70,9 @@ def test_cartesian1d_to_cartesian3d(self, vector):
)

assert isinstance(cart3d, cx.CartesianPosition3D)
assert qnp.array_equal(cart3d.x, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(cart3d.y, Quantity([5, 6, 7, 8], "km"))
assert qnp.array_equal(cart3d.z, Quantity([9, 10, 11, 12], "m"))
assert jnp.array_equal(cart3d.x, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(cart3d.y, Quantity([5, 6, 7, 8], "km"))
assert jnp.array_equal(cart3d.z, Quantity([9, 10, 11, 12], "m"))

def test_cartesian1d_to_spherical(self, vector):
"""Test ``coordinax.represent_as(SphericalPosition)``."""
Expand All @@ -83,9 +83,9 @@ def test_cartesian1d_to_spherical(self, vector):
)

assert isinstance(spherical, cx.SphericalPosition)
assert qnp.array_equal(spherical.r, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(spherical.theta, Quantity([4, 15, 60, 170], "deg"))
assert qnp.array_equal(spherical.phi, Quantity([0, 1, 2, 3], "rad"))
assert jnp.array_equal(spherical.r, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(spherical.theta, Quantity([4, 15, 60, 170], "deg"))
assert jnp.array_equal(spherical.phi, Quantity([0, 1, 2, 3], "rad"))

def test_cartesian1d_to_cylindrical(self, vector):
"""Test ``coordinax.represent_as(CylindricalPosition)``."""
Expand All @@ -96,9 +96,9 @@ def test_cartesian1d_to_cylindrical(self, vector):
)

assert isinstance(cylindrical, cx.CylindricalPosition)
assert qnp.array_equal(cylindrical.rho, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(cylindrical.phi, Quantity([0, 1, 2, 3], "rad"))
assert qnp.array_equal(cylindrical.z, Quantity([4, 5, 6, 7], "m"))
assert jnp.array_equal(cylindrical.rho, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(cylindrical.phi, Quantity([0, 1, 2, 3], "rad"))
assert jnp.array_equal(cylindrical.z, Quantity([4, 5, 6, 7], "m"))


class TestRadialPosition(AbstractPosition1DTest):
Expand All @@ -117,7 +117,7 @@ def test_radial_to_cartesian1d(self, vector):
cart1d = vector.represent_as(cx.CartesianPosition1D)

assert isinstance(cart1d, cx.CartesianPosition1D)
assert qnp.array_equal(cart1d.x, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(cart1d.x, Quantity([1, 2, 3, 4], "kpc"))

def test_radial_to_radial(self, vector):
"""Test ``coordinax.represent_as(RadialPosition)``."""
Expand All @@ -136,16 +136,16 @@ def test_radial_to_cartesian2d(self, vector):
)

assert isinstance(cart2d, cx.CartesianPosition2D)
assert qnp.array_equal(cart2d.x, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(cart2d.y, Quantity([5, 6, 7, 8], "km"))
assert jnp.array_equal(cart2d.x, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(cart2d.y, Quantity([5, 6, 7, 8], "km"))

def test_radial_to_polar(self, vector):
"""Test ``coordinax.represent_as(PolarPosition)``."""
polar = vector.represent_as(cx.PolarPosition, phi=Quantity([0, 1, 2, 3], "rad"))

assert isinstance(polar, cx.PolarPosition)
assert qnp.array_equal(polar.r, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(polar.phi, Quantity([0, 1, 2, 3], "rad"))
assert jnp.array_equal(polar.r, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(polar.phi, Quantity([0, 1, 2, 3], "rad"))

def test_radial_to_cartesian3d(self, vector):
"""Test ``coordinax.represent_as(CartesianPosition3D)``."""
Expand All @@ -156,9 +156,9 @@ def test_radial_to_cartesian3d(self, vector):
)

assert isinstance(cart3d, cx.CartesianPosition3D)
assert qnp.array_equal(cart3d.x, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(cart3d.y, Quantity([5, 6, 7, 8], "km"))
assert qnp.array_equal(cart3d.z, Quantity([9, 10, 11, 12], "m"))
assert jnp.array_equal(cart3d.x, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(cart3d.y, Quantity([5, 6, 7, 8], "km"))
assert jnp.array_equal(cart3d.z, Quantity([9, 10, 11, 12], "m"))

def test_radial_to_spherical(self, vector):
"""Test ``coordinax.represent_as(SphericalPosition)``."""
Expand All @@ -169,9 +169,9 @@ def test_radial_to_spherical(self, vector):
)

assert isinstance(spherical, cx.SphericalPosition)
assert qnp.array_equal(spherical.r, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(spherical.theta, Quantity([4, 15, 60, 170], "deg"))
assert qnp.array_equal(spherical.phi, Quantity([0, 1, 2, 3], "rad"))
assert jnp.array_equal(spherical.r, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(spherical.theta, Quantity([4, 15, 60, 170], "deg"))
assert jnp.array_equal(spherical.phi, Quantity([0, 1, 2, 3], "rad"))

def test_radial_to_cylindrical(self, vector):
"""Test ``coordinax.represent_as(CylindricalPosition)``."""
Expand All @@ -182,9 +182,9 @@ def test_radial_to_cylindrical(self, vector):
)

assert isinstance(cylindrical, cx.CylindricalPosition)
assert qnp.array_equal(cylindrical.rho, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(cylindrical.phi, Quantity([0, 1, 2, 3], "rad"))
assert qnp.array_equal(cylindrical.z, Quantity([4, 5, 6, 7], "m"))
assert jnp.array_equal(cylindrical.rho, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(cylindrical.phi, Quantity([0, 1, 2, 3], "rad"))
assert jnp.array_equal(cylindrical.z, Quantity([4, 5, 6, 7], "m"))


class AbstractVelocity1DTest(AbstractVelocityTest):
Expand Down Expand Up @@ -224,7 +224,7 @@ def test_cartesian1d_to_radial(self, difntl, vector):
radial = difntl.represent_as(cx.RadialVelocity, vector)

assert isinstance(radial, cx.RadialVelocity)
assert qnp.array_equal(radial.d_r, Quantity([1, 2, 3, 4], "km/s"))
assert jnp.array_equal(radial.d_r, Quantity([1, 2, 3, 4], "km/s"))

@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.filterwarnings("ignore:Explicitly requested dtype")
Expand All @@ -235,8 +235,8 @@ def test_cartesian1d_to_cartesian2d(self, difntl, vector):
)

assert isinstance(cart2d, cx.CartesianVelocity2D)
assert qnp.array_equal(cart2d.d_x, Quantity([1, 2, 3, 4], "km/s"))
assert qnp.array_equal(cart2d.d_y, Quantity([5, 6, 7, 8], "km/s"))
assert jnp.array_equal(cart2d.d_x, Quantity([1, 2, 3, 4], "km/s"))
assert jnp.array_equal(cart2d.d_y, Quantity([5, 6, 7, 8], "km/s"))

@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.filterwarnings("ignore:Explicitly requested dtype")
Expand All @@ -247,8 +247,8 @@ def test_cartesian1d_to_polar(self, difntl, vector):
)

assert isinstance(polar, cx.PolarVelocity)
assert qnp.array_equal(polar.d_r, Quantity([1, 2, 3, 4], "km/s"))
assert qnp.array_equal(polar.d_phi, Quantity([0, 1, 2, 3], "rad/s"))
assert jnp.array_equal(polar.d_r, Quantity([1, 2, 3, 4], "km/s"))
assert jnp.array_equal(polar.d_phi, Quantity([0, 1, 2, 3], "rad/s"))

@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.filterwarnings("ignore:Explicitly requested dtype")
Expand All @@ -262,9 +262,9 @@ def test_cartesian1d_to_cartesian3d(self, difntl, vector):
)

assert isinstance(cart3d, cx.CartesianVelocity3D)
assert qnp.array_equal(cart3d.d_x, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(cart3d.d_y, Quantity([5, 6, 7, 8], "km"))
assert qnp.array_equal(cart3d.d_z, Quantity([9, 10, 11, 12], "m"))
assert jnp.array_equal(cart3d.d_x, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(cart3d.d_y, Quantity([5, 6, 7, 8], "km"))
assert jnp.array_equal(cart3d.d_z, Quantity([9, 10, 11, 12], "m"))

@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.filterwarnings("ignore:Explicitly requested dtype")
Expand All @@ -278,9 +278,9 @@ def test_cartesian1d_to_spherical(self, difntl, vector):
)

assert isinstance(spherical, cx.SphericalVelocity)
assert qnp.array_equal(spherical.d_r, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(spherical.d_r, Quantity([1, 2, 3, 4], "kpc"))
assert spherical.d_theta == Quantity([4, 5, 6, 7], "rad/s")
assert qnp.array_equal(spherical.d_phi, Quantity([0, 1, 2, 3], "rad/s"))
assert jnp.array_equal(spherical.d_phi, Quantity([0, 1, 2, 3], "rad/s"))

@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.filterwarnings("ignore:Explicitly requested dtype")
Expand All @@ -294,9 +294,9 @@ def test_cartesian1d_to_cylindrical(self, difntl, vector):
)

assert isinstance(cylindrical, cx.CylindricalVelocity)
assert qnp.array_equal(cylindrical.d_rho, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(cylindrical.d_phi, Quantity([0, 1, 2, 3], "rad/s"))
assert qnp.array_equal(cylindrical.d_z, Quantity([4, 5, 6, 7], "m/s"))
assert jnp.array_equal(cylindrical.d_rho, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(cylindrical.d_phi, Quantity([0, 1, 2, 3], "rad/s"))
assert jnp.array_equal(cylindrical.d_z, Quantity([4, 5, 6, 7], "m/s"))


class TestRadialVelocity(AbstractVelocity1DTest):
Expand All @@ -321,7 +321,7 @@ def test_radial_to_cartesian1d(self, difntl, vector):
cart1d = difntl.represent_as(cx.CartesianVelocity1D, vector)

assert isinstance(cart1d, cx.CartesianVelocity1D)
assert qnp.array_equal(cart1d.d_x, Quantity([1, 2, 3, 4], "km/s"))
assert jnp.array_equal(cart1d.d_x, Quantity([1, 2, 3, 4], "km/s"))

@pytest.mark.filterwarnings("ignore:Explicitly requested dtype")
def test_radial_to_radial(self, difntl, vector):
Expand All @@ -343,8 +343,8 @@ def test_radial_to_cartesian2d(self, difntl, vector):
)

assert isinstance(cart2d, cx.CartesianVelocity2D)
assert qnp.array_equal(cart2d.d_x, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(cart2d.d_y, Quantity([5, 6, 7, 8], "km"))
assert jnp.array_equal(cart2d.d_x, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(cart2d.d_y, Quantity([5, 6, 7, 8], "km"))

@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.filterwarnings("ignore:Explicitly requested dtype")
Expand All @@ -355,8 +355,8 @@ def test_radial_to_polar(self, difntl, vector):
)

assert isinstance(polar, cx.PolarVelocity)
assert qnp.array_equal(polar.d_r, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(polar.d_phi, Quantity([0, 1, 2, 3], "rad"))
assert jnp.array_equal(polar.d_r, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(polar.d_phi, Quantity([0, 1, 2, 3], "rad"))

@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.filterwarnings("ignore:Explicitly requested dtype")
Expand All @@ -370,9 +370,9 @@ def test_radial_to_cartesian3d(self, difntl, vector):
)

assert isinstance(cart3d, cx.CartesianVelocity3D)
assert qnp.array_equal(cart3d.d_x, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(cart3d.d_y, Quantity([5, 6, 7, 8], "km"))
assert qnp.array_equal(cart3d.d_z, Quantity([9, 10, 11, 12], "m"))
assert jnp.array_equal(cart3d.d_x, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(cart3d.d_y, Quantity([5, 6, 7, 8], "km"))
assert jnp.array_equal(cart3d.d_z, Quantity([9, 10, 11, 12], "m"))

@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.filterwarnings("ignore:Explicitly requested dtype")
Expand All @@ -386,9 +386,9 @@ def test_radial_to_spherical(self, difntl, vector):
)

assert isinstance(spherical, cx.SphericalVelocity)
assert qnp.array_equal(spherical.d_r, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(spherical.d_r, Quantity([1, 2, 3, 4], "kpc"))
assert spherical.d_theta == Quantity([4, 5, 6, 7], "rad")
assert qnp.array_equal(spherical.d_phi, Quantity([0, 1, 2, 3], "rad"))
assert jnp.array_equal(spherical.d_phi, Quantity([0, 1, 2, 3], "rad"))

@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.filterwarnings("ignore:Explicitly requested dtype")
Expand All @@ -402,6 +402,6 @@ def test_radial_to_cylindrical(self, difntl, vector):
)

assert isinstance(cylindrical, cx.CylindricalVelocity)
assert qnp.array_equal(cylindrical.d_rho, Quantity([1, 2, 3, 4], "kpc"))
assert qnp.array_equal(cylindrical.d_phi, Quantity([0, 1, 2, 3], "rad"))
assert qnp.array_equal(cylindrical.d_z, Quantity([4, 5, 6, 7], "m"))
assert jnp.array_equal(cylindrical.d_rho, Quantity([1, 2, 3, 4], "kpc"))
assert jnp.array_equal(cylindrical.d_phi, Quantity([0, 1, 2, 3], "rad"))
assert jnp.array_equal(cylindrical.d_z, Quantity([4, 5, 6, 7], "m"))
Loading

0 comments on commit f07317b

Please sign in to comment.