diff --git a/test/test_megablox.py b/test/test_megablox.py
new file mode 100644
index 00000000000..af4ecf76b31
--- /dev/null
+++ b/test/test_megablox.py
@@ -0,0 +1,161 @@
+"""Tests for the grouped matrix multiplication kernels for TPU written in Pallas."""
+
+import logging
+import sys
+import unittest
+
+from typing import Optional, Union, Callable
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.experimental.megablox as megablox
+from torch_xla import runtime as xr
+from torch_xla._internal import tpu
+
+import numpy as np
+
+if xr.device_type() == 'TPU':
+  from torch_xla.experimental.custom_kernel import jax_import_guard
+  jax_import_guard()
+  import jax
+  import jax.numpy as jnp
+  from jax.experimental import pallas as pl
+
+
+class MegabloxTest(unittest.TestCase):
+
+  def _reference_gmm(
+      self,
+      lhs: np.ndarray,
+      rhs: np.ndarray,
+      group_sizes: np.ndarray,
+      preferred_element_type: np.dtype = np.float32,
+  ) -> np.ndarray:
+
+    start = 0
+    out = []
+    for i, size in enumerate(group_sizes):
+      result = np.dot(lhs[start:start + size, :], rhs[i, :, :])
+
+      result = result.astype(preferred_element_type)
+      out.append(result)
+      start += group_sizes[i]
+    return np.array(np.concatenate(out, axis=0))
+
+  def _group_sizes_strategy(self, m: int, num_groups: int) -> torch.Tensor:
+    # Randomly sample the ends of the groups in the m-dimension. Let the fuzzer
+    # sample with replacement so that it's possible to get zero-sized groups.
+    # Get 'num_groups - 1' run ends. The final group will end at 'm'.
+    ends_no_final = np.sort(
+        np.array(
+            [np.random.randint(low=0, high=m) for _ in range(num_groups - 1)],
+            dtype=np.int32,
+        ),)
+    ends = np.concatenate([ends_no_final, np.array([m], dtype=np.int32)])
+
+    # Calculate the run starts by shifting ends 1 to the right. The first run
+    # starts at zero.
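+    # Illustrative example (hypothetical values, not produced by this test):
+    # with m=10 and num_groups=3, sampled ends_no_final=[2, 2] gives
+    # ends=[2, 2, 10], starts=[0, 2, 2] and group sizes [2, 0, 8], which sum to
+    # m and include a zero-sized group.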
+    starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final])
+    return torch.from_numpy(ends - starts).to(torch.int32)
+
+  def _tolerances(self, lhs_dtype: torch.dtype, rhs_dtype: torch.dtype,
+                  out_dtype: torch.dtype) -> tuple[float, float]:
+    if (lhs_dtype == torch.bfloat16 or rhs_dtype == torch.bfloat16 or
+        out_dtype == torch.bfloat16):
+      return 1e-3, 1e-2  # atol, rtol
+    return 1e-4, 1e-2  # atol, rtol
+
+  LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]
+
+  def _init_test_cases(self):
+    self.tests_cases = []
+    self.tests_cases.append({
+        'dtype': torch.float32,
+        'm': 128,
+        'k': 128,
+        'n': 128,
+        'num_groups': 1
+    })
+    self.tests_cases.append({
+        'dtype': torch.float32,
+        'm': 256,
+        'k': 128,
+        'n': 128,
+        'num_groups': 1
+    })
+    self.tests_cases.append({
+        'dtype': torch.float32,
+        'm': 128,
+        'k': 256,
+        'n': 128,
+        'num_groups': 8
+    })
+    self.tests_cases.append({
+        'dtype': torch.float32,
+        'm': 512,
+        'k': 128,
+        'n': 256,
+        'num_groups': 2
+    })
+    self.tests_cases.append({
+        'dtype': torch.bfloat16,
+        'm': 128,
+        'k': 128,
+        'n': 128,
+        'num_groups': 1
+    })
+    self.tests_cases.append({
+        'dtype': torch.bfloat16,
+        'm': 256,
+        'k': 128,
+        'n': 128,
+        'num_groups': 1
+    })
+    self.tests_cases.append({
+        'dtype': torch.bfloat16,
+        'm': 128,
+        'k': 256,
+        'n': 128,
+        'num_groups': 8
+    })
+    self.tests_cases.append({
+        'dtype': torch.bfloat16,
+        'm': 512,
+        'k': 128,
+        'n': 256,
+        'num_groups': 2
+    })
+
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_gmm(self):
+    self._init_test_cases()
+    for test_case in self.tests_cases:
+      num_groups = test_case['num_groups']
+      k = test_case['k']
+      m = test_case['m']
+      n = test_case['n']
+      lhs_dtype = rhs_dtype = test_case['dtype']
+      out_dtype = torch.float32
+
+      lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
+      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
+      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+      out = megablox.gmm(lhs, rhs, group_sizes)
+
+      ref_out = self._reference_gmm(lhs.cpu().float().numpy(),
+                                    rhs.cpu().float().numpy(),
+                                    group_sizes.numpy())
+
+      atol, rtol = self._tolerances(lhs_dtype, rhs_dtype, out_dtype)
+      np.testing.assert_allclose(
+          ref_out, np.array(out[0].cpu()), rtol=rtol, atol=atol)
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  torch.set_default_dtype(torch.float32)
+  torch.manual_seed(42)
+  torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
+      use_full_mat_mul_precision=True)
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index a0eddebc3d5..2653d0ed8c2 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -22,6 +22,7 @@ python3 test/pjrt/test_dynamic_plugin_tpu.py
 python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
 python3 test/test_pallas.py
 python3 test/test_input_output_aliases.py
+python3 test/test_megablox.py
 python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py
 python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py
 python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py
diff --git a/torch_xla/experimental/megablox/__init__.py b/torch_xla/experimental/megablox/__init__.py
new file mode 100644
index 00000000000..63f60f808cd
--- /dev/null
+++ b/torch_xla/experimental/megablox/__init__.py
@@ -0,0 +1 @@
+from .gmm import gmm
diff --git a/torch_xla/experimental/megablox/common.py b/torch_xla/experimental/megablox/common.py
new file mode 100644
index 00000000000..19555ab208a
--- /dev/null
+++ b/torch_xla/experimental/megablox/common.py
@@ -0,0 +1,22 @@
+"""Common utilities for Pallas kernels."""
+
+from typing import Union
+import torch
+from torch_xla._internal import tpu
+
+
+def assert_is_supported_dtype(dtype: torch.dtype) -> None:
+  if dtype != torch.bfloat16 and dtype != torch.float32:
+    raise ValueError(f"Expected bfloat16 or float32 array but got {dtype}.")
+
+
+def select_input_dtype(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.dtype:
+  """Returns the type to which both inputs should be adapted before the dot product."""
+  # bf16 x bf16 matmul is only supported on TPU v4 and newer generations. In
+  # case of mixed input precision, we need to convert the bf16 argument to fp32
+  # beforehand.
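+  # For example (illustrative): on TPU v4 or newer, a (bf16, bf16) pair keeps
+  # torch.bfloat16, while (bf16, fp32), (fp32, fp32), or any pair on an older
+  # TPU generation falls back to torch.float32.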
+  if (tpu.version() >= 4 and lhs.dtype == torch.bfloat16 and
+      rhs.dtype == torch.bfloat16):
+    return torch.bfloat16
+  else:
+    return torch.float32
diff --git a/torch_xla/experimental/megablox/gmm.py b/torch_xla/experimental/megablox/gmm.py
new file mode 100644
index 00000000000..518553d474e
--- /dev/null
+++ b/torch_xla/experimental/megablox/gmm.py
@@ -0,0 +1,395 @@
+"""Grouped matrix multiplication kernels for TPU written in Pallas."""
+
+from typing import Any, Callable, Optional, Union
+from torch_xla.experimental.megablox import common
+from torch_xla.experimental.custom_kernel import jax_import_guard
+import torch
+import torch_xla
+import numpy as np
+
+
+def _validate_args(
+    *,
+    lhs: torch.Tensor,
+    rhs: torch.Tensor,
+    group_sizes: torch.Tensor,
+    expected_rhs_dims: int = 3,
+) -> 'tuple[jnp.ndarray, jnp.ndarray, jnp.dtype]':
+  """Validates the arguments for the gmm function."""
+  # Import JAX within the function so that we don't need to call
+  # jax_import_guard() at the global scope, which could cause problems for
+  # xmp.spawn.
+  jax_import_guard()
+  import jax
+  import jax.numpy as jnp
+
+  # Validate 'lhs'.
+  if lhs.dim() != 2:
+    raise ValueError(f"Expected 2-tensor for 'lhs' but got {lhs.dim()}-tensor.")
+  common.assert_is_supported_dtype(lhs.dtype)
+
+  # Validate 'rhs'.
+  if rhs.dim() != expected_rhs_dims:
+    raise ValueError(f"Expected {expected_rhs_dims}-tensor for 'rhs' but got"
+                     f" {rhs.dim()}-tensor.")
+  common.assert_is_supported_dtype(rhs.dtype)
+
+  # Validate 'group_sizes'.
+  if group_sizes.dtype != torch.int32:
+    raise ValueError(
+        f"Expected 32-bit integer 'group_sizes' but got {group_sizes.dtype}.")
+
+  return lhs, group_sizes, common.select_input_dtype(lhs, rhs)
+
+
+def _calculate_num_tiles(x: int, tx: int) -> int:
+  tiles, rem = divmod(x, tx)
+  if rem:
+    raise ValueError(f"{x} must be divisible by x-dimension tile size ({tx}).")
+  return tiles
+
+
+def _calculate_irregular_num_tiles(x: int, tx: int) -> tuple[int, int]:
+  tiles, rem = divmod(x, tx)
+  if rem:
+    tiles += 1
+  return tiles, rem
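+
+
+# Illustrative examples (comments only, not executed):
+#   _calculate_num_tiles(512, 128) == 4
+#   _calculate_irregular_num_tiles(300, 128) == (3, 44), i.e. three tiles with
+#   a remainder of 44 in the final, partial tile.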
+
+
+GroupMetadata = Any  # TODO(enriqueps): Clean this up and use a namedtuple
+
+
+def _make_group_metadata(
+    *,
+    group_sizes: 'jnp.ndarray',
+    m: int,
+    tm: int,
+    start_group: 'jnp.ndarray',
+    num_nonzero_groups: int,
+    visit_empty_groups: bool = True,
+) -> GroupMetadata:
+  """Create the metadata needed for grouped matmul computation.
+
+  Args:
+    group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype.
+    m: The number of rows in lhs.
+    tm: The m-dimension tile size being used.
+    start_group: The group in group sizes to start computing from. This is
+      particularly useful for when rhs num_groups is sharded.
+    num_nonzero_groups: Number of groups in group sizes to compute on. Useful in
+      combination with group_offset.
+    visit_empty_groups: If True, do not squeeze tiles for empty groups out of
+      the metadata. This is necessary for tgmm, where we at least need to zero
+      the output for each group.
+
+  Returns:
+    tuple of:
+      group_offsets: A 1d, jnp.ndarray with shape [num_groups+1] and jnp.int32
+        dtype. group_offsets[i] indicates the row at which group [i] starts in
+        the lhs matrix and group_offsets[num_groups] = m.
+      group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and
+        jnp.int32 dtype. group_ids[i] indicates which group grid index 'i' will
+        work on.
+      m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and
+        jnp.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i'
+        will work on.
+      num_tiles: The number of m-dimension tiles to execute.
+  """
+  # Import JAX within the function so that we don't need to call
+  # jax_import_guard() at the global scope, which could cause problems for
+  # xmp.spawn.
+  jax_import_guard()
+  import jax
+  import jax.numpy as jnp
+
+  num_groups = group_sizes.shape[0]
+  end_group = start_group + num_nonzero_groups - 1
+
+  # Calculate the offset of each group, starting at zero. This metadata is
+  # similar to row offsets in a CSR matrix. The following properties hold:
+  #
+  #   group_offsets.shape = [num_groups + 1]
+  #   group_offsets[0] = 0
+  #   group_offsets[num_groups] = m
+  #
+  # The row at which group 'i' starts is group_offsets[i].
+  group_ends = jnp.cumsum(group_sizes)
+  group_offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), group_ends])
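+  # For example (illustrative): group_sizes=[3, 0, 5] gives group_ends=[3, 3, 8]
+  # and group_offsets=[0, 3, 3, 8].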
+
+  # Assign a group id to each grid index.
+  #
+  # If a group starts somewhere other than the start of a tile or ends somewhere
+  # other than the end of a tile we need to compute that full tile. Calculate
+  # the number of tiles for each group by rounding their end up to the nearest
+  # 'tm' and their start down to the nearest 'tm'.
+
+  # (1) Round the group_ends up to the nearest multiple of 'tm'.
+  #
+  # NOTE: This does not change group_offsets[num_groups], which is m
+  # (because we enforce m is divisible by tm).
+  rounded_group_ends = ((group_ends + tm - 1) // tm * tm).astype(jnp.int32)
+
+  # (2) Round the group_starts down to the nearest multiple of 'tm'.
+  group_starts = jnp.concatenate(
+      [jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]])
+  rounded_group_starts = group_starts // tm * tm
+
+  # (3) Calculate the number of rows in each group.
+  #
+  # NOTE: Handle zero-sized groups as a special case. If the start for a
+  # zero-sized group is not divisible by 'tm' its start will be rounded down and
+  # its end will be rounded up such that its size will become 1 tile here.
+  rounded_group_sizes = rounded_group_ends - rounded_group_starts
+  rounded_group_sizes = jnp.where(group_sizes == 0, 0, rounded_group_sizes)
+
+  # (4) Convert the group sizes from units of rows to units of 'tm'-sized tiles.
+  #
+  # An m-dimension tile is 'owned' by group 'i' if the first row of the tile
+  # belongs to group 'i'. In addition to owned tiles, each group can have 0 or 1
+  # initial partial tiles if its first row does not occur in the first row of a
+  # tile. The '0-th' group never has a partial tile because it always starts at
+  # the 0-th row.
+  #
+  # If no group has a partial tile, the total number of tiles is equal to
+  # 'm // tm'. If every group has a partial tile except the 0-th group, the
+  # total number of tiles is equal to 'm // tm + num_groups - 1'. Thus we know
+  # that
+  #
+  #   tiles_m <= group_tiles.sum() <= tiles_m + num_groups - 1
+  #
+  # where tiles_m = m // tm.
+  #
+  # NOTE: All group sizes are divisible by 'tm' because of the rounding in steps
+  # (1) and (2), so this division is exact.
+  group_tiles = rounded_group_sizes // tm
+
+  if visit_empty_groups:
+    # Insert one tile for empty groups.
+    group_tiles = jnp.where(group_sizes == 0, 1, group_tiles)
+
+  # Create the group ids for each grid index based on the tile counts for each
+  # group.
+  #
+  # NOTE: This repeat(...) will pad group_ids with the final group id if
+  # group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized
+  # such that we only execute the necessary number of tiles.
+  tiles_m = _calculate_num_tiles(m, tm)
+  group_ids = jnp.repeat(
+      jnp.arange(num_groups, dtype=jnp.int32),
+      group_tiles,
+      total_repeat_length=tiles_m + num_groups - 1,
+  )
+
+  # Assign an m-dimension tile id to each grid index.
+  #
+  # NOTE: Output tiles can only be re-visited consecutively. The following
+  # procedure guarantees that m-dimension tile indices respect this.
+
+  # (1) Calculate how many times each m-dimension tile will be visited.
+  #
+  # Each tile is guaranteed to be visited once by the group that owns the tile.
+  # The remaining possible visits occur when a group starts inside of a tile at
+  # a position other than the first row. We can calculate which m-dimension tile
+  # each group starts in by floor-dividing its offset with `tm` and then count
+  # tile visits with a histogram.
+  #
+  # To avoid double counting tile visits from the group that owns the tile,
+  # filter these out by assigning their tile id to `tiles_m` (one beyond the
+  # max) such that they're ignored by the subsequent histogram. Also filter out
+  # any group which is empty.
+  #
+  # TODO(tgale): Invert the 'partial_tile_mask' predicates to be more clear.
+  partial_tile_mask = jnp.logical_or((group_offsets[:-1] % tm) == 0,
+                                     group_sizes == 0)
+
+  # Explicitly enable tiles for zero sized groups, if specified. This covers
+  # zero sized groups that start on a tile-aligned row and those that do not.
+  if visit_empty_groups:
+    partial_tile_mask = jnp.where(group_sizes == 0, 0, partial_tile_mask)
+
+  partial_tile_ids = jnp.where(partial_tile_mask, tiles_m,
+                               group_offsets[:-1] // tm)
+
+  tile_visits = (
+      jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0]
+      + 1)
+
+  # Create the m-dimension tile ids for each grid index based on the visit
+  # counts for each tile.
+  m_tile_ids = jnp.repeat(
+      jnp.arange(tiles_m, dtype=jnp.int32),
+      tile_visits.astype(jnp.int32),
+      total_repeat_length=tiles_m + num_groups - 1,
+  )
+
+  # Account for sharding.
+  #
+  # Find the start of the groups owned by our shard and shift the group_ids and
+  # m_tile_ids s.t. the metadata for our tiles are at the front of the arrays.
+  #
+  # TODO(tgale): Move this offset into the kernel to avoid these rolls.
+  first_tile_in_shard = (group_ids < start_group).sum()
+  group_ids = jnp.roll(group_ids, shift=-first_tile_in_shard, axis=0)
+  m_tile_ids = jnp.roll(m_tile_ids, shift=-first_tile_in_shard, axis=0)
+
+  # Calculate the number of tiles we need to compute for our shard.
+  #
+  # Remove tile visits that belong to a group not in our shard.
+  iota = jnp.arange(num_groups, dtype=jnp.int32)
+  active_group_mask = jnp.logical_and(iota <= end_group, iota >= start_group)
+  group_tiles = jnp.where(active_group_mask, group_tiles, 0)
+  num_tiles = group_tiles.sum()
+  return (group_offsets, group_ids, m_tile_ids), num_tiles
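+
+
+# Worked example for _make_group_metadata (illustrative only, not executed):
+# with m=8, tm=4, group_sizes=[3, 5], start_group=0, num_nonzero_groups=2 and
+# visit_empty_groups=False, tile 0 (rows 0-3) is shared by groups 0 and 1 while
+# tile 1 (rows 4-7) belongs to group 1 alone, so the returned metadata is
+# group_offsets=[0, 3, 8], group_ids=[0, 1, 1], m_tile_ids=[0, 0, 1] and
+# num_tiles=3.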
+
+
+def _zero_uninitialized_memory(
+    out: 'jnp.ndarray',
+    *,
+    start_group: 'jnp.ndarray',
+    num_nonzero_groups: int,
+    group_metadata: GroupMetadata,
+) -> torch.Tensor:
+  """Zero out uninitialized memory from output."""
+  # Import JAX within the function so that we don't need to call
+  # jax_import_guard() at the global scope, which could cause problems for
+  # xmp.spawn.
+  jax_import_guard()
+  import jax
+  import jax.numpy as jnp
+
+  group_offsets = group_metadata[0]
+  group_start = group_offsets[start_group]
+  group_end = group_offsets[start_group + num_nonzero_groups]
+  valid_mask = jax.lax.broadcasted_iota(jnp.int32, (out.shape[0],), 0)
+  valid_mask = (valid_mask >= group_start) & (valid_mask < group_end)
+  return torch.from_numpy(np.array(jnp.where(valid_mask[:, None], out,
+                                             0))).to('xla')
+
+
+LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]
+
+
+def _gmm(
+    lhs: torch.Tensor,
+    rhs: torch.Tensor,
+    group_sizes: torch.Tensor,
+    payload: str,
+    preferred_element_type: torch.dtype = torch.float32,
+    tiling: Optional[Union[tuple[int, int, int], LutFn]] = (128, 128, 128),
+    group_offset: Optional[torch.Tensor] = None,
+    existing_out: Optional[torch.Tensor] = None,
+    transpose_rhs: bool = False,
+    interpret: bool = False,
+) -> torch.Tensor:
+  """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
+
+  Args:
+    lhs: A 2d, torch.Tensor with shape [m, k].
+    rhs: A 3d, torch.Tensor with shape [num_groups, k, n].
+    group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32
+      dtype.
+    payload: The Pallas payload extracted from the JAX Pallas kernel.
+    preferred_element_type: torch.dtype, the element type for the output matrix.
+    tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
+    group_offset: The group in group sizes to start computing from. This is
+      particularly useful for when rhs num_groups is sharded.
+    existing_out: Existing output to write to.
+    transpose_rhs: True if the rhs needs to be transposed.
+    interpret: Whether or not to run the kernel in interpret mode, helpful for
+      testing and debugging.
+
+  Returns:
+    A 2d, torch.Tensor with shape [m, n].
+  """
+  # Import JAX within the function so that we don't need to call
+  # jax_import_guard() at the global scope, which could cause problems for
+  # xmp.spawn.
+  jax_import_guard()
+  import jax
+  import jax.numpy as jnp
+
+  if existing_out is not None:
+    assert isinstance(existing_out, jax.Array)
+    expected_dtype = existing_out.dtype
+    if expected_dtype != preferred_element_type:
+      raise ValueError(
+          "Existing output dtype must match preferred_element_type.")
+  if group_offset is None:
+    group_offset = jnp.array([0], dtype=jnp.int32)
+  else:
+    group_offset = jnp.asarray(group_offset.numpy())
+    if group_offset.shape:
+      raise ValueError(
+          f"group_offset must be a ()-shaped array. Got: {group_offset.shape}.")
+    group_offset = group_offset[None]
+  num_current_groups = rhs.shape[0]
+  num_total_groups = group_sizes.shape[0]
+  lhs, group_sizes, input_dtype = _validate_args(
+      lhs=lhs, rhs=rhs, group_sizes=group_sizes)
+
+  # Gather shape information.
+  m, k, n = (lhs.shape[0], lhs.shape[1], rhs.shape[2])
+  if transpose_rhs:
+    n = rhs.shape[1]
+
+  # If tiling is callable, look up the problem dimensions in the LUT. If no
+  # tuned tile dimensions are available, throw an error.
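+  # For example (illustrative), a LUT callable could look like
+  # `lambda m, k, n: (128, 128, 128) if (m, k, n) in tuned_shapes else None`,
+  # where `tuned_shapes` is a hypothetical set of tuned problem sizes.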
+  if callable(tiling):
+    tiling = tiling(m, k, n)
+
+  if tiling is None:
+    raise ValueError(f"No tuned tiling found for (m, k, n) = ({m}, {k}, {n})")
+
+  tm, tk, tn = tiling
+  tiles_k, k_rem = _calculate_irregular_num_tiles(k, tk)
+  tiles_n, n_rem = _calculate_irregular_num_tiles(n, tn)
+  del n_rem
+
+  # Create the metadata we need for computation.
+  group_sizes = jnp.asarray(group_sizes.numpy())
+  group_metadata, num_active_tiles = _make_group_metadata(  # pylint: disable=unbalanced-tuple-unpacking
+      group_sizes=group_sizes,
+      m=m,
+      tm=tm,
+      start_group=group_offset[0],
+      num_nonzero_groups=rhs.shape[0],
+      visit_empty_groups=False,
+  )
+  group_metadata0 = torch.from_numpy(np.array(group_metadata[0])).to(
+      torch.int32).to("xla")
+  group_metadata1 = torch.from_numpy(np.array(group_metadata[1])).to("xla")
+  group_metadata2 = torch.from_numpy(np.array(group_metadata[2])).to("xla")
+  num_active_tiles = torch.tensor(np.array(num_active_tiles)).to("xla")
+  group_offset_torch = torch.from_numpy(np.array(group_offset)).to("xla")
+  output_shape = torch.Size([m, n])
+  out = torch_xla._XLAC._xla_tpu_custom_call([
+      num_active_tiles, group_metadata0, group_metadata1, group_metadata2,
+      group_offset_torch, lhs, rhs
+  ], payload, [output_shape], [preferred_element_type])
+
+  if existing_out is None and num_current_groups < num_total_groups:
+    out = jnp.asarray(out.cpu().float().numpy())
+    out = _zero_uninitialized_memory(
+        out,
+        start_group=group_offset[0],
+        num_nonzero_groups=rhs.shape[0],
+        group_metadata=group_metadata,
+    )
+  return out
+
+
+def gmm(
+    lhs: torch.Tensor,
+    rhs: torch.Tensor,
+    group_sizes: torch.Tensor,
+    preferred_element_type: torch.dtype = torch.float32,
+    tiling: Optional[Union[tuple[int, int, int], LutFn]] = (128, 128, 128),
+    group_offset: Optional[torch.Tensor] = None,
+    existing_out: Optional[torch.Tensor] = None,
+    transpose_rhs: bool = False,
+    interpret: bool = False,
+):
+  # Import JAX within the function so that we don't need to call
+  # jax_import_guard() at the global scope, which could cause problems for
+  # xmp.spawn.
+  jax_import_guard()
+  import jax
+  from jax.experimental.pallas.ops.tpu.megablox import gmm as jax_gmm
+  from torch_xla.experimental.custom_kernel import trace_pallas
+
+  payload, _ = trace_pallas(jax_gmm, lhs, rhs, group_sizes)
+  out = _gmm(lhs, rhs, group_sizes, payload, preferred_element_type, tiling,
+             group_offset, existing_out, transpose_rhs, interpret)
+  return out
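For reference, a minimal usage sketch of the new API (a sketch only: it assumes
a TPU runtime, and the shapes, dtypes and group sizes below are hypothetical,
mirroring the cases exercised in test_gmm above). For each group 'i', the rows
of 'lhs' assigned to that group are multiplied by 'rhs[i]', and the test reads
the result back as out[0]:

  import torch
  import torch_xla.experimental.megablox as megablox

  lhs = torch.rand(128, 128, dtype=torch.bfloat16).to('xla')       # [m, k]
  rhs = torch.rand(4, 128, 128, dtype=torch.bfloat16).to('xla')    # [num_groups, k, n]
  group_sizes = torch.tensor([32, 0, 64, 32], dtype=torch.int32)   # sums to m, kept on CPU
  out = megablox.gmm(lhs, rhs, group_sizes)
  result = out[0].cpu()                                            # [m, n]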