From d8f768c00bac40c62dc711e8bc8050f477e49ad7 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Fri, 20 Sep 2024 07:42:21 +0000 Subject: [PATCH 1/9] Cleanup act_order code from the MarlinMoE kernels --- csrc/moe/marlin_moe_ops.cu | 524 +++++++++---------------------------- 1 file changed, 127 insertions(+), 397 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 49cc03f827f68..dd8c63042ec97 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -210,21 +210,6 @@ __device__ inline void scale_float(float* c, FragS& s) { c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); } -// Same as above, but for act_order (each K is multiplied individually) -__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, - FragS& frag_s_3, FragS& frag_s_4, int i) { - __half2 s_val_1_2; - s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; - - __half2 s_val_3_4; - s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - // Wait until barrier reaches `count`, then lock for current threadblock. __device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { @@ -344,7 +329,6 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > @@ -386,7 +370,7 @@ __device__ inline void MarlinMoESingle( int n_tiles = prob_n / 16 / thread_n_blocks; int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - if constexpr (!has_act_order && group_blocks != -1) { + if constexpr (group_blocks != -1) { if (group_blocks >= thread_k_blocks) { // Ensure that the number of tiles in each stripe is a multiple of the // groupsize; this avoids an annoying special case where a stripe starts @@ -482,20 +466,12 @@ __device__ inline void MarlinMoESingle( int s_gl_stride = prob_n / 8; constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; - // Scale size/strides with act_order constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? 
(tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; constexpr int sorted_sh_stride = threads; constexpr int sorted_gl_stride = threads; @@ -519,23 +495,17 @@ __device__ inline void MarlinMoESingle( int b_sh_wr = threadIdx.x * b_thread_vecs; int b_sh_rd = threadIdx.x * b_thread_vecs; - // For act_order constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; // No act_order int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; } + int s_sh_wr = threadIdx.x; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; @@ -555,17 +525,14 @@ __device__ inline void MarlinMoESingle( constexpr int sh_max_num_groups = 32; int shs_size; - if constexpr (has_act_order) - shs_size = sh_max_num_groups * s_sh_stride + threads; - else - shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; + shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. int4* sh_a = sh; int4* sh_b = sh_a + (stages * a_sh_stage); int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_g_idx; int* sh_sorted = (int*)(sh_s + shs_size); // Precompute which thread should not read memory in which iterations; this is @@ -622,8 +589,7 @@ __device__ inline void MarlinMoESingle( FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order + FragS frag_s[2][4]; // No act-order // Zero accumulators. 
auto zero_accums = [&]() { @@ -632,39 +598,6 @@ __device__ inline void MarlinMoESingle( reinterpret_cast(frag_c)[i] = 0; }; - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { @@ -693,41 +626,24 @@ __device__ inline void MarlinMoESingle( B_ptr[i] += b_gl_rd_delta_o; } - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); } + s_gl_rd += s_gl_rd_delta; } } } @@ -775,125 +691,35 @@ __device__ inline void MarlinMoESingle( } }; - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - auto fetch_scales_to_registers = [&](int k, int full_pipe) { int pipe = full_pipe % stages; - if constexpr 
(!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - int warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; - cur_k += warp_row * 16; + int warp_row = warp_id / n_warps; - int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } + int4* sh_s_stage = sh_s + s_sh_stage * pipe; - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; } - return; } - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } + return; }; // Execute the actual tensor core matmul of a sub-tile. 
@@ -916,25 +742,10 @@ __device__ inline void MarlinMoESingle( FragB frag_b0 = dequant(b_quant_0); FragB frag_b1 = dequant(b_quant_1); - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } + // Apply scale to frag_b0 and frag_b1 + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + scale(frag_b1, frag_s[k % 2][j], 1); } #pragma unroll @@ -1106,8 +917,7 @@ __device__ inline void MarlinMoESingle( // For per-column quantization we finally apply the scale here (only for // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { + if constexpr (group_blocks == -1 && w_type.size_bits() == 4) { res = __hmul2(res, s[0]); } @@ -1165,23 +975,14 @@ __device__ inline void MarlinMoESingle( #pragma unroll for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } fetch_to_shared(i, i, i < slice_iters); } zero_accums(); wait_for_stage(); - init_same_group(0); fetch_to_registers(0, 0); fetch_scales_to_registers(0, 0); a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); }; if (slice_iters) { start_pipes(); @@ -1204,7 +1005,6 @@ __device__ inline void MarlinMoESingle( slice_iters >= stages); pipe++; wait_for_stage(); - init_same_group(pipe % stages); } matmul(k); } @@ -1215,21 +1015,6 @@ __device__ inline void MarlinMoESingle( } a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } // Process results and, if necessary, proceed to the next column slice. 
// While this pattern may not be the most readable, other ways of writing @@ -1237,7 +1022,7 @@ __device__ inline void MarlinMoESingle( if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; - if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (group_blocks == -1) { if constexpr (w_type.size_bits() == 8) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); @@ -1256,7 +1041,7 @@ __device__ inline void MarlinMoESingle( } thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (group_blocks == -1) { if constexpr (w_type.size_bits() == 8) { cp_async_wait<0>(); __syncthreads(); @@ -1280,8 +1065,7 @@ __device__ inline void MarlinMoESingle( // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { + if constexpr (group_blocks == -1 && w_type.size_bits() == 8) { if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { @@ -1324,16 +1108,7 @@ __device__ inline void MarlinMoESingle( for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; start_pipes(); } } @@ -1346,7 +1121,6 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > @@ -1406,28 +1180,28 @@ __global__ void MarlinMoE( if (max_block == 1) { MarlinMoESingle( + stages, group_blocks>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { MarlinMoESingle( + stages, group_blocks>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 3) { MarlinMoESingle( + stages, group_blocks>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { MarlinMoESingle( + stages, group_blocks>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, @@ -1460,7 +1234,6 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > @@ -1509,24 +1282,23 @@ const int STAGES = 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; -#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if 
(q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - exec_cfg.max_m_blocks); \ +#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS, \ + NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute(MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + exec_cfg.max_m_blocks); \ } typedef struct { @@ -1561,10 +1333,8 @@ thread_config_t large_batch_thread_configs[] = { }; int get_scales_cache_size(thread_config_t const& th_config, int prob_m, - int prob_n, int prob_k, int num_bits, int group_size, - bool has_act_order, bool is_k_full) { - bool cache_scales_chunk = has_act_order && !is_k_full; - + int prob_n, int prob_k, int num_bits, + int group_size) { int tb_n = th_config.thread_n; int tb_k = th_config.thread_k; @@ -1578,17 +1348,9 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, tb_groups = ceildiv(tb_k, group_size); } - if (cache_scales_chunk) { - int load_groups = - tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K - load_groups = max(load_groups, 32); // We load at least 32 scale groups - return load_groups * tb_n * 2; - - } else { - int tb_scales = tb_groups * tb_n * 2; + int tb_scales = tb_groups * tb_n * 2; - return tb_scales * STAGES; - } + return tb_scales * STAGES; } bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, @@ -1629,8 +1391,7 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, bool has_act_order, bool is_k_full, - int max_shared_mem) { + int group_size, int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -1659,9 +1420,8 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, } // Determine cache for scales - int scales_cache_size = - get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full); + int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n, + prob_k, num_bits, group_size); // Check that pipeline fits into cache if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, @@ -1674,23 +1434,20 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, int num_bits, int group_size, - bool has_act_order, bool is_k_full, int max_shared_mem) { int max_m_blocks = 4; while (max_m_blocks > 0) { if 
(prob_m <= 16) { for (auto th_config : small_batch_thread_configs) { if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { + num_bits, group_size, max_shared_mem)) { return exec_config_t{max_m_blocks, th_config}; } } } else { for (auto th_config : large_batch_thread_configs) { if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { + num_bits, group_size, max_shared_mem)) { return exec_config_t{max_m_blocks, th_config}; } } @@ -1703,28 +1460,22 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) - -void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, - const void* sorted_ids, const void* topk_weights, - const void* topk_ids, const void* s, const void* g_idx, - const void* perm, void* a_tmp, void* expert_offsets, - int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, int num_groups, int group_size, - int num_experts, int topk, int moe_block_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, - int sms, int max_par, bool replicate_input, - bool apply_weights) { +#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) + +void marlin_mm_moe(const void* A, const void* B, void* C, + const void* sorted_ids, const void* topk_weights, + const void* topk_ids, const void* s, const void* g_idx, + const void* perm, void* a_tmp, void* expert_offsets, + int prob_m, int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& q_type, bool has_act_order, + int num_groups, int group_size, int num_experts, int topk, + int moe_block_size, int dev, cudaStream_t stream, + int thread_k, int thread_n, int sms, int max_par, + bool replicate_input, bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1747,23 +1498,20 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}}; } else { // Auto config - exec_cfg = - determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem); + exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, + group_size, max_shared_mem); } - TORCH_CHECK(exec_cfg.max_m_blocks > 0 && - is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, - prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem), - "Invalid thread config: max_m_blocks = ", 
exec_cfg.max_m_blocks, - ", thread_k = ", exec_cfg.tb_cfg.thread_k, - ", thread_n = ", exec_cfg.tb_cfg.thread_n, - ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", - prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, - ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, - ", max_shared_mem = ", max_shared_mem); + TORCH_CHECK( + exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m, + ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem); int num_threads = exec_cfg.tb_cfg.num_threads; thread_k = exec_cfg.tb_cfg.thread_k; @@ -1781,16 +1529,10 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, int group_blocks = 0; if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(group_size != -1); - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } else { - TORCH_CHECK(group_size == 0); - group_blocks = 0; - } - + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); } else { if (group_size == -1) { group_blocks = -1; @@ -1808,15 +1550,6 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, compute_expert_offsets<<<1, num_experts, 0, stream>>>( topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size); - bool do_permute_a = has_act_order; - - // If we have a full K, then we can run the non-act-order version of Marlin - // (since the weight rows are reordered by increasing group ids, and by - // having a full K, we have full original groups) - if (is_k_full) { - has_act_order = false; - } - int pack_factor = 32 / q_type.size_bits(); for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { @@ -1836,7 +1569,9 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const int* perm_ptr = (const int*)perm + prob_k * expert_idx; int* locks = (int*)workspace; - if (do_permute_a) { + // Since k is always full, we only use has_act_order here and treat it as + // false in the kernel, + if (has_act_order) { // Permute A columns int topk_rows = replicate_input ? tot_m : tot_m * topk; int block_rows = ceildiv(topk_rows, blocks); @@ -1864,7 +1599,6 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + - ", has_act_order = " + str(has_act_order) + ", num_groups = " + str(num_groups) + ", group_size = " + str(group_size) + ", thread_m_blocks = " + str(thread_m_blocks) + @@ -1888,7 +1622,8 @@ torch::Tensor marlin_gemm_moe( bool replicate_input, bool apply_weights) { TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, "b_q_type must be uint4b8 or uint8b128. 
Got = ", b_q_type->str()); - + TORCH_CHECK(is_k_full, + "The MoE kernel currently supports only is_k_full == true"); int pack_factor = 32 / b_q_type->size_bits(); int max_par = 4; @@ -1926,15 +1661,10 @@ torch::Tensor marlin_gemm_moe( num_groups = b_scales.size(1); if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); - TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by num_groups = ", num_groups); - group_size = size_k / num_groups; - } else { - group_size = 0; - } - + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; } else { if (num_groups > 1) { TORCH_CHECK( @@ -1946,13 +1676,13 @@ torch::Tensor marlin_gemm_moe( } } - marlin_moe::marlin_mm_moe_f16i4( + marlin_moe::marlin_mm_moe( a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, - topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + *b_q_type, has_act_order, num_groups, group_size, num_experts, topk, + moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); return c; } From bf97f62c3375e5bb071f0e57fa43fd0332ff34b7 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 23 Sep 2024 10:59:37 +0000 Subject: [PATCH 2/9] bring back act_order template --- csrc/moe/marlin_moe_ops.cu | 526 ++++++++++++++++++++++++++++--------- 1 file changed, 398 insertions(+), 128 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index dd8c63042ec97..22399eb0f4f72 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -210,6 +210,21 @@ __device__ inline void scale_float(float* c, FragS& s) { c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); } +// Same as above, but for act_order (each K is multiplied individually) +__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, + FragS& frag_s_3, FragS& frag_s_4, int i) { + __half2 s_val_1_2; + s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; + + __half2 s_val_3_4; + s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + // Wait until barrier reaches `count`, then lock for current threadblock. 
__device__ inline void barrier_acquire(int* lock, int count) { if (threadIdx.x == 0) { @@ -329,6 +344,7 @@ template shared // fetch pipeline + const bool has_act_order, // whether act_order is enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > @@ -370,7 +386,7 @@ __device__ inline void MarlinMoESingle( int n_tiles = prob_n / 16 / thread_n_blocks; int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - if constexpr (group_blocks != -1) { + if constexpr (!has_act_order && group_blocks != -1) { if (group_blocks >= thread_k_blocks) { // Ensure that the number of tiles in each stripe is a multiple of the // groupsize; this avoids an annoying special case where a stripe starts @@ -466,12 +482,20 @@ __device__ inline void MarlinMoESingle( int s_gl_stride = prob_n / 8; constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_tb_groups = - group_blocks != -1 && group_blocks < thread_k_blocks + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; + // Scale size/strides with act_order constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; constexpr int sorted_sh_stride = threads; constexpr int sorted_gl_stride = threads; @@ -495,17 +519,23 @@ __device__ inline void MarlinMoESingle( int b_sh_wr = threadIdx.x * b_thread_vecs; int b_sh_rd = threadIdx.x * b_thread_vecs; + // For act_order constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; // No act_order int s_gl_rd; - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } } - int s_sh_wr = threadIdx.x; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; @@ -525,14 +555,17 @@ __device__ inline void MarlinMoESingle( constexpr int sh_max_num_groups = 32; int shs_size; - shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; + if constexpr (has_act_order) + shs_size = sh_max_num_groups * s_sh_stride + threads; + else + shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. 
int4* sh_a = sh; int4* sh_b = sh_a + (stages * a_sh_stage); int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx; + int4* sh_s = sh_g_idx + (stages * g_idx_stage); int* sh_sorted = (int*)(sh_s + shs_size); // Precompute which thread should not read memory in which iterations; this is @@ -589,7 +622,8 @@ __device__ inline void MarlinMoESingle( FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2][b_thread_vecs]; FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order // Zero accumulators. auto zero_accums = [&]() { @@ -598,6 +632,39 @@ __device__ inline void MarlinMoESingle( reinterpret_cast(frag_c)[i] = 0; }; + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { @@ -626,24 +693,41 @@ __device__ inline void MarlinMoESingle( B_ptr[i] += b_gl_rd_delta_o; } - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; } - s_gl_rd += s_gl_rd_delta; } } } @@ -691,35 +775,125 @@ __device__ inline void MarlinMoESingle( } }; + bool 
is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + auto fetch_scales_to_registers = [&](int k, int full_pipe) { int pipe = full_pipe % stages; - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; + int warp_row = warp_id / n_warps; - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; - int4* sh_s_stage = sh_s + s_sh_stage * pipe; + int4* sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } } + + return; } - return; + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int 
k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } }; // Execute the actual tensor core matmul of a sub-tile. @@ -742,10 +916,25 @@ __device__ inline void MarlinMoESingle( FragB frag_b0 = dequant(b_quant_0); FragB frag_b1 = dequant(b_quant_1); - // Apply scale to frag_b0 and frag_b1 - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - scale(frag_b1, frag_s[k % 2][j], 1); + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } } #pragma unroll @@ -917,7 +1106,8 @@ __device__ inline void MarlinMoESingle( // For per-column quantization we finally apply the scale here (only for // 4-bit) - if constexpr (group_blocks == -1 && w_type.size_bits() == 4) { + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { res = __hmul2(res, s[0]); } @@ -975,14 +1165,23 @@ __device__ inline void MarlinMoESingle( #pragma unroll for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } fetch_to_shared(i, i, i < slice_iters); } zero_accums(); wait_for_stage(); + init_same_group(0); fetch_to_registers(0, 0); fetch_scales_to_registers(0, 0); a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); }; if (slice_iters) { start_pipes(); @@ -1005,6 +1204,7 @@ __device__ inline void MarlinMoESingle( slice_iters >= stages); pipe++; wait_for_stage(); + init_same_group(pipe % stages); } matmul(k); } @@ -1015,6 +1215,21 @@ __device__ inline void MarlinMoESingle( } a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } // Process results and, if necessary, proceed to the next column slice. 
// While this pattern may not be the most readable, other ways of writing @@ -1022,7 +1237,7 @@ __device__ inline void MarlinMoESingle( if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; - if constexpr (group_blocks == -1) { + if constexpr (!has_act_order && group_blocks == -1) { if constexpr (w_type.size_bits() == 8) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); @@ -1041,7 +1256,7 @@ __device__ inline void MarlinMoESingle( } thread_block_reduce(); - if constexpr (group_blocks == -1) { + if constexpr (!has_act_order && group_blocks == -1) { if constexpr (w_type.size_bits() == 8) { cp_async_wait<0>(); __syncthreads(); @@ -1065,7 +1280,8 @@ __device__ inline void MarlinMoESingle( // For 8-bit channelwise, we apply the scale before the global reduction // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) - if constexpr (group_blocks == -1 && w_type.size_bits() == 8) { + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { @@ -1108,7 +1324,16 @@ __device__ inline void MarlinMoESingle( for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } start_pipes(); } } @@ -1121,6 +1346,7 @@ template shared // fetch pipeline + const bool has_act_order, // whether act_order is enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > @@ -1180,28 +1406,28 @@ __global__ void MarlinMoE( if (max_block == 1) { MarlinMoESingle( + stages, has_act_order, group_blocks>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 2) { MarlinMoESingle( + stages, has_act_order, group_blocks>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else if (max_block == 3) { MarlinMoESingle( + stages, has_act_order, group_blocks>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, current_m_block); } else { MarlinMoESingle( + stages, has_act_order, group_blocks>( A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, @@ -1234,6 +1460,7 @@ template shared // fetch pipeline + const bool has_act_order, // whether act_order is enabled const int group_blocks = -1 // number of consecutive 16x16 blocks // with a separate quantization scale > @@ -1282,23 +1509,24 @@ const int STAGES = 4; // 4 pipeline stages fit into shared memory static constexpr int min_thread_n = 64; static constexpr int min_thread_k = 64; -#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, 
THREAD_K_BLOCKS, GROUP_BLOCKS, \ - NUM_THREADS) \ - else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - exec_cfg.max_m_blocks); \ +#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ + GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + exec_cfg.max_m_blocks); \ } typedef struct { @@ -1333,8 +1561,10 @@ thread_config_t large_batch_thread_configs[] = { }; int get_scales_cache_size(thread_config_t const& th_config, int prob_m, - int prob_n, int prob_k, int num_bits, - int group_size) { + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + int tb_n = th_config.thread_n; int tb_k = th_config.thread_k; @@ -1348,9 +1578,17 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, tb_groups = ceildiv(tb_k, group_size); } - int tb_scales = tb_groups * tb_n * 2; + if (cache_scales_chunk) { + int load_groups = + tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * tb_n * 2; - return tb_scales * STAGES; + return tb_scales * STAGES; + } } bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, @@ -1391,7 +1629,8 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, int max_shared_mem) { + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -1420,8 +1659,9 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, } // Determine cache for scales - int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n, - prob_k, num_bits, group_size); + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); // Check that pipeline fits into cache if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, @@ -1434,20 +1674,23 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full, 
int max_shared_mem) { int max_m_blocks = 4; while (max_m_blocks > 0) { if (prob_m <= 16) { for (auto th_config : small_batch_thread_configs) { if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, max_shared_mem)) { + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { return exec_config_t{max_m_blocks, th_config}; } } } else { for (auto th_config : large_batch_thread_configs) { if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, max_shared_mem)) { + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { return exec_config_t{max_m_blocks, th_config}; } } @@ -1460,22 +1703,28 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) - -void marlin_mm_moe(const void* A, const void* B, void* C, - const void* sorted_ids, const void* topk_weights, - const void* topk_ids, const void* s, const void* g_idx, - const void* perm, void* a_tmp, void* expert_offsets, - int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - int num_groups, int group_size, int num_experts, int topk, - int moe_block_size, int dev, cudaStream_t stream, - int thread_k, int thread_n, int sms, int max_par, - bool replicate_input, bool apply_weights) { +#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, + const void* sorted_ids, const void* topk_weights, + const void* topk_ids, const void* s, const void* g_idx, + const void* perm, void* a_tmp, void* expert_offsets, + int prob_m, int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, int num_groups, int group_size, + int num_experts, int topk, int moe_block_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, + int sms, int max_par, bool replicate_input, + bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1498,20 +1747,23 @@ void marlin_mm_moe(const void* A, const void* B, void* C, exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}}; } else { // Auto config - exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, - group_size, max_shared_mem); + exec_cfg = + determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); } - TORCH_CHECK( - exec_cfg.max_m_blocks > 0 && - is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, - prob_n, prob_k, num_bits, group_size, max_shared_mem), - "Invalid 
thread config: max_m_blocks = ", exec_cfg.max_m_blocks, - ", thread_k = ", exec_cfg.tb_cfg.thread_k, - ", thread_n = ", exec_cfg.tb_cfg.thread_n, - ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m, - ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem); + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", max_shared_mem = ", max_shared_mem); int num_threads = exec_cfg.tb_cfg.num_threads; thread_k = exec_cfg.tb_cfg.thread_k; @@ -1529,10 +1781,16 @@ void marlin_mm_moe(const void* A, const void* B, void* C, int group_blocks = 0; if (has_act_order) { - TORCH_CHECK(group_size != -1); - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + } else { if (group_size == -1) { group_blocks = -1; @@ -1550,6 +1808,15 @@ void marlin_mm_moe(const void* A, const void* B, void* C, compute_expert_offsets<<<1, num_experts, 0, stream>>>( topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size); + bool do_permute_a = has_act_order; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + int pack_factor = 32 / q_type.size_bits(); for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { @@ -1565,13 +1832,11 @@ void marlin_mm_moe(const void* A, const void* B, void* C, (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * prob_n / 8) * expert_idx; - const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; + const int* g_idx_ptr = (const int*)g_idx + prob_m * expert_idx; const int* perm_ptr = (const int*)perm + prob_k * expert_idx; int* locks = (int*)workspace; - // Since k is always full, we only use has_act_order here and treat it as - // false in the kernel, - if (has_act_order) { + if (do_permute_a) { // Permute A columns int topk_rows = replicate_input ? 
tot_m : tot_m * topk; int block_rows = ceildiv(topk_rows, blocks); @@ -1599,6 +1864,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C, else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + + ", has_act_order = " + str(has_act_order) + ", num_groups = " + str(num_groups) + ", group_size = " + str(group_size) + ", thread_m_blocks = " + str(thread_m_blocks) + @@ -1622,8 +1888,7 @@ torch::Tensor marlin_gemm_moe( bool replicate_input, bool apply_weights) { TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); - TORCH_CHECK(is_k_full, - "The MoE kernel currently supports only is_k_full == true"); + int pack_factor = 32 / b_q_type->size_bits(); int max_par = 4; @@ -1661,10 +1926,15 @@ torch::Tensor marlin_gemm_moe( num_groups = b_scales.size(1); if (has_act_order) { - TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); - TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by num_groups = ", num_groups); - group_size = size_k / num_groups; + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + } else { if (num_groups > 1) { TORCH_CHECK( @@ -1676,13 +1946,13 @@ torch::Tensor marlin_gemm_moe( } } - marlin_moe::marlin_mm_moe( + marlin_moe::marlin_mm_moe_f16i4( a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - *b_q_type, has_act_order, num_groups, group_size, num_experts, topk, - moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts, + topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); return c; } From 49d86479eec3c7b7e5e5079a7127a450aedd0500 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Mon, 23 Sep 2024 17:35:07 +0000 Subject: [PATCH 3/9] Fix group ptr increment, add tests for is_k_full==True --- csrc/moe/marlin_moe_ops.cu | 9 ++++-- tests/kernels/test_moe.py | 32 +++++++++++++------ .../layers/fused_moe/fused_marlin_moe.py | 8 +++-- 3 files changed, 34 insertions(+), 15 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index a947baa5176c7..2952a0767e22a 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1749,6 +1749,9 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, has_act_order, is_k_full, max_shared_mem); } + int group_tensor_size = + (!is_k_full && has_act_order) ? prob_k / num_groups : group_size; + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, @@ -1826,10 +1829,10 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, const int* sorted_ids_ptr = (const int*)sorted_ids; const int4* s_ptr = (const int4*)s + - (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * - prob_n / 8) * + ((group_tensor_size == -1 ? 
1 : prob_k / group_tensor_size) * prob_n / + 8) * expert_idx; - const int* g_idx_ptr = (const int*)g_idx + prob_m * expert_idx; + const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; const int* perm_ptr = (const int*)perm + prob_k * expert_idx; int* locks = (int*)workspace; diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index b1f0516dfa0b3..5c3c5b8abee86 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -142,6 +142,7 @@ def compute_max_diff(output, output_ref): @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("is_k_full", [True, False]) def test_fused_marlin_moe( m: int, n: int, @@ -151,6 +152,7 @@ def test_fused_marlin_moe( group_size: int, act_order: bool, num_bits: int, + is_k_full: bool, ): seed_everything(7) @@ -163,6 +165,9 @@ def test_fused_marlin_moe( return if group_size in (k, n): return + else: + if not is_k_full: + return quant_type = (scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128) @@ -243,6 +248,7 @@ def test_fused_marlin_moe( w1_scale=scales1, w2_scale=scales2, num_bits=num_bits, + is_k_full=is_k_full, ) assert compute_max_diff(marlin_output, triton_output) < 4e-2 @@ -258,6 +264,7 @@ def test_fused_marlin_moe( @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("is_k_full", [True, False]) def test_single_marlin_moe_multiply( m: int, n: int, @@ -267,6 +274,7 @@ def test_single_marlin_moe_multiply( group_size: int, act_order: bool, num_bits: int, + is_k_full: bool, ): if topk > e: return @@ -277,6 +285,9 @@ def test_single_marlin_moe_multiply( return if group_size == k: return + else: + if not is_k_full: + return quant_type = (scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128) @@ -307,15 +318,18 @@ def test_single_marlin_moe_multiply( sort_indices = stack_and_dev(sort_indices_l) score = torch.randn((m, e), device="cuda", dtype=dtype) - marlin_output = single_marlin_moe(a, - qweight, - scales, - score, - g_idx, - sort_indices, - topk, - renormalize=False, - num_bits=num_bits) + marlin_output = single_marlin_moe( + a, + qweight, + scales, + score, + g_idx, + sort_indices, + topk, + renormalize=False, + num_bits=num_bits, + is_k_full=is_k_full, + ) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 866b18d725a8c..8177e846127ee 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -21,6 +21,7 @@ def single_marlin_moe( renormalize: bool, override_config: Optional[Dict[str, Any]] = None, num_bits: int = 8, + is_k_full: bool = True, ) -> torch.Tensor: """ This function computes the multiplication of hidden_states with expert @@ -86,7 +87,7 @@ def single_marlin_moe( intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, - g_idx, perm, workspace, scalar_type, M, N, K, True, E, topk, + g_idx, perm, workspace, scalar_type, M, N, K, is_k_full, E, topk, block_size_m, True, False) return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) @@ -107,6 +108,7 @@ def 
fused_marlin_moe( w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, num_bits: int = 8, + is_k_full: bool = True, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -199,7 +201,7 @@ def fused_marlin_moe( M, 2 * N, K, - True, + is_k_full, E, topk, block_size_m, @@ -223,7 +225,7 @@ def fused_marlin_moe( M, K, N, - True, + is_k_full, E, topk, block_size_m, From 42e4fd6b70b17a13f5d02d98fcb24ac302388acd Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 24 Sep 2024 07:21:37 +0000 Subject: [PATCH 4/9] Assertion, remove superfluous variable --- csrc/moe/marlin_moe_ops.cu | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 2952a0767e22a..6ed39a819a4f1 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1749,9 +1749,6 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, has_act_order, is_k_full, max_shared_mem); } - int group_tensor_size = - (!is_k_full && has_act_order) ? prob_k / num_groups : group_size; - TORCH_CHECK(exec_cfg.max_m_blocks > 0 && is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, @@ -1827,11 +1824,7 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, int4* C_ptr = (int4*)C; const float* topk_weights_ptr = (const float*)topk_weights; const int* sorted_ids_ptr = (const int*)sorted_ids; - const int4* s_ptr = - (const int4*)s + - ((group_tensor_size == -1 ? 1 : prob_k / group_tensor_size) * prob_n / - 8) * - expert_idx; + const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx; const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; const int* perm_ptr = (const int*)perm + prob_k * expert_idx; int* locks = (int*)workspace; @@ -1925,6 +1918,11 @@ torch::Tensor marlin_gemm_moe( " is not size_n = ", size_n); num_groups = b_scales.size(1); + if (!is_k_full) { + TORCH_CHECK(has_act_order, + "if is_k_full is false, has_act_order must be true"); + } + if (has_act_order) { if (is_k_full) { TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); From 65122f8340765f959ca125c341572330e8288b92 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 24 Sep 2024 15:06:01 +0000 Subject: [PATCH 5/9] vllm_implies macro --- csrc/core/exception.hpp | 3 +++ csrc/moe/marlin_moe_ops.cu | 1 + 2 files changed, 4 insertions(+) create mode 100644 csrc/core/exception.hpp diff --git a/csrc/core/exception.hpp b/csrc/core/exception.hpp new file mode 100644 index 0000000000000..f3b2ffaef6cce --- /dev/null +++ b/csrc/core/exception.hpp @@ -0,0 +1,3 @@ +#pragma once + +#define VLLM_IMPLIES(p, q) (!(p) || (q)) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 6ed39a819a4f1..0bd4d9bc6476b 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -25,6 +25,7 @@ #include +#include "core/exception.hpp" #include "core/scalar_type.hpp" template From 567da228ebe5b6c605560fdd3e6e6814710269d4 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 24 Sep 2024 17:07:22 +0200 Subject: [PATCH 6/9] Update csrc/moe/marlin_moe_ops.cu Co-authored-by: Tyler Michael Smith --- csrc/moe/marlin_moe_ops.cu | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 6ed39a819a4f1..132fbbe079700 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1918,10 +1918,7 @@ torch::Tensor 
marlin_gemm_moe( " is not size_n = ", size_n); num_groups = b_scales.size(1); - if (!is_k_full) { - TORCH_CHECK(has_act_order, - "if is_k_full is false, has_act_order must be true"); - } +TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order), "if is_k_full is false, has_act_order must be true"); if (has_act_order) { if (is_k_full) { From 839e3a04701553a8999136f94251906a120f9ae5 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Tue, 24 Sep 2024 15:21:32 +0000 Subject: [PATCH 7/9] format --- csrc/moe/marlin_moe_ops.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 97d5b1a640c7a..8476ac0771e9c 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1919,7 +1919,8 @@ torch::Tensor marlin_gemm_moe( " is not size_n = ", size_n); num_groups = b_scales.size(1); -TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order), "if is_k_full is false, has_act_order must be true"); + TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order), + "if is_k_full is false, has_act_order must be true"); if (has_act_order) { if (is_k_full) { From 585d96acace0cef0f9c43443561d06b9d2d3e75b Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 25 Sep 2024 02:59:25 -0400 Subject: [PATCH 8/9] reserve more cache space for shared scales --- csrc/moe/marlin_moe_ops.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index c852dfad2debd..9d6c0568fe342 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -190,12 +190,12 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int load_groups = tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K load_groups = max(load_groups, 32); // We load at least 32 scale groups - return load_groups * tb_n * 2; + return load_groups * tb_n * 4; } else { int tb_scales = tb_groups * tb_n * 2; - return tb_scales * STAGES; + return tb_scales * STAGES * 2; } } From 22094482e4ad354ee8ab541042ce69824c7cdbd5 Mon Sep 17 00:00:00 2001 From: ElizaWszola Date: Wed, 25 Sep 2024 06:48:43 -0400 Subject: [PATCH 9/9] check if we can get away with not reserving too much cache space in is_k_full==True case --- csrc/moe/marlin_moe_ops.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 9d6c0568fe342..c97b5dbd2a54e 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -195,7 +195,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, } else { int tb_scales = tb_groups * tb_n * 2; - return tb_scales * STAGES * 2; + return tb_scales * STAGES; } }
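
The VLLM_IMPLIES macro added in csrc/core/exception.hpp encodes logical implication (p implies q, i.e. !p || q) so that a precondition such as "if is_k_full is false, has_act_order must be true" fits in a single TORCH_CHECK. Below is a minimal standalone sketch of the same check; assert() stands in for TORCH_CHECK so it builds without the torch headers, everything else mirrors the macro from the patch.

#include <cassert>

// Logical implication: p -> q  ==  !p || q (as in csrc/core/exception.hpp).
#define VLLM_IMPLIES(p, q) (!(p) || (q))

int main() {
  bool is_k_full = false;
  bool has_act_order = true;
  // "if is_k_full is false, has_act_order must be true"
  assert(VLLM_IMPLIES(!is_k_full, has_act_order));
  return 0;
}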
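
The hunks above spread the mapping from (has_act_order, is_k_full, group_size) to the kernel's group_blocks parameter across several branches. The sketch below pulls that mapping into one standalone function under stated assumptions: assert() stands in for TORCH_CHECK, the function name compute_group_blocks is illustrative rather than part of the patch, and the non-act-order branch with an explicit group_size is assumed to follow the usual Marlin pattern (group_size / 16 blocks, divisibility check on prob_k), since that part of the hunk is cut off above.

#include <cassert>

int compute_group_blocks(bool has_act_order, bool is_k_full, int group_size,
                         int prob_k) {
  int group_blocks = 0;
  if (has_act_order) {
    if (is_k_full) {
      assert(group_size != -1);
      group_blocks = group_size / 16;
      assert(prob_k % group_blocks == 0);
    } else {
      // Partial K: scales are addressed through g_idx, group_blocks stays 0.
      assert(group_size == 0);
      group_blocks = 0;
    }
  } else {
    if (group_size == -1) {
      group_blocks = -1;  // a single scale per output column
    } else {
      group_blocks = group_size / 16;  // assumed, see lead-in
      assert(prob_k % group_blocks == 0);
    }
  }
  return group_blocks;
}

int main() {
  assert(compute_group_blocks(true, true, 128, 4096) == 8);
  assert(compute_group_blocks(true, false, 0, 4096) == 0);
  assert(compute_group_blocks(false, false, -1, 4096) == -1);
  return 0;
}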
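
The do_permute_a logic in the launcher carries the key observation of the series: when K is full, the weight rows are already reordered by increasing group id, so only the A-column permutation is still needed and the kernel itself can take the cheaper non-act-order path. A compact sketch of that decision follows; the helper name act_order_dispatch is illustrative only and not part of the patch.

#include <cassert>
#include <utility>

// Returns {do_permute_a, kernel_has_act_order}.
std::pair<bool, bool> act_order_dispatch(bool has_act_order, bool is_k_full) {
  bool do_permute_a = has_act_order;  // A columns are permuted iff act_order
  if (is_k_full) {
    has_act_order = false;  // full K: grouped scales suffice inside the kernel
  }
  return {do_permute_a, has_act_order};
}

int main() {
  // Full K with act_order: permute A, but launch the non-act-order kernel.
  assert(act_order_dispatch(true, true) == std::make_pair(true, false));
  // Partial K with act_order: keep both the permutation and act-order path.
  assert(act_order_dispatch(true, false) == std::make_pair(true, true));
  return 0;
}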
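
The pointer arithmetic settled on in patches 3 and 4 comes down to per-expert strides: scales are laid out as num_groups rows of prob_n/8 int4 vectors (eight fp16 scales each) per expert, while g_idx and perm hold prob_k ints per expert. The sketch below restates those final offsets; a plain struct stands in for CUDA's int4 and the function names are illustrative, not part of the patch.

#include <cstdint>

struct int4_t { std::int32_t x, y, z, w; };  // 16 bytes = 8 fp16 scales

// Scales: num_groups * prob_n / 8 int4 vectors per expert.
const int4_t* expert_scales(const int4_t* s, int num_groups, int prob_n,
                            int expert_idx) {
  return s + num_groups * prob_n / 8 * expert_idx;
}

// g_idx and perm: one int per row of K per expert.
const int* expert_g_idx(const int* g_idx, int prob_k, int expert_idx) {
  return g_idx + prob_k * expert_idx;
}

int main() { return 0; }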