Skip to content

Commit

Permalink
Add docstrings for jnp.blackman, jnp.bartlett, jnp.hamming, jnp.hanni…
Browse files Browse the repository at this point in the history
…ng, jnp.kaiser

Part of #21461

PiperOrigin-RevId: 676866721
  • Loading branch information
Jake VanderPlas authored and Google-ML-Automation committed Sep 20, 2024
1 parent 81b8b4b commit 71450ca
Showing 1 changed file with 106 additions and 5 deletions.
111 changes: 106 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10328,8 +10328,28 @@ def clamp_index(i: DimSize, which: str):
return start, step, slice_size


@util.implements(np.blackman)
def blackman(M: int) -> Array:
"""Return a Blackman window of size M.
JAX implementation of :func:`numpy.blackman`.
Args:
M: The window size.
Returns:
An array of size M containing the Blackman window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.blackman(4))
[-0. 0.63 0.63 -0. ]
See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.blackman")
dtype = dtypes.canonicalize_dtype(float_)
if M <= 1:
Expand All @@ -10338,8 +10358,28 @@ def blackman(M: int) -> Array:
return 0.42 - 0.5 * ufuncs.cos(2 * pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * pi * n / (M - 1))


@util.implements(np.bartlett)
def bartlett(M: int) -> Array:
"""Return a Bartlett window of size M.
JAX implementation of :func:`numpy.bartlett`.
Args:
M: The window size.
Returns:
An array of size M containing the Bartlett window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.bartlett(4))
[0. 0.67 0.67 0. ]
See also:
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.bartlett")
dtype = dtypes.canonicalize_dtype(float_)
if M <= 1:
Expand All @@ -10348,8 +10388,28 @@ def bartlett(M: int) -> Array:
return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1)


@util.implements(np.hamming)
def hamming(M: int) -> Array:
"""Return a Hamming window of size M.
JAX implementation of :func:`numpy.hamming`.
Args:
M: The window size.
Returns:
An array of size M containing the Hamming window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.hamming(4))
[0.08 0.77 0.77 0.08]
See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.hamming")
dtype = dtypes.canonicalize_dtype(float_)
if M <= 1:
Expand All @@ -10358,8 +10418,28 @@ def hamming(M: int) -> Array:
return 0.54 - 0.46 * ufuncs.cos(2 * pi * n / (M - 1))


@util.implements(np.hanning)
def hanning(M: int) -> Array:
"""Return a Hanning window of size M.
JAX implementation of :func:`numpy.hanning`.
Args:
M: The window size.
Returns:
An array of size M containing the Hanning window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.hanning(4))
[0. 0.75 0.75 0. ]
See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.hanning")
dtype = dtypes.canonicalize_dtype(float_)
if M <= 1:
Expand All @@ -10368,8 +10448,29 @@ def hanning(M: int) -> Array:
return 0.5 * (1 - ufuncs.cos(2 * pi * n / (M - 1)))


@util.implements(np.kaiser)
def kaiser(M: int, beta: ArrayLike) -> Array:
"""Return a Kaiser window of size M.
JAX implementation of :func:`numpy.kaiser`.
Args:
M: The window size.
beta: The Kaiser window parameter.
Returns:
An array of size M containing the Kaiser window.
Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.kaiser(4, 1.5))
[0.61 0.95 0.95 0.61]
See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.kaiser")
dtype = dtypes.canonicalize_dtype(float_)
if M <= 1:
Expand Down

0 comments on commit 71450ca

Please sign in to comment.