Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support equality comparison #196

Merged
merged 4 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,4 @@
py-version = "3.10"
reports.output-format = "colorized"
similarities.ignore-imports = "yes"
max-module-lines = 1500
93 changes: 89 additions & 4 deletions src/coordinax/_src/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
import jax
import numpy as np
from astropy.units import PhysicalType as Dimensions
from jax import Device
from jaxtyping import ArrayLike
from jax import Device, tree
from jaxtyping import Array, ArrayLike, Bool
from plum import dispatch
from quax import ArrayValue
from quax import ArrayValue, register

import quaxed.lax as qlax
import quaxed.numpy as jnp
Expand Down Expand Up @@ -331,6 +331,81 @@
# ---------------------------------------------------------------
# Methods

def __eq__(self: "AbstractVector", other: object) -> Any:
"""Check if the vector is equal to another object.

Examples
--------
>>> import quaxed.numpy as jnp
>>> from unxt import Quantity
>>> import coordinax as cx

Positions are covered by a separate dispatch. So here we show velocities
and accelerations:

>>> vel1 = cx.CartesianVelocity1D(Quantity([1, 2, 3], "km/s"))
>>> vel2 = cx.CartesianVelocity1D(Quantity([1, 0, 3], "km/s"))
>>> jnp.equal(vel1, vel2)
Array([ True, False, True], dtype=bool)
>>> vel1 == vel2
Array([ True, False, True], dtype=bool)

>>> acc1 = cx.CartesianAcceleration1D(Quantity([1, 2, 3], "km/s2"))
>>> acc2 = cx.CartesianAcceleration1D(Quantity([1, 0, 3], "km/s2"))
>>> jnp.equal(acc1, acc2)
Array([ True, False, True], dtype=bool)
>>> acc1 == acc2
Array([ True, False, True], dtype=bool)

>>> vel1 = cx.RadialVelocity(Quantity([1, 2, 3], "km/s"))
>>> vel2 = cx.RadialVelocity(Quantity([1, 0, 3], "km/s"))
>>> jnp.equal(vel1, vel2)
Array([ True, False, True], dtype=bool)
>>> vel1 == vel2
Array([ True, False, True], dtype=bool)

>>> acc1 = cx.RadialAcceleration(Quantity([1, 2, 3], "km/s2"))
>>> acc2 = cx.RadialAcceleration(Quantity([1, 0, 3], "km/s2"))
>>> jnp.equal(acc1, acc2)
Array([ True, False, True], dtype=bool)
>>> acc1 == acc2
Array([ True, False, True], dtype=bool)

>>> vel1 = cx.CartesianVelocity2D.constructor([[1, 3], [2, 4]], "km/s")
>>> vel2 = cx.CartesianVelocity2D.constructor([[1, 3], [0, 4]], "km/s")
>>> vel1.d_x
Quantity['speed'](Array([1., 2.], dtype=float32), unit='km / s')
>>> jnp.equal(vel1, vel2)
Array([ True, False], dtype=bool)
>>> vel1 == vel2
Array([ True, False], dtype=bool)

>>> acc1 = cx.CartesianAcceleration2D.constructor([[1, 3], [2, 4]], "km/s2")
>>> acc2 = cx.CartesianAcceleration2D.constructor([[1, 3], [0, 4]], "km/s2")
>>> acc1.d2_x
Quantity['acceleration'](Array([1., 2.], dtype=float32), unit='km / s2')
>>> jnp.equal(acc1, acc2)
Array([ True, False], dtype=bool)
>>> acc1 == acc2
Array([ True, False], dtype=bool)

>>> vel1 = cx.CartesianVelocity3D.constructor([[1, 4], [2, 5], [3, 6]], "km/s")
>>> vel2 = cx.CartesianVelocity3D.constructor([[1, 4], [0, 5], [3, 0]], "km/s")
>>> vel1.d_x
Quantity['speed'](Array([1., 2., 3.], dtype=float32), unit='km / s')
>>> jnp.equal(vel1, vel2)
Array([ True, False, False], dtype=bool)
>>> vel1 == vel2
Array([ True, False, False], dtype=bool)

"""
if type(other) is not type(self):
return NotImplemented

Check warning on line 403 in src/coordinax/_src/base/base.py

View check run for this annotation

Codecov / codecov/patch

src/coordinax/_src/base/base.py#L403

Added line #L403 was not covered by tests

comp_tree = tree.map(jnp.equal, self, other)
comp_leaves = jnp.array(tree.leaves(comp_tree))
return jax.numpy.logical_and.reduce(comp_leaves)

