diff --git a/benchmarks/benchmark_fp6_llm.py b/benchmarks/benchmark_fp6_llm.py index b6fdca643..ae17764e6 100644 --- a/benchmarks/benchmark_fp6_llm.py +++ b/benchmarks/benchmark_fp6_llm.py @@ -7,12 +7,12 @@ def benchmark(m: int, k: int, n: int): - fp6_weight = torch.randint(256, size=(n, k // 4 * 3), dtype=torch.uint8, device="cuda") + fp6_weight = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda") scales = torch.rand(n, dtype=torch.half, device="cuda") + 0.5 - fp6_linear = Fp6LlmLinear(fp6_weight.view(torch.int32), scales) + fp6_linear = Fp6LlmLinear(fp6_weight, scales) fp16_linear = nn.Linear(k, n, bias=True, dtype=torch.half, device="cuda") - fp16_linear.weight.data = from_tc_float6_e3m2(fp6_weight.view(-1), n, k, dtype=torch.half) * scales[:, None] + fp16_linear.weight.data = from_tc_float6_e3m2(fp6_weight, dtype=torch.half) * scales[:, None] fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") fp6_output = fp6_linear(fp16_act) diff --git a/setup.py b/setup.py index 8268ac61e..a6a2da8a7 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ def read_version(file_path="version.txt"): CUDAExtension, BuildExtension, CUDA_HOME, + IS_WINDOWS ) @@ -52,20 +53,41 @@ def get_extensions(): use_cuda = torch.cuda.is_available() and CUDA_HOME is not None extension = CUDAExtension if use_cuda else CppExtension - extra_link_args = [] - extra_compile_args = { - "cxx": [ - "-O3" if not debug_mode else "-O0", - "-fdiagnostics-color=always", - ], - "nvcc": [ - "-O3" if not debug_mode else "-O0", - ] - } - if debug_mode: - extra_compile_args["cxx"].append("-g") - extra_compile_args["nvcc"].append("-g") - extra_link_args.extend(["-O0", "-g"]) + if not IS_WINDOWS: + extra_link_args = [] + extra_compile_args = { + "cxx": [ + "-O3" if not debug_mode else "-O0", + "-fdiagnostics-color=always", + ], + "nvcc": [ + "-O3" if not debug_mode else "-O0", + "-t=0", + ] + } + + if debug_mode: + extra_compile_args["cxx"].append("-g") + extra_compile_args["nvcc"].append("-g") + extra_link_args.extend(["-O0", "-g"]) + + else: + extra_link_args = [] + extra_compile_args = { + "cxx": [ + "/O2" if not debug_mode else "/Od", + "/permissive-" + ], + "nvcc": [ + "-O3" if not debug_mode else "-O0", + "-t=0", + ] + } + + if debug_mode: + extra_compile_args["cxx"].append("/ZI") + extra_compile_args["nvcc"].append("-g") + extra_link_args.append("/DEBUG") this_dir = os.path.dirname(os.path.curdir) extensions_dir = os.path.join(this_dir, "torchao", "csrc") diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index de7775ddc..ed11fc851 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_matmul.cuh +// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_matmul.cuh #include "configs.h" #include "utils_gmem.cuh" @@ -133,11 +133,12 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, uint32_t* __restrict__ write_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 uint32_t* __restrict__ write_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 // Trible-Buffer for B Tile - half __restrict__ (*read_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + // MODIFICATION NOTE: to support MSVC, half __restrict__ (*read_SPTR ) is changed to below. similarly for read2_SPTR and write_SPTR. + half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; #ifdef PIPELINE_LEVEL_SMEM - half __restrict__ (*read2_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + half (* __restrict__ read2_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; #endif - half __restrict__ (*write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + half (* __restrict__ write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; // bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter; // Copying A tile from Global to Register, Bypassing L1, using double-buffer diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh index d0985bd63..bafdd0b4e 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_mma.cuh +// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_mma.cuh /*************************************************************************** * Copyright 2023 The FLash-LLM Authors. All rights reserved. @@ -36,11 +36,14 @@ #include #include "configs.h" +// MODIFICATION NOTE: to support MSVC +// - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__ Reg)[4] +// - half __restrict__ (*read_SPTR) is changed to half (* __restrict__ read_SPTR) #ifdef PIPELINE_LEVEL_SMEM template -__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4], - half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - int slice_id) { +__device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[4], + half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + int slice_id) { #ifdef DEBUG_MODE static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); #endif @@ -112,8 +115,10 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[ } #endif +// MODIFICATION NOTE: to support MSVC, the function signature is changed from +// MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b). __device__ __forceinline__ void -MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b) +MMA_FP16_M16N8K16(uint32_t * __restrict__ c, uint32_t * __restrict__ a, uint32_t * __restrict__ b) { asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" "{ %0, %1, %2, %3}," diff --git a/torchao/csrc/cuda/fp6_llm/utils_core.cuh b/torchao/csrc/cuda/fp6_llm/utils_core.cuh index 5bfc043ef..07e37d85b 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_core.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_core.cuh @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_core.cuh +// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_core.cuh #ifndef UTILS_CORE_CUH #define UTILS_CORE_CUH @@ -35,12 +35,13 @@ __device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], u } } +// MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. template __device__ __forceinline__ void initialize_mma_slice(uint32_t (*a)[4], uint32_t (*b)[4], uint32_t* __restrict__ A1_SPTR_read, uint32_t* __restrict__ A2_SPTR_read, - half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], uint32_t* RPTR_Scales) { // Writing registers @@ -53,13 +54,14 @@ __device__ __forceinline__ void initialize_mma_slice(uint32_t ( B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers } +// MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. template __device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16], uint32_t (*a)[4], uint32_t (*b)[4], uint32_t* __restrict__ A1_SPTR_read, uint32_t* __restrict__ A2_SPTR_read, - half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], uint32_t* RPTR_Scales, int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching { diff --git a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh index 5c37452e1..a74930ba4 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_gmem.cuh @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh +// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_gmem.cuh #ifndef UTILS_GMEM_CUH #define UTILS_GMEM_CUH @@ -57,17 +57,18 @@ __device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantSc for(int i=0; i<2; i++) SPTR_QuantScales[Offset_Shared+i] = GPTR_A_Scales[Offset_Global+i*8]; } +// MODIFICATION NOTE: to support MSVC, half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. /* * (1) Copying X rows * 64 columns of FP16 values, originally in row major * (2) Copying 64 rows * X columns of FP16 values, originally in column major * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads */ template -__device__ __forceinline__ void CopyFromGlobalToShared(half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - const half* GlobalPTR, - const int GlobalStride, - const int NumOfLinesLeft, // To support arbitrary N dimensions. - bool Pred = true) { +__device__ __forceinline__ void CopyFromGlobalToShared(half (* __restrict__ SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + const half* GlobalPTR, + const int GlobalStride, + const int NumOfLinesLeft, // To support arbitrary N dimensions. + bool Pred = true) { // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time const int NumOfThreads = BLOCK_WARPS * WARP_SIZE; const int NumOfGroups = NumOfThreads / 8; diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index f6ce4cc04..48b0f968b 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_parallel_dequant.cuh +// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_parallel_dequant.cuh +// To support MSVC, all instances of u_int32_t are changed to uint32_t. #ifndef UTILS_PARALLELDEQUANT_CUH #define UTILS_PARALLELDEQUANT_CUH @@ -26,7 +27,7 @@ * Outputs: R1, R2 * Note: Simplified Exponent calculation is applied. */ -__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2) { +__device__ __forceinline__ void FP6_FP16_Cast_4Way(uint32_t *R1, uint32_t *R2) { *R2 = *R1 & 0x80808080; *R1 = *R1 >> 2; *R1 = *R1 & 0x1f1f1f1f; @@ -41,7 +42,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2) * Outputs: R1, R2 * Note: Simplified Exponent calculation is NOT applied. */ -__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_t *R2) { +__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(uint32_t *R1, uint32_t *R2) { //*R2 = *R1 & 0x80808080; *R2 = *R1 & 0xc0c0c0c0; *R1 = *R1 >> 2; @@ -63,7 +64,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_ //*R2 = 0x3c003c00; } -__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) { +__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) { half* FP16_1 = reinterpret_cast(&PackedFP16Pair); half* FP16_2 = FP16_1 + 1; uint32_t output; @@ -73,16 +74,19 @@ __device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Sc return output; } -__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4], - u_int32_t __restrict__ *read_RPTR_Frag1, - u_int32_t __restrict__ *read_RPTR_Frag2, - u_int32_t *Scales) { - u_int32_t *OutputRegs = reinterpret_cast (Reg); - u_int32_t *Frag1_PTR = read_RPTR_Frag1; - u_int32_t *Frag2_PTR = read_RPTR_Frag2; +// MODIFICATION NOTE: to support MSVC +// - u_int32_t __restrict__ Reg[][4] is changed to below. +// - u_int32_t __restrict__ *read_RPTR_Frag1 is changed to below. similarly for read_RPTR_Frag2 +__device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg)[4], + uint32_t * __restrict__ read_RPTR_Frag1, + uint32_t * __restrict__ read_RPTR_Frag2, + uint32_t * Scales) { + uint32_t *OutputRegs = reinterpret_cast (Reg); + uint32_t *Frag1_PTR = read_RPTR_Frag1; + uint32_t *Frag2_PTR = read_RPTR_Frag2; half *Scale_RPTR = reinterpret_cast(Scales); - u_int32_t Packed_FP6 = 0; - u_int32_t tmp = 0; + uint32_t Packed_FP6 = 0; + uint32_t tmp = 0; // Dequantizing 32 FP6, each Loop dequantizing 4 FP6 #pragma unroll(8) for(int i=0; i<8; i++) {