diff --git a/test/test_gmm.py b/test/test_gmm.py
index 7db9adcbf93..39679946074 100644
--- a/test/test_gmm.py
+++ b/test/test_gmm.py
@@ -93,48 +93,58 @@ def _init_test_cases(self):
 
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_gmm(self):
     met.clear_all()
-    jax.config.update('jax_default_matmul_precision', 'highest')
+    jax.config.update('jax_default_matmul_precision', "highest")
+    gmm_funcs = [
+        gmm, torch.ops.xla.gmm,
+        torch.compile(torch.ops.xla.gmm, backend="openxla")
+    ]
 
     self._init_test_cases()
-    for test_case in self.tests_cases:
-      num_groups = test_case['num_groups']
-      k = test_case['k']
-      m = test_case['m']
-      n = test_case['n']
-      lhs_dtype = rhs_dtype = test_case['dtype']
-
-      lhs = torch.rand(m, k, dtype=lhs_dtype)
-      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
-      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
-
-      out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+    for gmm_func in gmm_funcs:
+      for test_case in self.tests_cases:
+        num_groups = test_case['num_groups']
+        k = test_case['k']
+        m = test_case['m']
+        n = test_case['n']
+        lhs_dtype = rhs_dtype = test_case['dtype']
+
+        lhs = torch.rand(m, k, dtype=lhs_dtype)
+        rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+        ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+
+        out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+        self.assertTrue(torch.allclose(ref_out, out.cpu()))
 
     # Make sure gmm doesn't fallback.
     self.assertNotIn("aten::", met.short_metrics_report())
-    jax.config.update('jax_default_matmul_precision', 'default')
+    jax.config.update('jax_default_matmul_precision', "default")
 
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_gmm_bf16(self):
     met.clear_all()
+    gmm_funcs = [
+        gmm, torch.ops.xla.gmm,
+        torch.compile(torch.ops.xla.gmm, backend="openxla")
+    ]
     self._init_test_cases()
-    for test_case in self.tests_cases:
-      num_groups = test_case['num_groups']
-      k = test_case['k']
-      m = test_case['m']
-      n = test_case['n']
-      lhs_dtype = rhs_dtype = torch.bfloat16
+    for gmm_func in gmm_funcs:
+      for test_case in self.tests_cases:
+        num_groups = test_case['num_groups']
+        k = test_case['k']
+        m = test_case['m']
+        n = test_case['n']
+        lhs_dtype = rhs_dtype = torch.bfloat16
 
-      lhs = torch.rand(m, k, dtype=lhs_dtype)
-      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
-      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+        lhs = torch.rand(m, k, dtype=lhs_dtype)
+        rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+        ref_out = self._reference_gmm(lhs, rhs, group_sizes)
 
-      out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+        out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
 
-      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+        self.assertTrue(torch.allclose(ref_out, out.cpu()))
 
     # Make sure gmm doesn't fallback.
