Skip to content

Commit

Permalink
More elegant weight initialization for FP6 benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasvanderwerff committed Sep 24, 2024
1 parent 4211b85 commit 7fbbcca
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions benchmarks/benchmark_fp6.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import torch
import pandas as pd
import torchao
from torchao.dtypes.floatx import from_scaled_tc_floatx
from torchao.dtypes.floatx import from_scaled_tc_floatx, to_scaled_tc_floatx
from torchao.utils import benchmark_torch_function_in_microseconds
from tqdm import tqdm


def benchmark(m: int, k: int, n: int):
ebits = 3
mbits = 2
nbits = 1 + ebits + mbits

fp6_weight = torch.randint(256, (n, k // 8 * nbits), dtype=torch.uint8, device="cuda")
scale = torch.rand(n, device="cuda").half() + 0.5
fp32_weight = torch.randn(n, k, device="cuda")
fp6_weight, scale = to_scaled_tc_floatx(fp32_weight, ebits, mbits)
fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") + 0.5

fp6_output = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scale, splitK=1)
Expand Down

0 comments on commit 7fbbcca

Please sign in to comment.