Add PyTorch 2.4 tests in CI (#654)
msaroufim authored Aug 15, 2024
1 parent 0b0192e commit ffa88a4
Showing 12 changed files with 36 additions and 23 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/regression_test.yml
@@ -31,11 +31,17 @@ jobs:
torch-spec: 'torch==2.3.0'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
+- name: CUDA 2.4
+runs-on: linux.g5.12xlarge.nvidia.gpu
+torch-spec: 'torch==2.4.0'
+gpu-arch-type: "cuda"
+gpu-arch-version: "12.1"
- name: CUDA Nightly
runs-on: linux.g5.12xlarge.nvidia.gpu
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
+
- name: CPU 2.2.2
runs-on: linux.4xlarge
torch-spec: 'torch==2.2.2 --index-url https://download.pytorch.org/whl/cpu "numpy<2" '
@@ -46,6 +52,11 @@ jobs:
torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
gpu-arch-type: "cpu"
gpu-arch-version: ""
+- name: CPU 2.4
+runs-on: linux.4xlarge
+torch-spec: 'torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu'
+gpu-arch-type: "cpu"
+gpu-arch-version: ""
- name: CPU Nightly
runs-on: linux.4xlarge
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu'
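Each matrix entry's torch-spec is a pip requirement string. As a point of reference, here is a minimal sketch of how such an entry could be handed to pip, assuming the workflow forwards the value verbatim to its install step (that step is outside this hunk, and install_torch is a hypothetical helper):

import shlex
import subprocess
import sys

def install_torch(torch_spec: str) -> None:
    # Hypothetical helper: install torch from a CI matrix torch-spec string.
    # shlex.split honors embedded quoting such as "numpy<2" in the CPU 2.2.2 entry.
    args = shlex.split(torch_spec)
    subprocess.check_call([sys.executable, "-m", "pip", "install", *args])

# Example: the new CPU 2.4 matrix entry.
install_torch('torch==2.4.0 --index-url https://download.pytorch.org/whl/cpu')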
4 changes: 2 additions & 2 deletions test/float8/test_base.py
@@ -16,9 +16,9 @@
import torch
import torch.nn as nn

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


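The same two-line bump repeats across the float8 test files below. The module-level pytest.skip(..., allow_module_level=True) aborts collection of the entire file when the flag is false, so raising the gate from TORCH_VERSION_AT_LEAST_2_4 to TORCH_VERSION_AT_LEAST_2_5 keeps the float8 suite off the newly added 2.4 jobs, presumably because it depends on features newer than the 2.4.0 release. A hedged sketch of how such a version flag can be computed (not necessarily the actual torchao.utils implementation):

import torch
from packaging.version import parse

# Nightly builds report versions like "2.5.0.dev20240815+cu121", which
# packaging orders *before* the final "2.5.0", so compare against the
# .dev floor to count nightlies as "at least 2.5".
_torch_version = parse(torch.__version__.split("+")[0])
TORCH_VERSION_AT_LEAST_2_5 = _torch_version >= parse("2.5.0.dev0")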
4 changes: 2 additions & 2 deletions test/float8/test_compile.py
@@ -11,9 +11,9 @@

import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

import torch
4 changes: 2 additions & 2 deletions test/float8/test_dtensor.py
@@ -19,9 +19,9 @@

import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from torchao.float8 import Float8LinearConfig
4 changes: 2 additions & 2 deletions test/float8/test_fsdp.py
@@ -18,9 +18,9 @@

import fire

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

import torch
4 changes: 2 additions & 2 deletions test/float8/test_fsdp2/test_fsdp2.py
@@ -5,9 +5,9 @@
import unittest
from typing import Any, List

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


4 changes: 2 additions & 2 deletions test/float8/test_fsdp_compile.py
@@ -15,9 +15,9 @@

import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

import torch
4 changes: 2 additions & 2 deletions test/float8/test_inference_flows.py
@@ -12,11 +12,11 @@
import pytest
from unittest.mock import patch
from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
unwrap_tensor_subclass,
)

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

import torch
4 changes: 2 additions & 2 deletions test/float8/test_numerics_integration.py
@@ -11,9 +11,9 @@

import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

import torch
12 changes: 6 additions & 6 deletions test/integration/test_integration.py
@@ -913,7 +913,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
if dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
self.skipTest("test requires SM capability of at least (8, 0).")
from torch._inductor import config
-mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_4 else ("force_mixed_mm", True)
+mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_5 else ("force_mixed_mm", True)

with config.patch({
"epilogue_fusion": True,
@@ -943,7 +943,7 @@ def test_weight_only_quant_use_mixed_mm(self, device, dtype):
self.skipTest("test requires SM capability of at least (8, 0).")
torch.manual_seed(0)
from torch._inductor import config
-mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_4 else ("force_mixed_mm", True)
+mixed_mm_key, mixed_mm_val = ("mixed_mm_choice", "triton") if TORCH_VERSION_AT_LEAST_2_5 else ("force_mixed_mm", True)

with config.patch({
"epilogue_fusion": False,
@@ -1222,7 +1222,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
(1, 32, 128, 128),
(32, 32, 128, 128),
]))
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "autoquant requires 2.4+.")
+@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
undo_recommended_configs()
if device != "cuda" or not torch.cuda.is_available():
@@ -1254,7 +1254,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
self.assertTrue(sqnr >= 30)

@parameterized.expand(COMMON_DEVICE_DTYPE)
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "autoquant requires 2.4+.")
+@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
def test_autoquant_manual(self, device, dtype):
undo_recommended_configs()
if device != "cuda" or not torch.cuda.is_available():
@@ -1295,7 +1295,7 @@ def test_autoquant_manual(self, device, dtype):
(1, 32, 128, 128),
(32, 32, 128, 128),
]))
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "autoquant requires 2.4+.")
+@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
def test_autoquant_kwargs(self, device, dtype, m1, m2, k, n):
undo_recommended_configs()
if device != "cuda" or not torch.cuda.is_available():
@@ -1478,7 +1478,7 @@ def forward(self, x):

class TestUtils(unittest.TestCase):
@parameterized.expand(COMMON_DEVICE_DTYPE)
-@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "autoquant requires 2.4+.")
+@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
def test_get_model_size_autoquant(self, device, dtype):
if device != "cuda" and dtype != torch.bfloat16:
self.skipTest(f"autoquant currently does not support {device}")
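The two mixed-mm hunks pick an inductor knob by torch version: newer inductor replaces the boolean force_mixed_mm flag with a mixed_mm_choice setting. The pattern, extracted for readability (flag and key names exactly as they appear in the diff):

from torch._inductor import config
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

# Select whichever mixed-mm knob the installed torch's inductor expects.
mixed_mm_key, mixed_mm_val = (
    ("mixed_mm_choice", "triton")
    if TORCH_VERSION_AT_LEAST_2_5
    else ("force_mixed_mm", True)
)

with config.patch({mixed_mm_key: mixed_mm_val}):
    pass  # run the weight-only-quantized matmul under test here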
2 changes: 1 addition & 1 deletion test/prototype/test_low_bit_optim.py
@@ -229,7 +229,7 @@ class TestFSDP2(FSDPTest):
def world_size(self) -> int:
return 2

-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="torch >= 2.4 required")
+@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="OptimState8bit dispatch: attempting to run unimplemented operator/function: aten.as_strided.default")
@pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="https://github.com/pytorch/ao/issues/652")
@skip_if_lt_x_gpu(2)
def test_fsdp2(self):
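Note the combined effect on test_fsdp2: the first marker skips below 2.5 (older torch cannot dispatch aten.as_strided.default for OptimState8bit), while the second skips at 2.5+ pending pytorch/ao issue 652, so the test is skipped on every version for now. A minimal illustration of the stacked-skipif semantics:

import pytest

FLAG = True  # stand-in for TORCH_VERSION_AT_LEAST_2_5

@pytest.mark.skipif(not FLAG, reason="requires torch >= 2.5")
@pytest.mark.skipif(FLAG, reason="blocked on torch >= 2.5 by a known issue")
def test_always_skipped():
    # Exactly one of the two conditions holds for any value of FLAG,
    # so pytest skips this test regardless of the installed version.
    raise AssertionError("never runs")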
2 changes: 2 additions & 0 deletions test/quantization/test_qat.py
@@ -423,6 +423,7 @@ def test_qat_4w_primitives(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
+@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32" )
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_linear(self):
from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear
@@ -453,6 +454,7 @@ def test_qat_4w_linear(self):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
+@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32" )
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_quantizer(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
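As with the optimizer test above, the decorator stacks now cover every version: the existing markers skip below 2.4 and at 2.5+, and the added skipIf(TORCH_VERSION_AT_LEAST_2_4, ...) lines close the remaining 2.4 window, leaving both 4-bit QAT tests skipped until the int4 dtype assertion tracked in https://github.com/pytorch/ao/pull/517 is fixed.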
