From 606625329648e6eff1883e23040adfad82f219cf Mon Sep 17 00:00:00 2001
From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
Date: Thu, 23 May 2024 02:39:27 -0400
Subject: [PATCH] Marlin 24 prefill performance improvement (about 25% better
 on average) (#4983)

---
 benchmarks/kernels/benchmark_marlin.py     | 74 ++++++++++++++++---
 .../marlin/sparse/marlin_24_cuda_kernel.cu | 55 ++++++++++----
 tests/kernels/test_marlin_gemm.py          |  2 +-
 .../layers/quantization/gptq_marlin_24.py  |  8 +-
 4 files changed, 107 insertions(+), 32 deletions(-)

diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py
index 5dcffc284f3d4..b771911781574 100644
--- a/benchmarks/kernels/benchmark_marlin.py
+++ b/benchmarks/kernels/benchmark_marlin.py
@@ -6,9 +6,13 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
     GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
+from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
+    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    MarlinWorkspace, marlin_quantize)
+    MarlinWorkspace, marlin_24_quantize, marlin_quantize)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     gptq_pack, quantize_weights, sort_weights)
@@ -44,6 +48,10 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
         marlin_rand_perm,
     ) = marlin_quantize(b, num_bits, group_size, act_order)

+    # Marlin_24 quant
+    (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
+     marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
+
     # GPTQ quant
     (w_ref, q_w, s, g_idx, rand_perm) = quantize_weights(b, num_bits,
                                                          group_size, act_order)
@@ -56,28 +64,43 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
         (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)

     # Prepare
-    marlin_workspace = MarlinWorkspace(size_n)
+    marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
+                                       GPTQ_MARLIN_MAX_PARALLEL)
+
+    marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
+                                          GPTQ_MARLIN_24_MAX_PARALLEL)

     globals = {
+        # Gen params
+        "num_bits": num_bits,
+        "group_size": group_size,
+        "size_m": size_m,
+        "size_n": size_n,
+        "size_k": size_k,
+        "a": a,
+        "a_tmp": a_tmp,
+        # Marlin params
         "marlin_w_ref": marlin_w_ref,
         "marlin_q_w": marlin_q_w,
         "marlin_s": marlin_s,
         "marlin_g_idx": marlin_g_idx,
         "marlin_sort_indices": marlin_sort_indices,
         "marlin_rand_perm": marlin_rand_perm,
+        "marlin_workspace": marlin_workspace,
+        "is_k_full": is_k_full,
+        # Marlin_24 params
+        "marlin_24_w_ref": marlin_24_w_ref,
+        "marlin_24_q_w_comp": marlin_24_q_w_comp,
+        "marlin_24_meta": marlin_24_meta,
+        "marlin_24_s": marlin_24_s,
+        "marlin_24_workspace": marlin_24_workspace,
+        # GPTQ params
         "q_w_gptq": q_w_gptq,
         "repack_sort_indices": repack_sort_indices,
-        "num_bits": num_bits,
-        "group_size": group_size,
-        "size_m": size_m,
-        "size_n": size_n,
-        "size_k": size_k,
-        "is_k_full": is_k_full,
-        "a": a,
-        "a_tmp": a_tmp,
+        # Kernels
         "gptq_marlin_gemm": ops.gptq_marlin_gemm,
+        "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
         "gptq_marlin_repack": ops.gptq_marlin_repack,
-        "marlin_workspace": marlin_workspace,
     }

     min_run_time = 1
@@ -105,6 +128,18 @@ def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
description="gptq_marlin_gemm", ).blocked_autorange(min_run_time=min_run_time)) + if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS + and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): + results.append( + benchmark.Timer( + stmt= + "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_24_gemm", + ).blocked_autorange(min_run_time=min_run_time)) + results.append( benchmark.Timer( stmt= @@ -135,8 +170,20 @@ def main(args): continue for act_order in ACT_ORDER_OPTS: + if len(args.limit_act_order + ) > 0 and act_order not in args.limit_act_order: + continue + for is_k_full in K_FULL_OPTS: + if len(args.limit_k_full + ) > 0 and is_k_full not in args.limit_k_full: + continue + for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS: + if len(args.limit_num_bits + ) > 0 and num_bits not in args.limit_num_bits: + continue + for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: if len( args.limit_group_size @@ -159,7 +206,7 @@ def main(args): # For quick benchmarking use: -# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 # noqa E501 +# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501 # if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -178,6 +225,9 @@ def main(args): parser.add_argument("--limit-k", nargs="+", type=int, default=[]) parser.add_argument("--limit-n", nargs="+", type=int, default=[]) parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[]) + parser.add_argument("--limit-act-order", nargs="+", type=int, default=[]) + parser.add_argument("--limit-k-full", nargs="+", type=int, default=[]) args = parser.parse_args() main(args) diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index 54ad27676e207..686dd7851e6af 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -48,12 +48,12 @@ namespace marlin_24 { // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. 
 static constexpr int THREADS = 256;
-static constexpr int STAGES = 4;  // 4 pipeline stages fit into shared memory
+static constexpr int STAGES = 4;

 static constexpr int min_thread_n = 128;

 static constexpr int tile_size = 16;
-static constexpr int max_par = 16;
+static constexpr int max_par = 64;

 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
@@ -736,10 +736,10 @@ __global__ void Marlin_24(
     for (int pipe = 0; pipe < stages;) {
       fetch_to_shared((pipe + stages - 1) % stages, pipe,
                       slice_iters >= stages);
+      matmul(pipe);
       wait_for_stage();

       fetch_to_registers(pipe + 1, (pipe + 1) % stages);
-      matmul(pipe);

       pipe++;
       slice_iters--;
@@ -899,9 +899,12 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
      // than better compute utilization
      thread_k = 128;
      thread_m = 128;
-    } else {
+    } else if (prob_n <= 256) {
      thread_k = 64;
      thread_m = 256;
+    } else {
+      thread_k = 32;
+      thread_m = 512;
    }
  }
@@ -928,19 +931,21 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
  int4* C_ptr = (int4*)C;
  const int4* s_ptr = (const int4*)s;

+  constexpr int max_m_blocks = 4;
+
  int* locks = (int*)workspace;
-  for (int i = 0; i < tot_n_blocks; i += 4) {
+  for (int i = 0; i < tot_n_blocks; i += max_m_blocks) {
    int thread_n_blocks = tot_n_blocks - i;
    prob_n = tot_n - 16 * i;
    int par = 1;
-    if (thread_n_blocks > 4) {
+    if (thread_n_blocks > max_m_blocks) {
      // Note that parallel > 1 currently only works for inputs without any
      // padding
-      par = (16 * thread_n_blocks - pad) / 64;
+      par = (16 * thread_n_blocks - pad) / (max_m_blocks * 16);
      if (par > max_par) par = max_par;
-      prob_n = 64 * par;
-      i += 4 * (par - 1);
-      thread_n_blocks = 4;
+      prob_n = (max_m_blocks * 16) * par;
+      i += max_m_blocks * (par - 1);
+      thread_n_blocks = max_m_blocks;
    }

    // For compilation speed, we only define the kernel configurations that have
@@ -951,8 +956,9 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
    if (false) {
    }  // BMxBNxBK, group
    // 4-bit
-    CALL_IF_2_4(4, 8, 1, 4, -1) // e.g., 16x128x128
-    CALL_IF_2_4(4, 8, 1, 4, 4) // e.g., 16x128x128, 64
+    CALL_IF_2_4(4, 8, 1, 4, -1)   // e.g., 16x128x128
+    CALL_IF_2_4(4, 8, 1, 4, 4)    // e.g., 16x128x128, 64
+
    CALL_IF_2_4(4, 16, 1, 2, -1)  // e.g., 16x256x64
    CALL_IF_2_4(4, 16, 1, 2, 4)   // e.g., 16x256x64, 64
    CALL_IF_2_4(4, 16, 2, 2, -1)  // e.g.. 32x256x64
@@ -962,9 +968,19 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
    CALL_IF_2_4(4, 16, 2, 2, 4)
    CALL_IF_2_4(4, 16, 3, 2, -1)
    CALL_IF_2_4(4, 16, 3, 2, 4)
    CALL_IF_2_4(4, 16, 4, 2, -1)
    CALL_IF_2_4(4, 16, 4, 2, 4)

+    CALL_IF_2_4(4, 32, 1, 1, -1)  // e.g., 16x256x64
+    CALL_IF_2_4(4, 32, 1, 1, 4)   // e.g., 16x256x64, 64
+    CALL_IF_2_4(4, 32, 2, 1, -1)  // e.g.. 32x256x64
+    CALL_IF_2_4(4, 32, 2, 1, 4)
+    CALL_IF_2_4(4, 32, 3, 1, -1)
+    CALL_IF_2_4(4, 32, 3, 1, 4)
+    CALL_IF_2_4(4, 32, 4, 1, -1)
+    CALL_IF_2_4(4, 32, 4, 1, 4)
+
    // 8-bit
-    CALL_IF_2_4(8, 8, 1, 4, -1) // e.g., 16x128x128
-    CALL_IF_2_4(8, 8, 1, 4, 4) // e.g., 16x128x128, 64
+    CALL_IF_2_4(8, 8, 1, 4, -1)   // e.g., 16x128x128
+    CALL_IF_2_4(8, 8, 1, 4, 4)    // e.g., 16x128x128, 64
+
    CALL_IF_2_4(8, 16, 1, 2, -1)  // e.g., 16x256x64
    CALL_IF_2_4(8, 16, 1, 2, 4)   // e.g., 16x256x64, 64
    CALL_IF_2_4(8, 16, 2, 2, -1)  // e.g.. 32x256x64
@@ -973,6 +989,15 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
    CALL_IF_2_4(8, 16, 3, 2, 4)
    CALL_IF_2_4(8, 16, 4, 2, -1)
    CALL_IF_2_4(8, 16, 4, 2, 4)
+
+    CALL_IF_2_4(8, 32, 1, 1, -1)  // e.g., 16x256x64
+    CALL_IF_2_4(8, 32, 1, 1, 4)   // e.g., 16x256x64, 64
+    CALL_IF_2_4(8, 32, 2, 1, -1)  // e.g.. 32x256x64
+    CALL_IF_2_4(8, 32, 2, 1, 4)
+    CALL_IF_2_4(8, 32, 3, 1, -1)
+    CALL_IF_2_4(8, 32, 3, 1, 4)
+    CALL_IF_2_4(8, 32, 4, 1, -1)
+    CALL_IF_2_4(8, 32, 4, 1, 4)
    else {
      throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) +
                               ", " + str(prob_k) + ", " + str(prob_n) + "]" +
@@ -1062,7 +1087,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
  int thread_k = -1;
  int thread_m = -1;
  int sms = -1;
-  int max_par = 16;
+  int max_par = marlin_24::max_par;

  int groupsize = -1;
  if (b_scales.size(0) > 1) {
diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py
index 587fc3901eb7c..1f8d94bad26d9 100644
--- a/tests/kernels/test_marlin_gemm.py
+++ b/tests/kernels/test_marlin_gemm.py
@@ -27,7 +27,7 @@
 MARLIN_N_CHUNKS = [64, 128, 256]

 MARLIN_24_K_CHUNKS = [128]
-MARLIN_24_N_CHUNKS = [256]
+MARLIN_24_N_CHUNKS = [512]

 MNK_FACTORS = [
     (1, 1, 1),
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py
index f5345c0443029..6bcfc405afe71 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py
@@ -15,7 +15,7 @@
 GPTQ_MARLIN_24_TILE = 16
 GPTQ_MARLIN_24_MIN_THREAD_N = 128
 GPTQ_MARLIN_24_MIN_THREAD_K = 128
-GPTQ_MARLIN_24_MAX_PARALLEL = 16
+GPTQ_MARLIN_24_MAX_PARALLEL = 64

 GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
 GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
@@ -53,14 +53,14 @@ def __init__(
         self.tile_size = 16

         # Min out_features dim
-        self.min_n_threads = 128
+        self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N

         # Min in_features dim
-        self.min_k_threads = 128
+        self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K

         # Max parallel problems to solve at once (improves large
         # batch performance)
-        self.max_parallel = 16
+        self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL

         # Permutation length used by the marlin kernels.
         self.perm_len = 1024
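
For reviewers who want to try the new path outside the benchmark harness, here is a minimal sketch of the call sequence the updated benchmark times, condensed from benchmark_marlin.py above. The shapes, dtype, and device are illustrative assumptions and not values taken from the patch; the function names and argument order are the ones used in bench_run().

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MarlinWorkspace, marlin_24_quantize)

# Illustrative problem shape (assumed, not from the patch); K and N should be
# multiples of the 2:4 Marlin thread/tile constants.
size_m, size_k, size_n = 32, 4096, 4096
num_bits, group_size = 4, 128

a = torch.randn((size_m, size_k), dtype=torch.half, device="cuda")
b = torch.rand((size_k, size_n), dtype=torch.half, device="cuda")

# Quantize and 2:4-compress the weight, as bench_run() does above.
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
 marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)

# The workspace is now built from the exported min-thread-N / max-parallel
# constants rather than hard-coded values.
marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                      GPTQ_MARLIN_24_MAX_PARALLEL)

output = ops.gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta,
                                 marlin_24_s, marlin_24_workspace.scratch,
                                 num_bits, size_m, size_n, size_k)
print(output.shape)  # expected (size_m, size_n)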
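The prefill gain comes mainly from the host-side scheduling change in marlin_cuda_2_4(): max_par is raised from 16 to 64, the launch loop is written in terms of max_m_blocks, and a wider 32-block tile configuration plus the thread_k = 32 / thread_m = 512 heuristic are added. The sketch below re-expresses the launch-partitioning arithmetic in Python purely for illustration; the authoritative code is the C++ shown above. It assumes tot_n_blocks is ceil(tot_n / 16) and pad = 0, and it reads the tiled dimension as the batch (prefill) dimension, in line with the "improves large batch performance" comment in gptq_marlin_24.py.

# Sketch of the host-side partitioning loop from marlin_cuda_2_4(),
# re-expressed in Python. Names mirror the CUDA code; pad = 0 is assumed
# (parallel > 1 only works without padding, per the kernel comment).
MAX_M_BLOCKS = 4   # blocks of 16 rows handled per kernel launch
MAX_PAR = 64       # raised from 16 by this patch


def partition(tot_n, pad=0):
    """Yield (prob_n, par) for each launch covering tot_n rows."""
    tot_n_blocks = (tot_n + 15) // 16   # assumed ceildiv, as in the kernel
    i = 0
    while i < tot_n_blocks:
        thread_n_blocks = tot_n_blocks - i
        prob_n = tot_n - 16 * i
        par = 1
        if thread_n_blocks > MAX_M_BLOCKS:
            # Split the remaining rows into `par` independent sub-problems of
            # MAX_M_BLOCKS * 16 rows each, capped at MAX_PAR.
            par = min((16 * thread_n_blocks - pad) // (MAX_M_BLOCKS * 16),
                      MAX_PAR)
            prob_n = (MAX_M_BLOCKS * 16) * par
            i += MAX_M_BLOCKS * (par - 1)
        yield prob_n, par
        i += MAX_M_BLOCKS


# A 4096-row prefill is now covered by one launch with par = 64; with the
# previous cap of 16 it took four launches.
print(list(partition(4096)))  # -> [(4096, 64)]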