Skip to content

Commit

Permalink
Merge pull request #12960 from jakevdp:annotate-lax-numpy
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 483764493
  • Loading branch information
jax authors committed Oct 25, 2022
2 parents b0a1dea + 2f27d51 commit 05f78d7
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 65 deletions.
168 changes: 108 additions & 60 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,9 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
dtype = dtypes.to_inexact_dtype(arr.dtype)
if _ndim(bins) == 1:
return asarray(bins, dtype=dtype)
bins = core.concrete_or_error(operator.index, bins,
"bins argument of histogram_bin_edges")

bins_int = core.concrete_or_error(operator.index, bins,
"bins argument of histogram_bin_edges")
if range is None:
range = [arr.min(), arr.max()]
range = asarray(range, dtype=dtype)
Expand All @@ -436,7 +437,7 @@ def histogram_bin_edges(a: ArrayLike, bins: ArrayLike = 10,
range = (where(ptp(range) == 0, range[0] - 0.5, range[0]),
where(ptp(range) == 0, range[1] + 0.5, range[1]))
assert range is not None
return linspace(range[0], range[1], bins + 1, dtype=dtype)
return linspace(range[0], range[1], bins_int + 1, dtype=dtype)


@_wraps(np.histogram)
Expand Down Expand Up @@ -865,13 +866,13 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> Tuple[Array, ...]:
shape = [shape]
if _any(ndim(s) != 0 for s in shape):
raise ValueError("unravel_index: shape should be a scalar or 1D sequence.")
out_indices = [None] * len(shape)
out_indices = [0] * len(shape)
for i, s in reversed(list(enumerate(shape))):
indices_arr, out_indices[i] = divmod(indices_arr, s)
oob_pos = indices_arr > 0
oob_neg = indices_arr < -1
return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i))
for s, i in zip(shape, out_indices))
for s, i in safe_zip(shape, out_indices))

@_wraps(np.resize)
@partial(jit, static_argnames=('new_shape',))
Expand Down Expand Up @@ -986,41 +987,66 @@ def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05, atol: ArrayLike

@_wraps(np.interp)
@jit
def interp(x, xp, fp, left=None, right=None, period=None):
def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike,
left: Optional[ArrayLike] = None,
right: Optional[ArrayLike] = None,
period: Optional[ArrayLike] = None) -> Array:
_check_arraylike("interp", x, xp, fp)
if shape(xp) != shape(fp) or ndim(xp) != 1:
raise ValueError("xp and fp must be one-dimensional arrays of equal size")
x, xp = _promote_dtypes_inexact(x, xp)
fp, = _promote_dtypes_inexact(fp)
x_arr, xp_arr = _promote_dtypes_inexact(x, xp)
fp_arr, = _promote_dtypes_inexact(fp)
del x, xp, fp

if dtypes.issubdtype(x.dtype, np.complexfloating):
if dtypes.issubdtype(x_arr.dtype, np.complexfloating):
raise ValueError("jnp.interp: complex x values not supported.")

if period is not None:
if ndim(period) != 0:
raise ValueError(f"period must be a scalar; got {period}")
period = abs(period)
x = x % period
xp = xp % period
xp, fp = lax.sort_key_val(xp, fp)
xp = concatenate([xp[-1:] - period, xp, xp[:1] + period])
fp = concatenate([fp[-1:], fp, fp[:1]])

i = clip(searchsorted(xp, x, side='right'), 1, len(xp) - 1)
df = fp[i] - fp[i - 1]
dx = xp[i] - xp[i - 1]
delta = x - xp[i - 1]

epsilon = np.spacing(np.finfo(xp.dtype).eps)
x_arr = x_arr % period
xp_arr = xp_arr % period
xp_arr, fp_arr = lax.sort_key_val(xp_arr, fp_arr)
xp_arr = concatenate([xp_arr[-1:] - period, xp_arr, xp_arr[:1] + period])
fp_arr = concatenate([fp_arr[-1:], fp_arr, fp_arr[:1]])

i = clip(searchsorted(xp_arr, x_arr, side='right'), 1, len(xp_arr) - 1)
df = fp_arr[i] - fp_arr[i - 1]
dx = xp_arr[i] - xp_arr[i - 1]
delta = x_arr - xp_arr[i - 1]