def __len__(self) -> int:
"""Return the length of the vector.

Expand Down Expand Up @@ -831,7 +906,7 @@
return f"<{cls_name} ({comps})\n {vs}>"


# -----------------------------------------------
# ===============================================================
# Register additional constructors


Expand Down Expand Up @@ -916,3 +991,13 @@
return obj

return cls(**dict(field_items(obj)))


# ===============================================================
# Register primitives


@register(jax.lax.eq_p) # type: ignore[misc]
def _eq_vec_vec(lhs: AbstractVector, rhs: AbstractVector, /) -> Bool[Array, "..."]:
"""Element-wise equality of two vectors."""
return lhs == rhs
54 changes: 54 additions & 0 deletions src/coordinax/_src/base/base_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,51 @@

__neg__ = jnp.negative

# ===============================================================
# Binary operations

def __eq__(self: "AbstractPosition", other: object) -> Any:
"""Element-wise equality of two positions.

Examples
--------
>>> import quaxed.numpy as jnp
>>> import coordinax as cx

Showing the broadcasting, then element-wise comparison of two vectors:

>>> vec1 = cx.CartesianPosition3D.constructor([[1, 2, 3], [1, 2, 4]], "m")
>>> vec2 = cx.CartesianPosition3D.constructor([1, 2, 3], "m")
>>> jnp.equal(vec1, vec2)
Array([ True, False], dtype=bool)

Showing the change of representation:

>>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m")
>>> vec1 = vec.represent_as(cx.SphericalPosition)
>>> vec2 = vec.represent_as(cx.MathSphericalPosition)
>>> jnp.equal(vec1, vec2)
Array(True, dtype=bool)

Quick run-through of each dimensionality:

>>> vec1 = cx.CartesianPosition1D.constructor([1], "m")
>>> vec2 = cx.RadialPosition.constructor([1], "m")
>>> jnp.equal(vec1, vec2)
Array(True, dtype=bool)

>>> vec1 = cx.CartesianPosition2D.constructor([2, 0], "m")
>>> vec2 = cx.PolarPosition(r=Quantity(2, "m"), phi=Quantity(0, "rad"))
>>> jnp.equal(vec1, vec2)
Array(True, dtype=bool)

"""
if not isinstance(other, AbstractPosition):
return NotImplemented

Check warning on line 133 in src/coordinax/_src/base/base_pos.py

View check run for this annotation

Codecov / codecov/patch

src/coordinax/_src/base/base_pos.py#L133

Added line #L133 was not covered by tests

rhs = other.represent_as(type(self))
return super().__eq__(rhs)

# ===============================================================
# Convenience methods

Expand Down Expand Up @@ -209,6 +254,15 @@
# ------------------------------------------------


@register(jax.lax.eq_p) # type: ignore[misc]
def _eq_pos_pos(lhs: AbstractPosition, rhs: AbstractPosition, /) -> ArrayLike:
"""Element-wise equality of two positions."""
return lhs == rhs


# ------------------------------------------------


@register(jax.lax.mul_p) # type: ignore[misc]
def _mul_v_pos(lhs: ArrayLike, rhs: AbstractPosition, /) -> AbstractPosition:
"""Scale a position by a scalar.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_d1.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_cartesian1d_to_cartesian1d(self, vector):
"""Test ``coordinax.represent_as(CartesianPosition1D)``."""
# Jit can copy
newvec = vector.represent_as(cx.CartesianPosition1D)
assert newvec == vector
assert jnp.array_equal(newvec, vector)

# The normal `represent_as` method should return the same object
newvec = cx.represent_as(vector, cx.CartesianPosition1D)
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_radial_to_radial(self, vector):
"""Test ``coordinax.represent_as(RadialPosition)``."""
# Jit can copy
newvec = vector.represent_as(cx.RadialPosition)
assert newvec == vector
assert jnp.array_equal(newvec, vector)

# The normal `represent_as` method should return the same object
newvec = cx.represent_as(vector, cx.RadialPosition)
Expand Down Expand Up @@ -212,7 +212,7 @@ def test_cartesian1d_to_cartesian1d(self, difntl, vector):
"""Test ``difntl.represent_as(CartesianVelocity1D)``."""
# Jit can copy
newvec = difntl.represent_as(cx.CartesianVelocity1D, vector)
assert newvec == difntl
assert jnp.array_equal(newvec, difntl)

