Skip to content

Commit

Permalink
Merge pull request #12697 from jakevdp:lax-slicing-types
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 480131675
  • Loading branch information
jax authors committed Oct 10, 2022
2 parents ad4dcc4 + ae9f8ee commit 707b07c
Showing 1 changed file with 28 additions and 27 deletions.
55 changes: 28 additions & 27 deletions jax/_src/lax/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import enum
from functools import partial
from typing import Any, Callable, NamedTuple, Optional, Sequence, Tuple, Union
from typing import Callable, List, NamedTuple, Optional, Sequence, Tuple, Union
import weakref

import numpy as np
Expand All @@ -40,20 +40,18 @@
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.typing import Array, ArrayLike, Shape

xb = xla_bridge
xc = xla_client

Array = Any
Shape = core.Shape

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

_dtype = partial(dtypes.dtype, canonicalize=True)


def slice(operand: Array, start_indices: Sequence[int],
def slice(operand: ArrayLike, start_indices: Sequence[int],
limit_indices: Sequence[int],
strides: Optional[Sequence[int]] = None) -> Array:
"""Wraps XLA's `Slice
Expand All @@ -64,7 +62,7 @@ def slice(operand: Array, start_indices: Sequence[int],
limit_indices=tuple(limit_indices),
strides=None if strides is None else tuple(strides))

def dynamic_slice(operand: Array, start_indices: Sequence[Array],
def dynamic_slice(operand: Array, start_indices: Union[Array, Sequence[ArrayLike]],
slice_sizes: Shape) -> Array:
"""Wraps XLA's `DynamicSlice
<https://www.tensorflow.org/xla/operation_semantics#dynamicslice>`_
Expand Down Expand Up @@ -112,8 +110,8 @@ def dynamic_slice(operand: Array, start_indices: Sequence[Array],
return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes,
slice_sizes=tuple(static_sizes))

def dynamic_update_slice(operand: Array, update: Array,
start_indices: Array) -> Array:
def dynamic_update_slice(operand: Array, update: ArrayLike,
start_indices: Union[Array, Sequence[ArrayLike]]) -> Array:
"""Wraps XLA's `DynamicUpdateSlice
<https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice>`_
operator.
Expand Down Expand Up @@ -222,7 +220,7 @@ def from_any(s: Optional[Union[str, 'GatherScatterMode']]):
raise ValueError(f'Unknown gather mode "{s}"')


def gather(operand: Array, start_indices: Array,
def gather(operand: ArrayLike, start_indices: ArrayLike,
dimension_numbers: GatherDimensionNumbers,
slice_sizes: Shape,
*,
Expand Down Expand Up @@ -320,7 +318,7 @@ class ScatterDimensionNumbers(NamedTuple):
scatter_dims_to_operand_dims: Sequence[int]

def scatter_add(
operand: Array, scatter_indices: Array, updates: Array,
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
dimension_numbers: ScatterDimensionNumbers, *,
indices_are_sorted: bool = False, unique_indices: bool = False,
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
Expand Down Expand Up @@ -367,7 +365,7 @@ def scatter_add(
mode=GatherScatterMode.from_any(mode))

def scatter_mul(
operand: Array, scatter_indices: Array, updates: Array,
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
dimension_numbers: ScatterDimensionNumbers, *,
indices_are_sorted: bool = False, unique_indices: bool = False,
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
Expand Down Expand Up @@ -414,7 +412,7 @@ def scatter_mul(
mode=GatherScatterMode.from_any(mode))

def scatter_min(
operand: Array, scatter_indices: Array, updates: Array,
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
dimension_numbers: ScatterDimensionNumbers, *,
indices_are_sorted: bool = False, unique_indices: bool = False,
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
Expand Down Expand Up @@ -461,7 +459,7 @@ def scatter_min(
mode=GatherScatterMode.from_any(mode))

def scatter_max(
operand: Array, scatter_indices: Array, updates: Array,
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
dimension_numbers: ScatterDimensionNumbers, *,
indices_are_sorted: bool = False, unique_indices: bool = False,
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
Expand Down Expand Up @@ -573,7 +571,7 @@ def scatter_apply(
_scatter_reduction_computation = lambda x, y: y

def scatter(
operand: Array, scatter_indices: Array, updates: Array,
operand: ArrayLike, scatter_indices: ArrayLike, updates: ArrayLike,
dimension_numbers: ScatterDimensionNumbers, *,
indices_are_sorted: bool = False, unique_indices: bool = False,
mode: Optional[Union[str, GatherScatterMode]] = None) -> Array:
Expand Down Expand Up @@ -686,10 +684,10 @@ def index_in_dim(operand: Array, index: int, axis: int = 0,
return lax.squeeze(result, (axis,))


def dynamic_slice_in_dim(operand: Array, start_index: Array,
def dynamic_slice_in_dim(operand: Array, start_index: ArrayLike,
slice_size: int, axis: int = 0) -> Array:
"""Convenience wrapper around dynamic_slice applying to one dimension."""
start_indices = [np.zeros((), dtype=dtypes.dtype(start_index))] * operand.ndim
start_indices: List[ArrayLike] = [lax._const(start_index, 0)] * operand.ndim
slice_sizes = list(operand.shape)

axis = int(axis)
Expand All @@ -708,18 +706,18 @@ def dynamic_index_in_dim(operand: Array, index: Array, axis: int = 0,
return lax.squeeze(result, (axis,))


def dynamic_update_slice_in_dim(operand: Array, update: Array,
start_index: Array, axis: int) -> Array:
def dynamic_update_slice_in_dim(operand: Array, update: ArrayLike,
start_index: ArrayLike, axis: int) -> Array:
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
in a single ``axis``.
"""
axis = int(axis)
start_indices = [lax._zero(start_index)] * lax._ndim(operand)
start_indices: List[ArrayLike] = [lax._const(start_index, 0)] * lax._ndim(operand)
start_indices[axis] = start_index
return dynamic_update_slice(operand, update, start_indices)


def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array,
def dynamic_update_index_in_dim(operand: Array, update: ArrayLike, index: ArrayLike,
axis: int) -> Array:
"""Convenience wrapper around :func:`dynamic_update_slice` to update a slice
of size 1 in a single ``axis``.
Expand All @@ -731,7 +729,6 @@ def dynamic_update_index_in_dim(operand: Array, update: Array, index: Array,
return dynamic_update_slice_in_dim(operand, update, index, axis)



def _slice_shape_rule(operand, *, start_indices, limit_indices, strides):
lax._check_shapelike("slice", "start_indices", start_indices)
lax._check_shapelike("slice", "limit_indices", limit_indices)
Expand Down Expand Up @@ -2094,26 +2091,30 @@ def _scatter(operand_part, updates_part):
mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu")


def _dynamic_slice_indices(operand, start_indices: Any):
def _dynamic_slice_indices(
operand: Array,
start_indices: Union[Array, Sequence[ArrayLike]]
) -> List[Array]:
# Normalize the start_indices w.r.t. operand.shape
if len(start_indices) != operand.ndim:
msg = ("Length of slice indices must match number of operand dimensions ({} "
"vs {})")
raise ValueError(msg.format(len(start_indices), operand.shape))
if not isinstance(start_indices, (tuple, list)):
if start_indices.ndim != 1:
if start_indices.ndim != 1: # type: ignore[union-attr]
raise ValueError("Slice indices must be a 1D sequence, got {}"
.format(start_indices.shape))
.format(start_indices.shape)) # type: ignore[union-attr]
start_indices = list(start_indices)
result = []
result: List[Array] = []
for i, d in zip(start_indices, operand.shape):
# We test whether i and d are static to avoid unnecessary staging.
if isinstance(i, (int, np.integer)) and core.is_constant_dim(d):
result.append(lax.convert_element_type(i + d, _dtype(i)) if i < 0 else i)
result.append(lax.convert_element_type(i + d if i < 0 else i, _dtype(i)))
continue
d = core.dimension_as_value(d)
if isinstance(i, (int, np.integer)):
result.append(i + lax.convert_element_type(d, _dtype(i)) if i < 0 else i)
result.append(i + lax.convert_element_type(d, _dtype(i)) if i < 0
else lax.convert_element_type(i, _dtype(i)))
continue
d = lax.convert_element_type(d, _dtype(i))
result.append(lax.select(i < 0, i + d, i))
Expand Down

0 comments on commit 707b07c

Please sign in to comment.