Skip to content

Commit

Permalink
feat: register broadcasting primitive
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Sep 23, 2024
1 parent ddf8fec commit e378363
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 16 deletions.
3 changes: 3 additions & 0 deletions src/coordinax/_src/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@
from .base_acc import AbstractAcceleration
from .base_pos import AbstractPosition
from .base_vel import AbstractVelocity

# isort: split
from . import register_primitives # noqa: F401
17 changes: 1 addition & 16 deletions src/coordinax/_src/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from jax import Device
from jaxtyping import ArrayLike
from plum import dispatch
from quax import ArrayValue, quaxify, register
from quax import ArrayValue

import quaxed.lax as qlax
import quaxed.numpy as jnp
Expand Down Expand Up @@ -889,18 +889,3 @@ def constructor(cls: type[AbstractVector], obj: AbstractVector, /) -> AbstractVe
return obj

return cls(**dict(field_items(obj)))


# ===================================================================
# Dispatches for Primitives


@register(jax.lax.convert_element_type_p) # type: ignore[misc]
def _convert_element_type_p(operand: AbstractVector, **kwargs: Any) -> AbstractVector:
"""Convert the element type of a quantity."""
# TODO: examples
convert_p = quaxify(jax.lax.convert_element_type_p.bind)
return replace(
operand,
**{k: convert_p(v, **kwargs) for k, v in field_items(operand)},
)
139 changes: 139 additions & 0 deletions src/coordinax/_src/base/register_primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""Representation of coordinates in different systems."""

__all__: list[str] = []

from typing import Any

import jax
from quax import quaxify, register

import quaxed.numpy as jnp
from dataclassish import field_items, replace

from .base import AbstractVector


@register(jax.lax.convert_element_type_p) # type: ignore[misc]
def _convert_element_type_p(operand: AbstractVector, **kwargs: Any) -> AbstractVector:
"""Convert the element type of a quantity."""
# TODO: examples
convert_p = quaxify(jax.lax.convert_element_type_p.bind)
return replace(
operand,
**{k: convert_p(v, **kwargs) for k, v in field_items(operand)},
)


@register(jax.lax.broadcast_in_dim_p) # type: ignore[misc]
def _broadcast_in_dim_p(
operand: AbstractVector, *, shape: tuple[int, ...], **kwargs: Any
) -> AbstractVector:
"""Broadcast in a dimension.
Examples
--------
>>> import quaxed.numpy as jnp
>>> import coordinax as cx
Cartesian 1D position, velocity, and acceleration:
>>> q = cx.CartesianPosition1D.constructor([1], "m")
>>> q.x
Quantity['length'](Array(1., dtype=float32), unit='m')
>>> jnp.broadcast_to(q, (1, 1)).x
Quantity['length'](Array([1.], dtype=float32), unit='m')
>>> p = cx.CartesianVelocity1D.constructor([1], "m/s")
>>> p.d_x
Quantity['speed'](Array(1, dtype=int32), unit='m / s')
>>> jnp.broadcast_to(p, (1, 1)).d_x
Quantity['speed'](Array([1], dtype=int32), unit='m / s')
>>> a = cx.CartesianAcceleration1D.constructor([1], "m/s2")
>>> a.d2_x
Quantity['acceleration'](Array(1, dtype=int32), unit='m / s2')
>>> jnp.broadcast_to(a, (1, 1)).d2_x
Quantity['acceleration'](Array([1], dtype=int32), unit='m / s2')
Radial 1D position, velocity, and acceleration:
>>> q = cx.RadialPosition.constructor([1], "m")
>>> q.r
Distance(Array(1., dtype=float32), unit='m')
>>> jnp.broadcast_to(q, (1, 1)).r
Distance(Array([1.], dtype=float32), unit='m')
>>> p = cx.RadialVelocity.constructor([1], "m/s")
>>> p.d_r
Quantity['speed'](Array(1, dtype=int32), unit='m / s')
>>> jnp.broadcast_to(p, (1, 1)).d_r
Quantity['speed'](Array([1], dtype=int32), unit='m / s')
>>> a = cx.RadialAcceleration.constructor([1], "m/s2")
>>> a.d2_r
Quantity['acceleration'](Array(1, dtype=int32), unit='m / s2')
>>> jnp.broadcast_to(a, (1, 1)).d2_r
Quantity['acceleration'](Array([1], dtype=int32), unit='m / s2')
Cartesian 2D position, velocity, and acceleration:
>>> q = cx.CartesianPosition2D.constructor([1, 2], "m")
>>> q.x
Quantity['length'](Array(1., dtype=float32), unit='m')
>>> jnp.broadcast_to(q, (1, 2)).x
Quantity['length'](Array([1.], dtype=float32), unit='m')
>>> p = cx.CartesianVelocity2D.constructor([1, 2], "m/s")
>>> p.d_x
Quantity['speed'](Array(1., dtype=float32), unit='m / s')
>>> jnp.broadcast_to(p, (1, 2)).d_x
Quantity['speed'](Array([1.], dtype=float32), unit='m / s')
>>> a = cx.CartesianAcceleration2D.constructor([1, 2], "m/s2")
>>> a.d2_x
Quantity['acceleration'](Array(1., dtype=float32), unit='m / s2')
>>> jnp.broadcast_to(a, (1, 2)).d2_x
Quantity['acceleration'](Array([1.], dtype=float32), unit='m / s2')
Cartesian 3D position, velocity, and acceleration:
>>> q = cx.CartesianPosition3D.constructor([1, 2, 3], "m")
>>> q.x
Quantity['length'](Array(1., dtype=float32), unit='m')
>>> jnp.broadcast_to(q, (1, 3)).x
Quantity['length'](Array([1.], dtype=float32), unit='m')
>>> p = cx.CartesianVelocity3D.constructor([1, 2, 3], "m/s")
>>> p.d_x
Quantity['speed'](Array(1., dtype=float32), unit='m / s')
>>> jnp.broadcast_to(p, (1, 3)).d_x
Quantity['speed'](Array([1.], dtype=float32), unit='m / s')
>>> a = cx.CartesianAcceleration3D.constructor([1, 2, 3], "m/s2")
>>> a.d2_x
Quantity['acceleration'](Array(1., dtype=float32), unit='m / s2')
>>> jnp.broadcast_to(a, (1, 3)).d2_x
Quantity['acceleration'](Array([1.], dtype=float32), unit='m / s2')
"""
# TODO: use `jax.lax.broadcast_in_dim_p`
c_shape = shape[:-1]
return replace(
operand,
**{k: jnp.broadcast_to(v, c_shape) for k, v in field_items(operand)},
)

0 comments on commit e378363

Please sign in to comment.