From 30173aacd582385d32f68641939f100e96f9b379 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Wed, 17 Jul 2024 01:53:11 -0400 Subject: [PATCH] build: add dataclasstools (#138) * build: add dataclasstools Signed-off-by: nstarman --- pyproject.toml | 3 +- src/coordinax/_base.py | 3 +- src/coordinax/_base_acc.py | 3 +- src/coordinax/_base_pos.py | 3 +- src/coordinax/_base_vel.py | 3 +- src/coordinax/_d1/compat.py | 3 +- src/coordinax/_d2/compat.py | 3 +- src/coordinax/_d3/compat.py | 3 +- src/coordinax/_transform/accelerations.py | 2 +- src/coordinax/_transform/differentials.py | 2 +- src/coordinax/_utils.py | 52 ++--------------------- src/coordinax/operators/_base.py | 2 +- src/coordinax/operators/_composite.py | 4 +- tests/test_base.py | 2 +- tests/test_jax_ops.py | 2 +- 15 files changed, 27 insertions(+), 63 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4661f63..b405344 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ ] dependencies = [ "astropy", + "dataclasstools @ git+https://github.com/GalacticDynamics/dataclasstools.git", "equinox", "immutable_map_jax @ git+https://github.com/GalacticDynamics/immutable_map_jax.git", "jax", @@ -193,7 +194,7 @@ [tool.ruff.lint.isort] combine-as-imports = true extra-standard-library = ["typing_extensions"] - known-first-party = ["quaxed", "unxt"] + known-first-party = ["dataclasstools", "quaxed", "unxt"] known-local-folder = ["coordinax"] diff --git a/src/coordinax/_base.py b/src/coordinax/_base.py index c7ef524..e80fee4 100644 --- a/src/coordinax/_base.py +++ b/src/coordinax/_base.py @@ -23,9 +23,10 @@ from plum import dispatch import quaxed.array_api as xp +from dataclasstools import field_items, field_values, replace from unxt import Quantity, unitsystem -from ._utils import classproperty, field_items, field_values, full_shaped, replace +from ._utils import classproperty, full_shaped from coordinax._typing import Unit if TYPE_CHECKING: diff --git a/src/coordinax/_base_acc.py b/src/coordinax/_base_acc.py index 486a778..ec25611 100644 --- a/src/coordinax/_base_acc.py +++ b/src/coordinax/_base_acc.py @@ -11,12 +11,13 @@ import jax from plum import dispatch +from dataclasstools import field_items from unxt import Quantity from ._base import AbstractVector from ._base_pos import AbstractPosition from ._base_vel import AbstractVelocity -from ._utils import classproperty, field_items +from ._utils import classproperty if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/coordinax/_base_pos.py b/src/coordinax/_base_pos.py index 5ab8a55..96ffeb5 100644 --- a/src/coordinax/_base_pos.py +++ b/src/coordinax/_base_pos.py @@ -13,10 +13,11 @@ import jax from jaxtyping import ArrayLike +from dataclasstools import field_items from unxt import Quantity from ._base import AbstractVector -from ._utils import classproperty, field_items +from ._utils import classproperty if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/coordinax/_base_vel.py b/src/coordinax/_base_vel.py index 6561f73..ef53069 100644 --- a/src/coordinax/_base_vel.py +++ b/src/coordinax/_base_vel.py @@ -10,11 +10,12 @@ import jax +from dataclasstools import field_items from unxt import Quantity from ._base import AbstractVector from ._base_pos import AbstractPosition -from ._utils import classproperty, field_items +from ._utils import classproperty if TYPE_CHECKING: from typing_extensions import Self diff --git a/src/coordinax/_d1/compat.py b/src/coordinax/_d1/compat.py index 58cee46..5b211e5 100644 --- a/src/coordinax/_d1/compat.py +++ b/src/coordinax/_d1/compat.py @@ -7,11 +7,12 @@ from plum import conversion_method import quaxed.array_api as xp +from dataclasstools import field_values from unxt import Quantity from .base import AbstractPosition1D from .cartesian import CartesianAcceleration1D, CartesianPosition1D, CartesianVelocity1D -from coordinax._utils import field_values, full_shaped +from coordinax._utils import full_shaped @conversion_method(type_from=AbstractPosition1D, type_to=Quantity) # type: ignore[misc] diff --git a/src/coordinax/_d2/compat.py b/src/coordinax/_d2/compat.py index 9f6229e..7460b0d 100644 --- a/src/coordinax/_d2/compat.py +++ b/src/coordinax/_d2/compat.py @@ -7,11 +7,12 @@ from plum import conversion_method import quaxed.array_api as xp +from dataclasstools import field_values from unxt import Quantity from .base import AbstractPosition2D from .cartesian import CartesianPosition2D, CartesianVelocity2D -from coordinax._utils import field_values, full_shaped +from coordinax._utils import full_shaped ##################################################################### # Quantity diff --git a/src/coordinax/_d3/compat.py b/src/coordinax/_d3/compat.py index 8b82dc6..a70fcd6 100644 --- a/src/coordinax/_d3/compat.py +++ b/src/coordinax/_d3/compat.py @@ -7,11 +7,12 @@ from plum import conversion_method import quaxed.array_api as xp +from dataclasstools import field_values from unxt import Quantity from .base import AbstractPosition3D from .cartesian import CartesianAcceleration3D, CartesianPosition3D, CartesianVelocity3D -from coordinax._utils import field_values, full_shaped +from coordinax._utils import full_shaped ##################################################################### # Quantity diff --git a/src/coordinax/_transform/accelerations.py b/src/coordinax/_transform/accelerations.py index f57f9c2..1e4bd5b 100644 --- a/src/coordinax/_transform/accelerations.py +++ b/src/coordinax/_transform/accelerations.py @@ -10,6 +10,7 @@ from plum import dispatch import quaxed.array_api as xp +from dataclasstools import field_items from unxt import AbstractDistance, Quantity from coordinax._base_acc import AbstractAcceleration @@ -18,7 +19,6 @@ from coordinax._d1.base import AbstractAcceleration1D from coordinax._d2.base import AbstractAcceleration2D from coordinax._d3.base import AbstractAcceleration3D -from coordinax._utils import field_items # TODO: implement for cross-representations diff --git a/src/coordinax/_transform/differentials.py b/src/coordinax/_transform/differentials.py index 7c68c53..50a2654 100644 --- a/src/coordinax/_transform/differentials.py +++ b/src/coordinax/_transform/differentials.py @@ -10,6 +10,7 @@ from plum import dispatch import quaxed.array_api as xp +from dataclasstools import field_items from unxt import AbstractDistance, Quantity from coordinax._base_pos import AbstractPosition @@ -17,7 +18,6 @@ from coordinax._d1.base import AbstractVelocity1D from coordinax._d2.base import AbstractVelocity2D from coordinax._d3.base import AbstractVelocity3D -from coordinax._utils import field_items # TODO: implement for cross-representations diff --git a/src/coordinax/_utils.py b/src/coordinax/_utils.py index 05f63ae..1e436d9 100644 --- a/src/coordinax/_utils.py +++ b/src/coordinax/_utils.py @@ -2,61 +2,17 @@ __all__: list[str] = [] -from collections.abc import Callable, Iterator -from dataclasses import dataclass, fields, replace as _dataclass_replace -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Generic, - Protocol, - TypeVar, - runtime_checkable, -) - -from plum import dispatch +from collections.abc import Callable +from dataclasses import dataclass, replace as _dataclass_replace +from typing import TYPE_CHECKING, Generic, TypeVar import quaxed.array_api as xp +from dataclasstools import field_values if TYPE_CHECKING: from coordinax._base import AbstractVector -################################################################################ - - -@runtime_checkable -class DataclassInstance(Protocol): - """Protocol for dataclass instances.""" - - __dataclass_fields__: ClassVar[dict[str, Any]] - - # B/c of https://github.com/python/mypy/issues/3939 just having - # `__dataclass_fields__` is insufficient for `issubclass` checks. - @classmethod - def __subclasshook__(cls: type, c: type) -> bool: - """Customize the subclass check.""" - return hasattr(c, "__dataclass_fields__") - - -@dispatch # type: ignore[misc] -def replace(obj: DataclassInstance, /, **kwargs: Any) -> DataclassInstance: - """Replace the fields of a dataclass instance.""" - return _dataclass_replace(obj, **kwargs) - - -@dispatch # type: ignore[misc] -def field_values(obj: DataclassInstance) -> Iterator[Any]: - """Return the values of a dataclass instance.""" - yield from (getattr(obj, f.name) for f in fields(obj)) - - -@dispatch # type: ignore[misc] -def field_items(obj: DataclassInstance) -> Iterator[tuple[str, Any]]: - """Return the field names and values of a dataclass instance.""" - yield from ((f.name, getattr(obj, f.name)) for f in fields(obj)) - - def full_shaped(obj: "AbstractVector", /) -> "AbstractVector": """Return the vector, fully broadcasting all components.""" arrays = xp.broadcast_arrays(*field_values(obj)) diff --git a/src/coordinax/operators/_base.py b/src/coordinax/operators/_base.py index 55b326d..e57914d 100644 --- a/src/coordinax/operators/_base.py +++ b/src/coordinax/operators/_base.py @@ -9,10 +9,10 @@ import equinox as eqx from plum import dispatch +from dataclasstools import field_items from unxt import Quantity from coordinax._base_pos import AbstractPosition -from coordinax._utils import field_items if TYPE_CHECKING: from ._sequential import OperatorSequence diff --git a/src/coordinax/operators/_composite.py b/src/coordinax/operators/_composite.py index de32ca5..58fb1fb 100644 --- a/src/coordinax/operators/_composite.py +++ b/src/coordinax/operators/_composite.py @@ -6,18 +6,18 @@ from dataclasses import replace from typing import TYPE_CHECKING, Protocol, overload, runtime_checkable +from dataclasstools import DataclassInstance from unxt import Quantity from ._base import AbstractOperator, op_call_dispatch from coordinax._base_pos import AbstractPosition -from coordinax._utils import DataclassInstance if TYPE_CHECKING: from typing_extensions import Self @runtime_checkable -class HasOperatorsAttr(DataclassInstance, Protocol): +class HasOperatorsAttr(DataclassInstance, Protocol): # type: ignore[misc] """Protocol for classes with an `operators` attribute.""" operators: tuple[AbstractOperator, ...] diff --git a/tests/test_base.py b/tests/test_base.py index d3841c7..66628a7 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -11,6 +11,7 @@ import quaxed.array_api as xp import quaxed.numpy as qnp +from dataclasstools import field_items from unxt import AbstractQuantity from coordinax import ( @@ -38,7 +39,6 @@ SphericalPosition, SphericalVelocity, ) -from coordinax._utils import field_items BUILTIN_VECTORS = [ # 1D diff --git a/tests/test_jax_ops.py b/tests/test_jax_ops.py index 2d2f5a1..ab1fb08 100644 --- a/tests/test_jax_ops.py +++ b/tests/test_jax_ops.py @@ -6,11 +6,11 @@ import jax import pytest +from dataclasstools import field_items from unxt import AbstractQuantity, Quantity import coordinax as cx from coordinax._base_pos import VECTOR_CLASSES -from coordinax._utils import field_items VECTOR_CLASSES_3D = [c for c in VECTOR_CLASSES if issubclass(c, cx.AbstractPosition3D)]