Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better doc for jax.numpy.i0 #23790

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5973,19 +5973,53 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
return output


@custom_jvp
@util.implements(np.i0)
jakevdp marked this conversation as resolved.
Show resolved Hide resolved
@jit
def i0(x: ArrayLike) -> Array:
r"""Calculate modified Bessel function of first kind, zeroth order.

JAX implementation of :func:`numpy.i0`.
jakevdp marked this conversation as resolved.
Show resolved Hide resolved

Modified Bessel function of first kind, zeroth order is defined by:

.. math::

\mathrm{i0}(x) = I_0(x) = \sum_{k=0}^{\infty} \frac{(x^2/4)^k}{(k!)^2}

Args:
x: scalar or array. Specifies the argument of Bessel function. Complex inputs
are not supported.

Returns:
An array containing the corresponding vlaues of the modified Bessel function
of ``x``.

See also:
jakevdp marked this conversation as resolved.
Show resolved Hide resolved
- :func:`jax.scipy.special.i0`: Calculates the modified Bessel function of
zeroth order.
- :func:`jax.scipy.special.i1`: Calculates the modified Bessel function of
first order.
- :func:`jax.scipy.special.i0e`: Calculates the exponentially scaled modified
Bessel function of zeroth order.

Examples:
>>> x = jnp.array([-2, -1, 0, 1, 2])
>>> jnp.i0(x)
Array([2.2795851, 1.266066 , 1.0000001, 1.266066 , 2.2795851], dtype=float32)
"""
x_arr, = util.promote_args_inexact("i0", x)
if not issubdtype(x_arr.dtype, np.floating):
raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x)}")
x_arr = lax.abs(x_arr)
return lax.mul(lax.exp(x_arr), lax.bessel_i0e(x_arr))
return _i0(x_arr)


@custom_jvp
def _i0(x):
abs_x = lax.abs(x)
return lax.mul(lax.exp(abs_x), lax.bessel_i0e(abs_x))

@i0.defjvp
@_i0.defjvp
def _i0_jvp(primals, tangents):
primal_out, tangent_out = jax.jvp(i0.fun, primals, tangents)
primal_out, tangent_out = jax.jvp(_i0.fun, primals, tangents)
return primal_out, where(primals[0] == 0, 0.0, tangent_out)

def ix_(*args: ArrayLike) -> tuple[Array, ...]:
Expand Down