Skip to content

Commit

Permalink
[Inductor] make decompose_mm_pass support cpu case (pytorch#139696)
Browse files Browse the repository at this point in the history
Summary: Previously, decompose_mm_pass only works for gpu case. This diff make it support some cpu case as well for the performance optimization

Differential Revision: D65226131

Pull Request resolved: pytorch#139696
Approved by: https://github.com/eellison
  • Loading branch information
hl475 authored and pytorchmergebot committed Nov 12, 2024
1 parent 965555d commit 330c957
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 10 deletions.
69 changes: 69 additions & 0 deletions test/inductor/test_decompose_mem_bound_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
import torch._inductor
from torch._dynamo.utils import counters
from torch._inductor.fx_passes.decompose_mem_bound_mm import check_device
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
Expand Down Expand Up @@ -117,6 +118,29 @@ def test_decompose_bmm(self, b, m, n, k, should_decompose):
)
counters.clear()

@parametrize(
"b,m,k,n,should_decompose",
[(1, 2, 2, 2, True), (2, 2, 2, 2, False)],
)
def test_decompose_bmm_cpu(self, b, m, n, k, should_decompose):
torch._logging.set_logs(inductor=logging.DEBUG)
mat1 = torch.randn(b, m, k)
mat2 = torch.randn(b, k, n)

counters.clear()

module = MyModule2()
traced = torch.compile(module)
input = [mat1, mat2]
self.compare_pred(module, traced, input)

expected_val = 1 if should_decompose else 0
self.assertEqual(
counters["inductor"]["decompose_bmm"],
expected_val,
)
counters.clear()

@parametrize(
"m,k,n, should_decompose",
[(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
Expand Down Expand Up @@ -247,6 +271,28 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose):
)
counters.clear()

@parametrize(
"m,k,n, should_decompose",
[(1, 64, 16, True), (2, 64, 16, False), (1, 64, 32, False)],
)
def test_decompose_mm_cpu(self, m, n, k, should_decompose):
torch._logging.set_logs(inductor=logging.DEBUG)
mat1 = torch.randn(m, k)
mat2 = torch.randn(k, n)
counters.clear()

module = MyModule3()
traced = torch.compile(module)
input = [mat1, mat2]
self.compare_pred(module, traced, input)

expected_val = 1 if should_decompose else 0
self.assertEqual(
counters["inductor"]["decompose_mm"],
expected_val,
)
counters.clear()

@parametrize(
"m,k,n, should_decompose",
[(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
Expand Down Expand Up @@ -347,6 +393,29 @@ def foo(x, y):
# two kernels generated
FileCheck().check_count(".run(", 2, exactly=True).run(code[0])

def test_check_device(self):
m = 5
k = 5
n = 2
torch._logging.set_logs(inductor=logging.DEBUG)

input1 = torch.randn(m, k, device=GPU_TYPE)
input2 = torch.randn(k, n, device=GPU_TYPE)
self.assertTrue(check_device(input1, input2))
self.assertFalse(check_device(input1, input2, device="cpu"))

input1 = torch.randn(m, k)
input2 = torch.randn(k, n)
self.assertTrue(check_device(input1, input2, device="cpu"))
self.assertFalse(check_device(input1, input2))

input1 = torch.randn(m, k, device=GPU_TYPE)
input2 = torch.randn(k, n)
self.assertFalse(check_device(input1, input2, device="gpu"))
self.assertFalse(check_device(input1, input2, device="cpu"))

self.assertFalse(check_device(input1, input2, device="mtia"))


if __name__ == "__main__":
run_tests()
27 changes: 17 additions & 10 deletions torch/_inductor/fx_passes/decompose_mem_bound_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
].get("max_other_dimention_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION)


def check_device(a: Tensor, b: Tensor) -> bool:
return a.is_cuda and b.is_cuda
def check_device(a: Tensor, b: Tensor, device="cuda") -> bool:
return (a.device.type == b.device.type) and (b.device.type == device)


def realize_inputs(inputs: List[torch.fx.Node]):
Expand All @@ -45,19 +45,21 @@ def should_decompose_bmm(mat1, mat2) -> bool:
mat2 = mat2.meta["val"]
else:
return False
if not check_device(mat1, mat2):
if len(mat1.shape) != 3 or len(mat2.shape) != 3:
return False
else:
if len(mat1.shape) != 3 or len(mat2.shape) != 3:
return False
if check_device(mat1, mat2, device="cuda"):
if mat1.shape[0] < min_first_dimension_decomposition:
return False
# 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION
if (mat1.shape[1] < max_other_dimention_decomposition) + (
mat1.shape[2] < max_other_dimention_decomposition
) + (mat2.shape[2] < max_other_dimention_decomposition) < 2:
return False
return True
return True
elif check_device(mat1, mat2, device="cpu"):
if mat1.shape[0] == 1 and mat2.shape[0] == 1:
return True
return False


def should_decompose_mm(mat1, mat2) -> bool:
Expand All @@ -66,13 +68,18 @@ def should_decompose_mm(mat1, mat2) -> bool:
mat2 = mat2.meta["val"]
else:
return False
if len(mat1.shape) != 2 or len(mat2.shape) != 2:
return False
return (
check_device(mat1, mat2)
and len(mat1.shape) == 2
and len(mat2.shape) == 2
check_device(mat1, mat2, device="cuda")
and mat1.shape[0] >= min_first_dimension_decomposition
and mat2.shape[0] < max_other_dimention_decomposition
and mat2.shape[1] < max_other_dimention_decomposition
) or (
check_device(mat1, mat2, device="cpu")
and mat1.shape[0] == 1
and mat2.shape[0] <= 64
and mat2.shape[1] <= 16
)


Expand Down

0 comments on commit 330c957

Please sign in to comment.