Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add types to jax/_src/numpy/util.py #12641

Merged
merged 1 commit into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4353,7 +4353,7 @@ def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
_rng_bit_generator_lowering)


def _array_copy(arr):
def _array_copy(arr: ArrayLike) -> Array:
return copy_p.bind(arr)

# The copy_p primitive exists for expressing making copies of runtime arrays.
Expand Down
16 changes: 10 additions & 6 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
_register_stackable, _stackable, _where, _wraps)
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
canonicalize_axis as _canonicalize_axis)
from jax._src.array import ArrayImpl
Expand Down Expand Up @@ -1838,7 +1839,8 @@ def atleast_3d(*arys):
"""

@_wraps(np.array, lax_description=_ARRAY_DOC)
def array(object, dtype=None, copy=True, order="K", ndmin=0):
def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True,
order: str = "K", ndmin: int = 0) -> Array:
if order is not None and order != "K":
raise NotImplementedError("Only implemented for order='K'")

Expand Down Expand Up @@ -1878,6 +1880,8 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
# (See https://github.com/google/jax/issues/8950)
ndarray_types = (device_array.DeviceArray, core.Tracer, ArrayImpl)

out: ArrayLike

if not _any(isinstance(leaf, ndarray_types) for leaf in leaves):
# TODO(jakevdp): falling back to numpy here fails to overflow for lists
# containing large integers; see discussion in
Expand All @@ -1902,10 +1906,10 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):

raise TypeError(f"Unexpected input type for array: {type(object)}")

out = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
if ndmin > ndim(out):
out = lax.expand_dims(out, range(ndmin - ndim(out)))
return out
out_array: Array = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
if ndmin > ndim(out_array):
out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))
return out_array


def _convert_to_array_if_dtype_fails(x):
Expand All @@ -1918,7 +1922,7 @@ def _convert_to_array_if_dtype_fails(x):


@_wraps(np.asarray, lax_description=_ARRAY_DOC)
def asarray(a, dtype=None, order=None):
def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Any = None) -> Array:
lax_internal._check_user_dtype_supported(dtype, "asarray")
dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype
return array(a, dtype=dtype, copy=False, order=order)
Expand Down
64 changes: 33 additions & 31 deletions jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import re
import textwrap
from typing import (
Any, Callable, NamedTuple, Optional, Dict, Sequence, Set, Type, TypeVar
Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Type, TypeVar
)
import warnings

Expand All @@ -28,6 +28,7 @@
from jax._src import api
from jax import core
from jax._src.lax import lax
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape

import numpy as np

Expand Down Expand Up @@ -215,7 +216,7 @@ def wrap(op):

_dtype = partial(dtypes.dtype, canonicalize=True)

def _asarray(arr):
def _asarray(arr: ArrayLike) -> Array:
"""
Pared-down utility to convert object to a DeviceArray.
Note this will not correctly handle lists or tuples.
Expand All @@ -224,10 +225,10 @@ def _asarray(arr):
dtype, weak_type = dtypes._lattice_result_type(arr)
return lax_internal._convert_element_type(arr, dtype, weak_type)

def _promote_shapes(fun_name, *args):
def _promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
if len(args) < 2:
return args
return [_asarray(arg) for arg in args]
else:
shapes = [np.shape(arg) for arg in args]
if config.jax_dynamic_shapes:
Expand All @@ -238,10 +239,10 @@ def _promote_shapes(fun_name, *args):
return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
else:
if all(len(shapes[0]) == len(s) for s in shapes[1:]):
return args # no need for rank promotion, so rely on lax promotion
return [_asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion
nonscalar_ranks = {len(shp) for shp in shapes if shp}
if len(nonscalar_ranks) < 2:
return args # rely on lax scalar promotion
return [_asarray(arg) for arg in args] # rely on lax scalar promotion
else:
if config.jax_numpy_rank_promotion != "allow":
_rank_promotion_warning_or_error(fun_name, shapes)
Expand All @@ -250,7 +251,7 @@ def _promote_shapes(fun_name, *args):
for arg, shp in zip(args, shapes)]


def _rank_promotion_warning_or_error(fun_name, shapes):
def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
if config.jax_numpy_rank_promotion == "warn":
msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
"Set the jax_numpy_rank_promotion config option to 'allow' to "
Expand All @@ -265,18 +266,18 @@ def _rank_promotion_warning_or_error(fun_name, shapes):
raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))


def _promote_dtypes(*args):
def _promote_dtypes(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion."""
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
if len(args) < 2:
return args
return [_asarray(arg) for arg in args]
else:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
return [lax_internal._convert_element_type(x, to_dtype, weak_type) for x in args]


def _promote_dtypes_inexact(*args):
def _promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion.

Promotes arguments to an inexact type."""
Expand All @@ -287,7 +288,7 @@ def _promote_dtypes_inexact(*args):
for x in args]


