support gmm as a custom op for dynamo (#7672)
JackCaoG authored Jul 12, 2024
1 parent 8043050 commit 1651e76
Showing 2 changed files with 87 additions and 34 deletions.
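The change (exercised by the test update below) exposes the grouped matrix multiplication (gmm) kernel as a registered custom op, torch.ops.xla.gmm, so it can be called eagerly or traced by dynamo. A minimal usage sketch on a TPU host; the shapes, dtypes, and group sizes are illustrative rather than taken from the commit:

import torch
from torch_xla.experimental.custom_kernel import gmm  # eager wrapper, as used in the test

lhs = torch.rand(512, 128, dtype=torch.bfloat16)     # [m, k]
rhs = torch.rand(4, 128, 256, dtype=torch.bfloat16)  # [num_groups, k, n]
group_sizes = torch.tensor([128, 128, 128, 128], dtype=torch.int32)  # sums to m

# The registered op also works under torch.compile with the openxla backend.
compiled_gmm = torch.compile(torch.ops.xla.gmm, backend="openxla")
out = compiled_gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))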
70 changes: 40 additions & 30 deletions test/test_gmm.py
@@ -93,48 +93,58 @@ def _init_test_cases(self):
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_gmm(self):
     met.clear_all()
-    jax.config.update('jax_default_matmul_precision', 'highest')
+    jax.config.update('jax_default_matmul_precision', "highest")
+    gmm_funcs = [
+        gmm, torch.ops.xla.gmm,
+        torch.compile(torch.ops.xla.gmm, backend="openxla")
+    ]

     self._init_test_cases()
-    for test_case in self.tests_cases:
-      num_groups = test_case['num_groups']
-      k = test_case['k']
-      m = test_case['m']
-      n = test_case['n']
-      lhs_dtype = rhs_dtype = test_case['dtype']
-
-      lhs = torch.rand(m, k, dtype=lhs_dtype)
-      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
-      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
-
-      out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+    for gmm_func in gmm_funcs:
+      for test_case in self.tests_cases:
+        num_groups = test_case['num_groups']
+        k = test_case['k']
+        m = test_case['m']
+        n = test_case['n']
+        lhs_dtype = rhs_dtype = test_case['dtype']
+
+        lhs = torch.rand(m, k, dtype=lhs_dtype)
+        rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+        ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+
+        out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+        self.assertTrue(torch.allclose(ref_out, out.cpu()))

     # Make sure gmm doesn't fallback.
     self.assertNotIn("aten::", met.short_metrics_report())
-    jax.config.update('jax_default_matmul_precision', 'default')
+    jax.config.update('jax_default_matmul_precision', "default")

   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_gmm_bf16(self):
     met.clear_all()

+    gmm_funcs = [
+        gmm, torch.ops.xla.gmm,
+        torch.compile(torch.ops.xla.gmm, backend="openxla")
+    ]
     self._init_test_cases()
-    for test_case in self.tests_cases:
-      num_groups = test_case['num_groups']
-      k = test_case['k']
-      m = test_case['m']
-      n = test_case['n']
-      lhs_dtype = rhs_dtype = torch.bfloat16
+    for gmm_func in gmm_funcs:
+      for test_case in self.tests_cases:
+        num_groups = test_case['num_groups']
+        k = test_case['k']
+        m = test_case['m']
+        n = test_case['n']
+        lhs_dtype = rhs_dtype = torch.bfloat16

-      lhs = torch.rand(m, k, dtype=lhs_dtype)
-      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
-      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+        lhs = torch.rand(m, k, dtype=lhs_dtype)
+        rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+        ref_out = self._reference_gmm(lhs, rhs, group_sizes)

-      out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+        out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))

-      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+        self.assertTrue(torch.allclose(ref_out, out.cpu()))

     # Make sure gmm doesn't fallback.
     self.assertNotIn("aten::", met.short_metrics_report())
@@ -300,7 +310,7 @@ def test_sorting_input(self):
   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_tgmm(self):
     met.clear_all()
-    jax.config.update('jax_default_matmul_precision', 'highest')
+    jax.config.update('jax_default_matmul_precision', "highest")

     self._init_test_cases()
     for test_case in self.tests_cases:
@@ -320,7 +330,7 @@ def test_tgmm(self):

     # Make sure tgmm doesn't fallback.
     self.assertNotIn("aten::", met.short_metrics_report())
-    jax.config.update('jax_default_matmul_precision', 'default')
+    jax.config.update('jax_default_matmul_precision', "default")

   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_tgmm_bf16(self):
51 changes: 47 additions & 4 deletions torch_xla/experimental/custom_kernel.py
@@ -7,6 +7,7 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.spmd as xs
+import torch_xla.debug.metrics as met

 from typing import Any, List, Callable, Optional
 from torch.library import impl
@@ -617,6 +618,7 @@ def _make_group_metadata(
   # group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized
   # such that we only execute the necessary number of tiles.
   tiles_m = _calculate_num_tiles(m, tm)
+
   group_ids = repeat_with_fixed_output_size(
       torch.arange(num_groups, dtype=torch.int32).to(device), group_tiles,
       tiles_m + num_groups - 1)
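For context on the fixed output size used just above, a small worked example; it assumes _calculate_num_tiles(m, tm) is simply ceil(m / tm), which is not shown in this diff:

import math

m, tm, num_groups = 512, 128, 4
tiles_m = math.ceil(m / tm)           # 4 (assumed behaviour of _calculate_num_tiles)
fixed_len = tiles_m + num_groups - 1  # 7, the output length given to repeat_with_fixed_output_size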
@@ -683,7 +685,8 @@ def repeat_with_fixed_output_size(input: torch.Tensor, repeats: torch.Tensor,
   # shift the repeats by one
   # tensor([0, 0, 1, 2, 0, 4, 0, 6, 7, 8])
   exclusive_repeats = torch.roll(repeats, shifts=1)
-  exclusive_repeats[0] = 0
+  exclusive_repeats = exclusive_repeats.index_copy(
+      0, torch.tensor([0], device=device), torch.tensor([0], device=device))

   # tensor([ 0, 0, 1, 3, 3, 7, 7, 13, 20, 28])
   scatter_indices = torch.cumsum(exclusive_repeats, dim=0)
@@ -698,10 +701,12 @@
   # tensor([2, 1, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
   block_split_indicators = torch.zeros(
       total_repeat_length, dtype=torch.int32, device=device)
-  block_split_indicators.scatter_add_(0, valid_indices.to(torch.int64),
-                                      torch.ones_like(block_split_indicators))
+  block_split_indicators = block_split_indicators.scatter_add(
+      0, valid_indices.to(torch.int64), torch.ones_like(block_split_indicators))
   # out_of_bound indices also scatter to index 0, need to offset them
-  block_split_indicators[0] -= out_of_bound_count
+  block_split_indicators = block_split_indicators.index_copy(
+      0, torch.tensor([0], device=device),
+      (block_split_indicators[0] - out_of_bound_count).unsqueeze(0))

   # value in gather_indices represents the index in the input.
   # tensor([1, 2, 2, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7])
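The two hunks above swap in-place updates (exclusive_repeats[0] = 0, scatter_add_, and the in-place subtraction on block_split_indicators[0]) for out-of-place equivalents (index_copy, scatter_add), presumably to keep this helper free of tensor mutation when it is traced for dynamo. A small sketch of the equivalence with made-up values:

import torch

counts = torch.zeros(6, dtype=torch.int64)
idx = torch.tensor([2, 2, 5])

# In-place form:   counts.scatter_add_(0, idx, torch.ones_like(counts))
# Out-of-place form used in the diff:
counts = counts.scatter_add(0, idx, torch.ones_like(counts))
# counts is now tensor([0, 0, 2, 0, 0, 1])

# In-place form:   counts[0] = 9
# Out-of-place form used in the diff:
counts = counts.index_copy(0, torch.tensor([0]), torch.tensor([9]))
# counts is now tensor([9, 0, 2, 0, 0, 1])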
@@ -894,3 +899,41 @@ def paged_attention_non_xla(q: torch.Tensor,
                              pages_per_compute_block: int,
                              megacore_mode: str = None):
   return non_xla_attetion(q, k_pages, v_pages, "paged")
+
+
+XLA_LIB.define(
+    "gmm(Tensor lhs, Tensor rhs, Tensor group_sizes, int[]? tiling=None) -> Tensor",
+)
+
+
+@impl(XLA_LIB, "gmm", "XLA")
+def gmm_xla(
+    lhs: torch.Tensor,
+    rhs: torch.Tensor,
+    group_sizes: torch.Tensor,
+    # PyTorch custom ops do not allow tuple types, so use a list instead.
+    tiling: Optional[list[int]] = [512, 512, 512]):
+  assert len(tiling) == 3, "tiling must be a list with 3 integers"
+  assert lhs.dim() == 2, "lhs must be a 2d torch.Tensor with shape [m, k]"
+  assert rhs.dim(
+  ) == 3, "rhs must be a 3d torch.Tensor with shape [num_groups, k, n]"
+  tiling = tuple(tiling)
+  return gmm(lhs, rhs, group_sizes, tiling)
+
+
+@impl(XLA_LIB, "gmm", "CompositeExplicitAutograd")
+def gmm_non_xla(lhs: torch.Tensor,
+                rhs: torch.Tensor,
+                group_sizes: torch.Tensor,
+                tiling: Optional[list[int]] = [512, 512, 512]):
+  # This will be called when dynamo uses fake tensors to construct a fake output.
+  # We need to make sure the output tensor's shape is correct.
+  if lhs.device != torch.device("meta"):
+    warnings.warn(f'XLA gmm should only be applied to tensors on XLA device')
+  assert len(tiling) == 3, "tiling must be a list with 3 integers"
+  assert lhs.dim() == 2, "lhs must be a 2d torch.Tensor with shape [m, k]"
+  assert rhs.dim(
+  ) == 3, "rhs must be a 3d torch.Tensor with shape [num_groups, k, n]"
+
+  # We only need to return a tensor with the correct shape for the meta tensor.
+  return torch.empty(lhs.size()[0], rhs.size()[2], device=lhs.device)
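With the CompositeExplicitAutograd registration above, dynamo's fake-tensor tracing only needs the [m, n] output shape. A rough sketch of that behaviour on meta tensors; shapes are illustrative and it assumes importing the module performs the registration, as in this diff:

import torch
import torch_xla.experimental.custom_kernel  # noqa: F401, registers torch.ops.xla.gmm

lhs = torch.empty(512, 128, device="meta")     # [m, k]
rhs = torch.empty(4, 128, 256, device="meta")  # [num_groups, k, n]
group_sizes = torch.empty(4, dtype=torch.int32, device="meta")

out = torch.ops.xla.gmm(lhs, rhs, group_sizes)
print(out.shape)  # torch.Size([512, 256]), inferred without running the kernel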
