From f1211b3baebd8112b42981de708609f2c3944a1c Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 25 Sep 2024 12:53:36 -0400 Subject: [PATCH 1/4] feat: equality Signed-off-by: nstarman --- src/coordinax/_src/base/base.py | 25 +++++++++++++++++++++---- src/coordinax/_src/base/base_pos.py | 25 +++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/src/coordinax/_src/base/base.py b/src/coordinax/_src/base/base.py index f7bff5f..790d268 100644 --- a/src/coordinax/_src/base/base.py +++ b/src/coordinax/_src/base/base.py @@ -17,10 +17,10 @@ import jax import numpy as np from astropy.units import PhysicalType as Dimensions -from jax import Device -from jaxtyping import ArrayLike +from jax import Device, tree +from jaxtyping import Array, ArrayLike, Bool from plum import dispatch -from quax import ArrayValue +from quax import ArrayValue, register import quaxed.lax as qlax import quaxed.numpy as jnp @@ -831,7 +831,7 @@ def __str__(self) -> str: return f"<{cls_name} ({comps})\n {vs}>" -# ----------------------------------------------- +# =============================================================== # Register additional constructors @@ -916,3 +916,20 @@ def constructor(cls: type[AbstractVector], obj: AbstractVector, /) -> AbstractVe return obj return cls(**dict(field_items(obj))) + + +# =============================================================== +# Register primitives + + +@register(jax.lax.eq_p) # type: ignore[misc] +def _eq_pos_pos(lhs: AbstractVector, rhs: AbstractVector, /) -> Bool[Array, "..."]: + """Element-wise equality of two vectors.""" + # TODO: match the behaviour of `numpy.equal` + if type(lhs) is not type(rhs): + msg = f"Cannot compare {type(lhs)} with {type(rhs)}." + raise TypeError(msg) + + comp_tree = tree.map(jnp.equal, lhs, rhs) + comp_leaves = jnp.array(tree.leaves(comp_tree)) + return jax.numpy.logical_and.reduce(comp_leaves) diff --git a/src/coordinax/_src/base/base_pos.py b/src/coordinax/_src/base/base_pos.py index 863c303..2abd872 100644 --- a/src/coordinax/_src/base/base_pos.py +++ b/src/coordinax/_src/base/base_pos.py @@ -11,6 +11,7 @@ import equinox as eqx import jax +from jax import tree from jaxtyping import ArrayLike from plum import convert from quax import quaxify, register @@ -209,6 +210,30 @@ def _div_pos_v(lhs: AbstractPosition, rhs: ArrayLike) -> AbstractPosition: # ------------------------------------------------ +@register(jax.lax.eq_p) # type: ignore[misc] +def _eq_pos_pos(lhs: AbstractPosition, rhs: AbstractPosition, /) -> ArrayLike: + """Element-wise equality of two positions. + + Examples + -------- + >>> import quaxed.numpy as jnp + >>> import coordinax as cx + + >>> vec1 = cx.CartesianPosition3D.constructor([[1, 2, 3], [1, 2, 4]], "m") + >>> vec2 = cx.CartesianPosition3D.constructor([1, 2, 3], "m") + >>> jnp.equal(vec1, vec2) + Array([ True, False], dtype=bool) + + """ + rhs_ = rhs.represent_as(rhs._cartesian_cls) # noqa: SLF001 + comp_tree = tree.map(jnp.equal, lhs, rhs_) + comp_leaves = jnp.array(tree.leaves(comp_tree)) + return jax.numpy.logical_and.reduce(comp_leaves) + + +# ------------------------------------------------ + + @register(jax.lax.mul_p) # type: ignore[misc] def _mul_v_pos(lhs: ArrayLike, rhs: AbstractPosition, /) -> AbstractPosition: """Scale a position by a scalar. From aa2b248edc863fd7ec9ead54885cc20308cf715c Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 25 Sep 2024 22:07:02 -0400 Subject: [PATCH 2/4] feat: set __eq__ Signed-off-by: nstarman --- src/coordinax/_src/base/base.py | 72 ++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/src/coordinax/_src/base/base.py b/src/coordinax/_src/base/base.py index 790d268..a534acd 100644 --- a/src/coordinax/_src/base/base.py +++ b/src/coordinax/_src/base/base.py @@ -177,6 +177,8 @@ def aval(self) -> jax.core.ShapedArray: # =============================================================== # Array API + __eq__ = jnp.equal + # --------------------------------------------------------------- # Attributes @@ -923,8 +925,74 @@ def constructor(cls: type[AbstractVector], obj: AbstractVector, /) -> AbstractVe @register(jax.lax.eq_p) # type: ignore[misc] -def _eq_pos_pos(lhs: AbstractVector, rhs: AbstractVector, /) -> Bool[Array, "..."]: - """Element-wise equality of two vectors.""" +def _eq_vec_vec(lhs: AbstractVector, rhs: AbstractVector, /) -> Bool[Array, "..."]: + """Element-wise equality of two vectors. + + Examples + -------- + >>> import quaxed.numpy as jnp + >>> from unxt import Quantity + >>> import coordinax as cx + + Positions are covered by a separate dispatch. So here we show velocities and + accelerations: + + >>> vel1 = cx.CartesianVelocity1D(Quantity([1, 2, 3], "km/s")) + >>> vel2 = cx.CartesianVelocity1D(Quantity([1, 0, 3], "km/s")) + >>> jnp.equal(vel1, vel2) + Array([ True, False, True], dtype=bool) + >>> vel1 == vel2 + Array([ True, False, True], dtype=bool) + + >>> acc1 = cx.CartesianAcceleration1D(Quantity([1, 2, 3], "km/s2")) + >>> acc2 = cx.CartesianAcceleration1D(Quantity([1, 0, 3], "km/s2")) + >>> jnp.equal(acc1, acc2) + Array([ True, False, True], dtype=bool) + >>> acc1 == acc2 + Array([ True, False, True], dtype=bool) + + >>> vel1 = cx.RadialVelocity(Quantity([1, 2, 3], "km/s")) + >>> vel2 = cx.RadialVelocity(Quantity([1, 0, 3], "km/s")) + >>> jnp.equal(vel1, vel2) + Array([ True, False, True], dtype=bool) + >>> vel1 == vel2 + Array([ True, False, True], dtype=bool) + + >>> acc1 = cx.RadialAcceleration(Quantity([1, 2, 3], "km/s2")) + >>> acc2 = cx.RadialAcceleration(Quantity([1, 0, 3], "km/s2")) + >>> jnp.equal(acc1, acc2) + Array([ True, False, True], dtype=bool) + >>> acc1 == acc2 + Array([ True, False, True], dtype=bool) + + >>> vel1 = cx.CartesianVelocity2D.constructor([[1, 3], [2, 4]], "km/s") + >>> vel2 = cx.CartesianVelocity2D.constructor([[1, 3], [0, 4]], "km/s") + >>> vel1.d_x + Quantity['speed'](Array([1., 2.], dtype=float32), unit='km / s') + >>> jnp.equal(vel1, vel2) + Array([ True, False], dtype=bool) + >>> vel1 == vel2 + Array([ True, False], dtype=bool) + + >>> acc1 = cx.CartesianAcceleration2D.constructor([[1, 3], [2, 4]], "km/s2") + >>> acc2 = cx.CartesianAcceleration2D.constructor([[1, 3], [0, 4]], "km/s2") + >>> acc1.d2_x + Quantity['acceleration'](Array([1., 2.], dtype=float32), unit='km / s2') + >>> jnp.equal(acc1, acc2) + Array([ True, False], dtype=bool) + >>> acc1 == acc2 + Array([ True, False], dtype=bool) + + >>> vel1 = cx.CartesianVelocity3D.constructor([[1, 4], [2, 5], [3, 6]], "km/s") + >>> vel2 = cx.CartesianVelocity3D.constructor([[1, 4], [0, 5], [3, 0]], "km/s") + >>> vel1.d_x + Quantity['speed'](Array([1., 2., 3.], dtype=float32), unit='km / s') + >>> jnp.equal(vel1, vel2) + Array([ True, False, False], dtype=bool) + >>> vel1 == vel2 + Array([ True, False, False], dtype=bool) + + """ # TODO: match the behaviour of `numpy.equal` if type(lhs) is not type(rhs): msg = f"Cannot compare {type(lhs)} with {type(rhs)}." From 931b35968f0f0e59e51c98339d0ae8aaf9fd77a1 Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 25 Sep 2024 22:11:37 -0400 Subject: [PATCH 3/4] ci: fix pylint Signed-off-by: nstarman --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index fabd2a9..ea0f77a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -238,3 +238,4 @@ py-version = "3.10" reports.output-format = "colorized" similarities.ignore-imports = "yes" + max-module-lines = 1500 From f810282a82ad03f329404a6cfb0a9a78b9f41724 Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 25 Sep 2024 23:44:37 -0400 Subject: [PATCH 4/4] refactor: move eq to objects Signed-off-by: nstarman --- src/coordinax/_src/base/base.py | 154 ++++++++++++++-------------- src/coordinax/_src/base/base_pos.py | 65 ++++++++---- tests/test_d1.py | 8 +- tests/test_d2.py | 8 +- tests/test_d3.py | 10 +- 5 files changed, 137 insertions(+), 108 deletions(-) diff --git a/src/coordinax/_src/base/base.py b/src/coordinax/_src/base/base.py index a534acd..71b816d 100644 --- a/src/coordinax/_src/base/base.py +++ b/src/coordinax/_src/base/base.py @@ -177,8 +177,6 @@ def aval(self) -> jax.core.ShapedArray: # =============================================================== # Array API - __eq__ = jnp.equal - # --------------------------------------------------------------- # Attributes @@ -333,6 +331,81 @@ def T(self) -> "Self": # noqa: N802 # --------------------------------------------------------------- # Methods + def __eq__(self: "AbstractVector", other: object) -> Any: + """Check if the vector is equal to another object. + + Examples + -------- + >>> import quaxed.numpy as jnp + >>> from unxt import Quantity + >>> import coordinax as cx + + Positions are covered by a separate dispatch. So here we show velocities + and accelerations: + + >>> vel1 = cx.CartesianVelocity1D(Quantity([1, 2, 3], "km/s")) + >>> vel2 = cx.CartesianVelocity1D(Quantity([1, 0, 3], "km/s")) + >>> jnp.equal(vel1, vel2) + Array([ True, False, True], dtype=bool) + >>> vel1 == vel2 + Array([ True, False, True], dtype=bool) + + >>> acc1 = cx.CartesianAcceleration1D(Quantity([1, 2, 3], "km/s2")) + >>> acc2 = cx.CartesianAcceleration1D(Quantity([1, 0, 3], "km/s2")) + >>> jnp.equal(acc1, acc2) + Array([ True, False, True], dtype=bool) + >>> acc1 == acc2 + Array([ True, False, True], dtype=bool) + + >>> vel1 = cx.RadialVelocity(Quantity([1, 2, 3], "km/s")) + >>> vel2 = cx.RadialVelocity(Quantity([1, 0, 3], "km/s")) + >>> jnp.equal(vel1, vel2) + Array([ True, False, True], dtype=bool) + >>> vel1 == vel2 + Array([ True, False, True], dtype=bool) + + >>> acc1 = cx.RadialAcceleration(Quantity([1, 2, 3], "km/s2")) + >>> acc2 = cx.RadialAcceleration(Quantity([1, 0, 3], "km/s2")) + >>> jnp.equal(acc1, acc2) + Array([ True, False, True], dtype=bool) + >>> acc1 == acc2 + Array([ True, False, True], dtype=bool) + + >>> vel1 = cx.CartesianVelocity2D.constructor([[1, 3], [2, 4]], "km/s") + >>> vel2 = cx.CartesianVelocity2D.constructor([[1, 3], [0, 4]], "km/s") + >>> vel1.d_x + Quantity['speed'](Array([1., 2.], dtype=float32), unit='km / s') + >>> jnp.equal(vel1, vel2) + Array([ True, False], dtype=bool) + >>> vel1 == vel2 + Array([ True, False], dtype=bool) + + >>> acc1 = cx.CartesianAcceleration2D.constructor([[1, 3], [2, 4]], "km/s2") + >>> acc2 = cx.CartesianAcceleration2D.constructor([[1, 3], [0, 4]], "km/s2") + >>> acc1.d2_x + Quantity['acceleration'](Array([1., 2.], dtype=float32), unit='km / s2') + >>> jnp.equal(acc1, acc2) + Array([ True, False], dtype=bool) + >>> acc1 == acc2 + Array([ True, False], dtype=bool) + + >>> vel1 = cx.CartesianVelocity3D.constructor([[1, 4], [2, 5], [3, 6]], "km/s") + >>> vel2 = cx.CartesianVelocity3D.constructor([[1, 4], [0, 5], [3, 0]], "km/s") + >>> vel1.d_x + Quantity['speed'](Array([1., 2., 3.], dtype=float32), unit='km / s') + >>> jnp.equal(vel1, vel2) + Array([ True, False, False], dtype=bool) + >>> vel1 == vel2 + Array([ True, False, False], dtype=bool) + + """ + if type(other) is not type(self): + return NotImplemented + + comp_tree = tree.map(jnp.equal, self, other) + comp_leaves = jnp.array(tree.leaves(comp_tree)) + return jax.numpy.logical_and.reduce(comp_leaves) + def __len__(self) -> int: """Return the length of the vector. @@ -926,78 +999,5 @@ def constructor(cls: type[AbstractVector], obj: AbstractVector, /) -> AbstractVe @register(jax.lax.eq_p) # type: ignore[misc] def _eq_vec_vec(lhs: AbstractVector, rhs: AbstractVector, /) -> Bool[Array, "..."]: - """Element-wise equality of two vectors. - - Examples - -------- - >>> import quaxed.numpy as jnp - >>> from unxt import Quantity - >>> import coordinax as cx - - Positions are covered by a separate dispatch. So here we show velocities and - accelerations: - - >>> vel1 = cx.CartesianVelocity1D(Quantity([1, 2, 3], "km/s")) - >>> vel2 = cx.CartesianVelocity1D(Quantity([1, 0, 3], "km/s")) - >>> jnp.equal(vel1, vel2) - Array([ True, False, True], dtype=bool) - >>> vel1 == vel2 - Array([ True, False, True], dtype=bool) - - >>> acc1 = cx.CartesianAcceleration1D(Quantity([1, 2, 3], "km/s2")) - >>> acc2 = cx.CartesianAcceleration1D(Quantity([1, 0, 3], "km/s2")) - >>> jnp.equal(acc1, acc2) - Array([ True, False, True], dtype=bool) - >>> acc1 == acc2 - Array([ True, False, True], dtype=bool) - - >>> vel1 = cx.RadialVelocity(Quantity([1, 2, 3], "km/s")) - >>> vel2 = cx.RadialVelocity(Quantity([1, 0, 3], "km/s")) - >>> jnp.equal(vel1, vel2) - Array([ True, False, True], dtype=bool) - >>> vel1 == vel2 - Array([ True, False, True], dtype=bool) - - >>> acc1 = cx.RadialAcceleration(Quantity([1, 2, 3], "km/s2")) - >>> acc2 = cx.RadialAcceleration(Quantity([1, 0, 3], "km/s2")) - >>> jnp.equal(acc1, acc2) - Array([ True, False, True], dtype=bool) - >>> acc1 == acc2 - Array([ True, False, True], dtype=bool) - - >>> vel1 = cx.CartesianVelocity2D.constructor([[1, 3], [2, 4]], "km/s") - >>> vel2 = cx.CartesianVelocity2D.constructor([[1, 3], [0, 4]], "km/s") - >>> vel1.d_x - Quantity['speed'](Array([1., 2.], dtype=float32), unit='km / s') - >>> jnp.equal(vel1, vel2) - Array([ True, False], dtype=bool) - >>> vel1 == vel2 - Array([ True, False], dtype=bool) - - >>> acc1 = cx.CartesianAcceleration2D.constructor([[1, 3], [2, 4]], "km/s2") - >>> acc2 = cx.CartesianAcceleration2D.constructor([[1, 3], [0, 4]], "km/s2") - >>> acc1.d2_x - Quantity['acceleration'](Array([1., 2.], dtype=float32), unit='km / s2') - >>> jnp.equal(acc1, acc2) - Array([ True, False], dtype=bool) - >>> acc1 == acc2 - Array([ True, False], dtype=bool) - - >>> vel1 = cx.CartesianVelocity3D.constructor([[1, 4], [2, 5], [3, 6]], "km/s") - >>> vel2 = cx.CartesianVelocity3D.constructor([[1, 4], [0, 5], [3, 0]], "km/s") - >>> vel1.d_x - Quantity['speed'](Array([1., 2., 3.], dtype=float32), unit='km / s') - >>> jnp.equal(vel1, vel2) - Array([ True, False, False], dtype=bool) - >>> vel1 == vel2 - Array([ True, False, False], dtype=bool) - - """ - # TODO: match the behaviour of `numpy.equal` - if type(lhs) is not type(rhs): - msg = f"Cannot compare {type(lhs)} with {type(rhs)}." - raise TypeError(msg) - - comp_tree = tree.map(jnp.equal, lhs, rhs) - comp_leaves = jnp.array(tree.leaves(comp_tree)) - return jax.numpy.logical_and.reduce(comp_leaves) + """Element-wise equality of two vectors.""" + return lhs == rhs diff --git a/src/coordinax/_src/base/base_pos.py b/src/coordinax/_src/base/base_pos.py index 2abd872..1cd11c8 100644 --- a/src/coordinax/_src/base/base_pos.py +++ b/src/coordinax/_src/base/base_pos.py @@ -11,7 +11,6 @@ import equinox as eqx import jax -from jax import tree from jaxtyping import ArrayLike from plum import convert from quax import quaxify, register @@ -91,6 +90,51 @@ def differential_cls(cls) -> type["AbstractVelocity"]: __neg__ = jnp.negative + # =============================================================== + # Binary operations + + def __eq__(self: "AbstractPosition", other: object) -> Any: + """Element-wise equality of two positions. + + Examples + -------- + >>> import quaxed.numpy as jnp + >>> import coordinax as cx + + Showing the broadcasting, then element-wise comparison of two vectors: + + >>> vec1 = cx.CartesianPosition3D.constructor([[1, 2, 3], [1, 2, 4]], "m") + >>> vec2 = cx.CartesianPosition3D.constructor([1, 2, 3], "m") + >>> jnp.equal(vec1, vec2) + Array([ True, False], dtype=bool) + + Showing the change of representation: + + >>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m") + >>> vec1 = vec.represent_as(cx.SphericalPosition) + >>> vec2 = vec.represent_as(cx.MathSphericalPosition) + >>> jnp.equal(vec1, vec2) + Array(True, dtype=bool) + + Quick run-through of each dimensionality: + + >>> vec1 = cx.CartesianPosition1D.constructor([1], "m") + >>> vec2 = cx.RadialPosition.constructor([1], "m") + >>> jnp.equal(vec1, vec2) + Array(True, dtype=bool) + + >>> vec1 = cx.CartesianPosition2D.constructor([2, 0], "m") + >>> vec2 = cx.PolarPosition(r=Quantity(2, "m"), phi=Quantity(0, "rad")) + >>> jnp.equal(vec1, vec2) + Array(True, dtype=bool) + + """ + if not isinstance(other, AbstractPosition): + return NotImplemented + + rhs = other.represent_as(type(self)) + return super().__eq__(rhs) + # =============================================================== # Convenience methods @@ -212,23 +256,8 @@ def _div_pos_v(lhs: AbstractPosition, rhs: ArrayLike) -> AbstractPosition: @register(jax.lax.eq_p) # type: ignore[misc] def _eq_pos_pos(lhs: AbstractPosition, rhs: AbstractPosition, /) -> ArrayLike: - """Element-wise equality of two positions. - - Examples - -------- - >>> import quaxed.numpy as jnp - >>> import coordinax as cx - - >>> vec1 = cx.CartesianPosition3D.constructor([[1, 2, 3], [1, 2, 4]], "m") - >>> vec2 = cx.CartesianPosition3D.constructor([1, 2, 3], "m") - >>> jnp.equal(vec1, vec2) - Array([ True, False], dtype=bool) - - """ - rhs_ = rhs.represent_as(rhs._cartesian_cls) # noqa: SLF001 - comp_tree = tree.map(jnp.equal, lhs, rhs_) - comp_leaves = jnp.array(tree.leaves(comp_tree)) - return jax.numpy.logical_and.reduce(comp_leaves) + """Element-wise equality of two positions.""" + return lhs == rhs # ------------------------------------------------ diff --git a/tests/test_d1.py b/tests/test_d1.py index cebbbe6..7a0d56e 100644 --- a/tests/test_d1.py +++ b/tests/test_d1.py @@ -30,7 +30,7 @@ def test_cartesian1d_to_cartesian1d(self, vector): """Test ``coordinax.represent_as(CartesianPosition1D)``.""" # Jit can copy newvec = vector.represent_as(cx.CartesianPosition1D) - assert newvec == vector + assert jnp.array_equal(newvec, vector) # The normal `represent_as` method should return the same object newvec = cx.represent_as(vector, cx.CartesianPosition1D) @@ -123,7 +123,7 @@ def test_radial_to_radial(self, vector): """Test ``coordinax.represent_as(RadialPosition)``.""" # Jit can copy newvec = vector.represent_as(cx.RadialPosition) - assert newvec == vector + assert jnp.array_equal(newvec, vector) # The normal `represent_as` method should return the same object newvec = cx.represent_as(vector, cx.RadialPosition) @@ -212,7 +212,7 @@ def test_cartesian1d_to_cartesian1d(self, difntl, vector): """Test ``difntl.represent_as(CartesianVelocity1D)``.""" # Jit can copy newvec = difntl.represent_as(cx.CartesianVelocity1D, vector) - assert newvec == difntl + assert jnp.array_equal(newvec, difntl) # The normal `represent_as` method should return the same object newvec = cx.represent_as(difntl, cx.CartesianVelocity1D, vector) @@ -328,7 +328,7 @@ def test_radial_to_radial(self, difntl, vector): """Test ``difntl.represent_as(RadialVelocity)``.""" # Jit can copy newvec = difntl.represent_as(cx.RadialVelocity, vector) - assert newvec == difntl + assert jnp.array_equal(newvec, difntl) # The normal `represent_as` method should return the same object newvec = cx.represent_as(difntl, cx.RadialVelocity, vector) diff --git a/tests/test_d2.py b/tests/test_d2.py index f0c115e..51341cf 100644 --- a/tests/test_d2.py +++ b/tests/test_d2.py @@ -51,7 +51,7 @@ def test_cartesian2d_to_cartesian2d(self, vector): """Test ``coordinax.represent_as(CartesianPosition2D)``.""" # Jit can copy newvec = vector.represent_as(cx.CartesianPosition2D) - assert newvec == vector + assert jnp.array_equal(newvec, vector) # The normal `represent_as` method should return the same object newvec = cx.represent_as(vector, cx.CartesianPosition2D) @@ -170,7 +170,7 @@ def test_polar_to_polar(self, vector): """Test ``coordinax.represent_as(PolarPosition)``.""" # Jit can copy newvec = vector.represent_as(cx.PolarPosition) - assert newvec == vector + assert jnp.array_equal(newvec, vector) # The normal `represent_as` method should return the same object newvec = cx.represent_as(vector, cx.PolarPosition) @@ -262,7 +262,7 @@ def test_cartesian2d_to_cartesian2d(self, difntl, vector): """Test ``difntl.represent_as(CartesianVelocity2D, vector)``.""" # Jit can copy newvec = difntl.represent_as(cx.CartesianVelocity2D, vector) - assert newvec == difntl + assert jnp.array_equal(newvec, difntl) # The normal `represent_as` method should return the same object newvec = cx.represent_as(difntl, cx.CartesianVelocity2D, vector) @@ -378,7 +378,7 @@ def test_polar_to_polar(self, difntl, vector): """Test ``difntl.represent_as(PolarVelocity, vector)``.""" # Jit can copy newvec = difntl.represent_as(cx.PolarVelocity, vector) - assert newvec == difntl + assert all(newvec == difntl) # The normal `represent_as` method should return the same object newvec = cx.represent_as(difntl, cx.PolarVelocity, vector) diff --git a/tests/test_d3.py b/tests/test_d3.py index e8b53b3..2b05245 100644 --- a/tests/test_d3.py +++ b/tests/test_d3.py @@ -118,7 +118,7 @@ def test_cartesian3d_to_cartesian3d(self, vector): """Test ``coordinax.represent_as(CartesianPosition3D)``.""" # Jit can copy newvec = vector.represent_as(cx.CartesianPosition3D) - assert newvec == vector + assert jnp.array_equal(newvec, vector) # The normal `represent_as` method should return the same object newvec = cx.represent_as(vector, cx.CartesianPosition3D) @@ -287,7 +287,7 @@ def test_cylindrical_to_cylindrical(self, vector): """Test ``coordinax.represent_as(CylindricalPosition)``.""" # Jit can copy newvec = vector.represent_as(cx.CylindricalPosition) - assert newvec == vector + assert jnp.array_equal(newvec, vector) # The normal `represent_as` method should return the same object newvec = cx.represent_as(vector, cx.CylindricalPosition) @@ -440,7 +440,7 @@ def test_spherical_to_spherical(self, vector): """Test ``coordinax.represent_as(SphericalPosition)``.""" # Jit can copy newvec = vector.represent_as(cx.SphericalPosition) - assert newvec == vector + assert jnp.array_equal(newvec, vector) # The normal `represent_as` method should return the same object newvec = cx.represent_as(vector, cx.SphericalPosition) @@ -571,7 +571,7 @@ def test_cartesian3d_to_cartesian3d(self, difntl, vector): """Test ``coordinax.represent_as(CartesianPosition3D)``.""" # Jit can copy newvec = difntl.represent_as(cx.CartesianVelocity3D, vector) - assert newvec == difntl + assert jnp.array_equal(newvec, difntl) # The normal `represent_as` method should return the same object newvec = cx.represent_as(difntl, cx.CartesianVelocity3D, vector) @@ -939,7 +939,7 @@ def test_spherical_to_spherical(self, difntl, vector): """Test ``coordinax.represent_as(SphericalVelocity)``.""" # Jit can copy newvec = difntl.represent_as(cx.SphericalVelocity, vector) - assert newvec == difntl + assert all(newvec == difntl) # The normal `represent_as` method should return the same object newvec = cx.represent_as(difntl, cx.SphericalVelocity, vector)