-
Notifications
You must be signed in to change notification settings - Fork 2
/
ncclEnhance.h
113 lines (108 loc) · 3.47 KB
/
ncclEnhance.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#ifndef _NCCLENHANCE_H
#define _NCCLENHANCE_H
#include "nccl.h"
#include<cstddef>
// This function is copied from nccl.
// Maps an NCCL data type to the size, in bytes, of one element of that type.
// Returns -1 for any type that is not listed in the switch below.
static __inline__ int ncclTypeSize(ncclDataType_t type) {
    int elemBytes = -1;  // stays -1 for types not handled below
    switch (type) {
        case ncclInt8:
        case ncclUint8:
            elemBytes = 1;
            break;
        case ncclFloat16:
#if defined(__CUDA_BF16_TYPES_EXIST__)
        case ncclBfloat16:
#endif
            elemBytes = 2;
            break;
        case ncclInt32:
        case ncclUint32:
        case ncclFloat32:
            elemBytes = 4;
            break;
        case ncclInt64:
        case ncclUint64:
        case ncclFloat64:
            elemBytes = 8;
            break;
        default:
            break;
    }
    return elemBytes;
}
// Exchanges data with a single peer: sends `sendcount` elements from `sendbuff`
// to `peer` and receives `recvcount` elements from `peer` into `recvbuff`.
// Both operations use the same `datatype` and are issued inside one NCCL group
// on `stream`.
//
// Returns the first non-success result among ncclGroupStart, ncclSend,
// ncclRecv and ncclGroupEnd, or ncclSuccess when all of them succeed.
//
// Declared `inline` because the definition lives in a header; without it,
// including this header from more than one translation unit produces
// multiple-definition link errors.
inline ncclResult_t NCCLSendrecv(void *sendbuff, size_t sendcount, ncclDataType_t datatype, int peer,
                                 void *recvbuff, size_t recvcount, ncclComm_t comm, cudaStream_t stream)
{
    ncclResult_t status = ncclGroupStart();
    if (status != ncclSuccess)
        return status;  // no group was opened, so there is nothing to close

    status = ncclSend(sendbuff, sendcount, datatype, peer, comm, stream);
    if (status == ncclSuccess)
        status = ncclRecv(recvbuff, recvcount, datatype, peer, comm, stream);

    // Always close the group, even after a failure, so that the group nesting
    // stays balanced; also do not drop the result reported at group end.
    const ncclResult_t endStatus = ncclGroupEnd();
    return (status != ncclSuccess) ? status : endStatus;
}
// All-to-all built from grouped ncclSend/ncclRecv calls.
// Please be aware that sendcount,recvcount is the count for single rank:
// chunk i of `sendbuff` (sendcount elements of `senddatatype`) is sent to rank
// i, and the data received from rank i (recvcount elements of `recvdatatype`)
// is stored as chunk i of `recvbuff`.
//
// Returns ncclInvalidArgument when either data type is not known to
// ncclTypeSize; otherwise the first non-success result of the NCCL calls made
// here, or ncclSuccess.
//
// Declared `inline` because the definition lives in a header.
inline ncclResult_t NCCLAlltoall(void *sendbuff, size_t sendcount, ncclDataType_t senddatatype, void *recvbuff,
                                 size_t recvcount, ncclDataType_t recvdatatype, ncclComm_t comm, cudaStream_t stream)
{
    // Reject unknown types up front: ncclTypeSize returns -1 for them, which
    // would otherwise turn into a bogus pointer offset.
    const int sendElemBytes = ncclTypeSize(senddatatype);
    const int recvElemBytes = ncclTypeSize(recvdatatype);
    if (sendElemBytes < 0 || recvElemBytes < 0)
        return ncclInvalidArgument;

    int nRanks = 0;
    ncclResult_t status = ncclCommCount(comm, &nRanks);
    if (status != ncclSuccess)
        return status;

    // Per-rank chunk sizes in bytes, computed in size_t to avoid int overflow.
    const size_t sendStride = static_cast<size_t>(sendElemBytes) * sendcount;
    const size_t recvStride = static_cast<size_t>(recvElemBytes) * recvcount;
    std::byte *const sendBase = static_cast<std::byte *>(sendbuff);
    std::byte *const recvBase = static_cast<std::byte *>(recvbuff);

    status = ncclGroupStart();
    if (status != ncclSuccess)
        return status;

    for (int i = 0; i < nRanks && status == ncclSuccess; ++i)
    {
        const size_t rank = static_cast<size_t>(i);
        status = ncclSend(sendBase + rank * sendStride, sendcount, senddatatype, i, comm, stream);
        if (status == ncclSuccess)
        {
            // Receive with recvdatatype so the type matches the offset computation.
            status = ncclRecv(recvBase + rank * recvStride, recvcount, recvdatatype, i, comm, stream);
        }
    }

    // Always close the group, even when a send/recv failed, so the group
    // opened above is never left dangling.
    const ncclResult_t endStatus = ncclGroupEnd();
    return (status != ncclSuccess) ? status : endStatus;
}
// Gather built from grouped ncclSend/ncclRecv calls.
// Please be aware that sendcount,recvcount is the count for single rank:
// every rank (including root) sends sendcount elements of `senddatatype` from
// `sendbuff` to `root`; on root, the data received from rank i (recvcount
// elements of `recvdatatype`) is stored as chunk i of `recvbuff`.
// `recvbuff` is only accessed on the root rank.
//
// Returns ncclInvalidArgument on root when `recvdatatype` is not known to
// ncclTypeSize; otherwise the first non-success result of the NCCL calls made
// here, or ncclSuccess.
//
// Declared `inline` because the definition lives in a header.
inline ncclResult_t NCCLGather(void *sendbuff, size_t sendcount, ncclDataType_t senddatatype, void *recvbuff,
                               size_t recvcount, ncclDataType_t recvdatatype, int root, ncclComm_t comm,
                               cudaStream_t stream)
{
    int myRank = 0;
    int nRanks = 0;
    ncclResult_t status = ncclCommUserRank(comm, &myRank);
    if (status != ncclSuccess)
        return status;
    status = ncclCommCount(comm, &nRanks);
    if (status != ncclSuccess)
        return status;

    const bool isRoot = (myRank == root);
    size_t recvStride = 0;
    if (isRoot)
    {
        // ncclTypeSize returns -1 for unknown types; do not let that leak
        // into the pointer arithmetic below.
        const int recvElemBytes = ncclTypeSize(recvdatatype);
        if (recvElemBytes < 0)
            return ncclInvalidArgument;
        recvStride = static_cast<size_t>(recvElemBytes) * recvcount;
    }

    status = ncclGroupStart();
    if (status != ncclSuccess)
        return status;

    status = ncclSend(sendbuff, sendcount, senddatatype, root, comm, stream);
    if (isRoot)
    {
        std::byte *const recvBase = static_cast<std::byte *>(recvbuff);
        for (int i = 0; i < nRanks && status == ncclSuccess; ++i)
        {
            status = ncclRecv(recvBase + static_cast<size_t>(i) * recvStride, recvcount, recvdatatype, i, comm, stream);
        }
    }

    // Always close the group, even when a send/recv failed, so the group
    // opened above is never left dangling.
    const ncclResult_t endStatus = ncclGroupEnd();
    return (status != ncclSuccess) ? status : endStatus;
}
// Scatter built from grouped ncclSend/ncclRecv calls.
// Please be aware that sendcount,recvcount is the count for single rank:
// on root, chunk i of `sendbuff` (sendcount elements of `senddatatype`) is sent
// to rank i; every rank (including root) receives recvcount elements of
// `recvdatatype` from `root` into `recvbuff`.
// `sendbuff` is only accessed on the root rank.
//
// Returns ncclInvalidArgument on root when `senddatatype` is not known to
// ncclTypeSize; otherwise the first non-success result of the NCCL calls made
// here, or ncclSuccess.
//
// Declared `inline` because the definition lives in a header.
inline ncclResult_t NCCLScatter(void *sendbuff, size_t sendcount, ncclDataType_t senddatatype, void *recvbuff,
                                size_t recvcount, ncclDataType_t recvdatatype, int root,
                                ncclComm_t comm, cudaStream_t stream)
{
    int myRank = 0;
    int nRanks = 0;
    ncclResult_t status = ncclCommUserRank(comm, &myRank);
    if (status != ncclSuccess)
        return status;
    status = ncclCommCount(comm, &nRanks);
    if (status != ncclSuccess)
        return status;

    const bool isRoot = (myRank == root);
    size_t sendStride = 0;
    if (isRoot)
    {
        // ncclTypeSize returns -1 for unknown types; do not let that leak
        // into the pointer arithmetic below.
        const int sendElemBytes = ncclTypeSize(senddatatype);
        if (sendElemBytes < 0)
            return ncclInvalidArgument;
        sendStride = static_cast<size_t>(sendElemBytes) * sendcount;
    }

    status = ncclGroupStart();
    if (status != ncclSuccess)
        return status;

    if (isRoot)
    {
        std::byte *const sendBase = static_cast<std::byte *>(sendbuff);
        for (int i = 0; i < nRanks && status == ncclSuccess; ++i)
        {
            // Send with senddatatype: it is the type the chunk offset is
            // computed from and the type of the data in sendbuff.
            status = ncclSend(sendBase + static_cast<size_t>(i) * sendStride, sendcount, senddatatype, i, comm, stream);
        }
    }
    if (status == ncclSuccess)
        status = ncclRecv(recvbuff, recvcount, recvdatatype, root, comm, stream);

    // Always close the group, even when a send/recv failed, so the group
    // opened above is never left dangling.
    const ncclResult_t endStatus = ncclGroupEnd();
    return (status != ncclSuccess) ? status : endStatus;
}
#endif