Skip to content

Commit

Permalink
build: add dataclasstools (#138)
Browse files Browse the repository at this point in the history
* build: add dataclasstools

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Jul 17, 2024
1 parent db5d997 commit 30173aa
Show file tree
Hide file tree
Showing 15 changed files with 27 additions and 63 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
]
dependencies = [
"astropy",
"dataclasstools @ git+https://github.com/GalacticDynamics/dataclasstools.git",
"equinox",
"immutable_map_jax @ git+https://github.com/GalacticDynamics/immutable_map_jax.git",
"jax",
Expand Down Expand Up @@ -193,7 +194,7 @@
[tool.ruff.lint.isort]
combine-as-imports = true
extra-standard-library = ["typing_extensions"]
known-first-party = ["quaxed", "unxt"]
known-first-party = ["dataclasstools", "quaxed", "unxt"]
known-local-folder = ["coordinax"]


Expand Down
3 changes: 2 additions & 1 deletion src/coordinax/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
from plum import dispatch

import quaxed.array_api as xp
from dataclasstools import field_items, field_values, replace
from unxt import Quantity, unitsystem

from ._utils import classproperty, field_items, field_values, full_shaped, replace
from ._utils import classproperty, full_shaped
from coordinax._typing import Unit

if TYPE_CHECKING:
Expand Down
3 changes: 2 additions & 1 deletion src/coordinax/_base_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
import jax
from plum import dispatch

from dataclasstools import field_items
from unxt import Quantity

from ._base import AbstractVector
from ._base_pos import AbstractPosition
from ._base_vel import AbstractVelocity
from ._utils import classproperty, field_items
from ._utils import classproperty

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down
3 changes: 2 additions & 1 deletion src/coordinax/_base_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
import jax
from jaxtyping import ArrayLike

from dataclasstools import field_items
from unxt import Quantity

from ._base import AbstractVector
from ._utils import classproperty, field_items
from ._utils import classproperty

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down
3 changes: 2 additions & 1 deletion src/coordinax/_base_vel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

import jax

from dataclasstools import field_items
from unxt import Quantity

from ._base import AbstractVector
from ._base_pos import AbstractPosition
from ._utils import classproperty, field_items
from ._utils import classproperty

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down
3 changes: 2 additions & 1 deletion src/coordinax/_d1/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from plum import conversion_method

import quaxed.array_api as xp
from dataclasstools import field_values
from unxt import Quantity

from .base import AbstractPosition1D
from .cartesian import CartesianAcceleration1D, CartesianPosition1D, CartesianVelocity1D
from coordinax._utils import field_values, full_shaped
from coordinax._utils import full_shaped


@conversion_method(type_from=AbstractPosition1D, type_to=Quantity) # type: ignore[misc]
Expand Down
3 changes: 2 additions & 1 deletion src/coordinax/_d2/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from plum import conversion_method

import quaxed.array_api as xp
from dataclasstools import field_values
from unxt import Quantity

from .base import AbstractPosition2D
from .cartesian import CartesianPosition2D, CartesianVelocity2D
from coordinax._utils import field_values, full_shaped
from coordinax._utils import full_shaped

#####################################################################
# Quantity
Expand Down
3 changes: 2 additions & 1 deletion src/coordinax/_d3/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from plum import conversion_method

import quaxed.array_api as xp
from dataclasstools import field_values
from unxt import Quantity

from .base import AbstractPosition3D
from .cartesian import CartesianAcceleration3D, CartesianPosition3D, CartesianVelocity3D
from coordinax._utils import field_values, full_shaped
from coordinax._utils import full_shaped

#####################################################################
# Quantity
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_transform/accelerations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from plum import dispatch

import quaxed.array_api as xp
from dataclasstools import field_items
from unxt import AbstractDistance, Quantity

from coordinax._base_acc import AbstractAcceleration
Expand All @@ -18,7 +19,6 @@
from coordinax._d1.base import AbstractAcceleration1D
from coordinax._d2.base import AbstractAcceleration2D
from coordinax._d3.base import AbstractAcceleration3D
from coordinax._utils import field_items


# TODO: implement for cross-representations
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_transform/differentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from plum import dispatch

import quaxed.array_api as xp
from dataclasstools import field_items
from unxt import AbstractDistance, Quantity

from coordinax._base_pos import AbstractPosition
from coordinax._base_vel import AbstractVelocity
from coordinax._d1.base import AbstractVelocity1D
from coordinax._d2.base import AbstractVelocity2D
from coordinax._d3.base import AbstractVelocity3D
from coordinax._utils import field_items


# TODO: implement for cross-representations
Expand Down
52 changes: 4 additions & 48 deletions src/coordinax/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,61 +2,17 @@

__all__: list[str] = []

from collections.abc import Callable, Iterator
from dataclasses import dataclass, fields, replace as _dataclass_replace
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Generic,
Protocol,
TypeVar,
runtime_checkable,
)

