From d2899f5da6b0f594def4dee3636c0775c09470db Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Thu, 14 Nov 2024 21:23:37 -0500 Subject: [PATCH] build: bump unxt (#222) Signed-off-by: nstarman --- pyproject.toml | 2 +- .../coordinax_interop_astropy/constructors.py | 10 +- .../coordinax_interop_astropy/converters.py | 100 +++++++++--------- src/coordinax/_src/angle/base.py | 7 +- .../_src/angle/register_primitives.py | 34 +++--- src/coordinax/_src/base/base.py | 50 ++++----- src/coordinax/_src/distance/core.py | 20 ++-- .../_src/distance/register_primitives.py | 38 +++---- .../_src/operators/galilean/rotation.py | 18 ++-- src/coordinax/_src/space.py | 27 ++--- src/coordinax/_src/transform/d2.py | 13 ++- src/coordinax/_src/typing.py | 8 +- uv.lock | 10 +- 13 files changed, 162 insertions(+), 175 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0a03d43..c048bc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ "plum-dispatch>=2.5.2", "quax>=0.0.5", "quaxed>=0.6.4", - "unxt>=0.21", + "unxt>=0.22", "xmmutablemap>=0.1", ] diff --git a/src/coordinax/_interop/coordinax_interop_astropy/constructors.py b/src/coordinax/_interop/coordinax_interop_astropy/constructors.py index 2e0e8df..deebff0 100644 --- a/src/coordinax/_interop/coordinax_interop_astropy/constructors.py +++ b/src/coordinax/_interop/coordinax_interop_astropy/constructors.py @@ -6,10 +6,10 @@ from collections.abc import Mapping import astropy.coordinates as apyc -import astropy.units as u +import astropy.units as apyu from plum import convert -from unxt import Quantity +import unxt as u import coordinax as cx @@ -18,7 +18,7 @@ @cx.AbstractVector.from_._f.dispatch # noqa: SLF001 def from_( - cls: type[cx.AbstractVector], obj: Mapping[str, u.Quantity], / + cls: type[cx.AbstractVector], obj: Mapping[str, apyu.Quantity], / ) -> cx.AbstractVector: """Construct a vector from a mapping. @@ -486,7 +486,7 @@ def from_( @cx.AbstractVector.from_._f.dispatch # noqa: SLF001 -def from_(cls: type[cx.AbstractVector], obj: u.Quantity, /) -> cx.AbstractVector: +def from_(cls: type[cx.AbstractVector], obj: apyu.Quantity, /) -> cx.AbstractVector: """Construct a vector from an Astropy Quantity array. The array is expected to have the components as the last dimension. @@ -559,4 +559,4 @@ def from_(cls: type[cx.AbstractVector], obj: u.Quantity, /) -> cx.AbstractVector Quantity['length'](Array([1., 4.], dtype=float32), unit='m') """ - return cls.from_(convert(obj, Quantity)) + return cls.from_(convert(obj, u.Quantity)) diff --git a/src/coordinax/_interop/coordinax_interop_astropy/converters.py b/src/coordinax/_interop/coordinax_interop_astropy/converters.py index 24086b2..9795f87 100644 --- a/src/coordinax/_interop/coordinax_interop_astropy/converters.py +++ b/src/coordinax/_interop/coordinax_interop_astropy/converters.py @@ -6,11 +6,11 @@ import astropy.coordinates as apyc -import astropy.units as u +import astropy.units as apyu from jaxtyping import Shaped from plum import conversion_method, convert -import unxt as ux +import unxt as u import coordinax as cx @@ -20,59 +20,59 @@ # Quantity -@conversion_method(cx.AbstractPos3D, u.Quantity) # type: ignore[misc] -def vec_to_q(obj: cx.AbstractPos3D, /) -> Shaped[u.Quantity, "*batch 3"]: +@conversion_method(cx.AbstractPos3D, apyu.Quantity) # type: ignore[misc] +def vec_to_q(obj: cx.AbstractPos3D, /) -> Shaped[apyu.Quantity, "*batch 3"]: """`coordinax.AbstractPos3D` -> `astropy.units.Quantity`. Examples -------- >>> import coordinax as cx >>> from plum import convert - >>> from astropy.units import Quantity + >>> import astropy.units as apyu >>> vec = cx.CartesianPos3D.from_([1, 2, 3], "kpc") - >>> convert(vec, Quantity) + >>> convert(vec, apyu.Quantity) - >>> vec = cx.SphericalPos(r=Quantity(1, unit="kpc"), - ... theta=Quantity(2, unit="deg"), - ... phi=Quantity(3, unit="deg")) - >>> convert(vec, Quantity) + >>> vec = cx.SphericalPos(r=apyu.Quantity(1, unit="kpc"), + ... theta=apyu.Quantity(2, unit="deg"), + ... phi=apyu.Quantity(3, unit="deg")) + >>> convert(vec, apyu.Quantity) - >>> vec = cx.CylindricalPos(rho=Quantity(1, unit="kpc"), - ... phi=Quantity(2, unit="deg"), - ... z=Quantity(3, unit="pc")) - >>> convert(vec, Quantity) + >>> vec = cx.CylindricalPos(rho=apyu.Quantity(1, unit="kpc"), + ... phi=apyu.Quantity(2, unit="deg"), + ... z=apyu.Quantity(3, unit="pc")) + >>> convert(vec, apyu.Quantity) """ - return convert(convert(obj, ux.Quantity), u.Quantity) + return convert(convert(obj, u.Quantity), apyu.Quantity) -@conversion_method(cx.CartesianAcc3D, u.Quantity) # type: ignore[misc] -@conversion_method(cx.CartesianVel3D, u.Quantity) # type: ignore[misc] +@conversion_method(cx.CartesianAcc3D, apyu.Quantity) # type: ignore[misc] +@conversion_method(cx.CartesianVel3D, apyu.Quantity) # type: ignore[misc] def vec_diff_to_q( obj: cx.CartesianVel3D | cx.CartesianAcc3D, / -) -> Shaped[u.Quantity, "*batch 3"]: +) -> Shaped[apyu.Quantity, "*batch 3"]: """`coordinax.CartesianVel3D` -> `astropy.units.Quantity`. Examples -------- >>> import coordinax as cx >>> from plum import convert - >>> from astropy.units import Quantity + >>> from astropy.units import Quantity as AstropyQuantity >>> dif = cx.CartesianVel3D.from_([1, 2, 3], "km/s") - >>> convert(dif, Quantity) + >>> convert(dif, AstropyQuantity) >>> dif2 = cx.CartesianAcc3D.from_([1, 2, 3], "km/s2") - >>> convert(dif2, Quantity) + >>> convert(dif2, AstropyQuantity) """ - return convert(convert(obj, ux.Quantity), u.Quantity) + return convert(convert(obj, u.Quantity), apyu.Quantity) # ===================================== @@ -100,9 +100,9 @@ def cart3_to_apycart3(obj: cx.CartesianPos3D, /) -> apyc.CartesianRepresentation """ return apyc.CartesianRepresentation( - x=convert(obj.x, u.Quantity), - y=convert(obj.y, u.Quantity), - z=convert(obj.z, u.Quantity), + x=convert(obj.x, apyu.Quantity), + y=convert(obj.y, apyu.Quantity), + z=convert(obj.z, apyu.Quantity), ) @@ -112,7 +112,6 @@ def apycart3_to_cart3(obj: apyc.CartesianRepresentation, /) -> cx.CartesianPos3D Examples -------- - >>> import astropy.units as u >>> import coordinax as cx >>> from astropy.coordinates import CartesianRepresentation @@ -155,9 +154,9 @@ def cyl_to_apycyl(obj: cx.CylindricalPos, /) -> apyc.CylindricalRepresentation: """ return apyc.CylindricalRepresentation( - rho=convert(obj.rho, u.Quantity), - phi=convert(obj.phi, u.Quantity), - z=convert(obj.z, u.Quantity), + rho=convert(obj.rho, apyu.Quantity), + phi=convert(obj.phi, apyu.Quantity), + z=convert(obj.z, apyu.Quantity), ) @@ -206,9 +205,9 @@ def sph_to_apysph(obj: cx.SphericalPos, /) -> apyc.PhysicsSphericalRepresentatio """ return apyc.PhysicsSphericalRepresentation( - r=convert(obj.r, u.Quantity), - phi=convert(obj.phi, u.Quantity), - theta=convert(obj.theta, u.Quantity), + r=convert(obj.r, apyu.Quantity), + phi=convert(obj.phi, apyu.Quantity), + theta=convert(obj.theta, apyu.Quantity), ) @@ -258,9 +257,9 @@ def lonlatsph_to_apysph(obj: cx.LonLatSphericalPos, /) -> apyc.SphericalRepresen """ return apyc.SphericalRepresentation( - lon=convert(obj.lon, u.Quantity), - lat=convert(obj.lat, u.Quantity), - distance=convert(obj.distance, u.Quantity), + lon=convert(obj.lon, apyu.Quantity), + lat=convert(obj.lat, apyu.Quantity), + distance=convert(obj.distance, apyu.Quantity), ) @@ -308,9 +307,9 @@ def diffcart3_to_apycart3(obj: cx.CartesianVel3D, /) -> apyc.CartesianDifferenti """ return apyc.CartesianDifferential( - d_x=convert(obj.d_x, u.Quantity), - d_y=convert(obj.d_y, u.Quantity), - d_z=convert(obj.d_z, u.Quantity), + d_x=convert(obj.d_x, apyu.Quantity), + d_y=convert(obj.d_y, apyu.Quantity), + d_z=convert(obj.d_z, apyu.Quantity), ) @@ -322,7 +321,6 @@ def apycart3_to_diffcart3(obj: apyc.CartesianDifferential, /) -> cx.CartesianVel Examples -------- - >>> import astropy.units as u >>> import coordinax as cx >>> from astropy.coordinates import CartesianDifferential @@ -366,9 +364,9 @@ def diffcyl_to_apycyl(obj: cx.CylindricalVel, /) -> apyc.CylindricalDifferential """ return apyc.CylindricalDifferential( - d_rho=convert(obj.d_rho, u.Quantity), - d_phi=convert(obj.d_phi, u.Quantity), - d_z=convert(obj.d_z, u.Quantity), + d_rho=convert(obj.d_rho, apyu.Quantity), + d_phi=convert(obj.d_phi, apyu.Quantity), + d_z=convert(obj.d_z, apyu.Quantity), ) @@ -424,9 +422,9 @@ def diffsph_to_apysph(obj: cx.SphericalVel, /) -> apyc.PhysicsSphericalDifferent """ return apyc.PhysicsSphericalDifferential( - d_r=convert(obj.d_r, u.Quantity), - d_theta=convert(obj.d_theta, u.Quantity), - d_phi=convert(obj.d_phi, u.Quantity), + d_r=convert(obj.d_r, apyu.Quantity), + d_theta=convert(obj.d_theta, apyu.Quantity), + d_phi=convert(obj.d_phi, apyu.Quantity), ) @@ -484,9 +482,9 @@ def difflonlatsph_to_apysph( """ return apyc.SphericalDifferential( - d_distance=convert(obj.d_distance, u.Quantity), - d_lon=convert(obj.d_lon, u.Quantity), - d_lat=convert(obj.d_lat, u.Quantity), + d_distance=convert(obj.d_distance, apyu.Quantity), + d_lon=convert(obj.d_lon, apyu.Quantity), + d_lat=convert(obj.d_lat, apyu.Quantity), ) @@ -546,9 +544,9 @@ def diffloncoslatsph_to_apysph( """ # noqa: E501 return apyc.SphericalCosLatDifferential( - d_distance=convert(obj.d_distance, u.Quantity), - d_lon_coslat=convert(obj.d_lon_coslat, u.Quantity), - d_lat=convert(obj.d_lat, u.Quantity), + d_distance=convert(obj.d_distance, apyu.Quantity), + d_lon_coslat=convert(obj.d_lon_coslat, apyu.Quantity), + d_lat=convert(obj.d_lat, apyu.Quantity), ) diff --git a/src/coordinax/_src/angle/base.py b/src/coordinax/_src/angle/base.py index 2817f41..6c7fb0f 100644 --- a/src/coordinax/_src/angle/base.py +++ b/src/coordinax/_src/angle/base.py @@ -2,13 +2,12 @@ __all__: list[str] = [] -import astropy.units as u from plum import add_promotion_rule, conversion_method -from unxt import Quantity, dimensions_of +from unxt import Quantity, dimension, dimension_of from unxt.quantity import AbstractQuantity -angle_dimension = u.get_physical_type("angle") +angle_dimension = dimension("angle") class AbstractAngle(AbstractQuantity): # type: ignore[misc] @@ -39,7 +38,7 @@ class AbstractAngle(AbstractQuantity): # type: ignore[misc] def __check_init__(self) -> None: """Check the initialization.""" - if dimensions_of(self) != angle_dimension: + if dimension_of(self) != angle_dimension: msg = "Angle must have dimensions angle." raise ValueError(msg) diff --git a/src/coordinax/_src/angle/register_primitives.py b/src/coordinax/_src/angle/register_primitives.py index aee8dbd..648803d 100644 --- a/src/coordinax/_src/angle/register_primitives.py +++ b/src/coordinax/_src/angle/register_primitives.py @@ -6,19 +6,21 @@ from collections.abc import Callable from typing import Any, TypeVar -from astropy.units import dimensionless_unscaled as one, radian # pylint: disable=E0611 from jax import lax from jax.core import Primitive from jaxtyping import ArrayLike from quax import register as register_ +import unxt as u from quaxed import lax as qlax -from unxt import Quantity, ustrip from .base import AbstractAngle T = TypeVar("T") +one = u.unit("") +radian = u.unit("radian") + def register(primitive: Primitive, **kwargs: Any) -> Callable[[T], T]: """`quax.register`, but makes mypy happy.""" @@ -31,7 +33,7 @@ def register(primitive: Primitive, **kwargs: Any) -> Callable[[T], T]: # TODO: can this be done with promotion/conversion instead? @register(lax.cbrt_p) -def _cbrt_p_a(x: AbstractAngle) -> Quantity: +def _cbrt_p_a(x: AbstractAngle) -> u.Quantity: """Cube root of an angle. Examples @@ -44,14 +46,14 @@ def _cbrt_p_a(x: AbstractAngle) -> Quantity: Quantity['rad1/3'](Array(2., dtype=float32, weak_type=True), unit='rad(1/3)') """ - return Quantity(lax.cbrt(x.value), unit=x.unit ** (1 / 3)) + return u.Quantity(lax.cbrt(x.value), unit=x.unit ** (1 / 3)) # ============================================================================== @register(lax.cos_p) -def _cos_p(x: AbstractAngle) -> Quantity: +def _cos_p(x: AbstractAngle) -> u.Quantity: """Cosine of an Angle. Examples @@ -64,7 +66,7 @@ def _cos_p(x: AbstractAngle) -> Quantity: Quantity['dimensionless'](Array(1., dtype=float32, ...), unit='') """ - return Quantity(qlax.cos(ustrip(radian, x)), unit=one) + return u.Quantity(qlax.cos(u.ustrip(radian, x)), unit=one) # ============================================================================== @@ -73,7 +75,7 @@ def _cos_p(x: AbstractAngle) -> Quantity: @register(lax.dot_general_p) def _dot_general_aa( lhs: AbstractAngle, rhs: AbstractAngle, /, **kwargs: Any -) -> Quantity: +) -> u.Quantity: """Dot product of two Angles. Examples @@ -90,7 +92,7 @@ def _dot_general_aa( Quantity['solid angle'](Array(32, dtype=int32), unit='deg2') """ - return Quantity( + return u.Quantity( lax.dot_general_p.bind(lhs.value, rhs.value, **kwargs), unit=lhs.unit * rhs.unit, ) @@ -100,7 +102,7 @@ def _dot_general_aa( @register(lax.integer_pow_p) -def _integer_pow_p_a(x: AbstractAngle, *, y: Any) -> Quantity: +def _integer_pow_p_a(x: AbstractAngle, *, y: Any) -> u.Quantity: """Integer power of an Angle. Examples @@ -112,14 +114,14 @@ def _integer_pow_p_a(x: AbstractAngle, *, y: Any) -> Quantity: Quantity['rad3'](Array(8, dtype=int32, weak_type=True), unit='deg3') """ - return Quantity(value=lax.integer_pow(x.value, y), unit=x.unit**y) + return u.Quantity(value=lax.integer_pow(x.value, y), unit=x.unit**y) # ============================================================================== @register(lax.pow_p) -def _pow_p_a(x: AbstractAngle, y: ArrayLike) -> Quantity: +def _pow_p_a(x: AbstractAngle, y: ArrayLike) -> u.Quantity: """Power of an Angle by redispatching to Quantity. Examples @@ -133,14 +135,14 @@ def _pow_p_a(x: AbstractAngle, y: ArrayLike) -> Quantity: Quantity['rad3'](Array(1000., dtype=float32, ...), unit='deg3') """ - return Quantity(x.value, x.unit) ** y # TODO: better call to power + return u.Quantity(x.value, x.unit) ** y # TODO: better call to power # ============================================================================== @register(lax.sin_p) -def _sin_p(x: AbstractAngle) -> Quantity: +def _sin_p(x: AbstractAngle) -> u.Quantity: """Sine of an Angle. Examples @@ -153,14 +155,14 @@ def _sin_p(x: AbstractAngle) -> Quantity: Quantity['dimensionless'](Array(1., dtype=float32, ...), unit='') """ - return Quantity(qlax.sin(ustrip(radian, x)), unit=one) + return u.Quantity(qlax.sin(u.ustrip(radian, x)), unit=one) # ============================================================================== @register(lax.sqrt_p) -def _sqrt_p_a(x: AbstractAngle) -> Quantity: +def _sqrt_p_a(x: AbstractAngle) -> u.Quantity: """Square root of an Angle. Examples @@ -174,4 +176,4 @@ def _sqrt_p_a(x: AbstractAngle) -> Quantity: """ # Promote to something that supports sqrt units. - return Quantity(lax.sqrt(x.value), unit=x.unit ** (1 / 2)) + return u.Quantity(lax.sqrt(x.value), unit=x.unit ** (1 / 2)) diff --git a/src/coordinax/_src/base/base.py b/src/coordinax/_src/base/base.py index e4e693b..1bd329b 100644 --- a/src/coordinax/_src/base/base.py +++ b/src/coordinax/_src/base/base.py @@ -16,7 +16,7 @@ import jax import numpy as np -from astropy.units import PhysicalType as Dimensions +from astropy.units import PhysicalType as Dimension from jax import Device, tree from jaxtyping import Array, ArrayLike, Bool from plum import dispatch @@ -24,15 +24,8 @@ import quaxed.lax as qlax import quaxed.numpy as jnp +import unxt as u from dataclassish import field_items, field_values, replace -from unxt import ( - Quantity, - dimensions, - dimensions_of, - uconvert, - units_of, - unitsystem, -) from unxt.quantity import AbstractQuantity from .mixins import IPythonReprMixin @@ -81,10 +74,11 @@ def from_( Examples -------- >>> import jax.numpy as jnp - >>> from unxt import Quantity + >>> import unxt as u >>> import coordinax as cx - >>> xs = {"x": Quantity(1, "m"), "y": Quantity(2, "m"), "z": Quantity(3, "m")} + >>> xs = {"x": u.Quantity(1, "m"), "y": u.Quantity(2, "m"), + ... "z": u.Quantity(3, "m")} >>> vec = cx.CartesianPos3D.from_(xs) >>> vec CartesianPos3D( @@ -93,8 +87,8 @@ def from_( z=Quantity[...](value=f32[], unit=Unit("m")) ) - >>> xs = {"x": Quantity([1, 2], "m"), "y": Quantity([3, 4], "m"), - ... "z": Quantity([5, 6], "m")} + >>> xs = {"x": u.Quantity([1, 2], "m"), "y": u.Quantity([3, 4], "m"), + ... "z": u.Quantity([5, 6], "m")} >>> vec = cx.CartesianPos3D.from_(xs) >>> vec CartesianPos3D( @@ -148,7 +142,7 @@ def from_( Quantity['length'](Array([1., 4.], dtype=float32), unit='m') """ - obj = Quantity.from_(jnp.asarray(obj), unit) + obj = u.Quantity.from_(jnp.asarray(obj), unit) return cls.from_(obj) # re-dispatch # =============================================================== @@ -434,7 +428,7 @@ def __len__(self) -> int: """ return self.shape[0] if self.ndim > 0 else 0 - def __abs__(self) -> Quantity: + def __abs__(self) -> u.Quantity: """Return the norm of the vector. Examples @@ -696,7 +690,7 @@ def components(cls) -> tuple[str, ...]: @property def units(self) -> MappingProxyType[str, Unit]: """Get the units of the vector's components.""" - return MappingProxyType({k: units_of(v) for k, v in field_items(self)}) + return MappingProxyType({k: u.unit_of(v) for k, v in field_items(self)}) @property def dtypes(self) -> MappingProxyType[str, jnp.dtype]: @@ -750,11 +744,10 @@ def to_units(self, usys: Any, /) -> "AbstractVector": Examples -------- - >>> import astropy.units as u >>> from unxt import Quantity, unitsystem >>> import coordinax as cx - >>> usys = unitsystem(u.m, u.s, u.kg, u.rad) + >>> usys = unitsystem("m", "s", "kg", "rad") >>> vec = cx.CartesianPos3D.from_([1, 2, 3], "km") >>> vec.to_units(usys) @@ -765,21 +758,21 @@ def to_units(self, usys: Any, /) -> "AbstractVector": ) """ - usys = unitsystem(usys) + usys = u.unitsystem(usys) return replace( self, - **{k: uconvert(usys[dimensions_of(v)], v) for k, v in field_items(self)}, + **{k: u.uconvert(usys[u.dimension_of(v)], v) for k, v in field_items(self)}, ) @dispatch def to_units( - self: "AbstractVector", usys: Mapping[Dimensions | str, Unit | str], / + self: "AbstractVector", usys: Mapping[Dimension | str, Unit | str], / ) -> "AbstractVector": """Convert the vector to the given units. Parameters ---------- - usys : Mapping[Dimensions | str, Unit | str] + usys : Mapping[Dimension | str, Unit | str] The units to convert to according to the physical type of the components. @@ -809,11 +802,14 @@ def to_units( """ # Ensure `units_` is PT -> Unit - units_ = {dimensions(k): v for k, v in usys.items()} + units_ = {u.dimension(k): v for k, v in usys.items()} # Convert to the given units return replace( self, - **{k: uconvert(units_[dimensions_of(v)], v) for k, v in field_items(self)}, + **{ + k: u.uconvert(units_[u.dimension_of(v)], v) + for k, v in field_items(self) + }, ) @dispatch @@ -859,14 +855,14 @@ def to_units( dim2unit = {} units_ = {} for k, v in field_items(self): - pt = dimensions_of(v) + pt = u.dimension_of(v) if pt not in dim2unit: - dim2unit[pt] = units_of(v) + dim2unit[pt] = u.unit_of(v) units_[k] = dim2unit[pt] return replace( self, - **{k: uconvert(units_[k], v) for k, v in field_items(self)}, + **{k: u.uconvert(units_[k], v) for k, v in field_items(self)}, ) # =============================================================== diff --git a/src/coordinax/_src/distance/core.py b/src/coordinax/_src/distance/core.py index 7d73784..3e266bc 100644 --- a/src/coordinax/_src/distance/core.py +++ b/src/coordinax/_src/distance/core.py @@ -5,19 +5,19 @@ from dataclasses import KW_ONLY from typing import Any, final -import astropy.units as u import equinox as eqx import jax.numpy as jnp import quaxed.numpy as jnp -from unxt import Quantity, dimensions_of, ustrip +import unxt as u +from unxt import Quantity, dimension, dimension_of, ustrip from .base import AbstractDistance -parallax_base_length = Quantity(1, "AU") -distance_modulus_base_distance = Quantity(10, "pc") -angle_dimension = u.get_physical_type("angle") -length_dimension = u.get_physical_type("length") +parallax_base_length = u.Quantity(1, "AU") +distance_modulus_base_distance = u.Quantity(10, "pc") +angle_dimension = dimension("angle") +length_dimension = dimension("length") ############################################################################## @@ -45,7 +45,7 @@ class Distance(AbstractDistance): def __check_init__(self) -> None: """Check the initialization.""" - if dimensions_of(self) != length_dimension: + if dimension_of(self) != length_dimension: msg = "Distance must have dimensions length." raise ValueError(msg) @@ -65,7 +65,7 @@ def distance(self) -> "Distance": @property def parallax( # noqa: PLR0206 (needed for quax boundary) - self, base_length: Quantity["length"] = parallax_base_length + self, base_length: u.Quantity["length"] = parallax_base_length ) -> "Parallax": r"""The parallax of the distance. @@ -142,7 +142,7 @@ class Parallax(AbstractDistance): def __check_init__(self) -> None: """Check the initialization.""" - if dimensions_of(self) != angle_dimension: + if dimension_of(self) != angle_dimension: msg = "Parallax must have angular dimensions." raise ValueError(msg) @@ -223,7 +223,7 @@ class DistanceModulus(AbstractDistance): def __check_init__(self) -> None: """Check the initialization.""" - if self.unit != u.mag: + if self.unit != u.unit("mag"): msg = "Distance modulus must have units of magnitude." raise ValueError(msg) diff --git a/src/coordinax/_src/distance/register_primitives.py b/src/coordinax/_src/distance/register_primitives.py index 067f75e..8e42741 100644 --- a/src/coordinax/_src/distance/register_primitives.py +++ b/src/coordinax/_src/distance/register_primitives.py @@ -6,22 +6,21 @@ from collections.abc import Callable from typing import Any, TypeVar -from astropy.units import ( # pylint: disable=no-name-in-module - dimensionless_unscaled as one, - radian, -) from jax import lax from jax.core import Primitive from jaxtyping import ArrayLike from quax import register as register_ -from unxt import Quantity, is_unit_convertible, ustrip +import unxt as u from unxt.quantity import AbstractQuantity from .base import AbstractDistance T = TypeVar("T") +one = u.unit("") +radian = u.unit("radian") + def register(primitive: Primitive, **kwargs: Any) -> Callable[[T], T]: """`quax.register`, but makes mypy happy.""" @@ -34,7 +33,7 @@ def register(primitive: Primitive, **kwargs: Any) -> Callable[[T], T]: # TODO: can this be done with promotion/conversion instead? @register(lax.cbrt_p) -def _cbrt_p_d(x: AbstractDistance) -> Quantity: +def _cbrt_p_d(x: AbstractDistance) -> u.Quantity: """Cube root of a distance. Examples @@ -46,7 +45,7 @@ def _cbrt_p_d(x: AbstractDistance) -> Quantity: Quantity['m1/3'](Array(2., dtype=float32, ...), unit='m(1/3)') """ - return Quantity(lax.cbrt(x.value), unit=x.unit ** (1 / 3)) + return u.Quantity(lax.cbrt(x.value), unit=x.unit ** (1 / 3)) # ============================================================================== @@ -55,7 +54,7 @@ def _cbrt_p_d(x: AbstractDistance) -> Quantity: @register(lax.dot_general_p) def _dot_general_dd( lhs: AbstractDistance, rhs: AbstractDistance, /, **kwargs: Any -) -> Quantity: +) -> u.Quantity: """Dot product of two Distances. Examples @@ -63,6 +62,7 @@ def _dot_general_dd( This is a dot product of two Distances. >>> import quaxed.numpy as jnp + >>> import unxt as u >>> from coordinax.distance import Distance >>> q1 = Distance([1, 2, 3], "m") @@ -77,7 +77,7 @@ def _dot_general_dd( >>> Rz = jnp.asarray([[0, -1, 0], ... [1, 0, 0], ... [0, 0, 1]]) - >>> q = Quantity([1, 0, 0], "m") + >>> q = u.Quantity([1, 0, 0], "m") >>> Rz @ q Quantity['length'](Array([0, 1, 0], dtype=int32), unit='m') @@ -87,7 +87,7 @@ def _dot_general_dd( Quantity['length'](Array([0, 1, 0], dtype=int32), unit='m') """ - return Quantity( + return u.Quantity( lax.dot_general_p.bind(lhs.value, rhs.value, **kwargs), unit=lhs.unit * rhs.unit, ) @@ -97,7 +97,7 @@ def _dot_general_dd( @register(lax.integer_pow_p) -def _integer_pow_p_d(x: AbstractDistance, *, y: Any) -> Quantity: +def _integer_pow_p_d(x: AbstractDistance, *, y: Any) -> u.Quantity: """Integer power of a Distance. Examples @@ -108,14 +108,14 @@ def _integer_pow_p_d(x: AbstractDistance, *, y: Any) -> Quantity: Quantity['volume'](Array(8, dtype=int32, ...), unit='m3') """ - return Quantity(value=lax.integer_pow(x.value, y), unit=x.unit**y) + return u.Quantity(value=lax.integer_pow(x.value, y), unit=x.unit**y) # ============================================================================== @register(lax.pow_p) -def _pow_p_d(x: AbstractDistance, y: ArrayLike) -> Quantity: +def _pow_p_d(x: AbstractDistance, y: ArrayLike) -> u.Quantity: """Power of a Distance by redispatching to Quantity. Examples @@ -129,14 +129,14 @@ def _pow_p_d(x: AbstractDistance, y: ArrayLike) -> Quantity: Quantity['volume'](Array(1000., dtype=float32, ...), unit='m3') """ - return Quantity(x.value, x.unit) ** y # TODO: better call to power + return u.Quantity(x.value, x.unit) ** y # TODO: better call to power # ============================================================================== @register(lax.sqrt_p) -def _sqrt_p_d(x: AbstractDistance) -> Quantity: +def _sqrt_p_d(x: AbstractDistance) -> u.Quantity: """Square root of a quantity. Examples @@ -155,17 +155,17 @@ def _sqrt_p_d(x: AbstractDistance) -> Quantity: """ # Promote to something that supports sqrt units. - return Quantity(lax.sqrt(x.value), unit=x.unit ** (1 / 2)) + return u.Quantity(lax.sqrt(x.value), unit=x.unit ** (1 / 2)) # ============================================================================== def _to_value_rad_or_one(q: AbstractQuantity) -> ArrayLike: - return ustrip(radian if is_unit_convertible(q.unit, radian) else one, q) + return u.ustrip(radian if u.is_unit_convertible(q.unit, radian) else one, q) # TODO: figure out a promotion alternative that works in general @register(lax.tan_p) -def _tan_p_d(x: AbstractDistance) -> Quantity["dimensionless"]: - return Quantity(lax.tan(_to_value_rad_or_one(x)), unit=one) +def _tan_p_d(x: AbstractDistance) -> u.Quantity["dimensionless"]: + return u.Quantity(lax.tan(_to_value_rad_or_one(x)), unit=one) diff --git a/src/coordinax/_src/operators/galilean/rotation.py b/src/coordinax/_src/operators/galilean/rotation.py index 57eaef4..51c9ae0 100644 --- a/src/coordinax/_src/operators/galilean/rotation.py +++ b/src/coordinax/_src/operators/galilean/rotation.py @@ -72,7 +72,7 @@ class GalileanRotationOperator(AbstractGalileanOperator): We start with the required imports: >>> import jax.numpy as jnp - >>> from unxt import Quantity + >>> import unxt as u >>> import coordinax as cx >>> import coordinax.operators as co @@ -88,8 +88,8 @@ class GalileanRotationOperator(AbstractGalileanOperator): Translation operators can be applied to a Quantity[float, (N, 3), "...]: - >>> q = Quantity([1, 0, 0], "m") - >>> t = Quantity(1, "s") + >>> q = u.Quantity([1, 0, 0], "m") + >>> t = u.Quantity(1, "s") >>> newq, newt = op(q, t) >>> newq Quantity['length'](Array([0.70710677, 0.70710677, 0. ], dtype=float32), unit='m') @@ -101,8 +101,8 @@ class GalileanRotationOperator(AbstractGalileanOperator): This also works for a batch of vectors: - >>> q = Quantity([[1, 0, 0], [0, 1, 0]], "m") - >>> t = Quantity(0, "s") + >>> q = u.Quantity([[1, 0, 0], [0, 1, 0]], "m") + >>> t = u.Quantity(0, "s") >>> newq, newt = op(q, t) >>> newq @@ -197,17 +197,17 @@ def __call__( Examples -------- >>> import quaxed.numpy as jnp - >>> from unxt import Quantity + >>> import unxt as u >>> from coordinax.operators import GalileanRotationOperator - >>> theta = Quantity(45, "deg") + >>> theta = u.Quantity(45, "deg") >>> Rz = jnp.asarray([[jnp.cos(theta), -jnp.sin(theta), 0], ... [jnp.sin(theta), jnp.cos(theta), 0], ... [0, 0, 1]]) >>> op = GalileanRotationOperator(Rz) - >>> q = Quantity([1, 0, 0], "m") - >>> t = Quantity(1, "s") + >>> q = u.Quantity([1, 0, 0], "m") + >>> t = u.Quantity(1, "s") >>> newq, newt = op(q, t) >>> newq Quantity[...](Array([0.70710677, 0.70710677, 0. ], dtype=float32), unit='m') diff --git a/src/coordinax/_src/space.py b/src/coordinax/_src/space.py index 70c6c36..d256698 100644 --- a/src/coordinax/_src/space.py +++ b/src/coordinax/_src/space.py @@ -8,7 +8,6 @@ from typing import Any, TypeAlias, final from typing_extensions import override -import astropy.units as u import equinox as eqx import jax from astropy.units import PhysicalType as Dimension @@ -16,7 +15,7 @@ from plum import dispatch import quaxed.numpy as jnp -from unxt import Quantity, dimensions +import unxt as u from xmmutablemap import ImmutableMap from .base import AbstractAcc, AbstractPos, AbstractVector, AbstractVel @@ -28,7 +27,7 @@ def _get_dimension_name(dim: DimensionLike, /) -> str: - return dimensions(dim)._physical_type_list[0] # noqa: SLF001 + return u.dimension(dim)._physical_type_list[0] # noqa: SLF001 def _can_broadcast_shapes(*shapes: tuple[int, ...]) -> bool: @@ -62,7 +61,6 @@ class Space(AbstractVector, ImmutableMap[Dimension, AbstractVector]): # type: i Examples -------- >>> import coordinax as cx - >>> from unxt import Quantity >>> x = cx.CartesianPos3D.from_([1, 2, 3], "km") >>> v = cx.CartesianVel3D.from_([4, 5, 6], "km/s") @@ -220,8 +218,8 @@ def __getitem__(self, key: str | Dimension) -> Any: By dimension: - >>> import astropy.units as u - >>> w[u.get_physical_type("length")] + >>> import unxt as u + >>> w[u.dimension("length")] CartesianPos3D( x=Quantity[...](value=f32[1,2], unit=Unit("m")), y=Quantity[...](value=f32[1,2], unit=Unit("m")), @@ -276,7 +274,7 @@ def mT(self) -> "Self": # noqa: N802 @property def ndim(self) -> int: - """Number of array dimensions (axes). + """Number of array dimension (axes). Examples -------- @@ -299,7 +297,6 @@ def shape(self) -> tuple[int, ...]: Examples -------- >>> import coordinax as cx - >>> from unxt import Quantity >>> w = cx.Space( ... length=cx.CartesianPos3D.from_([[[1, 2, 3], [4, 5, 6]]], "m"), @@ -319,7 +316,6 @@ def size(self) -> int: Examples -------- >>> import coordinax as cx - >>> from unxt import Quantity >>> w = cx.Space( ... length=cx.CartesianPos3D.from_([[[1, 2, 3], [4, 5, 6]]], "m"), @@ -338,7 +334,6 @@ def T(self) -> "Self": # noqa: N802 Examples -------- >>> import coordinax as cx - >>> from unxt import Quantity >>> w = cx.Space( ... length=cx.CartesianPos3D.from_([[[1, 2, 3], [4, 5, 6]]], "m"), @@ -370,7 +365,6 @@ def __neg__(self) -> "Self": Examples -------- >>> import coordinax as cx - >>> from unxt import Quantity >>> w = cx.Space( ... length=cx.CartesianPos3D.from_([[[1, 2, 3], [4, 5, 6]]], "m"), @@ -389,7 +383,6 @@ def __repr__(self) -> str: Examples -------- >>> import coordinax as cx - >>> from unxt import Quantity >>> q = cx.CartesianPos3D.from_([1, 2, 3], "m") >>> p = cx.CartesianVel3D.from_([1, 2, 3], "m/s") @@ -422,8 +415,8 @@ def __str__(self) -> str: def asdict( self, *, - dict_factory: Callable[[Any], Mapping[str, Quantity]] = dict, - ) -> Mapping[str, Quantity]: + dict_factory: Callable[[Any], Mapping[str, u.Quantity]] = dict, + ) -> Mapping[str, u.Quantity]: """Return the vector as a Mapping. Parameters @@ -464,7 +457,6 @@ def dtypes(self) -> MappingProxyType[str, MappingProxyType[str, jnp.dtype]]: Examples -------- >>> import coordinax as cx - >>> from unxt import Quantity >>> w = cx.Space( ... length=cx.CartesianPos3D.from_([[[1, 2, 3], [4, 5, 6]]], "m"), @@ -486,7 +478,6 @@ def devices(self) -> MappingProxyType[str, MappingProxyType[str, Device]]: Examples -------- >>> import coordinax as cx - >>> from unxt import Quantity >>> w = cx.Space( ... length=cx.CartesianPos3D.from_([[[1, 2, 3], [4, 5, 6]]], "m"), @@ -508,7 +499,6 @@ def shapes(self) -> MappingProxyType[str, MappingProxyType[str, tuple[int, ...]] Examples -------- >>> import coordinax as cx - >>> from unxt import Quantity >>> w = cx.Space( ... length=cx.CartesianPos3D.from_([[[1, 2, 3], [4, 5, 6]]], "m"), @@ -528,7 +518,6 @@ def sizes(self) -> MappingProxyType[str, int]: Examples -------- >>> import coordinax as cx - >>> from unxt import Quantity >>> w = cx.Space( ... length=cx.CartesianPos3D.from_([[[1, 2, 3], [4, 5, 6]]], "m"), @@ -552,7 +541,7 @@ def represent_as(self, target: type[AbstractVector], /) -> "Space": # pylint: d @dispatch # type: ignore[misc] @override def to_units( - self: "Space", units: Mapping[u.PhysicalType | str, Unit | str], / + self: "Space", units: Mapping[Dimension | str, Unit | str], / ) -> "Space": """Convert the vector to the given units.""" raise NotImplementedError diff --git a/src/coordinax/_src/transform/d2.py b/src/coordinax/_src/transform/d2.py index 3bfe984..d33a1ec 100644 --- a/src/coordinax/_src/transform/d2.py +++ b/src/coordinax/_src/transform/d2.py @@ -5,7 +5,6 @@ from typing import Any from warnings import warn -import astropy.units as u from plum import dispatch import quaxed.numpy as jnp @@ -33,7 +32,7 @@ def represent_as( current: AbstractPos2D, target: type[AbstractPos3D], /, - z: AbstractQuantity = Quantity(0.0, u.m), + z: AbstractQuantity = Quantity(0.0, "m"), **kwargs: Any, ) -> AbstractPos3D: """AbstractPos2D -> Cartesian2D -> Cartesian3D -> AbstractPos3D. @@ -85,7 +84,7 @@ def represent_as( current: AbstractPos2D, target: type[AbstractPos3D], /, - z: AbstractQuantity = Quantity(0.0, u.m), + z: AbstractQuantity = Quantity(0.0, "m"), **kwargs: Any, ) -> AbstractPos3D: """AbstractPos2D -> PolarPos -> Cylindrical -> AbstractPos3D. @@ -189,7 +188,7 @@ def represent_as( target: type[CartesianPos3D], /, *, - z: AbstractQuantity = Quantity(0.0, u.m), + z: AbstractQuantity = Quantity(0.0, "m"), **kwargs: Any, ) -> CartesianPos3D: """CartesianPos2D -> CartesianPos3D. @@ -286,7 +285,7 @@ def represent_as( current: PolarPos, target: type[SphericalPos], /, - theta: Quantity["angle"] = Quantity(0.0, u.radian), # type: ignore[name-defined] + theta: Quantity["angle"] = Quantity(0.0, "radian"), **kwargs: Any, ) -> SphericalPos: """PolarPos -> SphericalPos. @@ -315,7 +314,7 @@ def represent_as( current: PolarPos, target: type[MathSphericalPos], /, - phi: Quantity["angle"] = Quantity(0.0, u.radian), # type: ignore[name-defined] + phi: Quantity["angle"] = Quantity(0.0, "radian"), **kwargs: Any, ) -> MathSphericalPos: """PolarPos -> MathSphericalPos. @@ -345,7 +344,7 @@ def represent_as( target: type[CylindricalPos], /, *, - z: Quantity["length"] = Quantity(0.0, u.m), # type: ignore[name-defined] + z: Quantity["length"] = Quantity(0.0, "m"), **kwargs: Any, ) -> CylindricalPos: """PolarPos -> CylindricalPos. diff --git a/src/coordinax/_src/typing.py b/src/coordinax/_src/typing.py index 89d893c..b163484 100644 --- a/src/coordinax/_src/typing.py +++ b/src/coordinax/_src/typing.py @@ -4,7 +4,11 @@ from typing import TypeAlias -import astropy.units as u +from astropy.units import ( + CompositeUnit as AstropyCompositeUnit, + Unit as AstropyUnit, + UnitBase as AstropyUnitBase, +) from jaxtyping import Float, Int, Shaped from unxt.quantity import AbstractQuantity, Quantity @@ -12,7 +16,7 @@ from .angle.base import AbstractAngle from .distance.base import AbstractDistance -Unit: TypeAlias = u.Unit | u.UnitBase | u.CompositeUnit +Unit: TypeAlias = AstropyUnit | AstropyUnitBase | AstropyCompositeUnit FloatScalarQ = Float[AbstractQuantity, ""] BatchFloatScalarQ = Shaped[FloatScalarQ, "*batch"] diff --git a/uv.lock b/uv.lock index ce6b1e5..4f53039 100644 --- a/uv.lock +++ b/uv.lock @@ -292,7 +292,7 @@ wheels = [ [[package]] name = "coordinax" -version = "0.13.3.dev1+g0d8da79.d20241114" +version = "0.13.4.dev0+g0cb55cf.d20241115" source = { editable = "." } dependencies = [ { name = "astropy" }, @@ -402,7 +402,7 @@ requires-dist = [ { name = "sybil", marker = "extra == 'all'", specifier = "!=7.1.0" }, { name = "sybil", marker = "extra == 'dev'", specifier = "!=7.1.0" }, { name = "sybil", marker = "extra == 'test'", specifier = "!=7.1.0" }, - { name = "unxt", specifier = ">=0.21" }, + { name = "unxt", specifier = ">=0.22" }, { name = "xmmutablemap", specifier = ">=0.1" }, ] @@ -1734,7 +1734,7 @@ wheels = [ [[package]] name = "unxt" -version = "0.21.0" +version = "0.22.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "astropy" }, @@ -1751,9 +1751,9 @@ dependencies = [ { name = "xmmutablemap" }, { name = "zeroth" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7b/1f/3af871c97032f9eb87de19d6765db70bffad73e5ea9fbc07aed767ce61c3/unxt-0.21.0.tar.gz", hash = "sha256:0e5a1d189e692238af9a5ca227cca8ed7e42c01efedf87b448db4b3edccdc899", size = 659410 } +sdist = { url = "https://files.pythonhosted.org/packages/75/11/a66511eb9b6e8510378d59c9c3a5f75c68da33d1ef5fdaff2b22078d7388/unxt-0.22.0.tar.gz", hash = "sha256:6f49cd0621d6a512a17bc89321f2b315de5f839c2d202498376b46f6a3605666", size = 659642 } wheels = [ - { url = "https://files.pythonhosted.org/packages/d8/e7/0c5355f26c6f0cf34965d8d284872598f9b082e163f3e1d2f5611d1c7243/unxt-0.21.0-py3-none-any.whl", hash = "sha256:25c35c9e08bd7eb68cd5eeee6067d09e3e7b513cc990ad10ab55e1c05f38cc74", size = 56453 }, + { url = "https://files.pythonhosted.org/packages/8e/af/82046f24cd6c66a8167af7b8d7cd078e725bde4e37047dfc564326e66b71/unxt-0.22.0-py3-none-any.whl", hash = "sha256:21ea3e0288b2d9a3286a889895bc80e181ae1f637034cfca3a87fba40cbe9ad7", size = 56961 }, ] [[package]]