Skip to content

Commit

Permalink
fixing internal test failure on non sm_80 machines (#107340)
Browse files Browse the repository at this point in the history
Summary:
These tests were failing on non sm_80+ machines used for internal CI, added check to skip this.

D48295360 added new tests that work in OSS but not in phabricator CI

https://www.internalfb.com/intern/test/562950057441807?ref_report_id=0

https://www.internalfb.com/intern/test/281475080709193?ref_report_id=0

Test Plan: see phabricator result

Differential Revision: D48417499

Pull Request resolved: #107340
Approved by: https://github.com/davidberard98
  • Loading branch information
HDCharles authored and pytorchmergebot committed Aug 19, 2023
1 parent b5642f0 commit 3ddf305
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions test/inductor/test_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch._inductor.fx_passes import joint_graph
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import HAS_CUDA

Expand Down Expand Up @@ -63,6 +64,7 @@ def _test_mixed_impl(self, fn, args, mixed_mm_expected, fallback_mixed_mm_expect
self.assertEqual("mixed_mm" in code, mixed_mm_expected)
self.assertEqual("fallback_mixed_mm" in code, fallback_mixed_mm_expected)

@unittest.skipIf(not SM80OrLater, "need sm_80")
@inductor_config.patch(force_mixed_mm=True)
def test_mixed_mm(self):
def fn(a, b):
Expand Down Expand Up @@ -90,6 +92,7 @@ def fn(a, b):
for args in args_list:
self._test_mixed_impl(fn, args, True, False)

@unittest.skipIf(not SM80OrLater, "need sm_80")
@inductor_config.patch(force_mixed_mm=True, max_autotune_gemm=True)
def test_mixed_mm_epi_works(self):
def fn(a, b, c, d):
Expand Down Expand Up @@ -119,6 +122,7 @@ def fn(a, b, c, d):
for args in args_list:
self._test_mixed_impl(fn, args, True, False)

@unittest.skipIf(not SM80OrLater, "need sm_80")
def test_mixed_mm_gating(self):
def fn(a, b):
return torch.mm(a, b.to(a.dtype))
Expand Down Expand Up @@ -154,6 +158,7 @@ def fn(a, b):
)
self._test_mixed_impl(fn, args, False, False)

@unittest.skipIf(not SM80OrLater, "need sm_80")
@inductor_config.patch(use_mixed_mm=True)
def test_uint4x2_mixed_mm(self):
def fn(a, b):
Expand Down Expand Up @@ -188,6 +193,7 @@ def fn(a, b):
torch.testing.assert_close(ref, test)
self.assertTrue("uint4x2_mixed_mm" in code)

@unittest.skipIf(not SM80OrLater, "need sm_80")
@inductor_config.patch(use_mixed_mm=True)
def test_uint4x2_mixed_mm_epi(self):
def fn(a, b, c, d):
Expand Down

0 comments on commit 3ddf305

Please sign in to comment.