epsilon = np.spacing(np.finfo(xp_arr.dtype).eps)
dx0 = lax.abs(dx) <= epsilon # Prevent NaN gradients when `dx` is small.
f = where(dx0, fp[i - 1], fp[i - 1] + (delta / where(dx0, 1, dx)) * df)
f = where(dx0, fp_arr[i - 1], fp_arr[i - 1] + (delta / where(dx0, 1, dx)) * df)

left_arr: ArrayLike = fp_arr[0] if left is None else left
right_arr: ArrayLike = fp_arr[-1] if right is None else right

if period is None:
f = where(x < xp[0], fp[0] if left is None else left, f)
f = where(x > xp[-1], fp[-1] if right is None else right, f)
f = where(x_arr < xp_arr[0], left_arr, f)
f = where(x_arr > xp_arr[-1], right_arr, f)
return f


@overload
def where(condition: ArrayLike, x: Literal[None] = None, y: Literal[None] = None, *,
size: Optional[int] = None,
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
) -> Tuple[Array, ...]: ...

@overload
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, *,
size: Optional[int] = None,
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
) -> Array: ...

@overload
def where(condition: ArrayLike, x: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None, *, size: Optional[int] = None,
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
) -> Union[Array, Tuple[Array, ...]]: ...

@_wraps(np.where,
lax_description=_dedent("""
At present, JAX does not support JIT-compilation of the single-argument form
Expand All @@ -1036,7 +1062,10 @@ def interp(x, xp, fp, left=None, right=None, period=None):
fill_value : array_like, optional
When ``size`` is specified and there are fewer than the indicated number of elements, the
remaining elements will be filled with ``fill_value``, which defaults to zero."""))
def where(condition, x=None, y=None, *, size=None, fill_value=None):
def where(condition: ArrayLike, x: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None, *, size: Optional[int] = None,
fill_value: Union[None, Array, Tuple[ArrayLike]] = None
) -> Union[Array, Tuple[Array, ...]]:
if x is None and y is None:
_check_arraylike("where", condition)
return nonzero(condition, size=size, fill_value=fill_value)
Expand Down Expand Up @@ -1094,6 +1123,13 @@ def bincount(x, weights=None, minlength=0, *, length=None):
raise ValueError("shape of weights must match shape of x.")
return zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights)

@overload
def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]: ...

@overload
def broadcast_shapes(*shapes: Tuple[Union[int, core.Tracer], ...]
) -> Tuple[Union[int, core.Tracer], ...]: ...

@_wraps(getattr(np, "broadcast_shapes", None))
def broadcast_shapes(*shapes):
if not shapes:
Expand All @@ -1102,17 +1138,22 @@ def broadcast_shapes(*shapes):
return lax.broadcast_shapes(*shapes)


broadcast_arrays = _wraps(np.broadcast_arrays, lax_description="""\
@_wraps(np.broadcast_arrays, lax_description="""\
The JAX version does not necessarily return a view of the input.
""")(_broadcast_arrays)
""")
def broadcast_arrays(*args: ArrayLike) -> List[Array]:
return _broadcast_arrays(*args)


broadcast_to = _wraps(np.broadcast_to, lax_description="""\
@_wraps(np.broadcast_to, lax_description="""\
The JAX version does not necessarily return a view of the input.
""")(_broadcast_to)
""")
def broadcast_to(array: ArrayLike, shape: Shape) -> Array:
return _broadcast_to(array, shape)


def _split(op, ary, indices_or_sections, axis=0):
def _split(op: str, ary: ArrayLike, indices_or_sections: Union[int, ArrayLike],
axis: int = 0) -> List[Array]:
_check_arraylike(op, ary)
ary = asarray(ary)
axis = core.concrete_or_error(operator.index, axis, f"in jax.numpy.{op} argument `axis`")
Expand All @@ -1133,7 +1174,7 @@ def _split(op, ary, indices_or_sections, axis=0):
else:
indices_or_sections = core.concrete_or_error(np.int64, indices_or_sections,
f"in jax.numpy.{op} argument 1")
part_size, r = _divmod(size, indices_or_sections)
part_size, r = _divmod(size, indices_or_sections) # type: ignore[misc]
if r == 0:
split_indices = np.arange(indices_or_sections + 1,
dtype=np.int64) * part_size
Expand All @@ -1150,12 +1191,12 @@ def _split(op, ary, indices_or_sections, axis=0):
for start, end in zip(split_indices[:-1], split_indices[1:])]

