Fix synchronization in allreduce8Read kernel #1457

Open · wants to merge 1 commit into base: develop
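Summary of the change, based on the diff below: the per-peer signal()/wait() on the input channels is hoisted out of the per-chunk loop so the ranks synchronize once before any peer data is read; the per-iteration and tail-chunk signal()/wait() on outChannels is removed; and a single __syncthreads() is added before one final signal()/wait() on outChannels so every thread's results are written before peers are notified. Two stray double semicolons are also dropped. A simplified sketch of the resulting synchronization skeleton (reduction and addressing elided; identifiers are taken from the patch, and this outline is for review only, not the full kernel):

  // Synchronize once with all peers before the first read: every rank's
  // input buffer is published before any channels[peerIdx].read<int4>() call.
  if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
    channels[threadIdx.x].signal();
    channels[threadIdx.x].wait();
  }
  __syncthreads();

  for (size_t itr = 0; itr < nItrs; itr++) {
    // read peer chunks, reduce with add_vectors<T>, write results
    // (the patch drops the per-iteration outChannels barrier that used to sit here)
  }
  if (restNInt4 > 0) {
    // same read/reduce/write for the remaining tail elements, with no extra barrier
  }

  // Make this block's writes visible, then signal peers exactly once that all
  // results behind outChannels are ready.
  __syncthreads();
  if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
    outChannels[threadIdx.x].signal();
    outChannels[threadIdx.x].wait();
  }
  __syncthreads();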
ext-src/read-allred.patch: 21 additions & 33 deletions
@@ -1,16 +1,16 @@
 diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
-index 1b85136..ee90c2f 100644
+index 1b85136..36d85e4 100644
 --- a/apps/nccl/src/allreduce.hpp
 +++ b/apps/nccl/src/allreduce.hpp
-@@ -386,24 +386,353 @@ __global__ void __launch_bounds__(512, 1)
+@@ -386,12 +386,323 @@ __global__ void __launch_bounds__(512, 1)
   }
 }

 +template <typename T>
 +__global__ void __launch_bounds__(512, 1)
 +    allreduce8Read(T* buff, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
-+                   mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset,
-+                   int rank, int nRanksPerNode, int worldSize, size_t nelems) {
++                   mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset, int rank,
++                   int nRanksPerNode, int worldSize, size_t nelems) {
 +  const int nPeer = nRanksPerNode - 1;
 +  const size_t chanOffset = nPeer * blockIdx.x;
 +  // assume (nelems * sizeof(T)) is divisible by (16 * worldSize)
@@ -22,7 +22,7 @@ index 1b85136..ee90c2f 100644
 +  int4* buff4 = reinterpret_cast<int4*>(buff);
 +  int4* resultBuff4 = reinterpret_cast<int4*>(resultBuff);
 +
-+  // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4`
++  // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4`
 +  constexpr size_t unitNInt4 = 512;
 +  const size_t maxNInt4PerBlock =
 +      (((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4;
@@ -48,18 +48,18 @@ index 1b85136..ee90c2f 100644
 +  }
 +  __syncwarp();
 +
-+  // we can use double buffering to hide synchronization overhead
-+  for (size_t itr = 0; itr < nItrs; itr++) {
-+    if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
-+      channels[threadIdx.x].signal();
-+      channels[threadIdx.x].wait();
-+    }
-+    __syncthreads();
++  // Wait for other GPUs before reading input from channels
++  if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
++    channels[threadIdx.x].signal();
++    channels[threadIdx.x].wait();
++  }
++  __syncthreads();
++
++  for (size_t itr = 0; itr < nItrs; itr++) {
 +    for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
 +      int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
 +      for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
-+        int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);;
++        int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);
 +        data = add_vectors<T>(val, data);
 +      }
 +      resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
@@ -69,27 +69,14 @@ index 1b85136..ee90c2f 100644
 +                                           data);
 +      }
 +    }
-+      if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
-+        outChannels[threadIdx.x].signal();
-+        outChannels[threadIdx.x].wait();
-+      }
-+      __syncthreads();
-+
 +    offsetOfThisBlock += nInt4PerChunk;
 +  }
 +
 +  if (restNInt4 > 0) {
-+    if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
-+      channels[threadIdx.x].signal();
-+      channels[threadIdx.x].wait();
-+
-+    }
-+    __syncthreads();
-+
 +    for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
 +      int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
 +      for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
-+        int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);;
++        int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);
 +        data = add_vectors<T>(val, data);
 +      }
 +      resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
@@ -98,14 +85,15 @@ index 1b85136..ee90c2f 100644
 +                                           data);
 +      }
 +    }
-+
-+      if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
-+        outChannels[threadIdx.x].signal();
-+        outChannels[threadIdx.x].wait();
-+      }
-+      __syncthreads();
 +  }
 +
++  // Synchronize threads before signaling that all results have been written to outChannels
++  __syncthreads();
++  if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
++    outChannels[threadIdx.x].signal();
++    outChannels[threadIdx.x].wait();
++  }
++  __syncthreads();
 +}
 +
 +template <typename T>
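For reviewers tracing the chunking arithmetic near the top of the kernel (unchanged by this patch, shown in the diff above): maxNInt4PerBlock rounds each block's share of nInt4PerRank up to a multiple of unitNInt4 = 512. A small host-side sketch of the same formula with hypothetical example values (the numbers are illustrative, not from the patch):

  #include <cstddef>
  #include <cstdio>

  int main() {
    // Hypothetical example: 10000 int4 elements per rank, 8 thread blocks.
    const std::size_t nInt4PerRank = 10000;
    const std::size_t gridDimX = 8;     // stands in for gridDim.x
    const std::size_t unitNInt4 = 512;

    // Same rounding as the patch: ceil over the blocks, then round up to a
    // multiple of unitNInt4.
    const std::size_t maxNInt4PerBlock =
        (((nInt4PerRank + gridDimX - 1) / gridDimX) + unitNInt4 - 1) / unitNInt4 * unitNInt4;

    std::printf("maxNInt4PerBlock = %zu\n", maxNInt4PerBlock);  // prints 1536
    return 0;
  }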