Skip to content

Commit

Permalink
clean up device checks in float8 unit test files (pytorch#923)
Browse files Browse the repository at this point in the history
Summary:

While working on rowwise scaling I noticed that some of the CUDA
device capability checks we had in the test files did not make sense,
cleaning this up.

Test Plan:

tests pass on my H100

CI, it should skip less tests now since CI only has CUDA capability 8, 9

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
vkuzo authored and weifengpy committed Sep 26, 2024
1 parent ebdeed0 commit 09ffa22
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 24 deletions.
23 changes: 0 additions & 23 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,6 @@ def test_linear(
linear_dtype: torch.dtype,
linear_bias: bool,
):
if not emulate:
if not torch.cuda.is_available():
warnings.warn("CUDA not available")
pytest.skip()
elif torch.cuda.get_device_capability() < (9, 0):
warnings.warn(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
)
pytest.skip()
x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)

Expand Down Expand Up @@ -290,16 +281,6 @@ def test_autocast_outputs(
emulate: bool,
linear_dtype: torch.dtype,
):
if not emulate:
if not torch.cuda.is_available():
warnings.warn("CUDA not available")
pytest.skip()
elif torch.cuda.get_device_capability() < (9, 0):
warnings.warn(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
)
pytest.skip()

m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
config = Float8LinearConfig(
cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
Expand Down Expand Up @@ -337,10 +318,6 @@ def test_autocast_outputs(
@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
emulate = (
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0)
)

m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
config = Float8LinearConfig(emulate=emulate)
m = Float8Linear.from_float(copy.deepcopy(m), config)
Expand Down
3 changes: 2 additions & 1 deletion test/float8/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ def forward(self, x):
return x_hp
return x_fp8

@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with float8 support not available")
# TODO(future): figure out why the test below fails on CUDA capability 8.9
@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA with capability 9.0 or greater not available")
def test_float8_with_graph_break_in_the_middle(self):
"""Test that having Float8Tensor object at the boundary of a subgraph"""
cnts = CompileCounterWithBackend("inductor")
Expand Down

0 comments on commit 09ffa22

Please sign in to comment.