Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: sphere constructor #155

Merged
merged 1 commit into from
Aug 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 102 additions & 9 deletions src/coordinax/_d3/sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@

import quaxed.array_api as xp
import quaxed.lax as qlax
import quaxed.numpy as jnp
from dataclassish import replace
from unxt import AbstractDistance, Distance, Quantity
from unxt import AbstractDistance, AbstractQuantity, Distance, Quantity

import coordinax._typing as ct
from .base import AbstractAcceleration3D, AbstractPosition3D, AbstractVelocity3D
Expand All @@ -46,6 +47,7 @@

_90d = Quantity(90, "deg")
_180d = Quantity(180, "deg")
_360d = Quantity(360, "deg")

##############################################################################
# Position
Expand All @@ -60,6 +62,9 @@ class AbstractSphericalPosition(AbstractPosition3D):
def differential_cls(cls) -> type["AbstractSphericalVelocity"]: ...


# ============================================================================


@final
class SphericalPosition(AbstractSphericalPosition):
"""Spherical vector representation.
Expand Down Expand Up @@ -110,6 +115,93 @@ def differential_cls(cls) -> type["SphericalVelocity"]:
return SphericalVelocity


@SphericalPosition.constructor._f.register # type: ignore[attr-defined, misc] # noqa: SLF001
def constructor(
cls: type[SphericalPosition],
*,
r: AbstractQuantity,
theta: AbstractQuantity,
phi: AbstractQuantity,
) -> SphericalPosition:
"""Construct SphericalPosition, allowing for out-of-range values.

Examples
--------
>>> from unxt import Quantity
>>> import coordinax as cx

Let's start with a valid input:

>>> cx.SphericalPosition.constructor(r=Quantity(3, "kpc"),
... theta=Quantity(90, "deg"),
... phi=Quantity(0, "deg"))
SphericalPosition(
r=Distance(value=f32[], unit=Unit("kpc")),
theta=Quantity[...](value=f32[], unit=Unit("deg")),
phi=Quantity[...](value=f32[], unit=Unit("deg"))
)

The radial distance can be negative, which wraps the azimuthal angle by 180
degrees and flips the polar angle:

>>> vec = cx.SphericalPosition.constructor(r=Quantity(-3, "kpc"),
... theta=Quantity(45, "deg"),
... phi=Quantity(0, "deg"))
>>> vec.r
Distance(Array(3., dtype=float32), unit='kpc')
>>> vec.theta
Quantity['angle'](Array(135., dtype=float32), unit='deg')
>>> vec.phi
Quantity[...](Array(180., dtype=float32), unit='deg')

The polar angle can be outside the [0, 180] deg range, causing the azimuthal
angle to be shifted by 180 degrees:

>>> vec = cx.SphericalPosition.constructor(r=Quantity(3, "kpc"),
... theta=Quantity(190, "deg"),
... phi=Quantity(0, "deg"))
>>> vec.r
Distance(Array(3., dtype=float32), unit='kpc')
>>> vec.theta
Quantity['angle'](Array(170., dtype=float32), unit='deg')
>>> vec.phi
Quantity['angle'](Array(180., dtype=float32), unit='deg')

The azimuth can be outside the [0, 360) deg range. This is wrapped to the
[0, 360) deg range (actually the base constructor does this):

>>> vec = cx.SphericalPosition.constructor(r=Quantity(3, "kpc"),
... theta=Quantity(90, "deg"),
... phi=Quantity(365, "deg"))
>>> vec.phi
Quantity['angle'](Array(5., dtype=float32), unit='deg')

"""
# 1) Convert the inputs
fields = SphericalPosition.__dataclass_fields__
r = fields["r"].metadata["converter"](r)
theta = fields["theta"].metadata["converter"](theta)
phi = fields["phi"].metadata["converter"](phi)

# 2) handle negative distances
r_pred = r < xp.zeros_like(r)
r = qlax.select(r_pred, -r, r)
phi = qlax.select(r_pred, phi + _180d, phi)
theta = qlax.select(r_pred, _180d - theta, theta)

# 3) Handle polar angle outside of [0, 180] degrees
theta = jnp.mod(theta, _360d) # wrap to [0, 360) deg
theta_pred = theta < _180d
theta = qlax.select(theta_pred, theta, _360d - theta)
phi = qlax.select(theta_pred, phi, phi + _180d)

# 4) Construct. This also handles the azimuthal angle wrapping
return cls(r=r, theta=theta, phi=phi)


# ============================================================================


@final
class MathSphericalPosition(AbstractSphericalPosition):
"""Spherical vector representation.
Expand Down Expand Up @@ -321,8 +413,8 @@ def constructor(
flips the latitude:

>>> vec = cx.LonLatSphericalPosition.constructor(lon=Quantity(0, "deg"),
... lat=Quantity(45, "deg"),
... distance=Quantity(-3, "kpc"))
... lat=Quantity(45, "deg"),
... distance=Quantity(-3, "kpc"))
>>> vec.lon
Quantity['angle'](Array(180., dtype=float32), unit='deg')
>>> vec.lat
Expand All @@ -334,8 +426,8 @@ def constructor(
to be shifted by 180 degrees:

>>> vec = cx.LonLatSphericalPosition.constructor(lon=Quantity(0, "deg"),
... lat=Quantity(-100, "deg"),
... distance=Quantity(3, "kpc"))
... lat=Quantity(-100, "deg"),
... distance=Quantity(3, "kpc"))
>>> vec.lon
Quantity['angle'](Array(180., dtype=float32), unit='deg')
>>> vec.lat
Expand All @@ -344,8 +436,8 @@ def constructor(
Distance(Array(3., dtype=float32), unit='kpc')

>>> vec = cx.LonLatSphericalPosition.constructor(lon=Quantity(0, "deg"),
... lat=Quantity(100, "deg"),
... distance=Quantity(3, "kpc"))
... lat=Quantity(100, "deg"),
... distance=Quantity(3, "kpc"))
>>> vec.lon
Quantity['angle'](Array(180., dtype=float32), unit='deg')
>>> vec.lat
Expand All @@ -357,8 +449,8 @@ def constructor(
[0, 360) deg range (actually the base constructor does this):

>>> vec = cx.LonLatSphericalPosition.constructor(lon=Quantity(365, "deg"),
... lat=Quantity(0, "deg"),
... distance=Quantity(3, "kpc"))
... lat=Quantity(0, "deg"),
... distance=Quantity(3, "kpc"))
>>> vec.lon
Quantity['angle'](Array(5., dtype=float32), unit='deg')

Expand All @@ -376,6 +468,7 @@ def constructor(
lat = qlax.select(distance_pred, -lat, lat)

# 3) Handle latitude outside of [-90, 90] degrees
# TODO: fix when lat < -180, lat > 180
lat_pred = lat < -_90d
lat = qlax.select(lat_pred, -_180d - lat, lat)
lon = qlax.select(lat_pred, lon + _180d, lon)
Expand Down