From 25150c1e6e757e5017b147969fe3dd15e10847e9 Mon Sep 17 00:00:00 2001 From: Costin Eseanu Date: Mon, 3 Jun 2024 23:00:44 +0000 Subject: [PATCH 1/5] Fixed the Windows build for inference ops. --- build_win.bat | 2 -- .../transformer/inference/csrc/pt_binding.cpp | 10 ++++---- .../cuda_layer_norm/layer_norm_cuda.cu | 1 - .../cuda_linear/include/kernel_matmul.cuh | 8 +++---- .../cuda_linear/include/ptx_cp.async.cuh | 12 ++++++---- .../core_ops/cuda_linear/include/ptx_mma.cuh | 20 ++++++++-------- .../cuda_linear/include/utils_core.cuh | 4 ++-- .../cuda_linear/include/utils_gmem.cuh | 2 +- .../include/utils_paralleldequant.cuh | 24 +++++++++---------- 9 files changed, 42 insertions(+), 41 deletions(-) diff --git a/build_win.bat b/build_win.bat index 18ff17347fc9..6e67cb8c2aa2 100644 --- a/build_win.bat +++ b/build_win.bat @@ -6,10 +6,8 @@ set DS_BUILD_AIO=0 set DS_BUILD_CUTLASS_OPS=0 set DS_BUILD_EVOFORMER_ATTN=0 set DS_BUILD_FP_QUANTIZER=0 -set DS_BUILD_INFERENCE_CORE_OPS=0 set DS_BUILD_RAGGED_DEVICE_OPS=0 set DS_BUILD_SPARSE_ATTN=0 -set DS_BUILD_TRANSFORMER_INFERENCE=0 python setup.py bdist_wheel diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 1b9f91cd9c88..175c8f0654c8 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -544,8 +544,8 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens(); auto prev_key = torch::from_blob(workspace + offset, {bsz, heads, all_tokens, k}, - {hidden_dim * InferenceContext::Instance().GetMaxTokenLength(), - k * InferenceContext::Instance().GetMaxTokenLength(), + {hidden_dim * static_cast(InferenceContext::Instance().GetMaxTokenLength()), + k * static_cast(InferenceContext::Instance().GetMaxTokenLength()), k, 1}, options); @@ -553,8 +553,8 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, auto prev_value = torch::from_blob(workspace + offset + value_offset, {bsz, heads, all_tokens, k}, - {hidden_dim * InferenceContext::Instance().GetMaxTokenLength(), - k * InferenceContext::Instance().GetMaxTokenLength(), + {hidden_dim * static_cast(InferenceContext::Instance().GetMaxTokenLength()), + k * static_cast(InferenceContext::Instance().GetMaxTokenLength()), k, 1}, options); @@ -1592,7 +1592,7 @@ std::vector ds_rms_mlp_gemm(at::Tensor& input, auto output = at::from_blob(output_ptr, input.sizes(), options); auto inp_norm = at::from_blob(inp_norm_ptr, input.sizes(), options); auto intermediate_gemm = - at::from_blob(intermediate_ptr, {input.size(0), input.size(1), mlp_1_out_neurons}, options); + at::from_blob(intermediate_ptr, {input.size(0), input.size(1), static_cast(mlp_1_out_neurons)}, options); auto act_func_type = static_cast(activation_type); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm_cuda.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm_cuda.cu index 15f52c46622b..fb6dd0578f1d 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm_cuda.cu +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_layer_norm/layer_norm_cuda.cu @@ -252,7 +252,6 @@ __global__ void fused_residual_ln(T* output, for (int i = 0; i < unRoll; i++) { T* iteration_buffer = local_buffer + i * T_per_load; T residual_buffer[T_per_load]; - T bias_buffer[T_per_load]; mem_access::load_global( iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh index 0262baef4614..860f70b226cb 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh @@ -179,13 +179,13 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight1, SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 * 4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 // Trible-Buffer for B Tile - half __restrict__(*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + half(*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; #ifdef PIPELINE_LEVEL_SMEM - half __restrict__(*read2_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + half(*__restrict__ read2_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; #endif - half __restrict__(*write_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + half(*__restrict__ write_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; // @@ -265,7 +265,7 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight1, } #else -#warning "The FP6 functions are only available on Ampere GPUs." + assert(("The FP6 functions are only available on Ampere GPUs.", false)); #endif } diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_cp.async.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_cp.async.cuh index 39874e023539..982d5a80010c 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_cp.async.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_cp.async.cuh @@ -30,7 +30,8 @@ __device__ __forceinline__ void cp_async(half* smem_ptr, "l"(global_ptr), "n"(SizeInBytes)); #else -#warning "The async copy functions are only supported on Ampere and newer architectures" + assert( + ("The async copy functions are only supported on Ampere and newer architectures", false)); #endif } @@ -40,7 +41,8 @@ __device__ __forceinline__ void cp_async_group_commit() #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("cp.async.commit_group;\n" ::); #else -#warning "The async copy functions are only supported on Ampere and newer architectures" + assert( + ("The async copy functions are only supported on Ampere and newer architectures", false)); #endif } @@ -51,7 +53,8 @@ __device__ __forceinline__ void cp_async_wait_group() #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); #else -#warning "The async copy functions are only supported on Ampere and newer architectures" + assert( + ("The async copy functions are only supported on Ampere and newer architectures", false)); #endif } @@ -64,7 +67,8 @@ __device__ __forceinline__ void cp_async_wait_all() #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("cp.async.wait_all;\n" ::); #else -#warning "The async copy functions are only supported on Ampere and newer architectures" + assert( + ("The async copy functions are only supported on Ampere and newer architectures", false)); #endif } diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh index 8023629caac9..34647e860ce6 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh @@ -18,8 +18,8 @@ #ifdef PIPELINE_LEVEL_SMEM template __device__ __forceinline__ void B_FromSharedToReg( - uint32_t __restrict__ Reg[][4], - half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + uint32_t (*__restrict__ Reg)[4], + half (*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], int slice_id) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 @@ -56,7 +56,7 @@ __device__ __forceinline__ void B_FromSharedToReg( } } #else -#warning "The matrix load functions are only supported on Ampere and newer architectures" + assert(("The matrix load functions are only supported on Ampere and newer architectures", false)); #endif } #else @@ -64,8 +64,8 @@ __device__ __forceinline__ void B_FromSharedToReg( // B is in column-major template __device__ __forceinline__ void B_FromSharedToReg( - uint32_t __restrict__ Reg[][4], - half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + uint32_t (*__restrict__ Reg)[4], + half (*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], int k_offset) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 @@ -102,14 +102,14 @@ __device__ __forceinline__ void B_FromSharedToReg( } } #else -#warning "The matrix load functions are only supported on Ampere and newer architectures" + assert(("The matrix load functions are only supported on Ampere and newer architectures", false)); #endif } #endif -__device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t __restrict__ c[], - uint32_t __restrict__* a, - uint32_t __restrict__* b) +__device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t* __restrict__ c, + uint32_t* __restrict__ a, + uint32_t* __restrict__ b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile( @@ -130,7 +130,7 @@ __device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t __restrict__ c[], "r"(c[2]), "r"(c[3])); #else -#warning "The mma functions are only implemented for Ampere and newer architectures" + assert(("The mma functions are only implemented for Ampere and newer architectures", false)); #endif } diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_core.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_core.cuh index a65575a1ba5a..bd8a009a02c6 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_core.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_core.cuh @@ -32,7 +32,7 @@ __device__ __forceinline__ void initialize_mma_slice( uint32_t (*b)[4], uint32_t* __restrict__ A1_SPTR_read, uint32_t* __restrict__ A2_SPTR_read, - half __restrict__ (*B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], uint32_t* RPTR_Scales) { // Writing registers @@ -54,7 +54,7 @@ __device__ __forceinline__ void core_mma_slice( uint32_t (*b)[4], uint32_t* __restrict__ A1_SPTR_read, uint32_t* __restrict__ A2_SPTR_read, - half __restrict__ (*B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], uint32_t* RPTR_Scales, int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching { diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_gmem.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_gmem.cuh index d0c58352cd56..3dd7e9e0104e 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_gmem.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_gmem.cuh @@ -57,7 +57,7 @@ __device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantSc */ template __device__ __forceinline__ void CopyFromGlobalToShared( - half __restrict__ (*SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + half (*__restrict__ SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], const half* GlobalPTR, const int GlobalStride, const int NumOfLinesLeft, // To support arbitrary N dimensions. diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh index 9f035f12cfcd..2282f6cc52d1 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh @@ -17,7 +17,7 @@ * Outputs: R1, R2 * Note: Simplified Exponent calculation is applied. */ -__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t* R1, u_int32_t* R2) +__device__ __forceinline__ void FP6_FP16_Cast_4Way(uint32_t* R1, uint32_t* R2) { *R2 = *R1 & 0x80808080; *R1 = *R1 >> 2; @@ -33,7 +33,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t* R1, u_int32_t* R2) * Outputs: R1, R2 * Note: Simplified Exponent calculation is NOT applied. */ -__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t* R1, u_int32_t* R2) +__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(uint32_t* R1, uint32_t* R2) { //*R2 = *R1 & 0x80808080; *R2 = *R1 & 0xc0c0c0c0; @@ -56,7 +56,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t* R1, u_int32_ //*R2 = 0x3c003c00; } -__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) +__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) { half* FP16_1 = reinterpret_cast(&PackedFP16Pair); half* FP16_2 = FP16_1 + 1; @@ -67,17 +67,17 @@ __device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Sc return output; } -__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4], - u_int32_t __restrict__* read_RPTR_Frag1, - u_int32_t __restrict__* read_RPTR_Frag2, - u_int32_t* Scales) +__device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (*__restrict__ Reg)[4], + uint32_t *__restrict__ read_RPTR_Frag1, + uint32_t *__restrict__ read_RPTR_Frag2, + uint32_t* Scales) { - u_int32_t* OutputRegs = reinterpret_cast(Reg); - u_int32_t* Frag1_PTR = read_RPTR_Frag1; - u_int32_t* Frag2_PTR = read_RPTR_Frag2; + uint32_t* OutputRegs = reinterpret_cast(Reg); + uint32_t* Frag1_PTR = read_RPTR_Frag1; + uint32_t* Frag2_PTR = read_RPTR_Frag2; half* Scale_RPTR = reinterpret_cast(Scales); - u_int32_t Packed_FP6 = 0; - u_int32_t tmp = 0; + uint32_t Packed_FP6 = 0; + uint32_t tmp = 0; // Dequantizing 32 FP6, each Loop dequantizing 4 FP6 #pragma unroll(8) for (int i = 0; i < 8; i++) { From c2496f9d75899c108f923c386fe9e73c633b945f Mon Sep 17 00:00:00 2001 From: Costin Eseanu Date: Mon, 3 Jun 2024 23:02:16 +0000 Subject: [PATCH 2/5] Fixed whitespace. --- .../transformer/inference/csrc/pt_binding.cpp | 37 ++++++++++--------- .../core_ops/cuda_linear/include/ptx_mma.cuh | 6 ++- .../include/utils_paralleldequant.cuh | 4 +- 3 files changed, 26 insertions(+), 21 deletions(-) diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 175c8f0654c8..2d5332578edc 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -542,22 +542,23 @@ std::vector ds_softmax_context(at::Tensor& query_key_value, 1); if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens(); - auto prev_key = torch::from_blob(workspace + offset, - {bsz, heads, all_tokens, k}, - {hidden_dim * static_cast(InferenceContext::Instance().GetMaxTokenLength()), - k * static_cast(InferenceContext::Instance().GetMaxTokenLength()), - k, - 1}, - options); - - auto prev_value = - torch::from_blob(workspace + offset + value_offset, - {bsz, heads, all_tokens, k}, - {hidden_dim * static_cast(InferenceContext::Instance().GetMaxTokenLength()), - k * static_cast(InferenceContext::Instance().GetMaxTokenLength()), - k, - 1}, - options); + auto prev_key = torch::from_blob( + workspace + offset, + {bsz, heads, all_tokens, k}, + {hidden_dim * static_cast(InferenceContext::Instance().GetMaxTokenLength()), + k * static_cast(InferenceContext::Instance().GetMaxTokenLength()), + k, + 1}, + options); + + auto prev_value = torch::from_blob( + workspace + offset + value_offset, + {bsz, heads, all_tokens, k}, + {hidden_dim * static_cast(InferenceContext::Instance().GetMaxTokenLength()), + k * static_cast(InferenceContext::Instance().GetMaxTokenLength()), + k, + 1}, + options); return {output, prev_key, prev_value}; } @@ -1592,7 +1593,9 @@ std::vector ds_rms_mlp_gemm(at::Tensor& input, auto output = at::from_blob(output_ptr, input.sizes(), options); auto inp_norm = at::from_blob(inp_norm_ptr, input.sizes(), options); auto intermediate_gemm = - at::from_blob(intermediate_ptr, {input.size(0), input.size(1), static_cast(mlp_1_out_neurons)}, options); + at::from_blob(intermediate_ptr, + {input.size(0), input.size(1), static_cast(mlp_1_out_neurons)}, + options); auto act_func_type = static_cast(activation_type); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh index 34647e860ce6..56f86a46f6b5 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh @@ -56,7 +56,8 @@ __device__ __forceinline__ void B_FromSharedToReg( } } #else - assert(("The matrix load functions are only supported on Ampere and newer architectures", false)); + assert( + ("The matrix load functions are only supported on Ampere and newer architectures", false)); #endif } #else @@ -102,7 +103,8 @@ __device__ __forceinline__ void B_FromSharedToReg( } } #else - assert(("The matrix load functions are only supported on Ampere and newer architectures", false)); + assert( + ("The matrix load functions are only supported on Ampere and newer architectures", false)); #endif } #endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh index 2282f6cc52d1..11603fcc576c 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh @@ -68,8 +68,8 @@ __device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scal } __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (*__restrict__ Reg)[4], - uint32_t *__restrict__ read_RPTR_Frag1, - uint32_t *__restrict__ read_RPTR_Frag2, + uint32_t* __restrict__ read_RPTR_Frag1, + uint32_t* __restrict__ read_RPTR_Frag2, uint32_t* Scales) { uint32_t* OutputRegs = reinterpret_cast(Reg); From a6120c855865dc43f729cbd89e59fc56f8848a22 Mon Sep 17 00:00:00 2001 From: Costin Eseanu Date: Wed, 5 Jun 2024 18:20:49 +0000 Subject: [PATCH 3/5] Fixed Windows setup.py not copying files in the right location. --- setup.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 408b300a78b0..06934bd9021f 100755 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ The wheel will be located at: dist/*.whl """ +import pathlib import os import shutil import sys @@ -209,9 +210,12 @@ def op_enabled(op_name): git_branch = "unknown" if sys.platform == "win32": - shutil.copytree('.\\csrc', '.\\deepspeed\\ops') - shutil.copytree('.\\op_builder', '.\\deepspeed\\ops') - shutil.copytree('.\\accelerator', '.\\deepspeed\\accelerator') + pathlib.Path('.\\deepspeed\\ops\\csrc').unlink(missing_ok=True) + shutil.copytree('.\\csrc', '.\\deepspeed\\ops\\csrc', dirs_exist_ok=True) + pathlib.Path('.\\deepspeed\\ops\\op_builder').unlink(missing_ok=True) + shutil.copytree('.\\op_builder', '.\\deepspeed\\ops\\op_builder', dirs_exist_ok=True) + pathlib.Path('.\\deepspeed\\accelerator').unlink(missing_ok=True) + shutil.copytree('.\\accelerator', '.\\deepspeed\\accelerator', dirs_exist_ok=True) egg_info.manifest_maker.template = 'MANIFEST_win.in' # Parse the DeepSpeed version string from version.txt. From 07aa5bfc5834b75ab6ac2eb8bb3f368cd368fe8b Mon Sep 17 00:00:00 2001 From: Costin Eseanu Date: Wed, 12 Jun 2024 23:11:30 +0000 Subject: [PATCH 4/5] Fix BF16 not being built on Win. Fixed consecutive builds on Windows failing. --- op_builder/builder.py | 1 + setup.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/op_builder/builder.py b/op_builder/builder.py index df54415c3b84..03611bf56284 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -678,6 +678,7 @@ def builder(self): if not self.build_for_cpu and self.enable_bf16: compile_args['cxx'].append("-DBF16_AVAILABLE") + compile_args['nvcc'].append("-DBF16_AVAILABLE") if self.is_rocm_pytorch(): compile_args['cxx'].append("-D__HIP_PLATFORM_AMD__=1") diff --git a/setup.py b/setup.py index 06934bd9021f..f39eab41938a 100755 --- a/setup.py +++ b/setup.py @@ -210,10 +210,13 @@ def op_enabled(op_name): git_branch = "unknown" if sys.platform == "win32": + shutil.rmtree('.\\deepspeed\\ops\\csrc', ignore_errors=True) pathlib.Path('.\\deepspeed\\ops\\csrc').unlink(missing_ok=True) shutil.copytree('.\\csrc', '.\\deepspeed\\ops\\csrc', dirs_exist_ok=True) + shutil.rmtree('.\\deepspeed\\ops\\op_builder', ignore_errors=True) pathlib.Path('.\\deepspeed\\ops\\op_builder').unlink(missing_ok=True) shutil.copytree('.\\op_builder', '.\\deepspeed\\ops\\op_builder', dirs_exist_ok=True) + shutil.rmtree('.\\deepspeed\\accelerator', ignore_errors=True) pathlib.Path('.\\deepspeed\\accelerator').unlink(missing_ok=True) shutil.copytree('.\\accelerator', '.\\deepspeed\\accelerator', dirs_exist_ok=True) egg_info.manifest_maker.template = 'MANIFEST_win.in' From f95d5f32a1a990319d1c8023247a8b08e8034060 Mon Sep 17 00:00:00 2001 From: Costin Eseanu Date: Mon, 17 Jun 2024 20:11:18 +0000 Subject: [PATCH 5/5] Fixed divide by zero caused by coarse time resolution on Windows. Replaced NCCL with GLOO on Windows. --- accelerator/cuda_accelerator.py | 3 ++- deepspeed/utils/timer.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py index 74b004205f2a..06fd443f9829 100644 --- a/accelerator/cuda_accelerator.py +++ b/accelerator/cuda_accelerator.py @@ -7,6 +7,7 @@ import os import pkgutil import importlib +import sys from .abstract_accelerator import DeepSpeedAccelerator # During setup stage torch may not be installed, pass on no torch will @@ -24,7 +25,7 @@ class CUDA_Accelerator(DeepSpeedAccelerator): def __init__(self): self._name = 'cuda' - self._communication_backend_name = 'nccl' + self._communication_backend_name = 'nccl' if sys.platform != 'win32' else 'gloo' self._compile_backend = "inductor" if pynvml is None: self._init_pynvml() diff --git a/deepspeed/utils/timer.py b/deepspeed/utils/timer.py index dd78b207cc37..00f17dea709c 100755 --- a/deepspeed/utils/timer.py +++ b/deepspeed/utils/timer.py @@ -18,6 +18,7 @@ BACKWARD_REDUCE_GLOBAL_TIMER = 'bwd_allreduce' STEP_MICRO_TIMER = 'step_microstep' STEP_GLOBAL_TIMER = 'step' +TIME_EPSILON = 1e-6 try: import psutil @@ -262,7 +263,7 @@ def stop(self, global_step=False, report_speed=True): self.micro_step_count, self.global_step_count, self.avg_samples_per_sec(), - self.batch_size / self.step_elapsed_time, + self.batch_size / (self.step_elapsed_time + TIME_EPSILON), round(get_accelerator().memory_allocated() / 1024**3, 2), round(get_accelerator().max_memory_allocated() / 1024**3, 2), ))