From 81e50118cfe868b11d6b8b7e126d6d94d00cd441 Mon Sep 17 00:00:00 2001 From: rajasekharporeddy Date: Fri, 20 Sep 2024 22:19:31 +0530 Subject: [PATCH] Better doc for jax.numpy.i0 --- jax/_src/numpy/lax_numpy.py | 46 ++++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a1601e9201fe..6e1834c0e92e 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5973,19 +5973,53 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, return output -@custom_jvp -@util.implements(np.i0) @jit def i0(x: ArrayLike) -> Array: + r"""Calculate modified Bessel function of first kind, zeroth order. + + JAX implementation of :func:`numpy.i0`. + + Modified Bessel function of first kind, zeroth order is defined by: + + .. math:: + + \mathrm{i0}(x) = I_0(x) = \sum_{k=0}^{\infty} \frac{(x^2/4)^k}{(k!)^2} + + Args: + x: scalar or array. Specifies the argument of Bessel function. Complex inputs + are not supported. + + Returns: + An array containing the corresponding vlaues of the modified Bessel function + of ``x``. + + See also: + - :func:`jax.scipy.special.i0`: Calculates the modified Bessel function of + zeroth order. + - :func:`jax.scipy.special.i1`: Calculates the modified Bessel function of + first order. + - :func:`jax.scipy.special.i0e`: Calculates the exponentially scaled modified + Bessel function of zeroth order. + + Examples: + >>> x = jnp.array([-2, -1, 0, 1, 2]) + >>> jnp.i0(x) + Array([2.2795851, 1.266066 , 1.0000001, 1.266066 , 2.2795851], dtype=float32) + """ x_arr, = util.promote_args_inexact("i0", x) if not issubdtype(x_arr.dtype, np.floating): raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x)}") - x_arr = lax.abs(x_arr) - return lax.mul(lax.exp(x_arr), lax.bessel_i0e(x_arr)) + return _i0(x_arr) + + +@custom_jvp +def _i0(x): + abs_x = lax.abs(x) + return lax.mul(lax.exp(abs_x), lax.bessel_i0e(abs_x)) -@i0.defjvp +@_i0.defjvp def _i0_jvp(primals, tangents): - primal_out, tangent_out = jax.jvp(i0.fun, primals, tangents) + primal_out, tangent_out = jax.jvp(_i0.fun, primals, tangents) return primal_out, where(primals[0] == 0, 0.0, tangent_out) def ix_(*args: ArrayLike) -> tuple[Array, ...]: