From 0cf25db72544a325e54d495b8e5aecb6cba1e2c1 Mon Sep 17 00:00:00 2001
From: Abatom
Date: Thu, 7 Nov 2024 23:58:10 +0800
Subject: [PATCH 1/4] FusedAddRMSNormKernel

---
 include/flashinfer/norm.cuh | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh
index 0774ec62..14b6692b 100644
--- a/include/flashinfer/norm.cuh
+++ b/include/flashinfer/norm.cuh
@@ -133,6 +133,8 @@ __global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ res
     input_vec.fill(0.f);
     vec_t<T, VEC_SIZE> residual_vec;
     residual_vec.fill(0.f);
+    vec_t<float, VEC_SIZE> x_vec;
+    x_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
       residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
@@ -143,10 +145,11 @@
       x += float(residual_vec[j]);
       sum_sq += x * x;
       residual_vec[j] = (T)x;
-      smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j] = x;
+      x_vec[j] = x;
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      x_vec.store(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
   }
 
@@ -174,15 +177,17 @@
   for (uint32_t i = 0; i < rounds; i++) {
     vec_t<T, VEC_SIZE> input_vec;
     vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<float, VEC_SIZE> x_vec;
     input_vec.fill(0.f);
     weight_vec.fill(0.f);
+    x_vec.fill(0.f);
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      x_vec.load(smem + num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
     }
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; j++) {
-      float x = smem[num_warps + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE + j];
-      input_vec[j] = x * rms_rcp * float(weight_vec[j]);
+      input_vec[j] = x_vec[j] * rms_rcp * float(weight_vec[j]);
     }
     if ((i * num_threads + thread_id) * VEC_SIZE < d) {
       input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);

From 17f5c14aa615472efebed4938005aca0d3f21fdb Mon Sep 17 00:00:00 2001
From: Abatom
Date: Fri, 8 Nov 2024 14:43:37 +0800
Subject: [PATCH 2/4] add a benchmark

---
 benchmarks/bench_fused_add_rmsnorm.py | 52 +++++++++++++++++++++++++++
 1 file changed, 52 insertions(+)
 create mode 100644 benchmarks/bench_fused_add_rmsnorm.py

diff --git a/benchmarks/bench_fused_add_rmsnorm.py b/benchmarks/bench_fused_add_rmsnorm.py
new file mode 100644
index 00000000..e2cf7f9c
--- /dev/null
+++ b/benchmarks/bench_fused_add_rmsnorm.py
@@ -0,0 +1,52 @@
+import argparse
+from typing import cast
+
+import torch
+from triton.testing import do_bench
+
+import flashinfer
+
+@torch.inference_mode()
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--batch-sizes", nargs='+', type=int, default=[1, 19, 99, 989])
+    parser.add_argument("--hidden-sizes", nargs='+', type=int, default=[111, 500, 1024, 3072, 4096, 8192])
+    parser.add_argument("--dtypes", nargs='+', choices=["float16", "float32"], default=["float16"])
+    args = parser.parse_args()
+
+    eps = 1e-6
+
+    # Loop over each combination of batch_size, hidden_size, and dtype
+    for batch_size in args.batch_sizes:
+        for hidden_size in args.hidden_sizes:
+            for dtype_str in args.dtypes:
+                dtype = getattr(torch, dtype_str)
+
+                # Define tensors with the correct dtype
+                x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda")
+                residual = torch.randn_like(x)
+                weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
+
+                @torch.cuda.nvtx.range(f"fused_add_rmsnorm batch_size={batch_size}, hidden_size={hidden_size}, dtype={dtype_str}")
+                def fn() -> None:
+                    flashinfer.fused_add_rmsnorm(x, residual, weight, eps)
+
+                # Run benchmarking
+                latency_ms = cast(float, do_bench(fn))
+                throughput = (
+                    (x.numel() * x.element_size() * 2 + weight.numel() * weight.element_size()) / (latency_ms * 1e-3)
+                )
+                print(
+                    f"batch_size: {batch_size:3},",
+                    f"hidden_size: {hidden_size:5},",
+                    f"dtype: {dtype_str:2},",
+                    f"latency: {latency_ms*1e3:2.0f}us,",
+                    f"throughput: {throughput*1e-9:7.3f}GB/s",
+                )
+
+            print("---")
+
+    torch.cuda.profiler.stop()
+
+if __name__ == "__main__":
+    main()

From 327faf6a6e32660ee4afb0c752b3ccee5ced5499 Mon Sep 17 00:00:00 2001
From: Abatom
Date: Sat, 9 Nov 2024 09:49:52 +0800
Subject: [PATCH 3/4] Revise as suggested

---
 benchmarks/bench_fused_add_rmsnorm.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/benchmarks/bench_fused_add_rmsnorm.py b/benchmarks/bench_fused_add_rmsnorm.py
index e2cf7f9c..d2117d90 100644
--- a/benchmarks/bench_fused_add_rmsnorm.py
+++ b/benchmarks/bench_fused_add_rmsnorm.py
@@ -34,7 +34,10 @@ def fn() -> None:
                 # Run benchmarking
                 latency_ms = cast(float, do_bench(fn))
                 throughput = (
-                    (x.numel() * x.element_size() * 2 + weight.numel() * weight.element_size()) / (latency_ms * 1e-3)
+                    (x.numel() * x.element_size() * 2
+                     + residual.numel() * residual.element_size() * 2
+                     + weight.numel() * weight.element_size())
+                    / (latency_ms * 1e-3)
                 )
                 print(
                     f"batch_size: {batch_size:3},",

From b97c8b1506b15f9aa7861b29da98e0d7c3e6aa21 Mon Sep 17 00:00:00 2001
From: Abatom
Date: Sat, 9 Nov 2024 10:16:28 +0800
Subject: [PATCH 4/4] dtypes

---
 benchmarks/bench_fused_add_rmsnorm.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/benchmarks/bench_fused_add_rmsnorm.py b/benchmarks/bench_fused_add_rmsnorm.py
index d2117d90..d56c1819 100644
--- a/benchmarks/bench_fused_add_rmsnorm.py
+++ b/benchmarks/bench_fused_add_rmsnorm.py
@@ -11,7 +11,7 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--batch-sizes", nargs='+', type=int, default=[1, 19, 99, 989])
     parser.add_argument("--hidden-sizes", nargs='+', type=int, default=[111, 500, 1024, 3072, 4096, 8192])
-    parser.add_argument("--dtypes", nargs='+', choices=["float16", "float32"], default=["float16"])
+    parser.add_argument("--dtypes", nargs='+', choices=["float16", "bfloat16"], default=["float16"])
     args = parser.parse_args()
 
     eps = 1e-6
@@ -42,7 +42,7 @@ def fn() -> None:
                 print(
                     f"batch_size: {batch_size:3},",
                     f"hidden_size: {hidden_size:5},",
-                    f"dtype: {dtype_str:2},",
+                    f"dtype: {dtype_str:8},",
                     f"latency: {latency_ms*1e3:2.0f}us,",
                     f"throughput: {throughput*1e-9:7.3f}GB/s",
                 )
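
Note for reviewers (not part of the patches themselves): patch 1 keeps the fused value x in a per-thread register vector (x_vec) and moves it to and from shared memory with one vectorized store/load per round, replacing the previous per-element scalar smem accesses; patches 2-4 add and refine a bandwidth benchmark for flashinfer.fused_add_rmsnorm.

Below is a minimal correctness sketch for the op under test. It assumes the in-place contract suggested by the kernel in patch 1 (residual receives input + residual; input receives the RMS-normalized, weight-scaled result); the helper name ref_fused_add_rmsnorm and the tolerances are illustrative, not part of this series.

import torch

import flashinfer


def ref_fused_add_rmsnorm(x, residual, weight, eps):
    # Unfused fp32 reference: r = x + residual, out = r * rsqrt(mean(r^2) + eps) * weight,
    # mirroring sum_sq / rms_rcp in the kernel.
    r = x.float() + residual.float()
    rms_rcp = torch.rsqrt(r.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (r * rms_rcp * weight.float()).to(x.dtype), r.to(residual.dtype)


@torch.inference_mode()
def check(batch_size=19, hidden_size=4096, dtype=torch.float16, eps=1e-6):
    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
    out_ref, res_ref = ref_fused_add_rmsnorm(x, residual, weight, eps)
    # Assumed contract: the fused kernel updates x and residual in place.
    flashinfer.fused_add_rmsnorm(x, residual, weight, eps)
    torch.testing.assert_close(residual, res_ref, rtol=2e-2, atol=2e-2)
    torch.testing.assert_close(x, out_ref, rtol=2e-2, atol=2e-2)


if __name__ == "__main__":
    check()
    print("ok")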