Skip to content

Commit

Permalink
refactor: rearrange interop (#171)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Aug 29, 2024
1 parent 30e3b84 commit bf4acf2
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 94 deletions.
94 changes: 1 addition & 93 deletions src/coordinax/_coordinax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import TYPE_CHECKING, Any, Literal, NoReturn, TypeVar

import astropy.units as u
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -807,7 +806,7 @@ def __str__(self) -> str:


# TODO: move to the class in py3.11+
@AbstractVector.constructor._f.dispatch # noqa: SLF001
@AbstractVector.constructor._f.dispatch # type: ignore[attr-defined, misc] # noqa: SLF001
def constructor(cls: type[AbstractVector], obj: AbstractVector, /) -> AbstractVector:
"""Construct a vector from another vector.
Expand Down Expand Up @@ -888,94 +887,3 @@ def constructor(cls: type[AbstractVector], obj: AbstractVector, /) -> AbstractVe
return obj

return cls(**dict(field_items(obj)))


@AbstractVector.constructor._f.dispatch # noqa: SLF001
def constructor(
cls: type[AbstractVector], obj: Mapping[str, u.Quantity], /
) -> AbstractVector:
"""Construct a vector from a mapping.
Parameters
----------
cls : type[AbstractVector]
The vector class.
obj : Mapping[str, `astropy.units.Quantity`]
The mapping of components.
Examples
--------
>>> import jax.numpy as jnp
>>> from astropy.units import Quantity
>>> import coordinax as cx
>>> xs = {"x": Quantity(1, "m"), "y": Quantity(2, "m"), "z": Quantity(3, "m")}
>>> vec = cx.CartesianPosition3D.constructor(xs)
>>> vec
CartesianPosition3D(
x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
y=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
z=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m"))
)
>>> xs = {"x": Quantity([1, 2], "m"), "y": Quantity([3, 4], "m"),
... "z": Quantity([5, 6], "m")}
>>> vec = cx.CartesianPosition3D.constructor(xs)
>>> vec
CartesianPosition3D(
x=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")),
y=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")),
z=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m"))
)
"""
return cls(**obj)


# TODO: move to the class in py3.11+
@AbstractVector.constructor._f.dispatch # noqa: SLF001
def constructor(cls: type[AbstractVector], obj: u.Quantity, /) -> AbstractVector:
"""Construct a vector from an Astropy Quantity array.
The array is expected to have the components as the last dimension.
Parameters
----------
cls : type[AbstractVector]
The vector class.
obj : Quantity[Any, (*#batch, N), "..."]
The array of components.
Examples
--------
>>> import jax.numpy as jnp
>>> from astropy.units import Quantity
>>> import coordinax as cx
>>> xs = Quantity([1, 2, 3], "meter")
>>> vec = cx.CartesianPosition3D.constructor(xs)
>>> vec
CartesianPosition3D(
x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
y=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
z=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m"))
)
>>> xs = Quantity(jnp.array([[1, 2, 3], [4, 5, 6]]), "meter")
>>> vec = cx.CartesianPosition3D.constructor(xs)
>>> vec
CartesianPosition3D(
x=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")),
y=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")),
z=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m"))
)
>>> vec.x
Quantity['length'](Array([1., 4.], dtype=float32), unit='m')
"""
_ = eqx.error_if(
obj,
obj.shape[-1] != len(fields(cls)),
f"Cannot construct {cls} from array with shape {obj.shape}.",
)
return cls(**{f.name: obj[..., i] for i, f in enumerate(fields(cls))})
96 changes: 95 additions & 1 deletion src/coordinax/_interop/coordinax_interop_astropy/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

__all__: list[str] = []

from collections.abc import Mapping
from dataclasses import fields

import astropy.coordinates as apyc
import astropy.units as u
Expand All @@ -14,6 +16,51 @@
#####################################################################


@cx.AbstractVector.constructor._f.dispatch # noqa: SLF001
def constructor(
cls: type[cx.AbstractVector], obj: Mapping[str, u.Quantity], /
) -> cx.AbstractVector:
"""Construct a vector from a mapping.
Parameters
----------
cls : type[AbstractVector]
The vector class.
obj : Mapping[str, `astropy.units.Quantity`]
The mapping of components.
Examples
--------
>>> import jax.numpy as jnp
>>> from astropy.units import Quantity
>>> import coordinax as cx
>>> xs = {"x": Quantity(1, "m"), "y": Quantity(2, "m"), "z": Quantity(3, "m")}
>>> vec = cx.CartesianPosition3D.constructor(xs)
>>> vec
CartesianPosition3D(
x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
y=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
z=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m"))
)
>>> xs = {"x": Quantity([1, 2], "m"), "y": Quantity([3, 4], "m"),
... "z": Quantity([5, 6], "m")}
>>> vec = cx.CartesianPosition3D.constructor(xs)
>>> vec
CartesianPosition3D(
x=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")),
y=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")),
z=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m"))
)
"""
return cls(**obj)


#####################################################################


@cx.AbstractPosition3D.constructor._f.dispatch(precedence=-1) # noqa: SLF001
def constructor(
cls: type[cx.AbstractPosition3D], obj: apyc.CartesianRepresentation, /
Expand Down Expand Up @@ -438,7 +485,54 @@ def constructor(
#####################################################################


# TODO: move to the class in py3.11+
@cx.AbstractVector.constructor._f.dispatch # noqa: SLF001
def constructor(cls: type[cx.AbstractVector], obj: u.Quantity, /) -> cx.AbstractVector:
"""Construct a vector from an Astropy Quantity array.
The array is expected to have the components as the last dimension.
Parameters
----------
cls : type[AbstractVector]
The vector class.
obj : Quantity[Any, (*#batch, N), "..."]
The array of components.
Examples
--------
>>> import jax.numpy as jnp
>>> from astropy.units import Quantity
>>> import coordinax as cx
>>> xs = Quantity([1, 2, 3], "meter")
>>> vec = cx.CartesianPosition3D.constructor(xs)
>>> vec
CartesianPosition3D(
x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
y=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
z=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m"))
)
>>> xs = Quantity(jnp.array([[1, 2, 3], [4, 5, 6]]), "meter")
>>> vec = cx.CartesianPosition3D.constructor(xs)
>>> vec
CartesianPosition3D(
x=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")),
y=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m")),
z=Quantity[PhysicalType('length')](value=f32[2], unit=Unit("m"))
)
>>> vec.x
Quantity['length'](Array([1., 4.], dtype=float32), unit='m')
"""
_ = eqx.error_if(
obj,
obj.shape[-1] != len(fields(cls)),
f"Cannot construct {cls} from array with shape {obj.shape}.",
)
return cls(**{f.name: obj[..., i] for i, f in enumerate(fields(cls))})


@cx.FourVector.constructor._f.dispatch # noqa: SLF001
def constructor(
cls: type[cx.FourVector], obj: Shaped[u.Quantity, "*batch 4"], /
Expand Down

0 comments on commit bf4acf2

Please sign in to comment.