Skip to content

Commit

Permalink
feat: allow broadcasting
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Sep 23, 2024
1 parent ddf8fec commit 82b52c5
Showing 1 changed file with 115 additions and 0 deletions.
115 changes: 115 additions & 0 deletions src/coordinax/_src/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,3 +904,118 @@ def _convert_element_type_p(operand: AbstractVector, **kwargs: Any) -> AbstractV
operand,
**{k: convert_p(v, **kwargs) for k, v in field_items(operand)},
)


@register(jax.lax.broadcast_in_dim_p) # type: ignore[misc]
def _broadcast_in_dim_p(
operand: AbstractVector, *, shape: tuple[int, ...], **kwargs: Any
) -> AbstractVector:
"""Broadcast in a dimension.
Examples
--------
>>> import quaxed.numpy as jnp
>>> import coordinax as cx
Cartesian 1D position, velocity, and acceleration:
>>> q = cx.CartesianPosition1D.constructor([1], "m")
>>> q.x
Quantity['length'](Array(1., dtype=float32), unit='m')
>>> jnp.broadcast_to(q, (1, 1)).x
Quantity['length'](Array([1.], dtype=float32), unit='m')
>>> p = cx.CartesianVelocity1D.constructor([1], "m/s")
>>> p.d_x
Quantity['speed'](Array(1, dtype=int32), unit='m / s')
>>> jnp.broadcast_to(p, (1, 1)).d_x
Quantity['speed'](Array([1], dtype=int32), unit='m / s')
>>> a = cx.CartesianAcceleration1D.constructor([1], "m/s2")
>>> a.d2_x
Quantity['acceleration'](Array(1, dtype=int32), unit='m / s2')
>>> jnp.broadcast_to(a, (1, 1)).d2_x
Quantity['acceleration'](Array([1], dtype=int32), unit='m / s2')
Radial 1D position, velocity, and acceleration:
>>> q = cx.RadialPosition.constructor([1], "m")
>>> q.r
Distance(Array(1., dtype=float32), unit='m')
>>> jnp.broadcast_to(q, (1, 1)).r
Distance(Array([1.], dtype=float32), unit='m')
>>> p = cx.RadialVelocity.constructor([1], "m/s")
>>> p.d_r
Quantity['speed'](Array(1, dtype=int32), unit='m / s')
>>> jnp.broadcast_to(p, (1, 1)).d_r
Quantity['speed'](Array([1], dtype=int32), unit='m / s')
>>> a = cx.RadialAcceleration.constructor([1], "m/s2")
>>> a.d2_r
Quantity['acceleration'](Array(1, dtype=int32), unit='m / s2')
>>> jnp.broadcast_to(a, (1, 1)).d2_r
Quantity['acceleration'](Array([1], dtype=int32), unit='m / s2')
Cartesian 2D position, velocity, and acceleration:
>>> q = cx.CartesianPosition2D.constructor([1, 2], "m")
>>> q.x
Quantity['length'](Array(1., dtype=float32), unit='m')
>>> jnp.broadcast_to(q, (1, 2)).x
Quantity['length'](Array([1.], dtype=float32), unit='m')
>>> p = cx.CartesianVelocity2D.constructor([1, 2], "m/s")
>>> p.d_x
Quantity['speed'](Array(1., dtype=float32), unit='m / s')
>>> jnp.broadcast_to(p, (1, 2)).d_x
Quantity['speed'](Array([1.], dtype=float32), unit='m / s')
>>> a = cx.CartesianAcceleration2D.constructor([1, 2], "m/s2")
>>> a.d2_x
Quantity['acceleration'](Array(1., dtype=float32), unit='m / s2')
>>> jnp.broadcast_to(a, (1, 2)).d2_x
Quantity['acceleration'](Array([1.], dtype=float32), unit='m / s2')
Cartesian 3D position, velocity, and acceleration:
>>> q = cx.CartesianPosition3D.constructor([1, 2, 3], "m")
>>> q.x
Quantity['length'](Array(1., dtype=float32), unit='m')
>>> jnp.broadcast_to(q, (1, 3)).x
Quantity['length'](Array([1.], dtype=float32), unit='m')
>>> p = cx.CartesianVelocity3D.constructor([1, 2, 3], "m/s")
>>> p.d_x
Quantity['speed'](Array(1., dtype=float32), unit='m / s')
>>> jnp.broadcast_to(p, (1, 3)).d_x
Quantity['speed'](Array([1.], dtype=float32), unit='m / s')
>>> a = cx.CartesianAcceleration3D.constructor([1, 2, 3], "m/s2")
>>> a.d2_x
Quantity['acceleration'](Array(1., dtype=float32), unit='m / s2')
>>> jnp.broadcast_to(a, (1, 3)).d2_x
Quantity['acceleration'](Array([1.], dtype=float32), unit='m / s2')
"""
# TODO: use `jax.lax.broadcast_in_dim_p`
c_shape = shape[:-1]
return replace(
operand,
**{k: jnp.broadcast_to(v, c_shape) for k, v in field_items(operand)},
)

0 comments on commit 82b52c5

Please sign in to comment.