# The normal `represent_as` method should return the same object
newvec = cx.represent_as(difntl, cx.CartesianVelocity1D, vector)
Expand Down Expand Up @@ -328,7 +328,7 @@ def test_radial_to_radial(self, difntl, vector):
"""Test ``difntl.represent_as(RadialVelocity)``."""
# Jit can copy
newvec = difntl.represent_as(cx.RadialVelocity, vector)
assert newvec == difntl
assert jnp.array_equal(newvec, difntl)

# The normal `represent_as` method should return the same object
newvec = cx.represent_as(difntl, cx.RadialVelocity, vector)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_d2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_cartesian2d_to_cartesian2d(self, vector):
"""Test ``coordinax.represent_as(CartesianPosition2D)``."""
# Jit can copy
newvec = vector.represent_as(cx.CartesianPosition2D)
assert newvec == vector
assert jnp.array_equal(newvec, vector)

# The normal `represent_as` method should return the same object
newvec = cx.represent_as(vector, cx.CartesianPosition2D)
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_polar_to_polar(self, vector):
"""Test ``coordinax.represent_as(PolarPosition)``."""
# Jit can copy
newvec = vector.represent_as(cx.PolarPosition)
assert newvec == vector
assert jnp.array_equal(newvec, vector)

# The normal `represent_as` method should return the same object
newvec = cx.represent_as(vector, cx.PolarPosition)
Expand Down Expand Up @@ -262,7 +262,7 @@ def test_cartesian2d_to_cartesian2d(self, difntl, vector):
"""Test ``difntl.represent_as(CartesianVelocity2D, vector)``."""
# Jit can copy
newvec = difntl.represent_as(cx.CartesianVelocity2D, vector)
assert newvec == difntl
assert jnp.array_equal(newvec, difntl)

# The normal `represent_as` method should return the same object
newvec = cx.represent_as(difntl, cx.CartesianVelocity2D, vector)
Expand Down Expand Up @@ -378,7 +378,7 @@ def test_polar_to_polar(self, difntl, vector):
"""Test ``difntl.represent_as(PolarVelocity, vector)``."""
# Jit can copy
newvec = difntl.represent_as(cx.PolarVelocity, vector)
assert newvec == difntl
assert all(newvec == difntl)

# The normal `represent_as` method should return the same object
newvec = cx.represent_as(difntl, cx.PolarVelocity, vector)
Expand Down
10 changes: 5 additions & 5 deletions tests/test_d3.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_cartesian3d_to_cartesian3d(self, vector):
"""Test ``coordinax.represent_as(CartesianPosition3D)``."""
# Jit can copy
newvec = vector.represent_as(cx.CartesianPosition3D)
assert newvec == vector
assert jnp.array_equal(newvec, vector)

# The normal `represent_as` method should return the same object
newvec = cx.represent_as(vector, cx.CartesianPosition3D)
Expand Down Expand Up @@ -287,7 +287,7 @@ def test_cylindrical_to_cylindrical(self, vector):
"""Test ``coordinax.represent_as(CylindricalPosition)``."""
# Jit can copy
newvec = vector.represent_as(cx.CylindricalPosition)
assert newvec == vector
assert jnp.array_equal(newvec, vector)

# The normal `represent_as` method should return the same object
newvec = cx.represent_as(vector, cx.CylindricalPosition)
Expand Down Expand Up @@ -440,7 +440,7 @@ def test_spherical_to_spherical(self, vector):
"""Test ``coordinax.represent_as(SphericalPosition)``."""
# Jit can copy
newvec = vector.represent_as(cx.SphericalPosition)
assert newvec == vector
assert jnp.array_equal(newvec, vector)

# The normal `represent_as` method should return the same object
newvec = cx.represent_as(vector, cx.SphericalPosition)
Expand Down Expand Up @@ -571,7 +571,7 @@ def test_cartesian3d_to_cartesian3d(self, difntl, vector):
"""Test ``coordinax.represent_as(CartesianPosition3D)``."""
# Jit can copy
newvec = difntl.represent_as(cx.CartesianVelocity3D, vector)
assert newvec == difntl
assert jnp.array_equal(newvec, difntl)

# The normal `represent_as` method should return the same object
newvec = cx.represent_as(difntl, cx.CartesianVelocity3D, vector)
Expand Down Expand Up @@ -939,7 +939,7 @@ def test_spherical_to_spherical(self, difntl, vector):
"""Test ``coordinax.represent_as(SphericalVelocity)``."""
# Jit can copy
newvec = difntl.represent_as(cx.SphericalVelocity, vector)
assert newvec == difntl
assert all(newvec == difntl)

# The normal `represent_as` method should return the same object
newvec = cx.represent_as(difntl, cx.SphericalVelocity, vector)
Expand Down