Skip to content

Commit

Permalink
jnp.roll: more efficient implementation for static shifts
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 2, 2023
1 parent e990453 commit 5497bb0
Showing 1 changed file with 33 additions and 18 deletions.
51 changes: 33 additions & 18 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3712,35 +3712,50 @@ def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:


@partial(jit, static_argnums=(2,))
def _roll(a, shift, axis):
a_shape = shape(a)
if axis is None:
return lax.reshape(_roll(ravel(a), shift, axis=0), a_shape)
shift = asarray(shift)
a_ndim = len(a_shape)
axis = np.asarray(axis)
b_shape = lax.broadcast_shapes(shift.shape, axis.shape, (1,))
def _roll_dynamic(a: Array, shift: Array, axis: Sequence[int]) -> Array:
b_shape = lax.broadcast_shapes(shift.shape, np.shape(axis))
if len(b_shape) != 1:
msg = "'shift' and 'axis' arguments to roll must be scalars or 1D arrays"
raise ValueError(msg)

for x, i in zip(broadcast_to(shift, b_shape),
np.broadcast_to(axis, b_shape)):
i = _canonicalize_axis(i, a_ndim)
a_shape_i = array(a_shape[i], dtype=np.int32)
a_shape_i = array(a.shape[i], dtype=np.int32)
x = ufuncs.remainder(lax.convert_element_type(x, np.int32),
lax.max(a_shape_i, np.int32(1)))
a = lax.concatenate((a, a), i)
a = lax.dynamic_slice_in_dim(a, a_shape_i - x, a_shape[i], axis=i)
lax.max(a_shape_i, np.int32(1)))
a_concat = lax.concatenate((a, a), i)
a = lax.dynamic_slice_in_dim(a_concat, a_shape_i - x, a.shape[i], axis=i)
return a

@partial(jit, static_argnums=(1, 2))
def _roll_static(a: Array, shift: Sequence[int], axis: Sequence[int]) -> Array:
for ax, s in zip(*np.broadcast_arrays(axis, shift)):
if a.shape[ax] == 0:
continue
i = (-s) % a.shape[ax]
a = lax.concatenate([lax.slice_in_dim(a, i, a.shape[ax], axis=ax),
lax.slice_in_dim(a, 0, i, axis=ax)],
dimension=ax)
return a

@util._wraps(np.roll)
def roll(a, shift, axis: Optional[Union[int, Sequence[int]]] = None):
util.check_arraylike("roll", a,)
if isinstance(axis, list):
axis = tuple(axis)
return _roll(a, shift, axis)
def roll(a: ArrayLike, shift: Union[ArrayLike, Sequence[int]],
axis: Optional[Union[int, Sequence[int]]] = None) -> Array:
util.check_arraylike("roll", a)
arr = asarray(a)
if axis is None:
return roll(arr.ravel(), shift, 0).reshape(arr.shape)
axis = _ensure_index_tuple(axis)
axis = tuple(_canonicalize_axis(ax, arr.ndim) for ax in axis)
if not core.is_constant_shape(arr.shape):
# TODO(necula): support static roll for polymorphic shapes.
return _roll_dynamic(arr, asarray(shift), axis)
try:
shift = _ensure_index_tuple(shift)
except TypeError:
return _roll_dynamic(arr, asarray(shift), axis)
else:
return _roll_static(arr, shift, axis)


@util._wraps(np.rollaxis, lax_description=_ARRAY_VIEW_DOC)
Expand Down

0 comments on commit 5497bb0

Please sign in to comment.