forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ProcessGroupNCCL.cpp
1874 lines (1673 loc) · 63.1 KB
/
ProcessGroupNCCL.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <c10/util/Optional.h>
#include <c10d/ProcessGroupNCCL.hpp>
#include <exception>
#include <map>
#include <tuple>
#include <unordered_set>
#include <THC/THC.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Logging.h>
#include <c10d/ParamCommsUtils.hpp>
#include <torch/csrc/cuda/nccl.h>
#include <c10d/Utils.hpp>
namespace c10d {
// Prefix of the store keys under which a rank records that it has aborted a
// communicator (the full key is built by getNcclAbortedCommStoreKey).
constexpr const char* const kNCCLAbortedCommStoreKey = "NCCLABORTEDCOMM";
namespace {
// Byte count passed to c10d::toVec when the group sequence number is
// serialized for the store.
constexpr int kBytes = 8;
// RAII helper class to manage NCCL group API and CUDA free mutex.
// Construction locks the caching allocator's free mutex and then opens an
// NCCL group; destruction closes the group and unlocks the mutex.
// The destructor is allowed to throw since this helper class only
// manages group and lock lifetimes.
// The mutex is released on every exit path, including when one of the NCCL
// group calls reports an error by throwing.
struct AutoNcclGroup {
  AutoNcclGroup() {
    (c10::cuda::CUDACachingAllocator::getFreeMutex())->lock();
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
    try {
      C10D_NCCL_CHECK(ncclGroupStart());
    } catch (...) {
      // The destructor does not run for an object whose constructor threw,
      // so the mutex has to be released here.
      (c10::cuda::CUDACachingAllocator::getFreeMutex())->unlock();
      throw;
    }
#endif
  }
  ~AutoNcclGroup() noexcept(false) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
    try {
      C10D_NCCL_CHECK(ncclGroupEnd());
    } catch (...) {
      // Do not leave the free mutex locked when closing the group fails.
      (c10::cuda::CUDACachingAllocator::getFreeMutex())->unlock();
      throw;
    }
#endif
    (c10::cuda::CUDACachingAllocator::getFreeMutex())->unlock();
  }
};
// NCCL op mapping
// Maps each c10d ReduceOp that has a direct NCCL counterpart to that NCCL
// reduction op. getNcclReduceOp consults this table and raises an error for
// ops that are not listed here.
const std::map<ReduceOp, ncclRedOp_t> ncclOp = {
{ReduceOp::MIN, ncclMin},
{ReduceOp::MAX, ncclMax},
{ReduceOp::SUM, ncclSum},
{ReduceOp::PRODUCT, ncclProd},
};
// NCCL type mapping
// Maps ATen scalar types to the NCCL data type used on the wire. Note that
// at::kBool shares ncclUint8 with at::kByte. The bfloat16 entry is only
// compiled in for HIP builds with HIP_VERSION >= 301.
std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
{at::kChar, ncclInt8},
{at::kByte, ncclUint8},
{at::kFloat, ncclFloat},
{at::kDouble, ncclDouble},
{at::kInt, ncclInt32},
{at::kLong, ncclInt64},
{at::kHalf, ncclHalf},
{at::kBool, ncclUint8},
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 301
{at::kBFloat16, ncclBfloat16},
#endif
};
// Helper function that gets the data type and issues error if not supported
// Looks `type` up in the ncclDataType table and returns the mapped NCCL data
// type. The TORCH_CHECK fails, naming the offending type, when the table has
// no entry for it.
ncclDataType_t getNcclDataType(at::ScalarType type) {
auto it = ncclDataType.find(type);
TORCH_CHECK(
it != ncclDataType.end(),
"Input tensor data type is not supported for NCCL process group: ",
type);
return it->second;
}
// Translates a c10d ReduceOp into the NCCL reduction op to use for `input`.
//
// reduceOp: the requested reduction.
// input:    tensor being reduced; only its scalar type is inspected.
// Returns the NCCL op from the ncclOp table, except that SUM on bool tensors
// is mapped to ncclMax.
// Throws std::runtime_error when the op has no entry in the ncclOp table.
ncclRedOp_t getNcclReduceOp(const ReduceOp reduceOp, at::Tensor& input) {
  if (reduceOp == ReduceOp::SUM && input.scalar_type() == at::kBool) {
    // For bool tensors, map sum to max, which both represent a bitwise or.
    // This is to prevent overflow issues with sum, since we use uint8 to
    // represent a bool (see ncclDataType mapping).
    return ncclMax;
  }
  // Plain lookup instead of at() + catch: an unsupported op is an expected
  // outcome here, not an exceptional one.
  const auto entry = ncclOp.find(reduceOp);
  if (entry != ncclOp.end()) {
    return entry->second;
  }
  switch (reduceOp) {
    case ReduceOp::BAND:
      throw std::runtime_error("Cannot use ReduceOp.BAND with NCCL");
    case ReduceOp::BOR:
      throw std::runtime_error("Cannot use ReduceOp.BOR with NCCL");
    case ReduceOp::BXOR:
      throw std::runtime_error("Cannot use ReduceOp.BXOR with NCCL");
    default:
      throw std::runtime_error("Unhandled ReduceOp");
  }
}
// Builds the comma-separated string of device indices for `devices`, in the
// order given (empty string for an empty list).
std::string getKeyFromDevices(const std::vector<at::Device>& devices) {
  std::string key;
  for (size_t pos = 0; pos < devices.size(); ++pos) {
    if (pos != 0) {
      key += ",";
    }
    key += std::to_string(devices[pos].index());
  }
  return key;
}
// Builds the "<low>:<high>" key for a send/recv pair, so both peers derive
// the same key regardless of which side calls.
std::string getKeySendRecv(int myRank, int peer) {
  int first = myRank;
  int second = peer;
  if (!(myRank < peer)) {
    first = peer;
    second = myRank;
  }
  return std::to_string(first) + ":" + std::to_string(second);
}
// Collects the device of every tensor in `tensors`, preserving order.
std::vector<at::Device> getDeviceList(const std::vector<at::Tensor>& tensors) {
  std::vector<at::Device> devices;
  devices.reserve(tensors.size());
  for (size_t idx = 0; idx < tensors.size(); ++idx) {
    devices.push_back(tensors[idx].device());
  }
  return devices;
}
// [Sync Streams] Makes each NCCL stream wait for the work already queued on
// the matching device's current stream. NCCL communications run on
// ncclStreams, while input tensors are produced on the current streams, so a
// communication must not start before the pending ops on the current stream
// have finished; otherwise the two streams could read/write the same tensors
// concurrently.
//
// This ordering alone is not sufficient: the input tensors must also stay
// alive until their use on the ncclStreams has finished. That is achieved
// with c10::cuda::CUDACachingAllocator::recordStream, which remembers the
// usage stream (ncclStream), creates an event on it when GC attempts to free
// the tensor, and delays GC until that event is done.
void syncStreams(
    const std::vector<at::Device>& devices,
    std::vector<at::cuda::CUDAEvent>& ncclEvents,
    std::vector<at::cuda::CUDAStream>& ncclStreams) {
  const size_t numDevices = devices.size();
  for (size_t idx = 0; idx != numDevices; ++idx) {
    // Record the current stream's progress, then make the NCCL stream wait
    // on that event.
    ncclEvents[idx].record(
        at::cuda::getCurrentCUDAStream(devices[idx].index()));
    ncclEvents[idx].block(ncclStreams[idx]);
  }
}
// Given a ncclUniqueId, convert it to a string representation that can be put
// in the store.
// Every byte of the id is written as exactly two lowercase hex digits, so the
// result always has 2 * NCCL_UNIQUE_ID_BYTES characters and two different ids
// can never produce the same string (without the zero padding, bytes
// {0x01, 0x23} and {0x12, 0x03} would both render as "123").
std::string buildNcclUniqueIdStr(const ncclUniqueId& ncclID) {
  const char* const kHexDigits = "0123456789abcdef";
  const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&ncclID);
  std::string idStr;
  idStr.reserve(2 * NCCL_UNIQUE_ID_BYTES);
  for (size_t i = 0; i < NCCL_UNIQUE_ID_BYTES; i++) {
    idStr.push_back(kHexDigits[bytes[i] >> 4]);
    idStr.push_back(kHexDigits[bytes[i] & 0x0F]);
  }
  return idStr;
}
// Builds the store key "<kNCCLAbortedCommStoreKey>:<ncclIdStr>" used to
// publish that the communicator with the given id string was aborted.
// The id is taken by const reference to avoid copying the string on each call.
std::string getNcclAbortedCommStoreKey(const std::string& ncclIdStr) {
  return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr;
}
// Extracts a printable message from a non-null exception_ptr: what() for
// std::exception-derived exceptions, a fixed placeholder for anything else.
std::string getExceptionMsgFromExceptionPtr(
    const std::exception_ptr& exceptionPtr) {
  TORCH_CHECK(exceptionPtr != nullptr);
  std::string message;
  try {
    std::rethrow_exception(exceptionPtr);
  } catch (const std::exception& err) {
    message = err.what();
  } catch (...) {
    message = "Unknown exception type";
  }
  return message;
}
} // namespace
// Longest time (ms) the watchdog thread waits between two passes.
const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 10000;
// Longest time (ms) the work cleanup thread waits between two scans of the
// work list.
const int64_t ProcessGroupNCCL::kWorkCleanupThreadSleepMillis = 1000;
// Timeout (ms) of the watchdog's store wait on an aborted-communicator key.
constexpr int64_t kWaitForAbortCommStoreKey = 1000;
// Sleep (ms) between two completion polls in WorkNCCL::synchronizeInternal.
constexpr int64_t kSynchronizeBusyWaitMillis = 10;
// Per-thread counter; getNCCLComm ends and later restarts this many NCCL
// groups around communicator creation.
thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0;
// Streams a one-line description of a WorkNCCL:
//   WorkNCCL(OpType=<op>[, TensorShape=<sizes>], Timeout(ms)=<n>)
// The TensorShape field is the shape of the first output tensor and is only
// printed when the work has an output list that is not empty, so an empty
// list is never indexed.
std::ostream& operator<<(
    std::ostream& output,
    const ProcessGroupNCCL::WorkNCCL& workNCCL) {
  const bool hasOutputShape =
      workNCCL.outputs_ && !workNCCL.outputs_->empty();
  std::string workInfo = c10::str(
      "WorkNCCL(", "OpType=", opTypeToString(workNCCL.opType_));
  if (hasOutputShape) {
    workInfo +=
        c10::str(", TensorShape=", (*workNCCL.outputs_)[0].sizes());
  }
  workInfo +=
      c10::str(", Timeout(ms)=", workNCCL.opTimeout_.count(), ")");
  return output << workInfo;
}
// Creates a work object for an operation spanning `devices`. The start time
// is taken at construction; one CUDA event wrapper and one communicator slot
// are reserved per device.
ProcessGroupNCCL::WorkNCCL::WorkNCCL(
    const std::vector<at::Device>& devices,
    int rank,
    OpType opType,
    const char* profilingTitle,
    const c10::optional<std::vector<at::Tensor>>& inputs)
    : Work(rank, opType, profilingTitle, inputs),
      devices_(devices),
      workStartTime_(std::chrono::steady_clock::now()) {
  const size_t numDevices = devices.size();
  // One communicator slot per device; the slots are filled in later.
  ncclComms_.resize(numDevices);
  // Only the event wrappers are created here. The underlying CUDA events are
  // created lazily on first record, with
  // DEFAULT_FLAGS = cudaEventDisableTiming.
  cudaEvents_ =
      std::make_shared<std::vector<at::cuda::CUDAEvent>>(numDevices);
}
// Copy constructor. Copies the device list, the communicators, the
// blocking-wait flag, the timeout and the start time from `w`, plus its
// completed/exception state. cudaEvents_ is a shared_ptr, so the copy refers
// to the same event vector as `w`.
// NOTE(review): the Work base is rebuilt from rank_ and opType_ only, and
// other members used elsewhere in this file (e.g. store_, outputs_,
// barrierTensors_) are not copied here — confirm that this is intentional.
// NOTE(review): w.completed_ and w.exception_ are read without taking
// w.mutex_ — confirm that callers never copy a work that is being mutated.
ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
: Work(w.rank_, w.opType_),
std::enable_shared_from_this<WorkNCCL>(w),
devices_(w.devices_),
cudaEvents_(w.cudaEvents_),
ncclComms_(w.ncclComms_),
blockingWait_(w.blockingWait_),
opTimeout_(w.opTimeout_),
workStartTime_(w.workStartTime_) {
completed_ = w.completed_;
exception_ = w.exception_;
}
// Nothing to release by hand; members clean up after themselves.
ProcessGroupNCCL::WorkNCCL::~WorkNCCL() = default;
// A work counts as completed once it has an exception recorded or all of its
// CUDA events report completion. Refreshes the exception state first.
bool ProcessGroupNCCL::WorkNCCL::isCompleted() {
  checkAndSetException();
  if (exception()) {
    return true;
  }
  return finishedGPUExecutionInternal();
}
// True only when no exception is recorded, the communicators report no NCCL
// error, and every CUDA event of this work has completed.
bool ProcessGroupNCCL::WorkNCCL::isSuccess() const {
  // An exception was already detected earlier.
  if (exception()) {
    return false;
  }
  if (checkForNCCLErrors(ncclComms_)) {
    return false;
  }
  return finishedGPUExecutionInternal();
}
// Polls this work's communicators for an asynchronous NCCL error and, if one
// is found, records it as the work's exception.
//
// An exception that is already recorded is never replaced or cleared. This
// matters because setException may be called from another thread between
// the initial check and the point where mutex_ is taken; storing the
// (possibly null) poll result unconditionally would drop that exception.
void ProcessGroupNCCL::WorkNCCL::checkAndSetException() {
  if (exception()) {
    // We already have an exception.
    return;
  }
  auto exception_ptr = checkForNCCLErrors(ncclComms_);
  if (!exception_ptr) {
    // No NCCL error found; leave the stored state untouched.
    return;
  }
  std::unique_lock<std::mutex> lock(mutex_);
  if (exception_) {
    // Another thread recorded an exception in the meantime; keep it.
    return;
  }
  exception_ = exception_ptr;
  LOG(INFO) << "[Rank " << rank_ << "]"
            << " found async exception when checking for NCCL errors: "
            << getExceptionMsgFromExceptionPtr(exception_);
}
// Stores `exception_ptr` as this work's exception while holding mutex_.
void ProcessGroupNCCL::WorkNCCL::setException(
    std::exception_ptr exception_ptr) {
  std::lock_guard<std::mutex> guard(mutex_);
  exception_ = exception_ptr;
}
// Helper that checks if the NCCL kernels are completed on the GPUs
// Refreshes the recorded exception state first, then returns whether every
// CUDA event of this work has completed.
bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecution() {
checkAndSetException();
return finishedGPUExecutionInternal();
}
// Returns true when the CUDA event of every device reports completion; stops
// at the first event that is still pending.
bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const {
  for (size_t idx = 0, count = devices_.size(); idx != count; ++idx) {
    const bool eventDone = (*cudaEvents_)[idx].query();
    if (!eventDone) {
      return false;
    }
  }
  return true;
}
// Refreshes the recorded exception state and rethrows the exception if one is
// present; returns normally otherwise.
void ProcessGroupNCCL::WorkNCCL::checkAndThrowException() {
  checkAndSetException();
  const std::exception_ptr pending = exception();
  if (pending) {
    std::rethrow_exception(pending);
  }
}
// Marks the work completed and, if an exception is recorded on it, logs an
// error and rethrows that exception to the caller.
void ProcessGroupNCCL::WorkNCCL::handleNCCLGuard() {
  std::lock_guard<std::mutex> guard(mutex_);
  completed_ = true;
  if (!exception_) {
    return;
  }
  const auto exceptionMsg = c10::str(
      "Some NCCL operations have failed or timed out. Due to the ",
      "asynchronous nature of CUDA kernels, subsequent GPU operations ",
      "might run on corrupted/incomplete data. To avoid this inconsistency, ",
      "we are taking the entire process down.");
  LOG(ERROR) << exceptionMsg;
  C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.WorkNCCL.handleNCCLGuard");
  std::rethrow_exception(exception_);
}
// Public synchronize entry point; forwards to synchronizeInternal with
// kNoTimeout.
void ProcessGroupNCCL::WorkNCCL::synchronize() {
// Call Synchronize without a timeout. We use this method to avoid adding a
// timeout argument to the public synchronize API.
synchronizeInternal(kNoTimeout);
}
// Makes the current CUDA stream of every device wait on this work's event for
// that device, i.e. blocks the current stream on the NCCL stream.
void ProcessGroupNCCL::WorkNCCL::synchronizeStreams() {
  const size_t numDevices = devices_.size();
  for (size_t idx = 0; idx < numDevices; ++idx) {
    (*cudaEvents_)[idx].block(
        at::cuda::getCurrentCUDAStream(devices_[idx].index()));
  }
}
// Waiting on the work's corresponding CUDA events
//
// Blocks the current streams on this work's events. In blocking-wait mode it
// additionally polls until the work completes, fails, or times out.
//
// timeout: limit for this call; kNoTimeout selects the work's own opTimeout_.
//          The elapsed time is measured from the work's start time.
// Throws std::runtime_error when the limit is exceeded (after aborting the
// work's communicators and publishing their ids in the store), and rethrows
// any exception recorded on the work.
void ProcessGroupNCCL::WorkNCCL::synchronizeInternal(
    std::chrono::milliseconds timeout) {
  synchronizeStreams();
  // In case of blocking, wait for the operation to complete.
  if (blockingWait_) {
    // Use the passed in timeout if provided, otherwise use the default
    // opTimeout for each WorkNCCL object.
    const std::chrono::milliseconds workTimeout =
        timeout == kNoTimeout ? opTimeout_ : timeout;
    // Wait for the operation to complete.
    while (!isCompleted()) {
      const auto timeElapsed =
          std::chrono::duration_cast<std::chrono::milliseconds>(
              std::chrono::steady_clock::now() - workStartTime_);
      // Compare against workTimeout so that a caller-supplied timeout is
      // actually honored instead of always falling back to opTimeout_.
      if (timeElapsed >= workTimeout) {
        // When operation times out due to some errors that are not
        // detected by nccl communicators, ncclCommWatchdog can not check this
        // time out error and thus can not abort ncclComms accordingly.
        // So explicitly abort ncclComms here before throwing this timed out
        // exception to users, after this, ncclCommWatchdog can detect nccl
        // communicators are aborted and clean up devNCCLCommMap_ accordingly.
        // if throwing timed out exception without aborting nccl communicators
        // here, it was observed that CUDA GPU will have 100% utilization and
        // can not run new events successfully.
        for (const auto& ncclComm : ncclComms_) {
          ncclComm->ncclCommAbort();
          const auto& storeKey = getNcclAbortedCommStoreKey(
              buildNcclUniqueIdStr(ncclComm->getNcclId()));
          auto rankStr = std::to_string(rank_);
          store_->set(
              storeKey,
              std::vector<uint8_t>(
                  reinterpret_cast<const uint8_t*>(rankStr.data()),
                  reinterpret_cast<const uint8_t*>(rankStr.data()) +
                      rankStr.size()));
          LOG(INFO) << "[Rank " << rank_
                    << "] Wrote aborted communicator id to store: " << storeKey;
        }
        std::string exceptionMsg = c10::str(
            "[Rank ",
            rank_,
            "] ",
            "Caught collective operation timeout: ",
            (*this),
            " ran for ",
            timeElapsed.count(),
            " milliseconds before timing out.");
        throw std::runtime_error(exceptionMsg);
      }
      // Check for errors and throw appropriate exception.
      checkAndThrowException();
      std::this_thread::sleep_for(
          std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
    }
    checkAndThrowException();
  }
  // Device synchronize only after we've completed timeout checks.
  if (!barrierTensors_.empty()) {
    // If we use the work to do barrier, we should block here
    for (auto& device : devices_) {
      at::cuda::CUDAGuard gpuGuard(device);
      AT_CUDA_CHECK(cudaDeviceSynchronize());
    }
  }
}
// Same as calling synchronize().
// Records a "wait" entry through RECORD_PARAM_COMMS (zero sizes, byte dtype,
// empty split sizes) and then forwards `timeout` to synchronizeInternal.
// Returns true whenever synchronizeInternal returns normally; exceptions
// raised by synchronizeInternal propagate to the caller.
bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) {
RECORD_PARAM_COMMS(
rank_, // rank
"wait", // colName
0, // inSize
0, // outSize
at::kByte, // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>()); // outSplitSizes
synchronizeInternal(timeout);
// Always return true, because abort API is not implemented.
return true;
}
// Not supported for NCCL work: always fails the TORCH_CHECK with a
// "not implemented" message.
void ProcessGroupNCCL::WorkNCCL::abort() {
TORCH_CHECK(false, "ProcessGroupNCCL::WorkNCCL::abort not implemented.");
}
// True once the time elapsed since the work's start time has reached
// opTimeout_ (millisecond resolution).
bool ProcessGroupNCCL::WorkNCCL::timedOut() {
  const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
      std::chrono::steady_clock::now() - workStartTime_);
  return elapsed >= opTimeout_;
}
// Constructs the process group for `rank` out of `size` ranks.
// Fails the TORCH_CHECK when no GPU is available. Reads the
// NCCL_BLOCKING_WAIT and NCCL_ASYNC_ERROR_HANDLING flags, starts the watchdog
// thread (only when ENABLE_NCCL_ERROR_CHECKING is defined) and the work
// cleanup thread (only with async error handling), and logs the resulting
// configuration.
ProcessGroupNCCL::ProcessGroupNCCL(
const c10::intrusive_ptr<Store>& store,
int rank,
int size,
c10::intrusive_ptr<Options> options)
: ProcessGroup(rank, size),
store_(store),
options_(options),
ncclCommCounter_(0),
terminateProcessGroup_(false) {
TORCH_CHECK(
at::cuda::getNumGPUs() != 0,
"ProcessGroupNCCL is only supported with GPUs, no GPUs found!");
blockingWait_ = parseEnvVarFlag(NCCL_BLOCKING_WAIT);
asyncErrorHandling_ = parseEnvVarFlag(NCCL_ASYNC_ERROR_HANDLING);
// The two modes are mutually exclusive; blocking wait takes precedence.
if (blockingWait_ && asyncErrorHandling_) {
LOG(INFO) << "[Rank " << rank_
<< "] NCCL_BLOCKING_WAIT and NCCL_ASYNC_ERROR_HANDLING "
<< "should not both be enabled. "
<< "Only NCCL_BLOCKING_WAIT is being used in this process.";
asyncErrorHandling_ = false;
}
// The threads are started only after the flags above have been settled,
// since both thread bodies read them.
#ifdef ENABLE_NCCL_ERROR_CHECKING
ncclCommWatchdogThread_ =
std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
#endif
if (asyncErrorHandling_) {
workCleanupThread_ = std::thread(&ProcessGroupNCCL::workCleanupLoop, this);
}
// NCCL_DEBUG is read only so that it can be reported in the log line below.
const char* ncclDebugLevel = std::getenv("NCCL_DEBUG");
if (!ncclDebugLevel) {
ncclDebugLevel = "UNSET";
}
LOG(INFO) << "[Rank " << rank_
<< "] ProcessGroupNCCL initialized with following options:"
<< "\nNCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_
<< "\nNCCL_BLOCKING_WAIT: " << blockingWait_
<< "\nTIMEOUT(ms): " << options_->timeout.count()
<< "\nUSE_HIGH_PRIORITY_STREAM: "
<< options_->is_high_priority_stream
<< "\nNCCL_DEBUG: " << ncclDebugLevel;
}
// Establishes the group's sequence number.
// Rank 0 picks a starting value of 1 + rand() and publishes it in the store
// under kSeqNumStoreKey; every other rank waits for that key (up to the
// configured timeout) and adopts the published value.
void ProcessGroupNCCL::setSequenceNumberForGroup() {
  if (rank_ == 0) {
    // Create and broadcast sequence number
    // rand() may return RAND_MAX, which can equal INT_MAX; widen before
    // adding 1 so that the addition cannot overflow a signed int.
    const uint64_t seq = 1 + static_cast<uint64_t>(rand());
    sequenceNum_ = c10d::SequenceNum(seq);
    std::vector<uint8_t> values = c10d::toVec<uint8_t>(seq, kBytes);
    store_->set(kSeqNumStoreKey, values);
  } else {
    // Read rank 0's sequence number from store.
    sequenceNum_ = c10d::SequenceNum();
    store_->wait({kSeqNumStoreKey}, options_->timeout);
    std::vector<uint8_t> values = store_->get(kSeqNumStoreKey);
    uint64_t num = c10d::fromVec<uint8_t>(values);
    sequenceNum_->set(num);
  }
}
// Returns the current value of the group's sequence number. The TORCH_CHECK
// fails (reporting this rank) when no sequence number has been set yet.
uint64_t ProcessGroupNCCL::getSequenceNumberForGroup() {
TORCH_CHECK(
sequenceNum_ != c10::nullopt,
"Sequence number is not set for rank ",
rank_);
return sequenceNum_->get();
}
// Shuts the process group down: raises the termination flag, wakes and joins
// the background threads that were started, and finally aborts every cached
// NCCL communicator.
// NOTE(review): the flag is set and the condition variables are notified
// without holding the mutexes the waiting threads use, so a wakeup could be
// missed and the join delayed by up to one sleep interval — confirm whether
// that delay is acceptable.
ProcessGroupNCCL::~ProcessGroupNCCL() {
terminateProcessGroup_.store(true);
watchdogCV_.notify_one();
#ifdef ENABLE_NCCL_ERROR_CHECKING
ncclCommWatchdogThread_.join();
#endif
// The cleanup thread only exists when async error handling is enabled.
if (asyncErrorHandling_) {
workMetaListCV_.notify_one();
workCleanupThread_.join();
}
{
// Abort all NCCL Communicators on Process Group Destruction
std::lock_guard<std::mutex> lock(mutex_);
for (auto& it : devNCCLCommMap_) {
auto& ncclComms = it.second;
for (const auto& ncclComm : ncclComms) {
ncclComm->ncclCommAbort();
}
}
}
}
void ProcessGroupNCCL::abortTimedOutCollectives(
std::unordered_set<std::string>& abortedCommIds) {
std::unique_lock<std::mutex> lock(workMetaListMutex_);
for (auto& work : workMetaList_) {
work.checkAndSetException();
// Aborting NCCL Communicators due to errors is already handled above.
if (work.exception()) {
continue;
}
// Check for Timeouts in the WorkNCCL Operations, and abort all
// communicators accordingly.
if (work.timedOut()) {
auto currentTimepoint = std::chrono::steady_clock::now();
auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
currentTimepoint - work.workStartTime_);
std::string exceptionMsg = c10::str(
"[Rank ",
rank_,
"] ",
"Watchdog caught collective operation timeout: ",
work,
" ran for ",
timeElapsed.count(),
" milliseconds before timing out.");
LOG(ERROR) << exceptionMsg;
std::exception_ptr exception_ptr =
std::make_exception_ptr(std::runtime_error(exceptionMsg));
work.setException(exception_ptr);
for (const auto& ncclComm : work.ncclComms_) {
ncclComm->ncclCommAbort();
abortedCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId()));
}
}
}
}
// Thread entry point of the watchdog. Runs the watchdog loop and logs how it
// ended; no exception is allowed to escape the thread.
void ProcessGroupNCCL::ncclCommWatchdog() {
  try {
    LOG(INFO) << "[Rank " << rank_ << "] NCCL watchdog thread started!";
    ncclCommWatchdogInternal();
    LOG(INFO) << "[Rank " << rank_
              << "] NCCL watchdog thread terminated normally";
  } catch (const std::exception& err) {
    LOG(INFO) << "[Rank " << rank_
              << "] NCCL watchdog thread terminated with exception: "
              << err.what();
  } catch (...) {
    LOG(INFO) << "[Rank " << rank_
              << "] NCCL watchdog thread terminated with unknown exception";
  }
}
// Body of the watchdog thread. Until the termination flag is raised, each
// pass:
//   1. checks every cached communicator for NCCL errors and, in blocking-wait
//      or async-error-handling mode, aborts the communicators of a failing
//      entry;
//   2. in async-error-handling mode, aborts collectives that timed out;
//   3. in blocking-wait mode, publishes locally aborted communicator ids to
//      the store and aborts local communicators whose ids other ranks have
//      published;
//   4. waits up to kWatchdogThreadSleepMillis, or until termination is
//      signalled.
void ProcessGroupNCCL::ncclCommWatchdogInternal() {
while (!terminateProcessGroup_.load()) {
// Ids aborted during this pass, and all ids currently in the cache.
std::unordered_set<std::string> abortedCommIds;
std::unordered_set<std::string> allCommIds;
{
// Loop through the cache of communicators for NCCL errors.
std::lock_guard<std::mutex> lock(mutex_);
for (auto& it : devNCCLCommMap_) {
auto& ncclComms = it.second;
for (const auto& ncclComm : ncclComms) {
allCommIds.emplace(buildNcclUniqueIdStr(ncclComm->getNcclId()));
}
std::exception_ptr ncclErrorException = checkForNCCLErrors(ncclComms);
if (ncclErrorException) {
LOG(INFO)
<< "[Rank " << rank_
<< "] Received NCCL errors for communicators in the cache: \n"
<< "NCCL error: \n"
<< getExceptionMsgFromExceptionPtr(ncclErrorException);
if (blockingWait_ || asyncErrorHandling_) {
LOG(INFO) << "[Rank " << rank_
<< "] Aborting communicators that received errors";
// We abort NCCL communicators that have received errors from this
// thread, and exceptions are set on the corresponding work objects.
// The workCleanupThread will then loop through the unfinished
// collectives and throw exceptions if an exception has been set on
// any of the work objects from this thread.
for (const auto& ncclComm : ncclComms) {
ncclComm->ncclCommAbort();
// Note that we don't remove the aborted communicators from the
// cache. The reason is that if we do remove the communicator
// from the cache, it is possible that a new collective operation
// calls `ncclCommInitRank` to create a new communicator whereas
// other ranks might have failed/timed out and didn't enter
// `ncclCommInitRank`. As a result, when there is a failure on
// a communicator the application receives an exception and it is
// their responsibility to destroy the process group and recreate
// it to recover from errors.
abortedCommIds.emplace(
buildNcclUniqueIdStr(ncclComm->getNcclId()));
}
}
}
}
}
if (asyncErrorHandling_) {
abortTimedOutCollectives(abortedCommIds);
}
if (blockingWait_) {
// When we abort a communicator on one rank, it is likely that might cause
// other ranks to hang indefinitely. As a result, whenever we abort a
// communicator, we write its ID to the store. The watchdog on other ranks
// then monitor the store, find an aborted communicator ID and abort their
// respective communicator as well.
// Record the aborted communicators locally and in the store.
for (const auto& abortedCommId : abortedCommIds) {
abortedComms_.emplace(abortedCommId);
const auto& storeKey = getNcclAbortedCommStoreKey(abortedCommId);
// The value stored under the key is this rank's number as text.
auto rankStr = std::to_string(rank_);
store_->set(
storeKey,
std::vector<uint8_t>(
reinterpret_cast<const uint8_t*>(rankStr.data()),
reinterpret_cast<const uint8_t*>(rankStr.data()) +
rankStr.size()));
LOG(INFO) << "[Rank " << rank_
<< "] Watchdog wrote aborted communicator id to store: "
<< storeKey;
}
// Check for any communicators in the store and abort them if needed.
for (const auto& commId : allCommIds) {
if (abortedComms_.find(commId) == abortedComms_.end()) {
// Check if we need to abort them if not already aborted (shouldn't
// wait more than the watchdog sleep time.).
const auto& storeKey = getNcclAbortedCommStoreKey(commId);
// Failures inside this try block (including a failed store wait) are
// only logged at VLOG(1) by the catch below.
try {
store_->wait(
{storeKey},
std::chrono::milliseconds(kWaitForAbortCommStoreKey));
auto val = store_->get(storeKey);
std::string rank(reinterpret_cast<char*>(val.data()), val.size());
LOG(INFO) << "[Rank " << rank_
<< "] Found key in store: " << storeKey
<< ", from rank: " << rank
<< ", aborting appropriate communicators";
// Now abort the appropriate communicators.
std::lock_guard<std::mutex> lock(mutex_);
auto it = ncclIdToCommMap_.find(commId);
TORCH_INTERNAL_ASSERT(it != ncclIdToCommMap_.end());
for (const auto& ncclComm : it->second) {
ncclComm->ncclCommAbort();
}
abortedComms_.emplace(commId);
LOG(INFO) << "[Rank " << rank_
<< "] Aborted communicators for key in store: "
<< storeKey;
} catch (std::exception& e) {
VLOG(1) << "Did not find key in store: " << storeKey
<< ", error: " << e.what();
}
}
}
}
// Sleep until the next pass, waking early if termination is signalled.
std::unique_lock<std::mutex> lock(watchdogCVMutex_);
watchdogCV_.wait_for(
lock,
std::chrono::milliseconds(kWatchdogThreadSleepMillis),
[&]() -> bool { return terminateProcessGroup_.load(); });
}
}
// Body of the work cleanup thread. Periodically scans workMetaList_, removes
// completed works and, unless the group is terminating, calls
// handleNCCLGuard on them (which rethrows an exception recorded on the work).
// The loop keeps running after termination has been requested until the work
// list has been drained.
void ProcessGroupNCCL::workCleanupLoop() {
bool done = false;
while (!terminateProcessGroup_.load() || !done) {
// Completed works are collected here so that they are destroyed after the
// list mutex has been released.
std::list<WorkNCCL> doneWorks;
{
std::unique_lock<std::mutex> lock(workMetaListMutex_);
// Wait up to kWorkCleanupThreadSleepMillis milliseconds between scans of
// the work list; the wait ends early once termination has been requested.
workMetaListCV_.wait_for(
lock,
std::chrono::milliseconds(kWorkCleanupThreadSleepMillis),
[&]() -> bool { return terminateProcessGroup_.load(); });
for (auto it = workMetaList_.begin(); it != workMetaList_.end();
/* no increment*/) {
auto& work = *it;
if (work.isCompleted()) {
// Handle Exceptions on failed GPU operations and remove completed
// workNCCL objects from work vector.
if (!terminateProcessGroup_.load()) {
work.handleNCCLGuard();
}
// NOTE(review): only a copy constructor for WorkNCCL is visible in this
// file, so this std::move may end up copying — confirm.
doneWorks.push_back(std::move(*it));
it = workMetaList_.erase(it);
} else {
// Increment the iterator if the current WorkNCCL object is not
// completed.
++it;
}
}
done = workMetaList_.empty();
}
doneWorks.clear();
}
}
// Work-level entry point for the NCCL error poll; forwards to
// checkForNCCLErrorsInternal. Returns a null exception_ptr when none of the
// communicators reports an error.
std::exception_ptr ProcessGroupNCCL::WorkNCCL::checkForNCCLErrors(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const {
return checkForNCCLErrorsInternal(ncclComms);
}
// Process-group-level entry point for the NCCL error poll; forwards to
// checkForNCCLErrorsInternal. Returns a null exception_ptr when none of the
// communicators reports an error.
std::exception_ptr ProcessGroupNCCL::checkForNCCLErrors(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) {
return checkForNCCLErrorsInternal(ncclComms);
}
std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal(
const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) {
for (const auto& ncclComm : ncclComms) {
ncclResult_t ncclAsyncErr = ncclComm->checkForNcclError();
if (ncclAsyncErr != ncclSuccess) {
return std::make_exception_ptr(std::runtime_error(
"NCCL error: " + ncclGetErrorWithVersion(ncclAsyncErr) + "\n" +
getNcclErrorDetailStr(ncclAsyncErr)));
}
}
return nullptr;
}
// Distributes a ncclUniqueId through the store.
//
// Collective operations: every communicator needs a unique id broadcast from
// rank 0 to all other ranks. Rank 0 sets a key in the store and the other
// ranks read it back. Because one process group may create several
// communicators, a per-group counter is used as the key to tell them apart.
//
// Point-to-point operations: only 2 of the processes in the group would
// advance that counter, so later collectives would see mismatching counters
// and fail at runtime. The src:target pair (`p2pKey`) is therefore used as
// the key instead.
//
// The publishing side (rank 0, or the p2p rank 0) writes *ncclID; every other
// caller overwrites *ncclID with the value read from the store.
void ProcessGroupNCCL::broadcastUniqueNCCLID(
    ncclUniqueId* ncclID,
    OpType opType,
    const std::string& p2pKey,
    int p2pRank) {
  const bool isP2P = isP2POp(opType);
  const std::string storeKey =
      isP2P ? p2pKey : std::to_string(ncclCommCounter_++);
  const bool publishesId = rank_ == 0 || (isP2P && p2pRank == 0);
  if (!publishesId) {
    auto vec = store_->get(storeKey);
    TORCH_CHECK(vec.size() == NCCL_UNIQUE_ID_BYTES);
    std::memcpy(ncclID, vec.data(), vec.size());
    return;
  }
  uint8_t* const idBytes = reinterpret_cast<uint8_t*>(ncclID);
  store_->set(
      storeKey,
      std::vector<uint8_t>(idBytes, idBytes + NCCL_UNIQUE_ID_BYTES));
}
// Returns the communicators cached under `devicesKey`, creating and caching
// them (together with their streams and events) on first use.
//
// devicesKey:     cache key; must not be empty.
// devices:        devices to create one communicator each for.
// opType:         decides between the collective and the point-to-point
//                 rank layout.
// p2pRank:        rank within the pair for point-to-point operations.
// isSendRecvSelf: true for a send/recv within the same process; the unique
//                 id is then not exchanged through the store.
// Throws std::runtime_error when `devicesKey` is empty.
// NOTE(review): the returned reference points into devNCCLCommMap_ and is
// used by the caller after mutex_ has been released — confirm that entries
// are never erased or replaced while in use.
std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
const std::string& devicesKey,
const std::vector<at::Device>& devices,
OpType opType,
int p2pRank,
bool isSendRecvSelf) {
// Sanity check
if (devicesKey.empty()) {
throw std::runtime_error(
"Not able to create/get the NCCL Communicator since "
"the GPU devices are not known");
}
for (auto& device : devices) {
usedDeviceIdxs_.insert(device.index());
}
{
std::lock_guard<std::mutex> lock(mutex_);
if (devNCCLCommMap_.find(devicesKey) != devNCCLCommMap_.end()) {
// Reuse the cached communicator if there is one.
return devNCCLCommMap_[devicesKey];
}
}
// NCCL communicator not cached, create a new entry
std::vector<std::shared_ptr<NCCLComm>> ncclComms;
ncclComms.resize(devices.size());
// Create the unique NCCL ID and broadcast it
ncclUniqueId ncclID;
// For point-to-point communication, lower rank of the two will get unique id.
if (rank_ == 0 || (isP2POp(opType) && p2pRank == 0)) {
C10D_NCCL_CHECK(ncclGetUniqueId(&ncclID));
}
// For point-to-point communication on the same process, don't need broadcast.
if (!isSendRecvSelf) {
// Broadcast so that each process can have a unique NCCL ID
broadcastUniqueNCCLID(&ncclID, opType, devicesKey, p2pRank);
}
at::cuda::OptionalCUDAGuard gpuGuard;
std::vector<at::cuda::CUDAStream> streamVal;
streamVal.reserve(devices.size());
// [Group Start/End Note] This is used to ensure that nccl communicator will
// be created before communication primitives are called. Let's look at this
// example: Using the batch_isend_irecv to send a tensor to a target process.
// On the sender side, the corresponding underlying NCCL calls will look like
// ncclGroupStart() // This is in batch_isend_irecv
// ncclGroupStart() // This is [Note 1]
// ncclCommInitRank() // Inside NCCLComm::create
// ncclSend()
// ncclGroupEnd() // This is [Note 2]
// ncclGroupEnd() // This is in batch_isend_irecv
// With this pattern, the nccl communicator will be created in the last
// ncclGroupEnd which means when ncclSend is processed, the passed
// communicator argument is NULL which will lead to runtime error. So we need
// to "close" all active nccl groups to ensure nccl communicator is actually
// created before encountering any communication calls. This is why we need
// the following for loop.
for (size_t i = 0; i < ncclActiveGroupCounter_; ++i) {
C10D_NCCL_CHECK(ncclGroupEnd());
}
// [Note 1] Create the NCCL communicators for each GPU
C10D_NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < devices.size(); ++i) {
// GPU world size and GPU rank
int numRanks, rank;
if (!isP2POp(opType)) {
// Collective: every process contributes devices.size() GPU ranks.
numRanks = getSize() * devices.size();
rank = getRank() * devices.size() + i;
} else if (isSendRecvSelf) {
// Same process send and recv.
numRanks = 1;
rank = 0;
} else {
// For point-to-point operation, there are only 2 processes involved so
// the GPU rank is either 0 or 1.
numRanks = 2;
rank = p2pRank;
}
// Get the device index
int deviceIndex = devices[i].index();
gpuGuard.set_index(deviceIndex);
ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
// Creates the NCCL streams
streamVal.push_back(
at::cuda::getStreamFromPool(options_->is_high_priority_stream));
}
// [Note 2 ]
C10D_NCCL_CHECK(ncclGroupEnd());
// See [Group Start/End Note]
// Re-open the groups that were closed before communicator creation.
for (size_t i = 0; i < ncclActiveGroupCounter_; ++i) {
C10D_NCCL_CHECK(ncclGroupStart());
}
ncclStreams_.emplace(devicesKey, std::move(streamVal));
// Note: these events are created with the (default) cudaEventDisableTiming
// flag. This flag provides the best performance when used with
// cudaStreamWaitEvent() and cudaEventQuery(). Since we here don't measure the
// performance using cudaEvent, this should be set.
ncclEvents_.emplace(
std::piecewise_construct,
std::make_tuple(devicesKey),
std::make_tuple(devices.size()));
// Hold the lock before modifying the cache.
std::lock_guard<std::mutex> lock(mutex_);
// Record the communicators based on ncclUniqueId.
ncclIdToCommMap_.emplace(buildNcclUniqueIdStr(ncclID), ncclComms);
// Move the NCCL resource to cache
devNCCLCommMap_.emplace(devicesKey, std::move(ncclComms));
return devNCCLCommMap_[devicesKey];
}
namespace {
// Validates a single tensor: it has to be a dense CUDA tensor and contiguous.
// Throws std::runtime_error describing the first violated requirement.
void check_gpu_single_tensor(const at::Tensor& tensor) {
  const bool isDenseCuda = tensor.is_cuda() && !tensor.is_sparse();
  if (!isDenseCuda) {
    throw std::runtime_error("Tensors must be CUDA and dense");
  }
  const bool isContiguous = tensor.is_contiguous();
  if (!isContiguous) {
    throw std::runtime_error("Tensors must be contiguous");
  }
}
// Validates a tensor list: it must be non-empty and no longer than the number
// of available GPUs, and every tensor must be a dense CUDA tensor with the
// same scalar type, sizes and strides as the first one, be non-overlapping
// and dense, and live on a GPU that no other tensor in the list uses.
// Throws std::runtime_error describing the first violated requirement.
void check_gpu_tensors(const std::vector<at::Tensor>& tensors) {
  if (tensors.empty()) {
    throw std::runtime_error("Tensor list must be nonempty");
  }
  const auto numGPUs = static_cast<size_t>(at::cuda::getNumGPUs());
  if (tensors.size() > numGPUs) {
    throw std::runtime_error(
        "Tensor list mustn't be larger than the number of available GPUs");
  }
  const auto& first = tensors.front();
  // Devices seen so far; used to reject two tensors on the same GPU.
  std::unordered_set<decltype(first.get_device())> seenDevices;
  seenDevices.reserve(tensors.size());
  for (size_t idx = 0; idx < tensors.size(); ++idx) {
    const auto& current = tensors[idx];
    const bool isDenseCuda = current.is_cuda() && !current.is_sparse();
    if (!isDenseCuda) {
      throw std::runtime_error("Tensors must be CUDA and dense");
    }
    if (current.scalar_type() != first.scalar_type()) {
      throw std::runtime_error("Tensors must have identical type");
    }
    if (current.sizes() != first.sizes()) {
      throw std::runtime_error("Tensors must have identical size");
    }
    if (current.strides() != first.strides()) {
      throw std::runtime_error("Tensors must have identical strides");
    }
    if (!current.is_non_overlapping_and_dense()) {
      throw std::runtime_error("Tensors must be non-overlapping and dense");
    }
    if (!seenDevices.insert(current.get_device()).second) {
      throw std::runtime_error("Tensors must be on distinct GPU devices");
    }
  }
}
// Flatten each list in `tensor_lists' for a gather or scatter operation, and
// ensure compatibility with the corresponding tensor in `other'.
std::vector<at::Tensor> flatten_for_scatter_gather(
std::vector<std::vector<at::Tensor>>& tensor_lists,
std::vector<at::Tensor>& other,
size_t world_size) {
if (tensor_lists.size() != other.size()) {
throw std::runtime_error(
"Tensor list operands to scatter/gather must have the same length");