from plum import dispatch
from collections.abc import Callable
from dataclasses import dataclass, replace as _dataclass_replace
from typing import TYPE_CHECKING, Generic, TypeVar

import quaxed.array_api as xp
from dataclasstools import field_values

if TYPE_CHECKING:
from coordinax._base import AbstractVector


################################################################################


@runtime_checkable
class DataclassInstance(Protocol):
"""Protocol for dataclass instances."""

__dataclass_fields__: ClassVar[dict[str, Any]]

# B/c of https://github.com/python/mypy/issues/3939 just having
# `__dataclass_fields__` is insufficient for `issubclass` checks.
@classmethod
def __subclasshook__(cls: type, c: type) -> bool:
"""Customize the subclass check."""
return hasattr(c, "__dataclass_fields__")


@dispatch # type: ignore[misc]
def replace(obj: DataclassInstance, /, **kwargs: Any) -> DataclassInstance:
"""Replace the fields of a dataclass instance."""
return _dataclass_replace(obj, **kwargs)


@dispatch # type: ignore[misc]
def field_values(obj: DataclassInstance) -> Iterator[Any]:
"""Return the values of a dataclass instance."""
yield from (getattr(obj, f.name) for f in fields(obj))


@dispatch # type: ignore[misc]
def field_items(obj: DataclassInstance) -> Iterator[tuple[str, Any]]:
"""Return the field names and values of a dataclass instance."""
yield from ((f.name, getattr(obj, f.name)) for f in fields(obj))


def full_shaped(obj: "AbstractVector", /) -> "AbstractVector":
"""Return the vector, fully broadcasting all components."""
arrays = xp.broadcast_arrays(*field_values(obj))
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/operators/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import equinox as eqx
from plum import dispatch

from dataclasstools import field_items
from unxt import Quantity

from coordinax._base_pos import AbstractPosition
from coordinax._utils import field_items

if TYPE_CHECKING:
from ._sequential import OperatorSequence
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/operators/_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@
from dataclasses import replace
from typing import TYPE_CHECKING, Protocol, overload, runtime_checkable

from dataclasstools import DataclassInstance
from unxt import Quantity

from ._base import AbstractOperator, op_call_dispatch
from coordinax._base_pos import AbstractPosition
from coordinax._utils import DataclassInstance

if TYPE_CHECKING:
from typing_extensions import Self


@runtime_checkable
class HasOperatorsAttr(DataclassInstance, Protocol):
class HasOperatorsAttr(DataclassInstance, Protocol): # type: ignore[misc]
"""Protocol for classes with an `operators` attribute."""

operators: tuple[AbstractOperator, ...]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import quaxed.array_api as xp
import quaxed.numpy as qnp
from dataclasstools import field_items
from unxt import AbstractQuantity

from coordinax import (
Expand Down Expand Up @@ -38,7 +39,6 @@
SphericalPosition,
SphericalVelocity,
)
from coordinax._utils import field_items

BUILTIN_VECTORS = [
# 1D
Expand Down
2 changes: 1 addition & 1 deletion tests/test_jax_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import jax
import pytest

from dataclasstools import field_items
from unxt import AbstractQuantity, Quantity

import coordinax as cx
from coordinax._base_pos import VECTOR_CLASSES
from coordinax._utils import field_items

VECTOR_CLASSES_3D = [c for c in VECTOR_CLASSES if issubclass(c, cx.AbstractPosition3D)]

Expand Down

0 comments on commit 30173aa

Please sign in to comment.