Fixed Windows inference build. #5609

Merged: 14 commits, Jun 24, 2024
Changes from all commits
3 changes: 2 additions & 1 deletion accelerator/cuda_accelerator.py
@@ -7,6 +7,7 @@
import os
import pkgutil
import importlib
import sys

from .abstract_accelerator import DeepSpeedAccelerator
# During setup stage torch may not be installed, pass on no torch will
@@ -24,7 +25,7 @@ class CUDA_Accelerator(DeepSpeedAccelerator):

def __init__(self):
self._name = 'cuda'
self._communication_backend_name = 'nccl'
self._communication_backend_name = 'nccl' if sys.platform != 'win32' else 'gloo'
self._compile_backend = "inductor"
if pynvml is None:
self._init_pynvml()
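Note: NCCL is not available on Windows, so on win32 the CUDA accelerator presumably has to fall back to the Gloo backend for torch.distributed; on every other platform the default communication backend stays 'nccl'.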
2 changes: 0 additions & 2 deletions build_win.bat
@@ -6,10 +6,8 @@ set DS_BUILD_AIO=0
set DS_BUILD_CUTLASS_OPS=0
set DS_BUILD_EVOFORMER_ATTN=0
set DS_BUILD_FP_QUANTIZER=0
set DS_BUILD_INFERENCE_CORE_OPS=0
set DS_BUILD_RAGGED_DEVICE_OPS=0
set DS_BUILD_SPARSE_ATTN=0
set DS_BUILD_TRANSFORMER_INFERENCE=0

python setup.py bdist_wheel
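Note: with DS_BUILD_INFERENCE_CORE_OPS=0 and DS_BUILD_TRANSFORMER_INFERENCE=0 removed, the Windows wheel build no longer force-disables those ops, which appears to be the point of the PR: the inference kernels are now expected to compile under MSVC.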

37 changes: 20 additions & 17 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -542,22 +542,23 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
1);

if (layer_id == num_layers - 1) InferenceContext::Instance().advance_tokens();
auto prev_key = torch::from_blob(workspace + offset,
{bsz, heads, all_tokens, k},
{hidden_dim * InferenceContext::Instance().GetMaxTokenLength(),
k * InferenceContext::Instance().GetMaxTokenLength(),
k,
1},
options);

auto prev_value =
torch::from_blob(workspace + offset + value_offset,
{bsz, heads, all_tokens, k},
{hidden_dim * InferenceContext::Instance().GetMaxTokenLength(),
k * InferenceContext::Instance().GetMaxTokenLength(),
k,
1},
options);
auto prev_key = torch::from_blob(
workspace + offset,
{bsz, heads, all_tokens, k},
{hidden_dim * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
k * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
k,
1},
options);

auto prev_value = torch::from_blob(
workspace + offset + value_offset,
{bsz, heads, all_tokens, k},
{hidden_dim * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
k * static_cast<int64_t>(InferenceContext::Instance().GetMaxTokenLength()),
k,
1},
options);

return {output, prev_key, prev_value};
}
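Note: the diff itself does not say why the static_cast<int64_t> calls were added. A plausible reading is that the strides handed to torch::from_blob are int64_t, and widening one operand keeps the products out of 32-bit int arithmetic, avoiding both overflow for large max token lengths and implicit-conversion complaints from MSVC. A minimal, self-contained sketch of the pattern (names and values below are hypothetical, not the DeepSpeed API):

#include <cstdint>
#include <iostream>

int main()
{
    // Stand-ins for hidden_dim and InferenceContext::Instance().GetMaxTokenLength().
    int hidden_dim = 7168;
    int max_token_length = 1 << 20;

    // Casting one operand first makes the multiply happen in 64-bit, matching the int64_t
    // stride type torch::from_blob expects; without the cast the product is computed in
    // 32-bit int and can overflow.
    int64_t stride0 = static_cast<int64_t>(max_token_length) * hidden_dim;

    std::cout << "stride0 = " << stride0 << "\n";
    return 0;
}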
@@ -1592,7 +1593,9 @@ std::vector<at::Tensor> ds_rms_mlp_gemm(at::Tensor& input,
auto output = at::from_blob(output_ptr, input.sizes(), options);
auto inp_norm = at::from_blob(inp_norm_ptr, input.sizes(), options);
auto intermediate_gemm =
at::from_blob(intermediate_ptr, {input.size(0), input.size(1), mlp_1_out_neurons}, options);
at::from_blob(intermediate_ptr,
{input.size(0), input.size(1), static_cast<int64_t>(mlp_1_out_neurons)},
options);

auto act_func_type = static_cast<ActivationFuncType>(activation_type);

@@ -252,7 +252,6 @@ __global__ void fused_residual_ln(T* output,
for (int i = 0; i < unRoll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
T residual_buffer[T_per_load];
T bias_buffer[T_per_load];

mem_access::load_global<ln::granularity>(
iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row);
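Note: as far as this hunk shows, the removed bias_buffer array was never used in the loop; dropping the dead declaration presumably silences an unused-variable diagnostic on the Windows toolchain.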
@@ -179,13 +179,13 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight1,
SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 *
4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16
// Trible-Buffer for B Tile
half __restrict__(*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
half(*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
smem_array + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
#ifdef PIPELINE_LEVEL_SMEM
half __restrict__(*read2_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
half(*__restrict__ read2_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
smem_array + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
#endif
half __restrict__(*write_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
half(*__restrict__ write_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] =
smem_array +
((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N;
//
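Note: this hunk and several below all make the same fix: the restrict qualifier moves from the element type onto the pointer declarator. The old spelling half __restrict__ (*p)[N] builds with the Linux toolchain but is presumably rejected by the Windows one; the new spelling half (*__restrict__ p)[N] declares a restrict-qualified pointer to an array of N half values, the portable form. A compile-only sketch of the two spellings (N here is a hypothetical stand-in for WARP_K + PADDING_SHARED_MEM_FOR_B_8):

#include <cuda_fp16.h>

constexpr int N = 72;  // hypothetical stand-in for WARP_K + PADDING_SHARED_MEM_FOR_B_8

// Old spelling, qualifier on the element type (breaks the Windows build):
//   half __restrict__ (*tile)[N];
// Fixed spelling, restrict-qualified pointer to arrays of N half values:
__device__ void touch_tile(half (*__restrict__ tile)[N])
{
    tile[0][0] = __float2half(0.0f);
}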
@@ -265,7 +265,7 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight1,
}

#else
#warning "The FP6 functions are only available on Ampere GPUs."
assert(("The FP6 functions are only available on Ampere GPUs.", false));
#endif
}
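Note: cl.exe rejects the non-standard #warning directive (at least under its default language modes), so the compile-time notice becomes a runtime assert in the non-Ampere fallback path; the comma-operator form keeps the human-readable message in the assertion text. The same substitution is applied to every #warning in the hunks below.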

@@ -30,7 +30,8 @@ __device__ __forceinline__ void cp_async(half* smem_ptr,
"l"(global_ptr),
"n"(SizeInBytes));
#else
#warning "The async copy functions are only supported on Ampere and newer architectures"
assert(
("The async copy functions are only supported on Ampere and newer architectures", false));
#endif
}

@@ -40,7 +41,8 @@ __device__ __forceinline__ void cp_async_group_commit()
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cp.async.commit_group;\n" ::);
#else
#warning "The async copy functions are only supported on Ampere and newer architectures"
assert(
("The async copy functions are only supported on Ampere and newer architectures", false));
#endif
}

@@ -51,7 +53,8 @@ __device__ __forceinline__ void cp_async_wait_group()
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#else
#warning "The async copy functions are only supported on Ampere and newer architectures"
assert(
("The async copy functions are only supported on Ampere and newer architectures", false));
#endif
}

@@ -64,7 +67,8 @@ __device__ __forceinline__ void cp_async_wait_all()
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cp.async.wait_all;\n" ::);
#else
#warning "The async copy functions are only supported on Ampere and newer architectures"
assert(
("The async copy functions are only supported on Ampere and newer architectures", false));
#endif
}

@@ -18,8 +18,8 @@
#ifdef PIPELINE_LEVEL_SMEM
template <typename TilingConfig>
__device__ __forceinline__ void B_FromSharedToReg(
uint32_t __restrict__ Reg[][4],
half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
uint32_t (*__restrict__ Reg)[4],
half (*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
int slice_id)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
@@ -56,16 +56,17 @@ __device__ __forceinline__ void B_FromSharedToReg(
}
}
#else
#warning "The matrix load functions are only supported on Ampere and newer architectures"
assert(
("The matrix load functions are only supported on Ampere and newer architectures", false));
#endif
}
#else
// Debug: Whether ldmatrix.trans is required???
// B is in column-major
template <typename TilingConfig>
__device__ __forceinline__ void B_FromSharedToReg(
uint32_t __restrict__ Reg[][4],
half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
uint32_t (*__restrict__ Reg)[4],
half (*__restrict__ read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
int k_offset)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
@@ -102,14 +103,15 @@ __device__ __forceinline__ void B_FromSharedToReg(
}
}
#else
#warning "The matrix load functions are only supported on Ampere and newer architectures"
assert(
("The matrix load functions are only supported on Ampere and newer architectures", false));
#endif
}
#endif

__device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t __restrict__ c[],
uint32_t __restrict__* a,
uint32_t __restrict__* b)
__device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t* __restrict__ c,
uint32_t* __restrict__ a,
uint32_t* __restrict__ b)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
@@ -130,7 +132,7 @@ __device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t __restrict__ c[],
"r"(c[2]),
"r"(c[3]));
#else
#warning "The mma functions are only implemented for Ampere and newer architectures"
assert(("The mma functions are only implemented for Ampere and newer architectures", false));
#endif
}

@@ -32,7 +32,7 @@ __device__ __forceinline__ void initialize_mma_slice(
uint32_t (*b)[4],
uint32_t* __restrict__ A1_SPTR_read,
uint32_t* __restrict__ A2_SPTR_read,
half __restrict__ (*B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
uint32_t* RPTR_Scales)
{
// Writing registers
@@ -54,7 +54,7 @@ __device__ __forceinline__ void core_mma_slice(
uint32_t (*b)[4],
uint32_t* __restrict__ A1_SPTR_read,
uint32_t* __restrict__ A2_SPTR_read,
half __restrict__ (*B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
half (*__restrict__ B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
uint32_t* RPTR_Scales,
int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching
{
@@ -57,7 +57,7 @@ __device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantSc
*/
template <int MaxNumOfLinesToCopy, int BLOCK_WARPS>
__device__ __forceinline__ void CopyFromGlobalToShared(
half __restrict__ (*SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
half (*__restrict__ SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
const half* GlobalPTR,
const int GlobalStride,
const int NumOfLinesLeft, // To support arbitrary N dimensions.
@@ -17,7 +17,7 @@
* Outputs: R1, R2
* Note: Simplified Exponent calculation is applied.
*/
__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t* R1, u_int32_t* R2)
__device__ __forceinline__ void FP6_FP16_Cast_4Way(uint32_t* R1, uint32_t* R2)
{
*R2 = *R1 & 0x80808080;
*R1 = *R1 >> 2;
@@ -33,7 +33,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t* R1, u_int32_
* Outputs: R1, R2
* Note: Simplified Exponent calculation is NOT applied.
*/
__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t* R1, u_int32_t* R2)
__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(uint32_t* R1, uint32_t* R2)
{
//*R2 = *R1 & 0x80808080;
*R2 = *R1 & 0xc0c0c0c0;
@@ -56,7 +56,7 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t* R1, u_int32_
//*R2 = 0x3c003c00;
}

__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale)
__device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale)
{
half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair);
half* FP16_2 = FP16_1 + 1;
@@ -67,17 +67,17 @@ __device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Sc
return output;
}

__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4],
u_int32_t __restrict__* read_RPTR_Frag1,
u_int32_t __restrict__* read_RPTR_Frag2,
u_int32_t* Scales)
__device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (*__restrict__ Reg)[4],
uint32_t* __restrict__ read_RPTR_Frag1,
uint32_t* __restrict__ read_RPTR_Frag2,
uint32_t* Scales)
{
u_int32_t* OutputRegs = reinterpret_cast<u_int32_t*>(Reg);
u_int32_t* Frag1_PTR = read_RPTR_Frag1;
u_int32_t* Frag2_PTR = read_RPTR_Frag2;
uint32_t* OutputRegs = reinterpret_cast<uint32_t*>(Reg);
uint32_t* Frag1_PTR = read_RPTR_Frag1;
uint32_t* Frag2_PTR = read_RPTR_Frag2;
half* Scale_RPTR = reinterpret_cast<half*>(Scales);
u_int32_t Packed_FP6 = 0;
u_int32_t tmp = 0;
uint32_t Packed_FP6 = 0;
uint32_t tmp = 0;
// Dequantizing 32 FP6, each Loop dequantizing 4 FP6
#pragma unroll(8)
for (int i = 0; i < 8; i++) {
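Note: u_int32_t is a BSD/glibc typedef that MSVC's headers do not provide; switching these FP6 dequantization helpers to the standard uint32_t keeps them portable to Windows.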
3 changes: 2 additions & 1 deletion deepspeed/utils/timer.py
@@ -18,6 +18,7 @@
BACKWARD_REDUCE_GLOBAL_TIMER = 'bwd_allreduce'
STEP_MICRO_TIMER = 'step_microstep'
STEP_GLOBAL_TIMER = 'step'
TIME_EPSILON = 1e-6

try:
import psutil
@@ -262,7 +263,7 @@ def stop(self, global_step=False, report_speed=True):
self.micro_step_count,
self.global_step_count,
self.avg_samples_per_sec(),
self.batch_size / self.step_elapsed_time,
self.batch_size / (self.step_elapsed_time + TIME_EPSILON),
round(get_accelerator().memory_allocated() / 1024**3, 2),
round(get_accelerator().max_memory_allocated() / 1024**3, 2),
))
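Note: the added TIME_EPSILON (1e-6) guards the throughput division; on Windows, time.time() can have coarse enough resolution that step_elapsed_time comes back as exactly 0.0 for a fast step, which would raise ZeroDivisionError. The epsilon makes the reported rate slightly approximate but keeps logging from crashing.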
1 change: 1 addition & 0 deletions op_builder/builder.py
@@ -678,6 +678,7 @@ def builder(self):

if not self.build_for_cpu and self.enable_bf16:
compile_args['cxx'].append("-DBF16_AVAILABLE")
compile_args['nvcc'].append("-DBF16_AVAILABLE")

if self.is_rocm_pytorch():
compile_args['cxx'].append("-D__HIP_PLATFORM_AMD__=1")
13 changes: 10 additions & 3 deletions setup.py
@@ -18,6 +18,7 @@
The wheel will be located at: dist/*.whl
"""

import pathlib
import os
import shutil
import sys
@@ -209,9 +210,15 @@ def op_enabled(op_name):
git_branch = "unknown"

if sys.platform == "win32":
shutil.copytree('.\\csrc', '.\\deepspeed\\ops')
shutil.copytree('.\\op_builder', '.\\deepspeed\\ops')
shutil.copytree('.\\accelerator', '.\\deepspeed\\accelerator')
shutil.rmtree('.\\deepspeed\\ops\\csrc', ignore_errors=True)
pathlib.Path('.\\deepspeed\\ops\\csrc').unlink(missing_ok=True)
shutil.copytree('.\\csrc', '.\\deepspeed\\ops\\csrc', dirs_exist_ok=True)
shutil.rmtree('.\\deepspeed\\ops\\op_builder', ignore_errors=True)
pathlib.Path('.\\deepspeed\\ops\\op_builder').unlink(missing_ok=True)
shutil.copytree('.\\op_builder', '.\\deepspeed\\ops\\op_builder', dirs_exist_ok=True)
shutil.rmtree('.\\deepspeed\\accelerator', ignore_errors=True)
pathlib.Path('.\\deepspeed\\accelerator').unlink(missing_ok=True)
shutil.copytree('.\\accelerator', '.\\deepspeed\\accelerator', dirs_exist_ok=True)
egg_info.manifest_maker.template = 'MANIFEST_win.in'
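Note: the Windows packaging step now copies into the actual package layout (deepspeed\ops\csrc, deepspeed\ops\op_builder, deepspeed\accelerator) rather than dropping both csrc and op_builder into deepspeed\ops, and each destination is cleared first (rmtree for a stale directory, unlink for a stale file or symlink) before being recreated with dirs_exist_ok=True, so repeated Windows builds from the same tree no longer fail on leftovers from a previous run.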

# Parse the DeepSpeed version string from version.txt.