diff --git a/pyproject.toml b/pyproject.toml index fabd2a9..ea0f77a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -238,3 +238,4 @@ py-version = "3.10" reports.output-format = "colorized" similarities.ignore-imports = "yes" + max-module-lines = 1500 diff --git a/src/coordinax/_src/base/base.py b/src/coordinax/_src/base/base.py index f7bff5f..71b816d 100644 --- a/src/coordinax/_src/base/base.py +++ b/src/coordinax/_src/base/base.py @@ -17,10 +17,10 @@ import jax import numpy as np from astropy.units import PhysicalType as Dimensions -from jax import Device -from jaxtyping import ArrayLike +from jax import Device, tree +from jaxtyping import Array, ArrayLike, Bool from plum import dispatch -from quax import ArrayValue +from quax import ArrayValue, register import quaxed.lax as qlax import quaxed.numpy as jnp @@ -331,6 +331,81 @@ def T(self) -> "Self": # noqa: N802 # --------------------------------------------------------------- # Methods + def __eq__(self: "AbstractVector", other: object) -> Any: + """Check if the vector is equal to another object. + + Examples + -------- + >>> import quaxed.numpy as jnp + >>> from unxt import Quantity + >>> import coordinax as cx + + Positions are covered by a separate dispatch. So here we show velocities + and accelerations: + + >>> vel1 = cx.CartesianVelocity1D(Quantity([1, 2, 3], "km/s")) + >>> vel2 = cx.CartesianVelocity1D(Quantity([1, 0, 3], "km/s")) + >>> jnp.equal(vel1, vel2) + Array([ True, False, True], dtype=bool) + >>> vel1 == vel2 + Array([ True, False, True], dtype=bool) + + >>> acc1 = cx.CartesianAcceleration1D(Quantity([1, 2, 3], "km/s2")) + >>> acc2 = cx.CartesianAcceleration1D(Quantity([1, 0, 3], "km/s2")) + >>> jnp.equal(acc1, acc2) + Array([ True, False, True], dtype=bool) + >>> acc1 == acc2 + Array([ True, False, True], dtype=bool) + + >>> vel1 = cx.RadialVelocity(Quantity([1, 2, 3], "km/s")) + >>> vel2 = cx.RadialVelocity(Quantity([1, 0, 3], "km/s")) + >>> jnp.equal(vel1, vel2) + Array([ True, False, True], dtype=bool) + >>> vel1 == vel2 + Array([ True, False, True], dtype=bool) + + >>> acc1 = cx.RadialAcceleration(Quantity([1, 2, 3], "km/s2")) + >>> acc2 = cx.RadialAcceleration(Quantity([1, 0, 3], "km/s2")) + >>> jnp.equal(acc1, acc2) + Array([ True, False, True], dtype=bool) + >>> acc1 == acc2 + Array([ True, False, True], dtype=bool) + + >>> vel1 = cx.CartesianVelocity2D.constructor([[1, 3], [2, 4]], "km/s") + >>> vel2 = cx.CartesianVelocity2D.constructor([[1, 3], [0, 4]], "km/s") + >>> vel1.d_x + Quantity['speed'](Array([1., 2.], dtype=float32), unit='km / s') + >>> jnp.equal(vel1, vel2) + Array([ True, False], dtype=bool) + >>> vel1 == vel2 + Array([ True, False], dtype=bool) + + >>> acc1 = cx.CartesianAcceleration2D.constructor([[1, 3], [2, 4]], "km/s2") + >>> acc2 = cx.CartesianAcceleration2D.constructor([[1, 3], [0, 4]], "km/s2") + >>> acc1.d2_x + Quantity['acceleration'](Array([1., 2.], dtype=float32), unit='km / s2') + >>> jnp.equal(acc1, acc2) + Array([ True, False], dtype=bool) + >>> acc1 == acc2 + Array([ True, False], dtype=bool) + + >>> vel1 = cx.CartesianVelocity3D.constructor([[1, 4], [2, 5], [3, 6]], "km/s") + >>> vel2 = cx.CartesianVelocity3D.constructor([[1, 4], [0, 5], [3, 0]], "km/s") + >>> vel1.d_x + Quantity['speed'](Array([1., 2., 3.], dtype=float32), unit='km / s') + >>> jnp.equal(vel1, vel2) + Array([ True, False, False], dtype=bool) + >>> vel1 == vel2 + Array([ True, False, False], dtype=bool) + + """ + if type(other) is not type(self): + return NotImplemented + + comp_tree = tree.map(jnp.equal, self, other) + comp_leaves = jnp.array(tree.leaves(comp_tree)) + return jax.numpy.logical_and.reduce(comp_leaves) + def __len__(self) -> int: """Return the length of the vector. @@ -831,7 +906,7 @@ def __str__(self) -> str: return f"<{cls_name} ({comps})\n {vs}>" -# ----------------------------------------------- +# =============================================================== # Register additional constructors @@ -916,3 +991,13 @@ def constructor(cls: type[AbstractVector], obj: AbstractVector, /) -> AbstractVe return obj return cls(**dict(field_items(obj))) + + +# =============================================================== +# Register primitives + + +@register(jax.lax.eq_p) # type: ignore[misc] +def _eq_vec_vec(lhs: AbstractVector, rhs: AbstractVector, /) -> Bool[Array, "..."]: + """Element-wise equality of two vectors.""" + return lhs == rhs diff --git a/src/coordinax/_src/base/base_pos.py b/src/coordinax/_src/base/base_pos.py index 863c303..1cd11c8 100644 --- a/src/coordinax/_src/base/base_pos.py +++ b/src/coordinax/_src/base/base_pos.py @@ -90,6 +90,51 @@ def differential_cls(cls) -> type["AbstractVelocity"]: __neg__ = jnp.negative + # =============================================================== + # Binary operations + + def __eq__(self: "AbstractPosition", other: object) -> Any: + """Element-wise equality of two positions. + + Examples + -------- + >>> import quaxed.numpy as jnp + >>> import coordinax as cx + + Showing the broadcasting, then element-wise comparison of two vectors: + + >>> vec1 = cx.CartesianPosition3D.constructor([[1, 2, 3], [1, 2, 4]], "m") + >>> vec2 = cx.CartesianPosition3D.constructor([1, 2, 3], "m") + >>> jnp.equal(vec1, vec2) + Array([ True, False], dtype=bool) + + Showing the change of representation: + + >>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m") + >>> vec1 = vec.represent_as(cx.SphericalPosition) + >>> vec2 = vec.represent_as(cx.MathSphericalPosition) + >>> jnp.equal(vec1, vec2) + Array(True, dtype=bool) + + Quick run-through of each dimensionality: + + >>> vec1 = cx.CartesianPosition1D.constructor([1], "m") + >>> vec2 = cx.RadialPosition.constructor([1], "m") + >>> jnp.equal(vec1, vec2) + Array(True, dtype=bool) + + >>> vec1 = cx.CartesianPosition2D.constructor([2, 0], "m") + >>> vec2 = cx.PolarPosition(r=Quantity(2, "m"), phi=Quantity(0, "rad")) + >>> jnp.equal(vec1, vec2) + Array(True, dtype=bool) + + """ + if not isinstance(other, AbstractPosition): + return NotImplemented + + rhs = other.represent_as(type(self)) + return super().__eq__(rhs) + # =============================================================== # Convenience methods @@ -209,6 +254,15 @@ def _div_pos_v(lhs: AbstractPosition, rhs: ArrayLike) -> AbstractPosition: # ------------------------------------------------ +@register(jax.lax.eq_p) # type: ignore[misc] +def _eq_pos_pos(lhs: AbstractPosition, rhs: AbstractPosition, /) -> ArrayLike: + """Element-wise equality of two positions.""" + return lhs == rhs + + +# ------------------------------------------------ + + @register(jax.lax.mul_p) # type: ignore[misc] def _mul_v_pos(lhs: ArrayLike, rhs: AbstractPosition, /) -> AbstractPosition: """Scale a position by a scalar. diff --git a/tests/test_d1.py b/tests/test_d1.py index cebbbe6..7a0d56e 100644 --- a/tests/test_d1.py +++ b/tests/test_d1.py @@ -30,7 +30,7 @@ def test_cartesian1d_to_cartesian1d(self, vector): """Test ``coordinax.represent_as(CartesianPosition1D)``.""" # Jit can copy newvec = vector.represent_as(cx.CartesianPosition1D) - assert newvec == vector + assert jnp.array_equal(newvec, vector) # The normal `represent_as` method should return the same object newvec = cx.represent_as(vector, cx.CartesianPosition1D) @@ -123,7 +123,7 @@ def test_radial_to_radial(self, vector): """Test ``coordinax.represent_as(RadialPosition)``.""" # Jit can copy newvec = vector.represent_as(cx.RadialPosition) - assert newvec == vector + assert jnp.array_equal(newvec, vector) # The normal `represent_as` method should return the same object newvec = cx.represent_as(vector, cx.RadialPosition) @@ -212,7 +212,7 @@ def test_cartesian1d_to_cartesian1d(self, difntl, vector): """Test ``difntl.represent_as(CartesianVelocity1D)``.""" # Jit can copy newvec = difntl.represent_as(cx.CartesianVelocity1D, vector) - assert newvec == difntl + assert jnp.array_equal(newvec, difntl) # The normal `represent_as` method should return the same object newvec = cx.represent_as(difntl, cx.CartesianVelocity1D, vector) @@ -328,7 +328,7 @@ def test_radial_to_radial(self, difntl, vector): """Test ``difntl.represent_as(RadialVelocity)``.""" # Jit can copy newvec = difntl.represent_as(cx.RadialVelocity, vector) - assert newvec == difntl + assert jnp.array_equal(newvec, difntl) # The normal `represent_as` method should return the same object newvec = cx.represent_as(difntl, cx.RadialVelocity, vector) diff --git a/tests/test_d2.py b/tests/test_d2.py index f0c115e..51341cf 100644 --- a/tests/test_d2.py +++ b/tests/test_d2.py @@ -51,7 +51,7 @@ def test_cartesian2d_to_cartesian2d(self, vector): """Test ``coordinax.represent_as(CartesianPosition2D)``.""" # Jit can copy newvec = vector.represent_as(cx.CartesianPosition2D) - assert newvec == vector + assert jnp.array_equal(newvec, vector) # The normal `represent_as` method should return the same object newvec = cx.represent_as(vector, cx.CartesianPosition2D) @@ -170,7 +170,7 @@ def test_polar_to_polar(self, vector): """Test ``coordinax.represent_as(PolarPosition)``.""" # Jit can copy newvec = vector.represent_as(cx.PolarPosition) - assert newvec == vector + assert jnp.array_equal(newvec, vector) # The normal `represent_as` method should return the same object newvec = cx.represent_as(vector, cx.PolarPosition) @@ -262,7 +262,7 @@ def test_cartesian2d_to_cartesian2d(self, difntl, vector): """Test ``difntl.represent_as(CartesianVelocity2D, vector)``.""" # Jit can copy newvec = difntl.represent_as(cx.CartesianVelocity2D, vector) - assert newvec == difntl + assert jnp.array_equal(newvec, difntl) # The normal `represent_as` method should return the same object newvec = cx.represent_as(difntl, cx.CartesianVelocity2D, vector) @@ -378,7 +378,7 @@ def test_polar_to_polar(self, difntl, vector): """Test ``difntl.represent_as(PolarVelocity, vector)``.""" # Jit can copy newvec = difntl.represent_as(cx.PolarVelocity, vector) - assert newvec == difntl + assert all(newvec == difntl) # The normal `represent_as` method should return the same object newvec = cx.represent_as(difntl, cx.PolarVelocity, vector) diff --git a/tests/test_d3.py b/tests/test_d3.py index e8b53b3..2b05245 100644 --- a/tests/test_d3.py +++ b/tests/test_d3.py @@ -118,7 +118,7 @@ def test_cartesian3d_to_cartesian3d(self, vector): """Test ``coordinax.represent_as(CartesianPosition3D)``.""" # Jit can copy newvec = vector.represent_as(cx.CartesianPosition3D) - assert newvec == vector + assert jnp.array_equal(newvec, vector) # The normal `represent_as` method should return the same object newvec = cx.represent_as(vector, cx.CartesianPosition3D) @@ -287,7 +287,7 @@ def test_cylindrical_to_cylindrical(self, vector): """Test ``coordinax.represent_as(CylindricalPosition)``.""" # Jit can copy newvec = vector.represent_as(cx.CylindricalPosition) - assert newvec == vector + assert jnp.array_equal(newvec, vector) # The normal `represent_as` method should return the same object newvec = cx.represent_as(vector, cx.CylindricalPosition) @@ -440,7 +440,7 @@ def test_spherical_to_spherical(self, vector): """Test ``coordinax.represent_as(SphericalPosition)``.""" # Jit can copy newvec = vector.represent_as(cx.SphericalPosition) - assert newvec == vector + assert jnp.array_equal(newvec, vector) # The normal `represent_as` method should return the same object newvec = cx.represent_as(vector, cx.SphericalPosition) @@ -571,7 +571,7 @@ def test_cartesian3d_to_cartesian3d(self, difntl, vector): """Test ``coordinax.represent_as(CartesianPosition3D)``.""" # Jit can copy newvec = difntl.represent_as(cx.CartesianVelocity3D, vector) - assert newvec == difntl + assert jnp.array_equal(newvec, difntl) # The normal `represent_as` method should return the same object newvec = cx.represent_as(difntl, cx.CartesianVelocity3D, vector) @@ -939,7 +939,7 @@ def test_spherical_to_spherical(self, difntl, vector): """Test ``coordinax.represent_as(SphericalVelocity)``.""" # Jit can copy newvec = difntl.represent_as(cx.SphericalVelocity, vector) - assert newvec == difntl + assert all(newvec == difntl) # The normal `represent_as` method should return the same object newvec = cx.represent_as(difntl, cx.SphericalVelocity, vector)