Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add docstrings for jnp.blackman, jnp.bartlett, jnp.hamming, jnp.hanning, jnp.kaiser #23776

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 106 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10328,8 +10328,28 @@ def clamp_index(i: DimSize, which: str):
return start, step, slice_size


@util.implements(np.blackman)
def blackman(M: int) -> Array:
"""Return a Blackman window of size M.

JAX implementation of :func:`numpy.blackman`.

Args:
M: The window size.

Returns:
An array of size M containing the Blackman window.

Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.blackman(4))
[-0. 0.63 0.63 -0. ]

See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.blackman")
dtype = dtypes.canonicalize_dtype(float_)
if M <= 1:
Expand All @@ -10338,8 +10358,28 @@ def blackman(M: int) -> Array:
return 0.42 - 0.5 * ufuncs.cos(2 * pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * pi * n / (M - 1))


@util.implements(np.bartlett)
def bartlett(M: int) -> Array:
"""Return a Bartlett window of size M.

JAX implementation of :func:`numpy.bartlett`.

Args:
M: The window size.

Returns:
An array of size M containing the Bartlett window.

Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.bartlett(4))
[0. 0.67 0.67 0. ]

See also:
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.bartlett")
dtype = dtypes.canonicalize_dtype(float_)
if M <= 1:
Expand All @@ -10348,8 +10388,28 @@ def bartlett(M: int) -> Array:
return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1)


@util.implements(np.hamming)
def hamming(M: int) -> Array:
"""Return a Hamming window of size M.

JAX implementation of :func:`numpy.hamming`.

Args:
M: The window size.

Returns:
An array of size M containing the Hamming window.

Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.hamming(4))
[0.08 0.77 0.77 0.08]

See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.hamming")
dtype = dtypes.canonicalize_dtype(float_)
if M <= 1:
Expand All @@ -10358,8 +10418,28 @@ def hamming(M: int) -> Array:
return 0.54 - 0.46 * ufuncs.cos(2 * pi * n / (M - 1))


@util.implements(np.hanning)
def hanning(M: int) -> Array:
"""Return a Hanning window of size M.

JAX implementation of :func:`numpy.hanning`.

Args:
M: The window size.

Returns:
An array of size M containing the Hanning window.

Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.hanning(4))
[0. 0.75 0.75 0. ]

See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.kaiser`: return a Kaiser window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.hanning")
dtype = dtypes.canonicalize_dtype(float_)
if M <= 1:
Expand All @@ -10368,8 +10448,29 @@ def hanning(M: int) -> Array:
return 0.5 * (1 - ufuncs.cos(2 * pi * n / (M - 1)))


@util.implements(np.kaiser)
def kaiser(M: int, beta: ArrayLike) -> Array:
"""Return a Kaiser window of size M.

JAX implementation of :func:`numpy.kaiser`.

Args:
M: The window size.
beta: The Kaiser window parameter.

Returns:
An array of size M containing the Kaiser window.

Examples:
>>> with jnp.printoptions(precision=2, suppress=True):
... print(jnp.kaiser(4, 1.5))
[0.61 0.95 0.95 0.61]

See also:
- :func:`jax.numpy.bartlett`: return a Bartlett window of size M.
- :func:`jax.numpy.blackman`: return a Blackman window of size M.
- :func:`jax.numpy.hamming`: return a Hamming window of size M.
- :func:`jax.numpy.hanning`: return a Hanning window of size M.
"""
M = core.concrete_or_error(int, M, "M argument of jnp.kaiser")
dtype = dtypes.canonicalize_dtype(float_)
if M <= 1:
Expand Down