From 07f55b38966a3d54b00dd0c1f490759d79776542 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 9 Sep 2022 12:20:23 -0700 Subject: [PATCH] jax.Array: support fast path for lax.transpose & lax.squeeze As part of this change, I created a helper function so that the logic of type checking is in a single location. Eventually we can replace this helper function with appropriate isinstance() checks using the APIs described in #11859. --- jax/_src/lax/lax.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 5bc6c504a395..d397f72afcb8 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -85,6 +85,14 @@ map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip +# TODO(jakevdp): replace this with an isinstance() check when JEP 12049 is complete. +def _is_array_or_tracer(operand: Any) -> bool: + if config.jax_array: + from jax.experimental import array # pylint: disable=g-import-not-at-top + return isinstance(operand, (core.Tracer, array.Array)) + else: + return isinstance(operand, (core.Tracer, device_array.DeviceArray)) + def _validate_shapes(shapes: Sequence[Shape]): def _check_static_shape(shape: Shape): checked = canonicalize_shape(shape) @@ -548,8 +556,6 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array: def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None, weak_type: bool = False): - from jax.experimental import array - # Don't canonicalize old_dtype because x64 context might cause # un-canonicalized operands to be passed in. old_dtype = dtypes.dtype(operand, canonicalize=False) @@ -576,8 +582,7 @@ def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None, operand = np.asarray(operand, new_dtype) old_weak_type = False - if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type) - and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))): + if (old_dtype, old_weak_type) == (new_dtype, new_weak_type) and _is_array_or_tracer(operand): return operand else: return convert_element_type_p.bind(operand, new_dtype=new_dtype, @@ -628,13 +633,11 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array: Returns: An array containing the concatenation. """ - from jax.experimental import array - if len(operands) == 0: raise ValueError("concatenate requires a non-empty sequences of arrays") if len(operands) == 1: op, = operands - if isinstance(op, (core.Tracer, device_array.DeviceArray, array.Array)): + if _is_array_or_tracer(op): return op return concatenate_p.bind(*operands, dimension=dimension) @@ -802,10 +805,7 @@ def broadcast_in_dim(operand: Array, shape: Shape, See Also: jax.lax.broadcast : simpler interface to add new leading dimensions. """ - from jax.experimental import array - - if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) - and isinstance(operand, (device_array.DeviceArray, core.Tracer, array.Array))): + if np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and _is_array_or_tracer(operand): return operand if config.jax_dynamic_shapes: # We must gate this behavior under a flag because otherwise the errors @@ -860,8 +860,6 @@ def reshape(operand: Array, new_sizes: Shape, >>> reshape(y, (6,), (1, 0)) DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32) """ - from jax.experimental import array - new_sizes = canonicalize_shape(new_sizes) # TODO new_sizes = tuple(new_sizes) same_shape = core.symbolic_equal_shape(np.shape(operand), new_sizes) @@ -871,8 +869,7 @@ def reshape(operand: Array, new_sizes: Shape, else: dims = api_util._ensure_index_tuple(dimensions) same_dims = tuple(dims) == tuple(range(np.ndim(operand))) - if (np.shape(operand) and same_shape and same_dims - and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))): + if np.shape(operand) and same_shape and same_dims and _is_array_or_tracer(operand): return operand else: dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes) @@ -951,8 +948,7 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array: operator. """ permutation = tuple(operator.index(d) for d in permutation) - if (permutation == tuple(range(np.ndim(operand))) - and isinstance(operand, (core.Tracer, device_array.DeviceArray))): + if permutation == tuple(range(np.ndim(operand))) and _is_array_or_tracer(operand): return operand else: return transpose_p.bind(operand, permutation=permutation) @@ -1282,7 +1278,7 @@ def squeeze(array: Array, dimensions: Sequence[int]) -> Array: """Squeeze any number of size 1 dimensions from an array.""" ndim = np.ndim(array) dimensions = tuple(sorted(canonicalize_axis(i, ndim) for i in dimensions)) - if not dimensions and isinstance(array, (core.Tracer, device_array.DeviceArray)): + if not dimensions and _is_array_or_tracer(array): return array return squeeze_p.bind(array, dimensions=dimensions)