Skip to content

Commit

Permalink
[typing]: add type annotations for jnp.einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 10, 2023
1 parent f25b701 commit 0de0141
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3136,18 +3136,42 @@ def tensordot(a, b, axes=2, *, precision=None):
the implementation.
"""

@overload
def einsum(
subscript: str, /,
*operands: ArrayLike,
out: None = None,
optimize: str = "optimal",
precision: PrecisionLike = None,
preferred_element_type: Optional[DTypeLike] = None,
_use_xeinsum: bool = False,
_dot_general: Callable[..., Array] = lax.dot_general,
) -> Array: ...

@overload
def einsum(
arr: ArrayLike,
axes: Sequence[Any], /,
*operands: Union[ArrayLike, Sequence[Any]],
out: None = None,
optimize: str = "optimal",
precision: PrecisionLike = None,
preferred_element_type: Optional[DTypeLike] = None,
_use_xeinsum: bool = False,
_dot_general: Callable[..., Array] = lax.dot_general,
) -> Array: ...

@util._wraps(np.einsum, lax_description=_EINSUM_DOC, skip_params=['out'])
def einsum(
subscripts,
*operands,
out=None,
optimize="optimal",
precision=None,
preferred_element_type=None,
_use_xeinsum=False,
_dot_general=lax.dot_general,
):
out: None = None,
optimize: str = "optimal",
precision: PrecisionLike = None,
preferred_element_type: Optional[DTypeLike] = None,
_use_xeinsum: bool = False,
_dot_general: Callable[..., Array] = lax.dot_general,
) -> Array:
operands = (subscripts, *operands)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")
Expand Down Expand Up @@ -3177,7 +3201,7 @@ def einsum(

_einsum_computation = jax.named_call(
_einsum, name=spec) if spec is not None else _einsum
return _einsum_computation(operands, contractions, precision,
return _einsum_computation(operands, contractions, precision, # type: ignore[operator]
preferred_element_type, _dot_general)


Expand Down

0 comments on commit 0de0141

Please sign in to comment.