@_wraps(np.split, lax_description=_ARRAY_VIEW_DOC)
def split(ary, indices_or_sections, axis: int = 0):
def split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]:
return _split("split", ary, indices_or_sections, axis=axis)

def _split_on_axis(op, axis):
def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, Union[int, ArrayLike]], List[Array]]:
@_wraps(getattr(np, op), update_doc=False)
def f(ary, indices_or_sections):
def f(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike]) -> List[Array]:
return _split(op, ary, indices_or_sections, axis=axis)
return f

Expand All @@ -1164,12 +1205,13 @@ def f(ary, indices_or_sections):
dsplit = _split_on_axis("dsplit", axis=2)

@_wraps(np.array_split)
def array_split(ary, indices_or_sections, axis: int = 0):
def array_split(ary: ArrayLike, indices_or_sections: Union[int, ArrayLike], axis: int = 0) -> List[Array]:
return _split("array_split", ary, indices_or_sections, axis=axis)

@_wraps(np.clip, skip_params=['out'])
@jit
def clip(a, a_min=None, a_max=None, out=None):
def clip(a: ArrayLike, a_min: Optional[ArrayLike] = None,
a_max: Optional[ArrayLike] = None, out: None = None) -> Array:
_check_arraylike("clip", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.clip is not supported.")
Expand All @@ -1179,11 +1221,11 @@ def clip(a, a_min=None, a_max=None, out=None):
a = maximum(a_min, a)
if a_max is not None:
a = minimum(a_max, a)
return a
return asarray(a)

@_wraps(np.around, skip_params=['out'])
@partial(jit, static_argnames=('decimals',))
def round(a, decimals=0, out=None):
def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array:
_check_arraylike("round", a)
decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round")
if out is not None:
Expand All @@ -1193,9 +1235,9 @@ def round(a, decimals=0, out=None):
if decimals < 0:
raise NotImplementedError(
"integer np.round not implemented for decimals < 0")
return a # no-op on integer types
return asarray(a) # no-op on integer types

def _round_float(x):
def _round_float(x: ArrayLike) -> Array:
if decimals == 0:
return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)

Expand All @@ -1219,7 +1261,7 @@ def _round_float(x):

@_wraps(np.fix, skip_params=['out'])
@jit
def fix(x, out=None):
def fix(x: ArrayLike, out: None = None) -> Array:
_check_arraylike("fix", x)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.fix is not supported.")
Expand All @@ -1229,7 +1271,9 @@ def fix(x, out=None):

@_wraps(np.nan_to_num)
@jit
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
def nan_to_num(x: ArrayLike, copy: bool = True, nan: ArrayLike = 0.0,
posinf: Optional[ArrayLike] = None,
neginf: Optional[ArrayLike] = None) -> Array:
del copy
_check_arraylike("nan_to_num", x)
dtype = _dtype(x)
Expand All @@ -1240,15 +1284,16 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
info = finfo(dtypes.canonicalize_dtype(dtype))
posinf = info.max if posinf is None else posinf
neginf = info.min if neginf is None else neginf
x = where(isnan(x), array(nan, dtype=x.dtype), x)
x = where(isposinf(x), array(posinf, dtype=x.dtype), x)
x = where(isneginf(x), array(neginf, dtype=x.dtype), x)
return x
out = where(isnan(x), asarray(nan, dtype=dtype), x)
out = where(isposinf(out), asarray(posinf, dtype=dtype), out)
out = where(isneginf(out), asarray(neginf, dtype=dtype), out)
return out


@_wraps(np.allclose)
@partial(jit, static_argnames=('equal_nan',))
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
def allclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = 1e-05,
atol: ArrayLike = 1e-08, equal_nan: bool = False):
_check_arraylike("allclose", a, b)
return all(isclose(a, b, rtol, atol, equal_nan))