def _promote_dtypes_numeric(*args):
def _promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion.

Promotes arguments to a numeric (non-bool) type."""
Expand All @@ -298,7 +299,7 @@ def _promote_dtypes_numeric(*args):
for x in args]


def _promote_dtypes_complex(*args):
def _promote_dtypes_complex(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion.

Promotes arguments to a complex type."""
Expand All @@ -309,23 +310,23 @@ def _promote_dtypes_complex(*args):
for x in args]


def _complex_elem_type(dtype):
def _complex_elem_type(dtype: DTypeLike) -> DType:
"""Returns the float type of the real/imaginary parts of a complex dtype."""
return np.abs(np.zeros((), dtype)).dtype


def _arraylike(x):
def _arraylike(x: ArrayLike) -> bool:
return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
hasattr(x, '__jax_array__') or np.isscalar(x))


def _stackable(*args):
def _stackable(*args: Any) -> bool:
return all(type(arg) in stackables for arg in args)
stackables: Set[Type] = set()
_register_stackable: Callable[[Type], None] = stackables.add


def _check_arraylike(fun_name, *args):
def _check_arraylike(fun_name: str, *args: Any):
"""Check if all args fit JAX's definition of arraylike."""
assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
if any(not _arraylike(arg) for arg in args):
Expand All @@ -335,7 +336,7 @@ def _check_arraylike(fun_name, *args):
raise TypeError(msg.format(fun_name, type(arg), pos))


def _check_no_float0s(fun_name, *args):
def _check_no_float0s(fun_name: str, *args: Any):
"""Check if none of the args have dtype float0."""
if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
raise TypeError(
Expand All @@ -348,20 +349,20 @@ def _check_no_float0s(fun_name, *args):
"taken a gradient with respect to an integer argument.")


def _promote_args(fun_name, *args):
def _promote_args(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument shape and dtype promotion."""
_check_arraylike(fun_name, *args)
_check_no_float0s(fun_name, *args)
return _promote_shapes(fun_name, *_promote_dtypes(*args))


def _promote_args_numeric(fun_name, *args):
def _promote_args_numeric(fun_name: str, *args: ArrayLike) -> List[Array]:
_check_arraylike(fun_name, *args)
_check_no_float0s(fun_name, *args)
return _promote_shapes(fun_name, *_promote_dtypes_numeric(*args))


def _promote_args_inexact(fun_name, *args):
def _promote_args_inexact(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument shape and dtype promotion.

Promotes non-inexact types to an inexact type."""
Expand All @@ -371,20 +372,18 @@ def _promote_args_inexact(fun_name, *args):


@partial(api.jit, inline=True)
def _broadcast_arrays(*args):
def _broadcast_arrays(*args: ArrayLike) -> List[Array]:
"""Like Numpy's broadcast_arrays but doesn't return views."""
shapes = [np.shape(arg) for arg in args]
if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
# TODO(mattjj): remove the array(arg) here
return [arg if isinstance(arg, ndarray) or np.isscalar(arg) else _asarray(arg)
for arg in args]
return [_asarray(arg) for arg in args]
result_shape = lax.broadcast_shapes(*shapes)
return [_broadcast_to(arg, result_shape) for arg in args]


def _broadcast_to(arr, shape):
def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
if hasattr(arr, "broadcast_to"):
return arr.broadcast_to(shape)
return arr.broadcast_to(shape) # type: ignore[union-attr]
_check_arraylike("broadcast_to", arr)
arr = arr if isinstance(arr, ndarray) else _asarray(arr)
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
Expand Down Expand Up @@ -412,15 +411,18 @@ def _broadcast_to(arr, shape):
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
# materialize the broadcast forms of scalar arguments.
@api.jit
def _where(condition, x=None, y=None):
def _where(condition: ArrayLike, x: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None) -> Array:
if x is None or y is None:
raise ValueError("Either both or neither of the x and y arguments should "
"be provided to jax.numpy.where, got {} and {}."
.format(x, y))
if not np.issubdtype(_dtype(condition), np.bool_):
condition = lax.ne(condition, lax_internal._zero(condition))
x, y = _promote_dtypes(x, y)
condition, x, y = _broadcast_arrays(condition, x, y)
try: is_always_empty = core.is_empty_shape(np.shape(x))
except: is_always_empty = False # can fail with dynamic shapes
return lax.select(condition, x, y) if not is_always_empty else x
condition_arr, x_arr, y_arr = _broadcast_arrays(condition, x, y)
try:
is_always_empty = core.is_empty_shape(x_arr.shape)
except:
is_always_empty = False # can fail with dynamic shapes
return lax.select(condition_arr, x_arr, y_arr) if not is_always_empty else x_arr