diff --git a/benchmarks/bench_fused_add_rmsnorm.py b/benchmarks/bench_fused_add_rmsnorm.py
new file mode 100644
index 00000000..d56c1819
--- /dev/null
+++ b/benchmarks/bench_fused_add_rmsnorm.py
@@ -0,0 +1,55 @@
+import argparse
+from typing import cast
+
+import torch
+from triton.testing import do_bench
+
+import flashinfer
+
+@torch.inference_mode()
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch-sizes", nargs='+', type=int, default=[1, 19, 99, 989])
+    parser.add_argument("--hidden-sizes", nargs='+', type=int, default=[111, 500, 1024, 3072, 4096, 8192])
+    parser.add_argument("--dtypes", nargs='+', choices=["float16", "bfloat16"], default=["float16"])
+    args = parser.parse_args()
+
+    eps = 1e-6
+
+    # Loop over each combination of batch_size, hidden_size, and dtype
+    for batch_size in args.batch_sizes:
+        for hidden_size in args.hidden_sizes:
+            for dtype_str in args.dtypes:
+                dtype = getattr(torch, dtype_str)
+
+                # Define tensors with the correct dtype
+                x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda")
+                residual = torch.randn_like(x)
+                weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
+
+                @torch.cuda.nvtx.range(f"fused_add_rmsnorm batch_size={batch_size}, hidden_size={hidden_size}, dtype={dtype_str}")
+                def fn() -> None:
+                    flashinfer.fused_add_rmsnorm(x, residual, weight, eps)
+
+                # Run benchmarking
+                latency_ms = cast(float, do_bench(fn))
+                throughput = (
+                    (x.numel() * x.element_size() * 2
+                     + residual.numel() * residual.element_size() * 2
+                     + weight.numel() * weight.element_size())
+                    / (latency_ms * 1e-3)
+                )
+                print(
+                    f"batch_size: {batch_size:3},",
+                    f"hidden_size: {hidden_size:5},",
+                    f"dtype: {dtype_str:8},",
+                    f"latency: {latency_ms*1e3:2.0f}us,",
+                    f"throughput: {throughput*1e-9:7.3f}GB/s",
+                )
+
+        print("---")
+
+    torch.cuda.profiler.stop()
+
+if __name__ == "__main__":
+    main()
diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh
index 0774ec62..14b6692b 100644
--- a/include/flashinfer/norm.cuh
+++ b/include/flashinfer/norm.cuh
@@ -133,6 +133,8 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
     input_vec.fill(0.f);
     vec_t<T, VEC_SIZE> residual_vec;
     residual_vec.fill(0.f);
+    vec_t<float, VEC_SIZE> x_vec;
+    x_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
       residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -143,10 +145,11 @@
       x += float(residual_vec[j]);
       sum_sq += x * x;
       residual_vec[j] = (T)x;
-      smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j] = x;
+      x_vec[j] = x;
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      x_vec.store(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
   }
 
@@ -174,15 +177,17 @@
   for (uint32_t i = 0; i < rounds; i++) {
     vec_t<T, VEC_SIZE> input_vec;
     vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<float, VEC_SIZE> x_vec;
     input_vec.fill(0.f);
     weight_vec.fill(0.f);
+    x_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      x_vec.load(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      float x = smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j];
-      input_vec[j] = x * rms_rcp * float(weight_vec[j]);
+      input_vec[j] = x_vec[j] * rms_rcp * float(weight_vec[j]);
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
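
Not part of the patch: a minimal numerics-check sketch to confirm the vectorized smem path leaves results unchanged. It assumes, as the kernel above suggests, that flashinfer.fused_add_rmsnorm(x, residual, weight, eps) updates both x and residual in place (residual <- x + residual, x <- RMSNorm(residual) * weight); the fp32 reference math and the tolerances are assumptions, not values from this diff.

import torch
import flashinfer

def ref_fused_add_rmsnorm(x, residual, weight, eps):
    # Accumulate in fp32, mirroring the kernel's float math on the smem values.
    z = x.float() + residual.float()
    rms_rcp = torch.rsqrt(z.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = z * rms_rcp * weight.float()
    return out.to(x.dtype), z.to(x.dtype)

torch.manual_seed(0)
x = torch.randn(19, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.randn(4096, dtype=torch.float16, device="cuda")

out_ref, residual_ref = ref_fused_add_rmsnorm(x, residual, weight, 1e-6)
flashinfer.fused_add_rmsnorm(x, residual, weight, 1e-6)  # assumed in-place update
torch.testing.assert_close(x, out_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(residual, residual_ref, rtol=1e-2, atol=1e-2)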