Adding megablox gmm standalone (pytorch#6940)
Co-authored-by: Wonjoo Lee <wonjoo@google.com>
miladm and wonjoolee95 committed May 10, 2024
1 parent c1b745e commit 40f7e1f
Showing 5 changed files with 580 additions and 0 deletions.
161 changes: 161 additions & 0 deletions test/test_megablox.py
@@ -0,0 +1,161 @@
"""Tests for grouped matrix multiplication (GMM) kernels for TPU written in Pallas."""

import logging
import sys
import unittest

from typing import Optional, Union, Callable

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.megablox as megablox
from torch_xla import runtime as xr
from torch_xla._internal import tpu

import numpy as np

if xr.device_type() == 'TPU':
from torch_xla.experimental.custom_kernel import jax_import_guard
jax_import_guard()
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


class MegabloxTest(unittest.TestCase):

  def _reference_gmm(
      self,
      lhs: np.ndarray,
      rhs: np.ndarray,
      group_sizes: np.ndarray,
      preferred_element_type: np.dtype = np.float32,
  ) -> np.ndarray:

start = 0
out = []
for i, size in enumerate(group_sizes):
result = np.dot(lhs[start:start + size, :], rhs[i, :, :])

result = result.astype(preferred_element_type)
out.append(result)
start += group_sizes[i]
return np.array(np.concatenate(out, axis=0))

def _group_sizes_strategy(self, m: int, num_groups: int) -> torch.Tensor:
# Randomly sample the ends of the groups in the m-dimension. Let the fuzzer
# sample with replacement so that it's possible to get zero-sized groups. Get
# 'num_groups - 1' run ends. The final group will end at 'm'.
ends_no_final = np.sort(
np.array(
[np.random.randint(low=0, high=m) for _ in range(num_groups - 1)],
dtype=np.int32,
),)
ends = np.concatenate([ends_no_final, np.array([m], dtype=np.int32)])

# Calculate the run starts by shifting ends 1 to the right. The first run
# starts at zero.
starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final])
return torch.from_numpy(ends - starts).to(torch.int32)

def _tolerances(self, lhs_dtype: torch.dtype, rhs_dtype: torch.dtype,
out_dtype: torch.dtype) -> tuple[float, float]:
if (lhs_dtype == torch.bfloat16 or rhs_dtype == torch.bfloat16 or
out_dtype == torch.bfloat16):
return 1e-3, 1e-2 # atol, rtol
return 1e-4, 1e-2 # atol, rtol

LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]

def _init_test_cases(self):
self.tests_cases = []
self.tests_cases.append({
'dtype': torch.float32,
'm': 128,
'k': 128,
'n': 128,
'num_groups': 1
})
self.tests_cases.append({
'dtype': torch.float32,
'm': 256,
'k': 128,
'n': 128,
'num_groups': 1
})
self.tests_cases.append({
'dtype': torch.float32,
'm': 128,
'k': 256,
'n': 128,
'num_groups': 8
})
self.tests_cases.append({
'dtype': torch.float32,
'm': 512,
'k': 128,
'n': 256,
'num_groups': 2
})
self.tests_cases.append({
'dtype': torch.bfloat16,
'm': 128,
'k': 128,
'n': 128,
'num_groups': 1
})
self.tests_cases.append({
'dtype': torch.bfloat16,
'm': 256,
'k': 128,
'n': 128,
'num_groups': 1
})
self.tests_cases.append({
'dtype': torch.bfloat16,
'm': 128,
'k': 256,
'n': 128,
'num_groups': 8
})
self.tests_cases.append({
'dtype': torch.bfloat16,
'm': 512,
'k': 128,
'n': 256,
'num_groups': 2
})

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_gmm(self):
self._init_test_cases()
for test_case in self.tests_cases:
num_groups = test_case['num_groups']
k = test_case['k']
m = test_case['m']
n = test_case['n']
lhs_dtype = rhs_dtype = test_case['dtype']
out_dtype = torch.float32

lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
out = megablox.gmm(lhs, rhs, group_sizes)

ref_out = self._reference_gmm(lhs.cpu().float().numpy(),
rhs.cpu().float().numpy(),
group_sizes.numpy())

atol, rtol = self._tolerances(lhs_dtype, rhs_dtype, out_dtype)
np.testing.assert_allclose(
ref_out, np.array(out[0].cpu()), rtol=rtol, atol=atol)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)
torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
use_full_mat_mul_precision=True)
  test = unittest.main(exit=False)
sys.exit(0 if test.result.wasSuccessful() else 1)
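For reference, a small worked example of the group-size sampling used in _group_sizes_strategy above (the numbers here are chosen purely for illustration): with m = 512 and num_groups = 2, a single end is drawn uniformly from [0, 512), say 200, giving ends = [200, 512], starts = [0, 200], and group_sizes = ends - starts = [200, 312], which always sums to m and may contain zero-sized groups.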
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -22,6 +22,7 @@ python3 test/pjrt/test_dynamic_plugin_tpu.py
python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
python3 test/test_pallas.py
python3 test/test_input_output_aliases.py
python3 test/test_megablox.py
python3 test/torch_distributed/test_torch_distributed_all_gather_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_all_reduce_xla_backend.py
python3 test/torch_distributed/test_torch_distributed_multi_all_reduce_xla_backend.py
1 change: 1 addition & 0 deletions torch_xla/experimental/megablox/__init__.py
@@ -0,0 +1 @@
from .gmm import gmm
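For context, a minimal sketch of how this exported entry point is exercised (mirroring test/test_megablox.py above; the shapes, dtype, and group sizes below are illustrative assumptions, not part of this diff):

import torch
import torch_xla.experimental.megablox as megablox

# lhs: [m, k], rhs: [num_groups, k, n], group_sizes: int32 [num_groups] summing to m.
lhs = torch.rand(512, 128, dtype=torch.bfloat16).to('xla')
rhs = torch.rand(8, 128, 256, dtype=torch.bfloat16).to('xla')
group_sizes = torch.full((8,), 64, dtype=torch.int32)  # 8 groups of 64 rows each
out = megablox.gmm(lhs, rhs, group_sizes)  # row block i is lhs[start_i:end_i] @ rhs[i]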
22 changes: 22 additions & 0 deletions torch_xla/experimental/megablox/common.py
@@ -0,0 +1,22 @@
"""Common utilities for Pallas kernels."""

from typing import Union
import torch
from torch_xla._internal import tpu


def assert_is_supported_dtype(dtype: torch.dtype) -> None:
if dtype != torch.bfloat16 and dtype != torch.float32:
raise ValueError(f"Expected bfloat16 or float32 array but got {dtype}.")


def select_input_dtype(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.dtype:
  """Returns the dtype to which both inputs should be cast before the dot product."""
  # bf16 x bf16 matmul is only supported on TPU v4 and newer. For mixed
  # input precision, the bf16 argument is converted to fp32 beforehand.
if (tpu.version() >= 4 and lhs.dtype == torch.bfloat16 and
rhs.dtype == torch.bfloat16):
return torch.bfloat16
else:
return torch.float32
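A quick illustration of the intended behavior (a sketch assuming a TPU v4 or newer host; the tensors below only stand in for their dtypes):

a = torch.empty(4, 4, dtype=torch.bfloat16)
b = torch.empty(4, 4, dtype=torch.float32)
assert select_input_dtype(a, a) == torch.bfloat16  # bf16 x bf16 stays in bf16 on v4+
assert select_input_dtype(a, b) == torch.float32   # mixed precision falls back to fp32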
