Skip to content

Commit

Permalink
[typing] annotate next chunk of lax_numpy.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 25, 2022
1 parent 05f78d7 commit a5bccc8
Showing 1 changed file with 67 additions and 57 deletions.
124 changes: 67 additions & 57 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,22 +760,25 @@ def isrealobj(x: Any) -> bool:


@_wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC)
def reshape(a: ArrayLike, newshape: Shape, order: str = "C") -> Array:
def reshape(a: ArrayLike, newshape: Union[DimSize, Shape], order: str = "C") -> Array:
_stackable(a) or _check_arraylike("reshape", a)
try:
# forward to method for ndarrays
return a.reshape(newshape, order=order) # type: ignore[call-overload,union-attr]
except AttributeError:
return _reshape(asarray(a), newshape, order=order)

def _compute_newshape(a: ArrayLike, newshape: Shape) -> Shape:
def _compute_newshape(a: ArrayLike, newshape: Union[DimSize, Shape]) -> Shape:
"""Fixes a -1 value in newshape, if present."""
# other errors, like having more than one -1, are caught downstream, in
# reshape_shape_rule.
try: iter(newshape)
except: iterable = False
else: iterable = True
newshape = core.canonicalize_shape(newshape if iterable else [newshape])
try:
iter(newshape) # type: ignore[arg-type]
except:
iterable = False
else:
iterable = True
newshape = core.canonicalize_shape(newshape if iterable else [newshape]) # type: ignore[arg-type]
return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
if core.symbolic_equal_dim(d, -1) else d
for d in newshape)
Expand Down Expand Up @@ -1100,20 +1103,20 @@ def select(condlist, choicelist, default=0):
Additionally, while ``np.bincount`` raises an error if the input array contains
negative values, ``jax.numpy.bincount`` clips negative values to zero.
""")
def bincount(x, weights=None, minlength=0, *, length=None):
def bincount(x: ArrayLike, weights: Optional[ArrayLike] = None,
minlength: int = 0, *, length: Optional[int] = None) -> Array:
_check_arraylike("bincount", x)
if not issubdtype(_dtype(x), integer):
msg = f"x argument to bincount must have an integer type; got {x.dtype}"
raise TypeError(msg)
raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}")
if ndim(x) != 1:
raise ValueError("only 1-dimensional input supported.")
minlength = core.concrete_or_error(operator.index, minlength,
"The error occurred because of argument 'minlength' of jnp.bincount.")
if length is None:
x = core.concrete_or_error(asarray, x,
x_arr = core.concrete_or_error(asarray, x,
"The error occurred because of argument 'x' of jnp.bincount. "
"To avoid this error, pass a static `length` argument.")
length = _max(minlength, x.size and x.max() + 1)
length = _max(minlength, x_arr.size and int(x_arr.max()) + 1)
else:
length = core.concrete_or_error(operator.index, length,
"The error occurred because of argument 'length' of jnp.bincount.")
Expand Down Expand Up @@ -1293,7 +1296,7 @@ def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0,
@_wraps(np.allclose)
@partial(jit, static_argnames=('equal_nan',))
def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05,
atol: ArrayLike = 1e-08, equal_nan: bool = False):
atol: ArrayLike = 1e-08, equal_nan: bool = False) -> Array:
_check_arraylike("allclose", a, b)
return all(isclose(a, b, rtol, atol, equal_nan))

Expand Down Expand Up @@ -1347,7 +1350,8 @@ def flatnonzero(a: ArrayLike, *, size: Optional[int] = None,

@_wraps(np.unwrap)
@partial(jit, static_argnames=('axis',))
def unwrap(p, discont=None, axis: int = -1, period = 2 * pi):
def unwrap(p: ArrayLike, discont: Optional[ArrayLike] = None,
axis: int = -1, period: ArrayLike = 2 * pi) -> Array:
_check_arraylike("unwrap", p)
p = asarray(p)
if issubdtype(p.dtype, np.complexfloating):
Expand Down Expand Up @@ -1689,7 +1693,8 @@ def pad(array, pad_width, mode="constant", **kwargs):


@_wraps(np.stack, skip_params=['out'])
def stack(arrays, axis: int = 0, out=None, dtype=None):
def stack(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
axis: int = 0, out: None = None, dtype: Optional[DTypeLike] = None) -> Array:
if not len(arrays):
raise ValueError("Need at least one array to stack.")
if out is not None:
Expand All @@ -1709,21 +1714,24 @@ def stack(arrays, axis: int = 0, out=None, dtype=None):
return concatenate(new_arrays, axis=axis, dtype=dtype)

@_wraps(np.tile)
def tile(A, reps):
def tile(A: ArrayLike, reps: Union[DimSize, Sequence[DimSize]]) -> Array:
_stackable(A) or _check_arraylike("tile", A)
try:
iter(reps)
iter(reps) # type: ignore[arg-type]
except TypeError:
reps = (reps,)
reps = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep
for rep in reps)
A_shape = (1,) * (len(reps) - ndim(A)) + shape(A)
reps = (1,) * (len(A_shape) - len(reps)) + reps
reps_tup: Tuple[DimSize, ...] = (reps,)
else:
reps_tup = tuple(reps) # type: ignore[assignment,arg-type]
reps_tup = tuple(operator.index(rep) if core.is_constant_dim(rep) else rep
for rep in reps_tup)
A_shape = (1,) * (len(reps_tup) - ndim(A)) + shape(A)
reps_tup = (1,) * (len(A_shape) - len(reps_tup)) + reps_tup
result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]),
[k for pair in zip(reps, A_shape) for k in pair])
return reshape(result, tuple(np.multiply(A_shape, reps)))
[k for pair in zip(reps_tup, A_shape) for k in pair])
return reshape(result, tuple(np.multiply(A_shape, reps_tup)))

def _concatenate_array(arr, axis: Optional[int], dtype=None):
def _concatenate_array(arr: ArrayLike, axis: Optional[int],
dtype: Optional[DTypeLike] = None) -> Array:
# Fast path for concatenation when the input is an ndarray rather than a list.
arr = asarray(arr, dtype=dtype)
if arr.ndim == 0 or arr.shape[0] == 0:
Expand All @@ -1738,7 +1746,8 @@ def _concatenate_array(arr, axis: Optional[int], dtype=None):
return lax.reshape(arr, shape, dimensions)

@_wraps(np.concatenate)
def concatenate(arrays, axis: Optional[int] = 0, dtype=None):
def concatenate(arrays: Union[np.ndarray, Array, Sequence[ArrayLike]],
axis: Optional[int] = 0, dtype: Optional[DTypeLike] = None) -> Array:
if isinstance(arrays, (np.ndarray, ndarray)):
return _concatenate_array(arrays, axis, dtype=dtype)
_stackable(*arrays) or _check_arraylike("concatenate", *arrays)
Expand All @@ -1749,27 +1758,25 @@ def concatenate(arrays, axis: Optional[int] = 0, dtype=None):
if axis is None:
return concatenate([ravel(a) for a in arrays], axis=0, dtype=dtype)
if hasattr(arrays[0], "concatenate"):
return arrays[0].concatenate(arrays[1:], axis, dtype=dtype)
return arrays[0].concatenate(arrays[1:], axis, dtype=dtype) # type: ignore[union-attr]
axis = _canonicalize_axis(axis, ndim(arrays[0]))
if dtype is None:
arrays = _promote_dtypes(*arrays)
arrays_out = _promote_dtypes(*arrays)
else:
arrays = [asarray(arr, dtype=dtype) for arr in arrays]
arrays_out = [asarray(arr, dtype=dtype) for arr in arrays]
# lax.concatenate can be slow to compile for wide concatenations, so form a
# tree of concatenations as a workaround especially for op-by-op mode.
# (https://github.com/google/jax/issues/653).
k = 16
if len(arrays) == 1:
return asarray(arrays[0], dtype=dtype)
else:
while len(arrays) > 1:
arrays = [lax.concatenate(arrays[i:i+k], axis)
for i in range(0, len(arrays), k)]
return arrays[0]
while len(arrays_out) > 1:
arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
for i in range(0, len(arrays_out), k)]
return arrays_out[0]


@_wraps(np.vstack)
def vstack(tup, dtype=None):
def vstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
dtype: Optional[DTypeLike] = None) -> Array:
if isinstance(tup, (np.ndarray, ndarray)):
arrs = jax.vmap(atleast_2d)(tup)
else:
Expand All @@ -1779,7 +1786,8 @@ def vstack(tup, dtype=None):


@_wraps(np.hstack)
def hstack(tup, dtype=None):
def hstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
dtype: Optional[DTypeLike] = None) -> Array:
if isinstance(tup, (np.ndarray, ndarray)):
arrs = jax.vmap(atleast_1d)(tup)
arr0_ndim = arrs.ndim - 1
Expand All @@ -1790,7 +1798,8 @@ def hstack(tup, dtype=None):


@_wraps(np.dstack)
def dstack(tup, dtype=None):
def dstack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]],
dtype: Optional[DTypeLike] = None) -> Array:
if isinstance(tup, (np.ndarray, ndarray)):
arrs = jax.vmap(atleast_3d)(tup)
else:
Expand All @@ -1799,7 +1808,7 @@ def dstack(tup, dtype=None):


@_wraps(np.column_stack)
def column_stack(tup):
def column_stack(tup: Union[np.ndarray, Array, Sequence[ArrayLike]]) -> Array:
if isinstance(tup, (np.ndarray, ndarray)):
arrs = jax.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup
else:
Expand All @@ -1808,7 +1817,8 @@ def column_stack(tup):


@_wraps(np.choose, skip_params=['out'])
def choose(a, choices, out=None, mode='raise'):
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = None, mode: str = 'raise') -> Array:
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
_check_arraylike('choose', a, *choices)
Expand All @@ -1817,51 +1827,51 @@ def choose(a, choices, out=None, mode='raise'):
N = len(choices)

if mode == 'raise':
a = core.concrete_or_error(asarray, a,
arr: Array = core.concrete_or_error(asarray, a,
"The error occurred because jnp.choose was jit-compiled"
" with mode='raise'. Use mode='wrap' or mode='clip' instead.")
if any((a < 0) | (a >= N)):
if any((arr < 0) | (arr >= N)):
raise ValueError("invalid entry in choice array")
elif mode == 'wrap':
a = a % N
arr = asarray(a) % N
elif mode == 'clip':
a = clip(a, 0, N - 1)
arr = clip(a, 0, N - 1)
else:
raise ValueError(f"mode={mode!r} not understood. Must be 'raise', 'wrap', or 'clip'")

a, *choices = broadcast_arrays(a, *choices)
return array(choices)[(a,) + indices(a.shape, sparse=True)]
arr, *choices = broadcast_arrays(arr, *choices)
return array(choices)[(arr,) + indices(arr.shape, sparse=True)]


def _atleast_nd(x, n):
def _atleast_nd(x: ArrayLike, n: int) -> Array:
m = ndim(x)
return lax.broadcast(x, (1,) * (n - m)) if m < n else x
return lax.broadcast(x, (1,) * (n - m)) if m < n else asarray(x)

def _block(xs):
def _block(xs: Union[ArrayLike, List[ArrayLike]]) -> Tuple[Array, int]:
if isinstance(xs, tuple):
raise ValueError("jax.numpy.block does not allow tuples, got {}"
.format(xs))
elif isinstance(xs, list):
if len(xs) == 0:
raise ValueError("jax.numpy.block does not allow empty list arguments")
xs, depths = unzip2([_block(x) for x in xs])
xs_tup, depths = unzip2([_block(x) for x in xs])
if _any(d != depths[0] for d in depths[1:]):
raise ValueError("Mismatched list depths in jax.numpy.block")
rank = _max(depths[0], _max(ndim(x) for x in xs))
xs = [_atleast_nd(x, rank) for x in xs]
return concatenate(xs, axis=-depths[0]), depths[0] + 1
rank = _max(depths[0], _max(ndim(x) for x in xs_tup))
xs_tup = tuple(_atleast_nd(x, rank) for x in xs_tup)
return concatenate(xs_tup, axis=-depths[0]), depths[0] + 1
else:
return asarray(xs), 1

@_wraps(np.block)
@jit
def block(arrays):
def block(arrays: Union[ArrayLike, List[ArrayLike]]) -> Array:
out, _ = _block(arrays)
return out

@_wraps(np.atleast_1d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_1d(*arys):
def atleast_1d(*arys: ArrayLike) -> Union[Array, List[Array]]:
if len(arys) == 1:
arr = asarray(arys[0])
return arr if ndim(arr) >= 1 else reshape(arr, -1)
Expand All @@ -1871,7 +1881,7 @@ def atleast_1d(*arys):

@_wraps(np.atleast_2d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_2d(*arys):
def atleast_2d(*arys: ArrayLike) -> Union[Array, List[Array]]:
if len(arys) == 1:
arr = asarray(arys[0])
if ndim(arr) >= 2:
Expand All @@ -1886,7 +1896,7 @@ def atleast_2d(*arys):

@_wraps(np.atleast_3d, update_doc=False, lax_description=_ARRAY_VIEW_DOC)
@jit
def atleast_3d(*arys):
def atleast_3d(*arys: ArrayLike) -> Union[Array, List[Array]]:
if len(arys) == 1:
arr = asarray(arys[0])
if ndim(arr) == 0:
Expand Down

0 comments on commit a5bccc8

Please sign in to comment.