clean up device checks in float8 unit test files
Summary:

While working on rowwise scaling, I noticed that some of the CUDA device capability checks in the test files did not make sense; this commit cleans them up.

Test Plan:

Tests pass on my H100.

In CI, fewer tests should be skipped now, since CI only has CUDA capability 8, 9.

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Sep 23, 2024
1 parent 53b6b78 commit 5f1879b
Showing 2 changed files with 1 addition and 25 deletions.
23 changes: 0 additions & 23 deletions test/float8/test_base.py
@@ -231,15 +231,6 @@ def test_linear(
         linear_dtype: torch.dtype,
         linear_bias: bool,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)

@@ -287,16 +278,6 @@ def test_autocast_outputs(
         emulate: bool,
         linear_dtype: torch.dtype,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-
         m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig(
             cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
@@ -334,10 +315,6 @@ def test_autocast_outputs(
     @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
-        emulate = (
-            not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)
-        )
-
         m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         config = Float8LinearConfig(emulate=emulate)
         m = Float8Linear.from_float(copy.deepcopy(m), config)
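For context, a minimal sketch of the pattern the deleted in-body checks collapse into: compute a capability flag once at module scope and let decorators and parametrization do the skipping. The flag name matches the is_cuda_8_9 already used in these files; the standalone test function below is illustrative, not code from the repo.

import unittest

import pytest
import torch
import torch.nn as nn

# Evaluate hardware support once at import time instead of inside each test body.
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


# Only offer the non-emulated variant when the hardware can actually run float8;
# skip everything when there is no CUDA device at all.
@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_linear_example(emulate: bool):
    # In the real tests, `emulate` would feed Float8LinearConfig(emulate=emulate);
    # here it only demonstrates that the parametrize/skip gating composes.
    m = nn.Linear(32, 16, device="cuda", dtype=torch.bfloat16)
    x = torch.randn(4, 32, device="cuda", dtype=torch.bfloat16)
    assert m(x).shape == (4, 16)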
3 changes: 1 addition & 2 deletions test/float8/test_compile.py
@@ -32,7 +32,6 @@
 from torch._dynamo.test_case import TestCase as DynamoTestCase
 from torch._dynamo.testing import CompileCounterWithBackend

-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

 def _test_compile_base(
@@ -224,7 +223,7 @@ def forward(self, x):
                 return x_hp
             return x_fp8

-    @unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
+    @unittest.skipIf(not torch.cuda.is_available() or not is_cuda_8_9, "CUDA with float8 support not available")
     def test_float8_with_graph_break_in_the_middle(self):
         """Test that having Float8Tensor object at the boundary of a subgraph"""
         cnts = CompileCounterWithBackend("inductor")
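The switch from is_H100 to is_cuda_8_9 widens these compile tests from capability (9, 0) hardware to any device with float8 support. A hedged sketch of how the two thresholds differ (the helper name and labels below are illustrative, not part of the torchao test suite):

import torch

def float8_test_gate() -> str:
    # Classify the current device the same way the flags above do.
    if not torch.cuda.is_available():
        return "emulate only: no CUDA device"
    capability = torch.cuda.get_device_capability()
    if capability >= (9, 0):
        return "capability >= (9, 0): the old is_H100 gate"
    if capability >= (8, 9):
        return "capability >= (8, 9): is_cuda_8_9, CUDA with float8 support"
    return "emulate only: capability below (8, 9)"

print(float8_test_gate())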