self.assertNotIn("aten::", met.short_metrics_report()) @@ -300,7 +310,7 @@ def test_sorting_input(self): @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") def test_tgmm(self): met.clear_all() - jax.config.update('jax_default_matmul_precision', 'highest') + jax.config.update('jax_default_matmul_precision', "highest") self._init_test_cases() for test_case in self.tests_cases: @@ -320,7 +330,7 @@ def test_tgmm(self): # Make sure tgmm doesn't fallback. self.assertNotIn("aten::", met.short_metrics_report()) - jax.config.update('jax_default_matmul_precision', 'default') + jax.config.update('jax_default_matmul_precision', "default") @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.") def test_tgmm_bf16(self): diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 95d5e9f0df7..1e967a33c5b 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -7,6 +7,7 @@ import torch_xla import torch_xla.core.xla_model as xm import torch_xla.distributed.spmd as xs +import torch_xla.debug.metrics as met from typing import Any, List, Callable, Optional from torch.library import impl @@ -617,6 +618,7 @@ def _make_group_metadata( # group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized # such that we only execute the necessary number of tiles. tiles_m = _calculate_num_tiles(m, tm) + group_ids = repeat_with_fixed_output_size( torch.arange(num_groups, dtype=torch.int32).to(device), group_tiles, tiles_m + num_groups - 1) @@ -683,7 +685,8 @@ def repeat_with_fixed_output_size(input: torch.Tensor, repeats: torch.Tensor, # shift the repeats by one # tensor([0, 0, 1, 2, 0, 4, 0, 6, 7, 8]) exclusive_repeats = torch.roll(repeats, shifts=1) - exclusive_repeats[0] = 0 + exclusive_repeats = exclusive_repeats.index_copy( + 0, torch.tensor([0], device=device), torch.tensor([0], device=device)) # tensor([ 0, 0, 1, 3, 3, 7, 7, 13, 20, 28]) scatter_indices = torch.cumsum(exclusive_repeats, dim=0) @@ -698,10 +701,12 @@ def repeat_with_fixed_output_size(input: torch.Tensor, repeats: torch.Tensor, # tensor([2, 1, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]) block_split_indicators = torch.zeros( total_repeat_length, dtype=torch.int32, device=device) - block_split_indicators.scatter_add_(0, valid_indices.to(torch.int64), - torch.ones_like(block_split_indicators)) + block_split_indicators = block_split_indicators.scatter_add( + 0, valid_indices.to(torch.int64), torch.ones_like(block_split_indicators)) # out_of_bound indices also scatter to index 0, need to offset them - block_split_indicators[0] -= out_of_bound_count + block_split_indicators = block_split_indicators.index_copy( + 0, torch.tensor([0], device=device), + (block_split_indicators[0] - out_of_bound_count).unsqueeze(0)) # value in gather_indices represents the index in the input. # tensor([1, 2, 2, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7]) @@ -894,3 +899,41 @@ def paged_attention_non_xla(q: torch.Tensor, pages_per_compute_block: int, megacore_mode: str = None): return non_xla_attetion(q, k_pages, v_pages, "paged") + + +XLA_LIB.define( + "gmm(Tensor lhs, Tensor rhs, Tensor group_sizes, int[]? 
tiling=None) -> Tensor", +) + + +@impl(XLA_LIB, "gmm", "XLA") +def gmm_xla( + lhs: torch.Tensor, + rhs: torch.Tensor, + group_sizes: torch.Tensor, + # pytorch custom op does not allow tuple type, use list instead + tiling: Optional[list[int]] = [512, 512, 512]): + assert len(tiling) == 3, "tiling must be a list with 3 integers" + assert lhs.dim() == 2, "lhs must be a 2d, torch.Tensor with shape [k, m]" + assert rhs.dim( + ) == 3, "rhs must be a A 3d torch.Tensor with shape [num_groups, k, n]" + tiling = tuple(tiling) + return gmm(lhs, rhs, group_sizes, tiling) + + +@impl(XLA_LIB, "gmm", "CompositeExplicitAutograd") +def gmm_non_xla(lhs: torch.Tensor, + rhs: torch.Tensor, + group_sizes: torch.Tensor, + tiling: Optional[list[int]] = [512, 512, 512]): + # This will be called when dynamo use fake tensor to construct the fake output. + # We need to make sure output tensor's shape is correct. + if lhs.device != torch.device("meta"): + warnings.warn(f'XLA gmm should only be applied to tensors on XLA device') + assert len(tiling) == 3, "tiling must be a list with 3 integers" + assert lhs.dim() == 2, "lhs must be a 2d, torch.Tensor with shape [k, m]" + assert rhs.dim( + ) == 3, "rhs must be a A 3d torch.Tensor with shape [num_groups, k, n]" + + # we only need to return the tensor with correct shape for meta tensor. + return torch.empty(lhs.size()[0], rhs.size()[2], device=lhs.device)