Skip to content

Commit

Permalink
jax.Array: support fast path for lax.transpose & lax.squeeze
Browse files Browse the repository at this point in the history
As part of this change, I created a helper function so that the logic of type checking is
in a single location. Eventually we can replace this helper function with appropriate
isinstance() checks using the APIs described in jax-ml#11859.
  • Loading branch information
jakevdp committed Sep 9, 2022
1 parent 056f400 commit 07f55b3
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

# TODO(jakevdp): replace this with an isinstance() check when JEP 12049 is complete.
def _is_array_or_tracer(operand: Any) -> bool:
if config.jax_array:
from jax.experimental import array # pylint: disable=g-import-not-at-top
return isinstance(operand, (core.Tracer, array.Array))
else:
return isinstance(operand, (core.Tracer, device_array.DeviceArray))

def _validate_shapes(shapes: Sequence[Shape]):
def _check_static_shape(shape: Shape):
checked = canonicalize_shape(shape)
Expand Down Expand Up @@ -548,8 +556,6 @@ def convert_element_type(operand: Array, new_dtype: DType) -> Array:

def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
weak_type: bool = False):
from jax.experimental import array

# Don't canonicalize old_dtype because x64 context might cause
# un-canonicalized operands to be passed in.
old_dtype = dtypes.dtype(operand, canonicalize=False)
Expand All @@ -576,8 +582,7 @@ def _convert_element_type(operand: Array, new_dtype: Optional[DType] = None,
operand = np.asarray(operand, new_dtype)
old_weak_type = False

if ((old_dtype, old_weak_type) == (new_dtype, new_weak_type)
and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))):
if (old_dtype, old_weak_type) == (new_dtype, new_weak_type) and _is_array_or_tracer(operand):
return operand
else:
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
Expand Down Expand Up @@ -628,13 +633,11 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
Returns:
An array containing the concatenation.
"""
from jax.experimental import array

if len(operands) == 0:
raise ValueError("concatenate requires a non-empty sequences of arrays")
if len(operands) == 1:
op, = operands
if isinstance(op, (core.Tracer, device_array.DeviceArray, array.Array)):
if _is_array_or_tracer(op):
return op
return concatenate_p.bind(*operands, dimension=dimension)

Expand Down Expand Up @@ -802,10 +805,7 @@ def broadcast_in_dim(operand: Array, shape: Shape,
See Also:
jax.lax.broadcast : simpler interface to add new leading dimensions.
"""
from jax.experimental import array

if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions)
and isinstance(operand, (device_array.DeviceArray, core.Tracer, array.Array))):
if np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and _is_array_or_tracer(operand):
return operand
if config.jax_dynamic_shapes:
# We must gate this behavior under a flag because otherwise the errors
Expand Down Expand Up @@ -860,8 +860,6 @@ def reshape(operand: Array, new_sizes: Shape,
>>> reshape(y, (6,), (1, 0))
DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32)
"""
from jax.experimental import array

new_sizes = canonicalize_shape(new_sizes) # TODO
new_sizes = tuple(new_sizes)
same_shape = core.symbolic_equal_shape(np.shape(operand), new_sizes)
Expand All @@ -871,8 +869,7 @@ def reshape(operand: Array, new_sizes: Shape,
else:
dims = api_util._ensure_index_tuple(dimensions)
same_dims = tuple(dims) == tuple(range(np.ndim(operand)))
if (np.shape(operand) and same_shape and same_dims
and isinstance(operand, (core.Tracer, device_array.DeviceArray, array.Array))):
if np.shape(operand) and same_shape and same_dims and _is_array_or_tracer(operand):
return operand
else:
dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
Expand Down Expand Up @@ -951,8 +948,7 @@ def transpose(operand: Array, permutation: Sequence[int]) -> Array:
operator.
"""
permutation = tuple(operator.index(d) for d in permutation)
if (permutation == tuple(range(np.ndim(operand)))
and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
if permutation == tuple(range(np.ndim(operand))) and _is_array_or_tracer(operand):
return operand
else:
return transpose_p.bind(operand, permutation=permutation)
Expand Down Expand Up @@ -1282,7 +1278,7 @@ def squeeze(array: Array, dimensions: Sequence[int]) -> Array:
"""Squeeze any number of size 1 dimensions from an array."""
ndim = np.ndim(array)
dimensions = tuple(sorted(canonicalize_axis(i, ndim) for i in dimensions))
if not dimensions and isinstance(array, (core.Tracer, device_array.DeviceArray)):
if not dimensions and _is_array_or_tracer(array):
return array
return squeeze_p.bind(array, dimensions=dimensions)

Expand Down

0 comments on commit 07f55b3

Please sign in to comment.