From 71450cad5689e6466583b0d705be766397fec21c Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 20 Sep 2024 09:14:38 -0700 Subject: [PATCH] Add docstrings for jnp.blackman, jnp.bartlett, jnp.hamming, jnp.hanning, jnp.kaiser Part of https://github.com/jax-ml/jax/issues/21461 PiperOrigin-RevId: 676866721 --- jax/_src/numpy/lax_numpy.py | 111 ++++++++++++++++++++++++++++++++++-- 1 file changed, 106 insertions(+), 5 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 5b936268581a..716c764ee074 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -10328,8 +10328,28 @@ def clamp_index(i: DimSize, which: str): return start, step, slice_size -@util.implements(np.blackman) def blackman(M: int) -> Array: + """Return a Blackman window of size M. + + JAX implementation of :func:`numpy.blackman`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Blackman window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.blackman(4)) + [-0. 0.63 0.63 -0. ] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.blackman") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: @@ -10338,8 +10358,28 @@ def blackman(M: int) -> Array: return 0.42 - 0.5 * ufuncs.cos(2 * pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * pi * n / (M - 1)) -@util.implements(np.bartlett) def bartlett(M: int) -> Array: + """Return a Bartlett window of size M. + + JAX implementation of :func:`numpy.bartlett`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Bartlett window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.bartlett(4)) + [0. 0.67 0.67 0. ] + + See also: + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.bartlett") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: @@ -10348,8 +10388,28 @@ def bartlett(M: int) -> Array: return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1) -@util.implements(np.hamming) def hamming(M: int) -> Array: + """Return a Hamming window of size M. + + JAX implementation of :func:`numpy.hamming`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Hamming window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.hamming(4)) + [0.08 0.77 0.77 0.08] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.hamming") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: @@ -10358,8 +10418,28 @@ def hamming(M: int) -> Array: return 0.54 - 0.46 * ufuncs.cos(2 * pi * n / (M - 1)) -@util.implements(np.hanning) def hanning(M: int) -> Array: + """Return a Hanning window of size M. + + JAX implementation of :func:`numpy.hanning`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Hanning window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.hanning(4)) + [0. 0.75 0.75 0. ] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.hanning") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: @@ -10368,8 +10448,29 @@ def hanning(M: int) -> Array: return 0.5 * (1 - ufuncs.cos(2 * pi * n / (M - 1))) -@util.implements(np.kaiser) def kaiser(M: int, beta: ArrayLike) -> Array: + """Return a Kaiser window of size M. + + JAX implementation of :func:`numpy.kaiser`. + + Args: + M: The window size. + beta: The Kaiser window parameter. + + Returns: + An array of size M containing the Kaiser window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.kaiser(4, 1.5)) + [0.61 0.95 0.95 0.61] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.kaiser") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: