Commit
add mx-to the bunch
drisspg committed Jan 29, 2025
1 parent 2516390 commit 0d1cbe6
Showing 1 changed file with 41 additions and 15 deletions.
benchmarks/fp8_matmul.py: 56 changes (41 additions & 15 deletions)
@@ -46,61 +46,87 @@ class FP8Kernel(Enum):
     PERSISTENT_TMA = "Persistent-TMA"
     DEVICE_TMA = "Device-TMA"
     SCALED_MM = "Scaled-MM"
-    MX_FP8 = "MX-FP8"
+    CUTLASS_MX = "Cutlass-MX-FP8"


 class ScalingStrategy(Enum):
     PER_TENSOR = "PerTensor"
     PER_ROW = "PerRow"
+    E8M0 = "E8M0"


+def is_col_major(stride):
+    assert len(stride) == 2, "is_col_major only supports 2D tensors"
+    return stride[1] > stride[0] and stride[0] == 1
+
+
+def get_e8_scales(A: torch.Tensor, B: torch.Tensor):
+    M, K = A.shape
+    N, _ = B.shape
+    n_a_rows = ceil_div(M, 128)
+    n_a_cols = ceil_div(K, 32)
+    n_b_rows = ceil_div(N, 128)
+    n_b_cols = ceil_div(K, 32)
+    a_scales = torch.randint(256, (n_a_rows, n_a_cols), dtype=torch.uint8, device="cuda")
+    b_scales = torch.randint(256, (n_b_rows, n_b_cols), dtype=torch.uint8, device="cuda")
+    return a_scales, b_scales


 def get_fp8_matmul(
     A: torch.Tensor, B: torch.Tensor, scaling_strategy: ScalingStrategy, fp8_kernel: FP8Kernel
 ):
     A_fp8 = A.to(torch.float8_e4m3fn)
     B_fp8 = B.to(torch.float8_e4m3fn)
     A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True))

+    # Handle E8M0 format for supported kernels
+    if scaling_strategy == ScalingStrategy.E8M0:
+        if fp8_kernel not in [FP8Kernel.CUTLASS_MX, FP8Kernel.SCALED_MM]:
+            raise ValueError(
+                "E8M0 scaling strategy is only supported by CUTLASS_MX and SCALED_MM kernels"
+            )
+
     if scaling_strategy == ScalingStrategy.PER_TENSOR:
         a_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
         b_scale = torch.tensor(1, device="cuda", dtype=torch.float32)
     elif scaling_strategy == ScalingStrategy.PER_ROW:
         a_scale = torch.ones((A_fp8.size(0), 1), device="cuda", dtype=torch.float32)
         b_scale = torch.ones((B_fp8.size(1), 1), device="cuda", dtype=torch.float32).T
+    elif scaling_strategy == ScalingStrategy.E8M0:
+        a_scale, b_scale = get_e8_scales(A_fp8, B_fp8)
     else:
         raise ValueError(f"Invalid scaling strategy: {scaling_strategy}")

     if fp8_kernel == FP8Kernel.PERSISTENT:
         return lambda: matmul_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16)
     elif fp8_kernel == FP8Kernel.PERSISTENT_TMA:
         return lambda: matmul_tma_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16)
     elif fp8_kernel == FP8Kernel.DEVICE_TMA:
         return lambda: matmul_device_tma_persistent(A_fp8, a_scale, B_fp8, b_scale, torch.bfloat16)
     elif fp8_kernel == FP8Kernel.SCALED_MM:
+        if scaling_strategy == ScalingStrategy.E8M0:
+            # Use the scales we computed earlier for E8M0
+            return lambda: torch._scaled_mm(
+                A_fp8,
+                B_fp8,
+                b_scale,  # swap since we haven't figured this out yet
+                a_scale,
+                out_dtype=torch.bfloat16,
+                scale_dtype=1,
+            )
         return lambda: addmm_float8_unwrapped_inference(
             A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
         )
-    elif fp8_kernel == FP8Kernel.MX_FP8:
+    elif fp8_kernel == FP8Kernel.CUTLASS_MX:
         assert (
             scaling_strategy == ScalingStrategy.E8M0
         ), "E8M0 scaling strategy is required for CUTLASS_MX"
         try:
             from driss_torch import mx_fp8_bf16
         except ModuleNotFoundError:
             print("Driss Torch not installed")
             return None
-        # This is crude but we just care about perf, numerics checked elsewhere
-        M, K = A_fp8.shape
-        N, _ = B_fp8.shape
-        n_a_rows = ceil_div(M, 128)
-        n_a_cols = ceil_div(K, 32)
-        n_b_rows = ceil_div(N, 128)
-        n_b_cols = ceil_div(K, 32)
-        a_scales = torch.randint(256, (n_a_rows, n_a_cols), dtype=torch.uint8, device="cuda")
-        b_scales = torch.randint(256, (n_b_rows, n_b_cols), dtype=torch.uint8, device="cuda")
-
-        return lambda: mx_fp8_bf16(A_fp8, B_fp8, a_scales, b_scales)
+        return lambda: mx_fp8_bf16(A_fp8, B_fp8, a_scale, b_scale)
     else:
         raise ValueError(f"Invalid FP8 kernel: {fp8_kernel}")
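For context on the new E8M0 path: get_e8_scales allocates one uint8 scale per 128-row by 32-column tile of each operand, and an E8M0 byte x decodes to the power-of-two scale 2^(x - 127) per the OCP MX spec (0xFF is reserved for NaN). A minimal sketch of the shapes involved, not part of the commit and assuming ceil_div(a, b) behaves as -(-a // b):

import torch

def ceil_div(a: int, b: int) -> int:
    # Assumed to match the benchmark's ceil_div helper.
    return -(-a // b)

M, K, N = 8192, 4096, 8192
# One uint8 E8M0 exponent per 128-row x 32-column tile of each operand.
a_scales_shape = (ceil_div(M, 128), ceil_div(K, 32))  # (64, 128)
b_scales_shape = (ceil_div(N, 128), ceil_div(K, 32))  # (64, 128)

# An E8M0 byte x encodes the scale 2**(x - 127); sampling below 255 avoids NaN.
x = torch.randint(255, a_scales_shape, dtype=torch.uint8)
scales = torch.pow(2.0, x.float() - 127.0)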

@@ -246,15 +272,15 @@ def get_configs_varying_k(
     M: int = 8192, N: int = 8192, bf16: bool = False
 ) -> List[ExperimentConfig]:
     shapes = [(M, K, N) for K in range(1024, 16385, 1024)]
-    scaling_strategies = [ScalingStrategy.PER_TENSOR]
+    scaling_strategies = [ScalingStrategy.E8M0]
     compile_options = [False]
     configs = []
     fp8_kernels = [
         FP8Kernel.SCALED_MM,
         # FP8Kernel.PERSISTENT,
         # FP8Kernel.PERSISTENT_TMA,
         # FP8Kernel.DEVICE_TMA,
-        FP8Kernel.MX_FP8,
+        FP8Kernel.CUTLASS_MX,
     ]

     for (M, K, N), strategy, compile, kernel in itertools.product(
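For reference, get_configs_varying_k sweeps K while holding M = N = 8192, so the Cartesian product above yields 16 shapes x 1 strategy x 1 compile option x 2 kernels = 32 experiments. A minimal sketch of the enumeration, not part of the commit and with ExperimentConfig stubbed out as a plain tuple:

import itertools

M = N = 8192
shapes = [(M, K, N) for K in range(1024, 16385, 1024)]  # 16 (M, K, N) shapes
scaling_strategies = ["E8M0"]
compile_options = [False]
fp8_kernels = ["Scaled-MM", "Cutlass-MX-FP8"]

# Stand-in for the ExperimentConfig construction in the benchmark.
configs = list(
    itertools.product(shapes, scaling_strategies, compile_options, fp8_kernels)
)
print(len(configs))  # 32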
