Skip to content

Commit

Permalink
♻️ refactor(vecs): pass kwargs to astype (#304)
Browse files Browse the repository at this point in the history
* ♻️ refactor(vecs): pass kwargs to astype
* 🎨 style(vecs): import format for checks
  • Loading branch information
nstarman authored Dec 18, 2024
1 parent 43d1605 commit beea0f8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
12 changes: 8 additions & 4 deletions src/coordinax/_src/vectors/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,9 @@ def to_device(self, device: None | Device = None) -> "Self":
# -------------------------------

@dispatch
def astype(self: "AbstractVector", dtype: Any, /) -> "AbstractVector":
def astype(
self: "AbstractVector", dtype: Any, /, **kwargs: Any
) -> "AbstractVector":
"""Cast the vector to a new dtype.
Examples
Expand All @@ -715,11 +717,13 @@ def astype(self: "AbstractVector", dtype: Any, /) -> "AbstractVector":
CartesianPos1D(x=Quantity[...](value=f32[2], unit=Unit("m")))
"""
return replace(self, **{k: v.astype(dtype) for k, v in field_items(self)})
return replace(
self, **{k: v.astype(dtype, **kwargs) for k, v in field_items(self)}
)

@dispatch
def astype(
self: "AbstractVector", dtypes: Mapping[str, DTypeLike], /
self: "AbstractVector", dtypes: Mapping[str, DTypeLike], /, **kwargs: Any
) -> "AbstractVector":
"""Cast the vector to a new dtype.
Expand All @@ -738,7 +742,7 @@ def astype(
return replace(
self,
**{
k: (v.astype(dtypes[k]) if k in dtypes else v)
k: (v.astype(dtypes[k], **kwargs) if k in dtypes else v)
for k, v in field_items(self)
},
)
Expand Down
16 changes: 8 additions & 8 deletions src/coordinax/_src/vectors/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import equinox as eqx

import quaxed.numpy as xp
import quaxed.numpy as jnp
import unxt as u
from unxt.quantity import AbstractQuantity

Expand Down Expand Up @@ -76,7 +76,7 @@ def check_polar_range(
)
return eqx.error_if(
polar,
xp.any(xp.logical_or((polar < _l), (polar > _u))),
jnp.any(jnp.logical_or((polar < _l), (polar > _u))),
"The inclination angle must be in the range [0, pi].",
)

Expand All @@ -102,7 +102,7 @@ def check_non_negative(x: AbstractQuantity, /, *, name: str = "") -> AbstractQua
"""
name = f" {name}" if name else name
return eqx.error_if(x, xp.any(x < 0), f"The input{name} must be non-negative.")
return eqx.error_if(x, jnp.any(x < 0), f"The input{name} must be non-negative.")


def check_non_negative_non_zero(
Expand Down Expand Up @@ -133,7 +133,7 @@ def check_non_negative_non_zero(
"""
name = f" {name}" if name else name
return eqx.error_if(
x, xp.any(x <= 0), f"The input{name} must be non-negative and non-zero."
x, jnp.any(x <= 0), f"The input{name} must be non-negative and non-zero."
)


Expand Down Expand Up @@ -165,7 +165,7 @@ def check_less_than(
"""
name = f" {name}" if name else name
msg = f"The input{name} must be less than {comparison_name}."
return eqx.error_if(x, xp.any(x >= max_val), msg)
return eqx.error_if(x, jnp.any(x >= max_val), msg)


def check_less_than_equal(
Expand Down Expand Up @@ -196,7 +196,7 @@ def check_less_than_equal(
"""
name = f" {name}" if name else name
msg = f"The input{name} must be less than or equal to {comparison_name}."
return eqx.error_if(x, xp.any(x > max_val), msg)
return eqx.error_if(x, jnp.any(x > max_val), msg)


def check_greater_than(
Expand Down Expand Up @@ -227,7 +227,7 @@ def check_greater_than(
"""
name = f" {name}" if name else name
msg = f"The input{name} must be greater than {comparison_name}."
return eqx.error_if(x, xp.any(x <= min_val), msg)
return eqx.error_if(x, jnp.any(x <= min_val), msg)


def check_greater_than_equal(
Expand Down Expand Up @@ -258,4 +258,4 @@ def check_greater_than_equal(
"""
name = f" {name}" if name else name
msg = f"The input{name} must be greater than or equal to {comparison_name}."
return eqx.error_if(x, xp.any(x < min_val), msg)
return eqx.error_if(x, jnp.any(x < min_val), msg)

0 comments on commit beea0f8

Please sign in to comment.