diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm_split_k.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm_split_k.h index 55423457214396..b4f10395798612 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm_split_k.h +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm_split_k.h @@ -231,6 +231,7 @@ struct GemmFpAIntBSplitK { public: CUTLASS_HOST_DEVICE Params() = default; + CUTLASS_HOST_DEVICE Params(Arguments const &args, int device_sms, int sm_occupancy) : params_A(args.ref_A.layout()), @@ -842,10 +843,10 @@ struct GemmFpAIntBSplitK { init_iterator_scale(tile_work, params.mode); // Initialize accumulators AccumulatorTile accumulator_tile; - Mma mma(shared_storage.main_loop, -1, thread_idx, warp_idx, lane_idx); - Mma mma(shared_storage.main_loop, -1, thread_idx, warp_idx, lane_idx); - Mma mma(shared_storage.main_loop, -1, thread_idx, warp_idx, lane_idx); - Mma mma(shared_storage.main_loop, -1, thread_idx, warp_idx, lane_idx); + accumulator_tile.clear(); + // static_assert(print_type()); + + // Perform this tile's range of multiply-accumulate (MAC) iterations Mma mma(shared_storage.main_loop, -1, thread_idx, warp_idx, lane_idx); mma(tile_work.k_iters_remaining,