support gmm as a custom op for dynamo #7672

Merged · 6 commits · Jul 12, 2024
Changes from all commits
test/test_gmm.py — 40 additions & 30 deletions
@@ -93,48 +93,58 @@ def _init_test_cases(self):
  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  def test_gmm(self):
    met.clear_all()
-    jax.config.update('jax_default_matmul_precision', 'highest')
+    jax.config.update('jax_default_matmul_precision', "highest")
+    gmm_funcs = [
+        gmm, torch.ops.xla.gmm,
+        torch.compile(torch.ops.xla.gmm, backend="openxla")
+    ]

    self._init_test_cases()
-    for test_case in self.tests_cases:
-      num_groups = test_case['num_groups']
-      k = test_case['k']
-      m = test_case['m']
-      n = test_case['n']
-      lhs_dtype = rhs_dtype = test_case['dtype']
-
-      lhs = torch.rand(m, k, dtype=lhs_dtype)
-      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
-      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
-
-      out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+    for gmm_func in gmm_funcs:
+      for test_case in self.tests_cases:
+        num_groups = test_case['num_groups']
+        k = test_case['k']
+        m = test_case['m']
+        n = test_case['n']
+        lhs_dtype = rhs_dtype = test_case['dtype']
+
+        lhs = torch.rand(m, k, dtype=lhs_dtype)
+        rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+        ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+
+        out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+        self.assertTrue(torch.allclose(ref_out, out.cpu()))

    # Make sure gmm doesn't fallback.
    self.assertNotIn("aten::", met.short_metrics_report())
-    jax.config.update('jax_default_matmul_precision', 'default')
+    jax.config.update('jax_default_matmul_precision', "default")

  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  def test_gmm_bf16(self):
    met.clear_all()

+    gmm_funcs = [
+        gmm, torch.ops.xla.gmm,
+        torch.compile(torch.ops.xla.gmm, backend="openxla")
+    ]
    self._init_test_cases()
-    for test_case in self.tests_cases:
-      num_groups = test_case['num_groups']
-      k = test_case['k']
-      m = test_case['m']
-      n = test_case['n']
-      lhs_dtype = rhs_dtype = torch.bfloat16
-
-      lhs = torch.rand(m, k, dtype=lhs_dtype)
-      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
-      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
-      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
-
-      out = gmm(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
-
-      self.assertTrue(torch.allclose(ref_out, out.cpu()))
+    for gmm_func in gmm_funcs:
+      for test_case in self.tests_cases:
+        num_groups = test_case['num_groups']
+        k = test_case['k']
+        m = test_case['m']
+        n = test_case['n']
+        lhs_dtype = rhs_dtype = torch.bfloat16
+
+        lhs = torch.rand(m, k, dtype=lhs_dtype)
+        rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype)
+        group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+        ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+
+        out = gmm_func(lhs.to("xla"), rhs.to("xla"), group_sizes.to("xla"))
+
+        self.assertTrue(torch.allclose(ref_out, out.cpu()))

    # Make sure gmm doesn't fallback.
    self.assertNotIn("aten::", met.short_metrics_report())
@@ -300,7 +310,7 @@ def test_sorting_input(self):
  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  def test_tgmm(self):
    met.clear_all()
-    jax.config.update('jax_default_matmul_precision', 'highest')
+    jax.config.update('jax_default_matmul_precision', "highest")

    self._init_test_cases()
    for test_case in self.tests_cases:
@@ -320,7 +330,7 @@ def test_tgmm(self):

    # Make sure tgmm doesn't fallback.
    self.assertNotIn("aten::", met.short_metrics_report())
-    jax.config.update('jax_default_matmul_precision', 'default')
+    jax.config.update('jax_default_matmul_precision', "default")

  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
  def test_tgmm_bf16(self):
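The updated tests sweep the same inputs through all three entry points: the Python wrapper `gmm`, the dispatcher-registered `torch.ops.xla.gmm`, and a `torch.compile`-wrapped version traced by dynamo with the `openxla` backend. Outside the test harness, the three call paths look like this — a minimal sketch with made-up sizes, assuming a TPU-backed XLA device and the shapes used above (`lhs` is `[m, k]`, `rhs` is `[num_groups, k, n]`):

```python
import torch
import torch_xla  # registers the XLA device and the torch.ops.xla namespace
from torch_xla.experimental.custom_kernel import gmm

lhs = torch.rand(128, 64).to("xla")                      # [m, k]
rhs = torch.rand(4, 64, 32).to("xla")                    # [num_groups, k, n]
group_sizes = torch.tensor([32, 32, 32, 32], dtype=torch.int32).to("xla")

out_eager = gmm(lhs, rhs, group_sizes)                   # plain Python wrapper
out_op = torch.ops.xla.gmm(lhs, rhs, group_sizes)        # registered custom op
compiled_gmm = torch.compile(torch.ops.xla.gmm, backend="openxla")
out_dynamo = compiled_gmm(lhs, rhs, group_sizes)         # dynamo-traced path
```

All three should agree with the reference result, and none should show an `aten::` fallback in the metrics report.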
torch_xla/experimental/custom_kernel.py — 47 additions & 4 deletions
@@ -7,6 +7,7 @@
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
+import torch_xla.debug.metrics as met
Collaborator: Do you add this in accident?

Collaborator (Author): lol forgot to delete
from typing import Any, List, Callable, Optional
from torch.library import impl
@@ -617,6 +618,7 @@ def _make_group_metadata(
  # group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized
  # such that we only execute the necessary number of tiles.
  tiles_m = _calculate_num_tiles(m, tm)
+
  group_ids = repeat_with_fixed_output_size(
      torch.arange(num_groups, dtype=torch.int32).to(device), group_tiles,
      tiles_m + num_groups - 1)
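For intuition: `group_ids` maps each m-tile of the kernel grid to the group it belongs to, repeating group index `i` exactly `group_tiles[i]` times, and is materialized at the static length `tiles_m + num_groups - 1` so XLA always sees a fixed shape. A rough illustration with made-up tile counts (the exact padding of the tail is an implementation detail):

```python
import torch

# hypothetical: tiles_m = 6 row-tiles covering 3 groups
group_tiles = torch.tensor([2, 1, 3], dtype=torch.int32)
group_ids = repeat_with_fixed_output_size(
    torch.arange(3, dtype=torch.int32), group_tiles, 6 + 3 - 1)
# roughly tensor([0, 0, 1, 2, 2, 2, ...]) padded out to length 8
```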
@@ -683,7 +685,8 @@ def repeat_with_fixed_output_size(input: torch.Tensor, repeats: torch.Tensor,
  # shift the repeats by one
  # tensor([0, 0, 1, 2, 0, 4, 0, 6, 7, 8])
  exclusive_repeats = torch.roll(repeats, shifts=1)
-  exclusive_repeats[0] = 0
+  exclusive_repeats = exclusive_repeats.index_copy(
+      0, torch.tensor([0], device=device), torch.tensor([0], device=device))

  # tensor([ 0, 0, 1, 3, 3, 7, 7, 13, 20, 28])
  scatter_indices = torch.cumsum(exclusive_repeats, dim=0)
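The rewrite above swaps the in-place `exclusive_repeats[0] = 0` for an out-of-place `index_copy`, keeping the op mutation-free when dynamo functionalizes the graph. The roll-then-zero trick itself computes an exclusive prefix sum; here is a standalone repro of the values traced in the comments, assuming a `repeats` vector consistent with those traces:

```python
import torch

repeats = torch.tensor([0, 1, 2, 0, 4, 0, 6, 7, 8, 0])
ex = torch.roll(repeats, shifts=1)            # last element wraps to the front
ex = ex.index_copy(0, torch.tensor([0]), torch.tensor([0]))  # zero slot 0, out of place
print(ex)                   # tensor([0, 0, 1, 2, 0, 4, 0, 6, 7, 8])
print(torch.cumsum(ex, 0))  # tensor([ 0,  0,  1,  3,  3,  7,  7, 13, 20, 28])
```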
@@ -698,10 +701,12 @@
  # tensor([2, 1, 0, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
  block_split_indicators = torch.zeros(
      total_repeat_length, dtype=torch.int32, device=device)
-  block_split_indicators.scatter_add_(0, valid_indices.to(torch.int64),
-                                      torch.ones_like(block_split_indicators))
+  block_split_indicators = block_split_indicators.scatter_add(
+      0, valid_indices.to(torch.int64), torch.ones_like(block_split_indicators))
  # out_of_bound indices also scatter to index 0, need to offset them
-  block_split_indicators[0] -= out_of_bound_count
+  block_split_indicators = block_split_indicators.index_copy(
+      0, torch.tensor([0], device=device),
+      (block_split_indicators[0] - out_of_bound_count).unsqueeze(0))

  # value in gather_indices represents the index in the input.
  # tensor([1, 2, 2, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7])
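Same pattern here: in-place `scatter_add_` and indexed assignment become out-of-place `scatter_add` and `index_copy`, again for dynamo's functionalization. The `scatter_add` of ones is simply a fixed-length histogram of `valid_indices`; a small sanity check of that equivalence with made-up indices:

```python
import torch

valid_indices = torch.tensor([0, 0, 1, 3, 3, 7])   # hypothetical clamped indices
zeros = torch.zeros(10, dtype=torch.int64)
hist = zeros.scatter_add(0, valid_indices, torch.ones_like(zeros))
assert torch.equal(hist, torch.bincount(valid_indices, minlength=10))
```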
@@ -894,3 +899,41 @@ def paged_attention_non_xla(q: torch.Tensor,
                              pages_per_compute_block: int,
                              megacore_mode: str = None):
  return non_xla_attetion(q, k_pages, v_pages, "paged")
+
+
+XLA_LIB.define(
+    "gmm(Tensor lhs, Tensor rhs, Tensor group_sizes, int[]? tiling=None) -> Tensor",
+)
+
+
+@impl(XLA_LIB, "gmm", "XLA")
+def gmm_xla(
+    lhs: torch.Tensor,
+    rhs: torch.Tensor,
+    group_sizes: torch.Tensor,
+    # pytorch custom op does not allow tuple type, use list instead
+    tiling: Optional[list[int]] = [512, 512, 512]):
+  assert len(tiling) == 3, "tiling must be a list of 3 integers"
+  assert lhs.dim() == 2, "lhs must be a 2d torch.Tensor with shape [m, k]"
+  assert rhs.dim(
+  ) == 3, "rhs must be a 3d torch.Tensor with shape [num_groups, k, n]"
+  tiling = tuple(tiling)
+  return gmm(lhs, rhs, group_sizes, tiling)
+
+
+@impl(XLA_LIB, "gmm", "CompositeExplicitAutograd")
+def gmm_non_xla(lhs: torch.Tensor,
+                rhs: torch.Tensor,
+                group_sizes: torch.Tensor,
+                tiling: Optional[list[int]] = [512, 512, 512]):
+  # This will be called when dynamo uses fake tensors to construct the fake output.
+  # We need to make sure the output tensor's shape is correct.
+  if lhs.device != torch.device("meta"):
+    warnings.warn('XLA gmm should only be applied to tensors on an XLA device')
+  assert len(tiling) == 3, "tiling must be a list of 3 integers"
+  assert lhs.dim() == 2, "lhs must be a 2d torch.Tensor with shape [m, k]"
+  assert rhs.dim(
+  ) == 3, "rhs must be a 3d torch.Tensor with shape [num_groups, k, n]"
+
+  # We only need to return a tensor with the correct shape for the meta tensor.
+  return torch.empty(lhs.size()[0], rhs.size()[2], device=lhs.device)
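With the schema and both kernels registered, the op accepts the default tiling or an explicit one (a list rather than a tuple, since custom-op schemas support `int[]` but not tuples), and dynamo can shape-propagate through it via the fake-tensor path. A sketch of the shape rule the `CompositeExplicitAutograd` implementation encodes, checked on meta tensors:

```python
import torch

# gmm: lhs [m, k] x rhs [num_groups, k, n] -> out [m, n]
lhs = torch.empty(16, 8, device="meta")
rhs = torch.empty(4, 8, 32, device="meta")
out = torch.empty(lhs.size(0), rhs.size(2), device=lhs.device)
assert out.shape == (16, 32)

# On a real XLA device (assumed available), both spellings are valid:
#   torch.ops.xla.gmm(lhs, rhs, group_sizes)
#   torch.ops.xla.gmm(lhs, rhs, group_sizes, [256, 256, 256])  # hypothetical tiling
```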