add to benchmarks
drisspg committed Jan 29, 2025
1 parent 382cb0f commit 2516390
Showing 2 changed files with 27 additions and 2 deletions.
28 changes: 26 additions & 2 deletions benchmarks/fp8_matmul.py
@@ -30,6 +30,11 @@
import csv
import logging


def ceil_div(a, b):
    # Ceiling division: number of size-b blocks needed to cover a elements.
    return (a + b - 1) // b


torch._dynamo.config.cache_size_limit = 10000
logging.getLogger("transformer_nuggets").setLevel(logging.INFO)
torch._inductor.config.max_autotune_gemm_backends = "TRITON"
@@ -41,6 +41,7 @@ class FP8Kernel(Enum):
    PERSISTENT_TMA = "Persistent-TMA"
    DEVICE_TMA = "Device-TMA"
    SCALED_MM = "Scaled-MM"
    MX_FP8 = "MX-FP8"


class ScalingStrategy(Enum):
@@ -78,6 +84,23 @@ def get_fp8_matmul(
        return lambda: addmm_float8_unwrapped_inference(
            A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
        )
    elif fp8_kernel == FP8Kernel.MX_FP8:
        try:
            from driss_torch import mx_fp8_bf16
        except ModuleNotFoundError:
            print("Driss Torch not installed")
            return None
        # This is crude, but we only care about perf here; numerics are checked elsewhere.
        M, K = A_fp8.shape
        N, _ = B_fp8.shape
        # One uint8 scale per 128-row x 32-column block of each operand.
        n_a_rows = ceil_div(M, 128)
        n_a_cols = ceil_div(K, 32)
        n_b_rows = ceil_div(N, 128)
        n_b_cols = ceil_div(K, 32)
        a_scales = torch.randint(256, (n_a_rows, n_a_cols), dtype=torch.uint8, device="cuda")
        b_scales = torch.randint(256, (n_b_rows, n_b_cols), dtype=torch.uint8, device="cuda")

        return lambda: mx_fp8_bf16(A_fp8, B_fp8, a_scales, b_scales)
    else:
        raise ValueError(f"Invalid FP8 kernel: {fp8_kernel}")
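
To make the scale bookkeeping concrete, here is a quick standalone check of the shapes this branch allocates (a sketch; M=N=8192 are the benchmark defaults below, and K=4096 is one value from the K-sweep):

def ceil_div(a, b):
    return (a + b - 1) // b

M, K, N = 8192, 4096, 8192
print(ceil_div(M, 128), ceil_div(K, 32))  # 64 128 -> a_scales has shape (64, 128)
print(ceil_div(N, 128), ceil_div(K, 32))  # 64 128 -> b_scales has shape (64, 128)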

@@ -223,14 +246,15 @@ def get_configs_varying_k(
    M: int = 8192, N: int = 8192, bf16: bool = False
) -> List[ExperimentConfig]:
    shapes = [(M, K, N) for K in range(1024, 16385, 1024)]
-    scaling_strategies = [ScalingStrategy.PER_ROW]
-    compile_options = [True, False]
+    scaling_strategies = [ScalingStrategy.PER_TENSOR]
+    compile_options = [False]
    configs = []
    fp8_kernels = [
        FP8Kernel.SCALED_MM,
        # FP8Kernel.PERSISTENT,
        # FP8Kernel.PERSISTENT_TMA,
        # FP8Kernel.DEVICE_TMA,
        FP8Kernel.MX_FP8,
    ]

    for (M, K, N), strategy, compile, kernel in itertools.product(
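With this change, the K-sweep benchmarks a fixed grid of configurations; a quick count (a sketch assuming one experiment per element of the itertools.product above):

import itertools

shapes = [(8192, K, 8192) for K in range(1024, 16385, 1024)]  # 16 K values
strategies = ["PER_TENSOR"]        # 1 scaling strategy
compile_options = [False]          # compilation disabled for this sweep
kernels = ["Scaled-MM", "MX-FP8"]  # 2 kernels compared

print(len(list(itertools.product(shapes, strategies, compile_options, kernels))))  # 32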
1 change: 1 addition & 0 deletions transformer_nuggets/mx/__init__.py
@@ -0,0 +1 @@
from transformer_nuggets.mx.to_blocked import to_blocked
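
The new mx/__init__.py re-exports to_blocked, the helper that rearranges a per-block scale matrix into the tiled layout that block-scaled MX kernels consume. For illustration, a minimal sketch of that swizzle (a hypothetical re-implementation, not necessarily this repo's exact code; it assumes the scale matrix is padded to whole 128 x 4 tiles and each tile is stored as 32x16-byte groups):

import torch

def to_blocked_sketch(scales: torch.Tensor) -> torch.Tensor:
    # Hypothetical version of to_blocked, for illustration only.
    rows, cols = scales.shape
    n_row_blocks = (rows + 127) // 128
    n_col_blocks = (cols + 3) // 4
    # Pad up to whole 128x4 tiles.
    padded = torch.zeros(
        n_row_blocks * 128, n_col_blocks * 4, dtype=scales.dtype, device=scales.device
    )
    padded[:rows, :cols] = scales
    # Carve into (128, 4) tiles, then lay each tile out as 4 groups of 32 rows x 16 bytes.
    tiles = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    return tiles.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16).flatten()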
