fix precision error due to dtype
gau-nernst committed May 26, 2024
1 parent d6c6b6a commit d798eaf
Showing 2 changed files with 3 additions and 2 deletions.
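For context, a minimal illustration of the dtype issue (an inference from the diff, not text from the commit): the FP6 kernel computes in fp16, so a test input created in float32 is downcast inside the layer, and values that round-trip through fp16 no longer match their float32 originals. Building the input directly with dtype=torch.half, as the updated tests do, removes that extra cast from the eager-vs-compiled comparison.

import torch

# Illustration only: float32 values do not survive a round trip through fp16,
# so mixing the two dtypes in a comparison can exceed tight tolerances.
x_fp32 = torch.randn(4, 8)   # how the tests previously built the input
x_fp16 = x_fp32.half()       # the dtype the kernel actually computes in

print(torch.allclose(x_fp16.float(), x_fp32))   # typically False
print((x_fp16.float() - x_fp32).abs().max())    # small but nonzero error

torch.testing.assert_close also applies looser default tolerances for half precision than for float32, so keeping both the input and the outputs in half keeps the test_fp6_llm_linear_compile comparison on a tolerance appropriate for the kernel.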
4 changes: 2 additions & 2 deletions test/quantization/test_fp6_llm.py
@@ -60,7 +60,7 @@ def test_fp6_llm_linear_forward(self, bias, leading_dims):
         fp6_linear = Fp6LlmLinear.from_float(linear)
         assert (fp6_linear.bias is not None) == bias
 
-        x = torch.randn(*leading_dims, IC, device=device)
+        x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half)
         fp6_linear(x)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -72,7 +72,7 @@ def test_fp6_llm_linear_compile(self, bias):
         linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
         fp6_linear = Fp6LlmLinear.from_float(linear)
 
-        x = torch.randn(N, IC, device=device)
+        x = torch.randn(N, IC, device=device, dtype=torch.half)
         expected = fp6_linear(x)
         actual = torch.compile(fp6_linear)(x)
         torch.testing.assert_close(actual, expected)
1 change: 1 addition & 0 deletions torchao/quantization/fp6_llm.py
@@ -124,6 +124,7 @@ def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None
         self.in_features = weight.shape[1] * 16 // 3
 
     def forward(self, x: Tensor) -> Tensor:
+        # TODO: splitK map
         out = fp16act_fp6weight_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=1)
         if self.bias is not None:
             out = out + self.bias
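For reference, a self-contained sketch of the forward pattern shown above, with a plain fp16 matmul standing in for the fused fp16act_fp6weight_linear kernel. The packed FP6 weight layout (and the weight.shape[1] * 16 // 3 in_features relation), the splitK handling, the shapes below, and the final reshape of the output (the diff is truncated after out = out + self.bias) are all stand-ins for illustration, not the library's actual implementation.

import torch
from torch import Tensor
from typing import Optional

def fp16_linear_stub(x: Tensor, weight: Tensor, scales: Tensor, splitK: int = 1) -> Tensor:
    # Stand-in for the fused kernel: a plain fp16 matmul with per-output-channel
    # scales; splitK is accepted but ignored here.
    return (x @ weight.t()) * scales

class HalfLinearSketch(torch.nn.Module):
    # Mirrors the pattern in forward() above: flatten the leading dimensions,
    # cast the activation to fp16 for the kernel, add the bias, restore the shape.
    def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None):
        super().__init__()
        self.weight = weight    # (out_features, in_features) in fp16 (unpacked stand-in)
        self.scales = scales    # (out_features,) in fp16
        self.bias = bias
        self.out_features, self.in_features = weight.shape

    def forward(self, x: Tensor) -> Tensor:
        out = fp16_linear_stub(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=1)
        if self.bias is not None:
            out = out + self.bias
        return out.view(*x.shape[:-1], self.out_features)

device = "cuda" if torch.cuda.is_available() else "cpu"
layer = HalfLinearSketch(
    torch.randn(16, 32, dtype=torch.half, device=device),
    torch.ones(16, dtype=torch.half, device=device),
    bias=torch.zeros(16, dtype=torch.half, device=device),
)
y = layer(torch.randn(2, 4, 32, dtype=torch.half, device=device))   # shape (2, 4, 16)

Passing a half-precision input means the .half() cast inside forward is a no-op, which is exactly what the updated tests now exercise.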
