Commit 5b3fb3e
Revert "Skip Unit Tests for ROCm CI (#1563)"
This reverts commit a1c67b9.
andrewor14 authored Jan 17, 2025
1 parent a1c67b9 commit 5b3fb3e
Showing 16 changed files with 1 addition and 71 deletions. The revert deletes the skip_if_rocm helper from test/test_utils.py, every use of it across the test suite, the module-level ROCm skips in test/test_ops.py and test/test_s8s4_linear_cutlass.py, and the empty test/__init__.py.
Empty file removed: test/__init__.py
test/dtypes/test_affine_quantized.py (4 changes: 0 additions & 4 deletions)

@@ -2,7 +2,6 @@
 import unittest
 
 import torch
-from test_utils import skip_if_rocm
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -90,7 +89,6 @@ def test_tensor_core_layout_transpose(self):
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
 
-    @skip_if_rocm("ROCm development in progress")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize(
         "apply_quant", get_quantization_functions(True, True, "cuda", True)
@@ -170,7 +168,6 @@ def apply_uint6_weight_only_quant(linear):
 
         deregister_aqt_quantized_linear_dispatch(dispatch_condition)
 
-    @skip_if_rocm("ROCm development in progress")
     @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_print_quantized_module(self, apply_quant):
@@ -183,7 +180,6 @@ class TestAffineQuantizedBasic(TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.bfloat16]
 
-    @skip_if_rocm("ROCm development in progress")
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
     def test_flatten_unflatten(self, device, dtype):
test/dtypes/test_floatx.py (2 changes: 0 additions & 2 deletions)

@@ -2,7 +2,6 @@
 import unittest
 
 import torch
-from test_utils import skip_if_rocm
 from torch.testing._internal.common_utils import (
     TestCase,
     instantiate_parametrized_tests,
@@ -109,7 +108,6 @@ def test_to_copy_device(self, ebits, mbits):
     @parametrize("ebits,mbits", _Floatx_DTYPES)
     @parametrize("bias", [False, True])
     @parametrize("dtype", [torch.half, torch.bfloat16])
-    @skip_if_rocm("ROCm development in progress")
     @unittest.skipIf(is_fbcode(), reason="broken in fbcode")
     def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
         N, OC, IC = 4, 256, 64
test/float8/test_base.py (3 changes: 0 additions & 3 deletions)

@@ -24,8 +24,6 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
-from test_utils import skip_if_rocm
-
 from torchao.float8.config import (
     CastConfig,
     Float8LinearConfig,
@@ -425,7 +423,6 @@ def test_linear_from_config_params(
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_bias", [True, False])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    @skip_if_rocm("ROCm development in progress")
     def test_linear_from_recipe(
         self,
         recipe_name,
test/hqq/test_hqq_affine.py (2 changes: 0 additions & 2 deletions)

@@ -1,7 +1,6 @@
 import unittest
 
 import torch
-from test_utils import skip_if_rocm
 
 from torchao.quantization import (
     MappingType,
@@ -111,7 +110,6 @@ def test_hqq_plain_5bit(self):
             ref_dot_product_error=0.000704,
         )
 
-    @skip_if_rocm("ROCm development in progress")
     def test_hqq_plain_4bit(self):
         self._test_hqq(
             dtype=torch.uint4,
test/integration/test_integration.py (7 changes: 0 additions & 7 deletions)

@@ -93,8 +93,6 @@
 except ModuleNotFoundError:
     has_gemlite = False
 
-from test_utils import skip_if_rocm
-
 logger = logging.getLogger("INFO")
 
 torch.manual_seed(0)
@@ -571,7 +569,6 @@ def test_per_token_linear_cpu(self):
         self._test_per_token_linear_impl("cpu", dtype)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @skip_if_rocm("ROCm development in progress")
     def test_per_token_linear_cuda(self):
         for dtype in (torch.float32, torch.float16, torch.bfloat16):
             self._test_per_token_linear_impl("cuda", dtype)
@@ -690,7 +687,6 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
-    @skip_if_rocm("ROCm development in progress")
     def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -710,7 +706,6 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
-    @skip_if_rocm("ROCm development in progress")
     def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -904,7 +899,6 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
-    @skip_if_rocm("ROCm development in progress")
     def test_int4_weight_only_quant_subclass(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -924,7 +918,6 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
-    @skip_if_rocm("ROCm development in progress")
     def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
test/kernel/test_galore_downproj.py (2 changes: 0 additions & 2 deletions)

@@ -8,7 +8,6 @@
 
 import torch
 from galore_test_utils import make_data
-from test_utils import skip_if_rocm
 
 from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk
 from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
@@ -30,7 +29,6 @@
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
 @pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS)
-@skip_if_rocm("ROCm development in progress")
 def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype):
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
     MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32
test/prototype/test_awq.py (3 changes: 0 additions & 3 deletions)

@@ -10,8 +10,6 @@
 if TORCH_VERSION_AT_LEAST_2_3:
     from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
 
-from test_utils import skip_if_rocm
-
 
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
@@ -115,7 +113,6 @@ def test_awq_loading(device, qdtype):
 
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@skip_if_rocm("ROCm development in progress")
 def test_save_weights_only():
     dataset_size = 100
     l1, l2, l3 = 512, 256, 128
test/prototype/test_low_bit_optim.py (2 changes: 0 additions & 2 deletions)

@@ -42,7 +42,6 @@
 except ImportError:
     lpmm = None
 
-from test_utils import skip_if_rocm
 
 _DEVICES = get_available_devices()
 
@@ -113,7 +112,6 @@ class TestOptim(TestCase):
     )
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
-    @skip_if_rocm("ROCm development in progress")
     def test_optim_smoke(self, optim_name, dtype, device):
         if optim_name.endswith("Fp8") and device == "cuda":
             if not TORCH_VERSION_AT_LEAST_2_4:
test/prototype/test_splitk.py (3 changes: 0 additions & 3 deletions)

@@ -13,16 +13,13 @@
 except ImportError:
     triton_available = False
 
-from test_utils import skip_if_rocm
-
 from torchao.utils import skip_if_compute_capability_less_than
 
 
 @unittest.skipIf(not triton_available, "Triton is required but not available")
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
 class TestFP8Gemm(TestCase):
     @skip_if_compute_capability_less_than(9.0)
-    @skip_if_rocm("ROCm development in progress")
     def test_gemm_split_k(self):
         dtype = torch.float16
         qdtype = torch.float8_e4m3fn
test/quantization/test_galore_quant.py (2 changes: 0 additions & 2 deletions)

@@ -13,7 +13,6 @@
     dequantize_blockwise,
     quantize_blockwise,
 )
-from test_utils import skip_if_rocm
 
 from torchao.prototype.galore.kernels import (
     triton_dequant_blockwise,
@@ -83,7 +82,6 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
     "dim1,dim2,dtype,signed,blocksize",
     TEST_CONFIGS,
 )
-@skip_if_rocm("ROCm development in progress")
 def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
     g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01
test/quantization/test_marlin_qqq.py (3 changes: 0 additions & 3 deletions)

@@ -3,7 +3,6 @@
 
 import pytest
 import torch
-from test_utils import skip_if_rocm
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
 
@@ -46,7 +45,6 @@ def setUp(self):
         )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    @skip_if_rocm("ROCm development in progress")
     def test_marlin_qqq(self):
         output_ref = self.model(self.input)
         for group_size in [-1, 128]:
@@ -68,7 +66,6 @@ def test_marlin_qqq(self):
 
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    @skip_if_rocm("ROCm development in progress")
     def test_marlin_qqq_compile(self):
         model_copy = copy.deepcopy(self.model)
         model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
test/sparsity/test_marlin.py (4 changes: 1 addition & 3 deletions)

@@ -2,7 +2,6 @@
 
 import pytest
 import torch
-from test_utils import skip_if_rocm
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
 
@@ -38,7 +37,6 @@ def setUp(self):
         )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    @skip_if_rocm("ROCm development in progress")
     def test_quant_sparse_marlin_layout_eager(self):
         apply_fake_sparsity(self.model)
         model_copy = copy.deepcopy(self.model)
@@ -50,13 +48,13 @@ def test_quant_sparse_marlin_layout_eager(self):
         # Sparse + quantized
         quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         sparse_result = self.model(self.input)
 
         assert torch.allclose(
             dense_result, sparse_result, atol=3e-1
         ), "Results are not close"
 
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    @skip_if_rocm("ROCm development in progress")
     def test_quant_sparse_marlin_layout_compile(self):
         apply_fake_sparsity(self.model)
         model_copy = copy.deepcopy(self.model)
test/test_ops.py (3 changes: 0 additions & 3 deletions)

@@ -19,9 +19,6 @@
 from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff, is_fbcode
 
-if torch.version.hip is not None:
-    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
-
 if is_fbcode():
     pytest.skip(
         "Skipping the test in fbcode since we don't have TARGET file for kernels"
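Unlike the per-test skip_if_rocm decorator removed elsewhere in this commit, the guard deleted here skipped the entire module at collection time. A minimal sketch of that pattern in a hypothetical standalone test module (file and test names are illustrative, not from this repository):

# test_rocm_module_skip.py -- hypothetical example of the module-level pattern
import pytest
import torch

# On ROCm builds, torch.version.hip is a version string rather than None,
# so pytest skips every test in this module during collection.
if torch.version.hip is not None:
    pytest.skip("Skipping the test in ROCm", allow_module_level=True)


def test_cuda_only_kernel():
    # Reached only on non-ROCm builds (CUDA or CPU).
    assert torch.version.hip is None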
test/test_s8s4_linear_cutlass.py (3 changes: 0 additions & 3 deletions)

@@ -7,9 +7,6 @@
 from torchao.quantization.utils import group_quantize_tensor_symmetric
 from torchao.utils import compute_max_diff
 
-if torch.version.hip is not None:
-    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
-
 S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
 S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
 S8S4_LINEAR_CUTLASS_SIZE_MNK = [
test/test_utils.py (29 changes: 0 additions & 29 deletions)

@@ -1,40 +1,11 @@
-import functools
 import unittest
 from unittest.mock import patch
 
-import pytest
 import torch
 
 from torchao.utils import TorchAOBaseTensor, torch_version_at_least
 
 
-def skip_if_rocm(message=None):
-    """Decorator to skip tests on ROCm platform with custom message.
-
-    Args:
-        message (str, optional): Additional information about why the test is skipped.
-    """
-
-    def decorator(func):
-        @functools.wraps(func)
-        def wrapper(*args, **kwargs):
-            if torch.version.hip is not None:
-                skip_message = "Skipping the test in ROCm"
-                if message:
-                    skip_message += f": {message}"
-                pytest.skip(skip_message)
-            return func(*args, **kwargs)
-
-        return wrapper
-
-    # Handle both @skip_if_rocm and @skip_if_rocm() syntax
-    if callable(message):
-        func = message
-        message = None
-        return decorator(func)
-    return decorator
-
-
 class TestTorchVersionAtLeast(unittest.TestCase):
     def test_torch_version_at_least(self):
         test_cases = [
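For reference, a minimal sketch of how tests consumed the deleted helper before this revert; as the comment in the removed code notes, both the bare and the parenthesized decorator forms were accepted (the test bodies below are illustrative, not from this repository):

import torch
from test_utils import skip_if_rocm  # as defined above, prior to this revert


@skip_if_rocm  # bare form: the function itself arrives in `message`,
def test_add():  # and the callable(message) branch applies the decorator
    assert torch.add(torch.ones(1), 1).item() == 2.0


@skip_if_rocm("ROCm development in progress")  # parenthesized form with a reason
def test_mul():
    assert torch.mul(torch.ones(1), 3).item() == 3.0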