Expand All @@ -1269,31 +1314,34 @@ def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
"""

@_wraps(np.nonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
def nonzero(a, *, size=None, fill_value=None):
def nonzero(a: ArrayLike, *, size: Optional[int] = None,
fill_value: Union[None, ArrayLike, Tuple[ArrayLike]] = None
) -> Tuple[Array, ...]:
_check_arraylike("nonzero", a)
a = atleast_1d(a)
mask = a if a.dtype == bool else (a != 0)
arr = atleast_1d(a)
del a
mask = arr if arr.dtype == bool else (arr != 0)
if size is None:
size = mask.sum()
size = core.concrete_or_error(operator.index, size,
"The size argument of jnp.nonzero must be statically specified "
"to use jnp.nonzero within JAX transformations.")
if a.size == 0 or size == 0:
return tuple(zeros(size, int) for dim in a.shape)
if arr.size == 0 or size == 0:
return tuple(zeros(size, int) for dim in arr.shape)
flat_indices = cumsum(bincount(cumsum(mask), length=size))
strides = (np.cumprod(a.shape[::-1])[::-1] // a.shape).astype(int_)
out = tuple((flat_indices // stride) % size for stride, size in zip(strides, a.shape))
strides = (np.cumprod(arr.shape[::-1])[::-1] // arr.shape).astype(int_)
out = tuple((flat_indices // stride) % size for stride, size in zip(strides, arr.shape))
if size is not None and fill_value is not None:
if not isinstance(fill_value, tuple):
fill_value = a.ndim * (fill_value,)
if _shape(fill_value) != (a.ndim,):
raise ValueError(f"fill_value must be a scalar or a tuple of length {a.ndim}; got {fill_value}")
fill_value_tup = fill_value if isinstance(fill_value, tuple) else arr.ndim * (fill_value,)
if _any(_shape(val) != () for val in fill_value_tup):
raise ValueError(f"fill_value must be a scalar or a tuple of length {arr.ndim}; got {fill_value}")
fill_mask = arange(size) >= mask.sum()
out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value, out))
out = tuple(where(fill_mask, fval, entry) for fval, entry in safe_zip(fill_value_tup, out))
return out

@_wraps(np.flatnonzero, lax_description=_NONZERO_DOC, extra_params=_NONZERO_EXTRA_PARAMS)
def flatnonzero(a, *, size=None, fill_value=None):
def flatnonzero(a: ArrayLike, *, size: Optional[int] = None,
fill_value: Union[None, ArrayLike, Tuple[ArrayLike]] = None) -> Array:
return nonzero(ravel(a), size=size, fill_value=fill_value)[0]


Expand Down
3 changes: 1 addition & 2 deletions jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,8 +414,7 @@ def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
# materialize the broadcast forms of scalar arguments.
@api.jit
def _where(condition: ArrayLike, x: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None) -> Array:
def _where(condition: ArrayLike, x: ArrayLike, y: ArrayLike) -> Array:
if x is None or y is None:
raise ValueError("Either both or neither of the x and y arguments should "
"be provided to jax.numpy.where, got {} and {}."
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/scipy/optimize/_lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def body_fun(state: LBFGSResults):
gamma = rho_k_inv / jnp.real(_dot(jnp.conj(y_k), y_k))

# replacements for next iteration
status = 0
status = jnp.array(0)
status = jnp.where(state.f_k - f_kp1 < ftol, 4, status)
status = jnp.where(state.ngev >= maxgrad, 3, status) # type: ignore
status = jnp.where(state.nfev >= maxfun, 2, status) # type: ignore
Expand Down
3 changes: 1 addition & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5134,7 +5134,7 @@ def testWrappedSignaturesMatch(self):
# TODO(jakevdp): fix some of the following signatures. Some are due to wrong argument names.
unsupported_params = {
'asarray': ['like'],
'broadcast_to': ['subok', 'array'],
'broadcast_to': ['subok'],
'clip': ['kwargs'],
'copy': ['subok'],
'corrcoef': ['ddof', 'bias', 'dtype'],
Expand Down Expand Up @@ -5164,7 +5164,6 @@ def testWrappedSignaturesMatch(self):
}

extra_params = {
'broadcast_to': ['arr'],
'einsum': ['precision'],
'einsum_path': ['subscripts'],
'take_along_axis': ['mode'],
Expand Down

0 comments on commit 05f78d7

Please sign in to comment.