From e7c3a5c6fef60fc0c919198942c428a698a568cf Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Date: Thu, 15 Aug 2024 11:38:19 -0400
Subject: [PATCH] Car revert (#140)

* Per @iotamudelta suggestion until the deadlocks issue is better understood

Revert "Make CAR ROCm 6.1 compatible. (#137)"

This reverts commit 4d2dda61c18bf93fa591cd84a5481ee9dd8ee428.

* Per @iotamudelta suggestion until the deadlocks issue is better understood

Revert "Optimize custom all reduce (#130)"

This reverts commit 636ff019a1c9164321ae4414b1b933cddf853b7e.
---
 csrc/custom_all_reduce.cuh         | 170 +++++++++++++----------------
 csrc/custom_all_reduce_test.cu     |   7 --
 vllm/distributed/parallel_state.py |  16 +--
 3 files changed, 77 insertions(+), 116 deletions(-)

diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh
index 27e5c271fd8d2..c640b15a2346a 100644
--- a/csrc/custom_all_reduce.cuh
+++ b/csrc/custom_all_reduce.cuh
@@ -43,12 +43,7 @@ struct __align__(16) RankData { const void* ptrs[8]; };
 struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
 #endif
 
-struct __align__(16) RankSignals {
-#ifndef USE_ROCM
-  volatile
-#endif
-      Signal* signals[8];
-};
+struct __align__(16) RankSignals { volatile Signal* signals[8]; };
 
 // like std::array, but aligned
 template <typename T, int sz>
@@ -141,28 +136,25 @@ DINLINE O downcast(array_t<float, O::size> val) {
 // This function is meant to be used as the first synchronization in the all
 // reduce kernel. Thus, it doesn't need to make any visibility guarantees for
 // prior memory accesses. Note: volatile writes will not be reordered against
-// other volatile writes (CUDA-only).
+// other volatile writes.
 template <int ngpus>
+DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
+                        int rank) {
 #ifdef USE_ROCM
-DINLINE void start_sync(const RankSignals& sg, Signal* self_sg, int rank) {
+  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
-    __atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0,
-                     __ATOMIC_RELAXED);
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], 1,
+    __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
                      __ATOMIC_RELAXED);
-    __atomic_thread_fence(__ATOMIC_ACQ_REL);
     // wait until we got true from all ranks
-    while (!__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
-                            __ATOMIC_RELAXED));
+    while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
+                           __ATOMIC_RELAXED) < flag);
   }
   __syncthreads();
-}
+  // use one thread to update flag
+  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
 #else
-DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
-                        int rank) {
-  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
     // reset flag for next time
    self_sg->end[blockIdx.x][threadIdx.x] = 0;
@@ -173,38 +165,36 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
     while (!self_sg->start[blockIdx.x][threadIdx.x]);
   }
   __syncthreads();
-}
 #endif
+}
 
 // This function is meant to be used as the second or the final synchronization
 // barrier in the all reduce kernel. If it's the final synchronization barrier,
 // we don't need to make any visibility guarantees for prior memory accesses.
 template <int ngpus, bool final_sync = false>
+DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
+                      int rank) {
 #ifdef USE_ROCM
-DINLINE void end_sync(const RankSignals& sg, Signal* self_sg, int rank) {
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become
   // visible. Note that I did not managed to make this happen through a lot of
   // testing. Might be the case that hardware provides stronger guarantee than
   // the memory model.
+  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
   if (threadIdx.x < ngpus) {
-    // reset flag for next time
-    __atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0,
-                     __ATOMIC_RELAXED);
     // simultaneously write to the corresponding flag of all ranks.
     // Latency = 1 p2p write
-    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1,
-                     __ATOMIC_RELAXED);
-    __atomic_thread_fence(__ATOMIC_ACQ_REL);
+    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
+                     final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
     // wait until we got true from all ranks
-    while (!__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
-                            __ATOMIC_RELAXED));
+    while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
+                           final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
+           flag);
   }
-  if constexpr (!final_sync) __syncthreads();
-}
+  __syncthreads();
+  // use one thread to update flag
+  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
 #else
-DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
-                      int rank) {
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become
   // visible. Note that I did not managed to make this happen through a lot of
@@ -221,8 +211,8 @@ DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
     while (!self_sg->end[blockIdx.x][threadIdx.x]);
   }
   if constexpr (!final_sync) __syncthreads();
-}
 #endif
+}
 
 template <typename P, int ngpus, typename A>
 DINLINE P packed_reduce(const P* ptrs[], int idx) {
@@ -237,11 +227,8 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
-#ifndef USE_ROCM
-                               volatile
-#endif
-                                   Signal* self_sg,
-                               T* __restrict__ result, int rank, int size) {
+                               volatile Signal* self_sg, T* __restrict__ result,
+                               int rank, int size) {
   using P = typename packed_t<T>::P;
   using A = typename packed_t<T>::A;
   // note: we don't reorder the address so the accumulation order is the same
@@ -257,22 +244,15 @@
 }
 
 template <typename P>
-DINLINE P* get_tmp_buf(
-#ifndef USE_ROCM
-    volatile
-#endif
-    Signal* sg) {
+DINLINE P* get_tmp_buf(volatile Signal* sg) {
   return (P*)(((Signal*)sg) + 1);
 }
 
 template <typename T, int ngpus>
 __global__ void __launch_bounds__(512, 1)
     cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
-#ifndef USE_ROCM
-                               volatile
-#endif
-                                   Signal* self_sg,
-                               T* __restrict__ result, int rank, int size) {
+                               volatile Signal* self_sg, T* __restrict__ result,
+                               int rank, int size) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = gridDim.x * blockDim.x;
   using P = typename packed_t<T>::P;
@@ -475,41 +455,37 @@ class CustomAllreduce {
    */
   template <typename T>
   void allreduce(cudaStream_t stream, T* input, T* output, int size,
-#ifdef USE_ROCM
-                int threads = 512, int block_limit = 18){
-#else
                  int threads = 512, int block_limit = 36) {
-#endif
-  auto d = packed_t<T>::P::size;
-  if (size % d != 0)
-    throw std::runtime_error(
-        "custom allreduce currently requires input length to be multiple "
-        "of " +
-        std::to_string(d));
-  if (block_limit > kMaxBlocks)
-    throw std::runtime_error("max supported block limit is " +
-                             std::to_string(kMaxBlocks) + ". Got " +
-                             std::to_string(block_limit));
-
-  RankData* ptrs;
-  cudaStreamCaptureStatus status;
-  CUDACHECK(cudaStreamIsCapturing(stream, &status));
-  if (status == cudaStreamCaptureStatusActive) {
-    ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
-    graph_unreg_buffers_.push_back(input);
-  } else {
-    auto it = buffers_.find(input);
-    if (it == buffers_.end())
+    auto d = packed_t<T>::P::size;
+    if (size % d != 0)
       throw std::runtime_error(
-          "buffer address " +
-          std::to_string(reinterpret_cast<uint64_t>(input)) +
-          " is not registered!");
-    ptrs = it->second;
-  }
+          "custom allreduce currently requires input length to be multiple "
+          "of " +
+          std::to_string(d));
+    if (block_limit > kMaxBlocks)
+      throw std::runtime_error("max supported block limit is " +
+                               std::to_string(kMaxBlocks) + ". Got " +
+                               std::to_string(block_limit));
+
+    RankData* ptrs;
+    cudaStreamCaptureStatus status;
+    CUDACHECK(cudaStreamIsCapturing(stream, &status));
+    if (status == cudaStreamCaptureStatusActive) {
+      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
+      graph_unreg_buffers_.push_back(input);
+    } else {
+      auto it = buffers_.find(input);
+      if (it == buffers_.end())
+        throw std::runtime_error(
+            "buffer address " +
+            std::to_string(reinterpret_cast<uint64_t>(input)) +
+            " is not registered!");
+      ptrs = it->second;
+    }
 
-  size /= d;
-  auto bytes = size * sizeof(typename packed_t<T>::P);
-  int blocks = std::min(block_limit, (size + threads - 1) / threads);
+    size /= d;
+    auto bytes = size * sizeof(typename packed_t<T>::P);
+    int blocks = std::min(block_limit, (size + threads - 1) / threads);
 
 #define KL(ngpus, name)                                                       \
   name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
                                                  rank_, size);
@@ -528,27 +504,27 @@ class CustomAllreduce {
       break;                                                                  \
     }
 
-  switch (world_size_) {
-    REDUCE_CASE(2)
-    REDUCE_CASE(4)
-    REDUCE_CASE(6)
-    REDUCE_CASE(8)
-    default:
-      throw std::runtime_error(
-          "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
-          "gpus = " +
-          std::to_string(world_size_));
-  }
+    switch (world_size_) {
+      REDUCE_CASE(2)
+      REDUCE_CASE(4)
+      REDUCE_CASE(6)
+      REDUCE_CASE(8)
+      default:
+        throw std::runtime_error(
+            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
+            "gpus = " +
+            std::to_string(world_size_));
+    }
 #undef REDUCE_CASE
 #undef KL
-}
+  }
 
-~CustomAllreduce() {
-  for (auto [_, ptr] : ipc_handles_) {
-    CUDACHECK(cudaIpcCloseMemHandle(ptr));
+  ~CustomAllreduce() {
+    for (auto [_, ptr] : ipc_handles_) {
+      CUDACHECK(cudaIpcCloseMemHandle(ptr));
+    }
   }
-}
-};  // namespace vllm
+};
 /**
  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
  a template instantiation:
diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu
index c0652e875aeff..9b809caa6e045 100644
--- a/csrc/custom_all_reduce_test.cu
+++ b/csrc/custom_all_reduce_test.cu
@@ -330,17 +330,10 @@ int main(int argc, char** argv) {
   //     run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
   //   }
   // }
-#ifdef USE_ROCM
-  for (int sz = 512; sz <= (8 << 22); sz *= 2) {
-    run<half>(myRank, nRanks, comm, 512, 18, sz + 8 * 47, performance_test);
-  }
-#else
   for (int sz = 512; sz <= (8 << 20); sz *= 2) {
     run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
   }
-#endif
 
   cudaProfilerStop();
-  MPICHECK(MPI_Finalize());
   return EXIT_SUCCESS;
 }
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index ca2f9d4bb8698..6e9b017ea93b3 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -199,18 +199,10 @@ def initialize_model_parallel(
     if _ENABLE_CUSTOM_ALL_REDUCE:
         from vllm.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce)
-
-        # max size defaults to 8 MiB, increase to 16 MiB on ROCm
-        # due to later crossover
-        if is_hip():
-            _TP_CA_COMMUNICATOR = CustomAllreduce(group=_TP_CPU_GROUP,
-                                                  device=_LOCAL_RANK,
-                                                  max_size=2 * 8192 * 1024)
-        else:
-            _TP_CA_COMMUNICATOR = CustomAllreduce(
-                group=_TP_CPU_GROUP,
-                device=_LOCAL_RANK,
-            )
+        _TP_CA_COMMUNICATOR = CustomAllreduce(
+            group=_TP_CPU_GROUP,
+            device=_LOCAL_RANK,
+        )
 
     # Build the pipeline model-parallel groups.
     global _PP_DEVICE_GROUP, _PP_CPU_GROUP
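
Aside (not part of the patch): while the deadlock is being root-caused, the custom all-reduce (CAR) kernel can also be bypassed entirely at runtime, falling back to NCCL/RCCL for the tensor-parallel all-reduce. A minimal sketch, assuming the `disable_custom_all_reduce` engine argument exposed by vLLM; the model name and tensor_parallel_size below are placeholders:

    # Minimal sketch, not from this patch: skip the custom all-reduce (CAR)
    # kernel and use the NCCL/RCCL all-reduce path instead.
    # Assumes vLLM's `disable_custom_all_reduce` engine argument; the model
    # name and tensor_parallel_size are placeholders.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="facebook/opt-125m",       # placeholder model
        tensor_parallel_size=2,          # CAR only matters when TP > 1
        disable_custom_all_reduce=True,  # bypass the CAR kernel entirely
    )
    out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
    print(out[0].outputs[0].text)

This disables only the CAR fast path; correctness is unaffected, at the cost of somewhat higher all-reduce latency for small tensors.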