-
Notifications
You must be signed in to change notification settings - Fork 22.7k
/
ProcessGroupNCCL.hpp
1280 lines (1027 loc) · 47 KB
/
ProcessGroupNCCL.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#pragma once
#ifdef USE_C10D_NCCL
#if defined(__linux__)
#include <fcntl.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#endif
#include <atomic>
#include <chrono>
#include <deque>
#include <future>
#include <iostream>
#include <list>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>
#include <ATen/DynamicLibrary.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/core/Stream.h>
#include <c10/core/StreamGuard.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/custom_class.h>
namespace c10d {
// Environment variables consulted by ProcessGroupNCCL.
//
// Every knob below is a list of accepted variable names. Where two names are
// listed, the first is the TORCH_-prefixed spelling and the second an
// NCCL_-prefixed alias; how the list is consulted (e.g. precedence) is decided
// by the getCvar* helpers that read it.

// Control broadcasting of the NCCL uniqueId.
static std::vector<std::string> TORCH_NCCL_BCAST_UNIQUEID{
    "TORCH_NCCL_BCAST_UNIQUEID"};

// Control whether to always use high-priority CUDA streams.
static std::vector<std::string> TORCH_NCCL_HIGH_PRIORITY{
    "TORCH_NCCL_HIGH_PRIORITY"};

// Control whether wait() is blocking or non-blocking.
static std::vector<std::string> TORCH_NCCL_BLOCKING_WAIT{
    "TORCH_NCCL_BLOCKING_WAIT",
    "NCCL_BLOCKING_WAIT"};

// Control whether we perform Async Error Handling with NCCL
// (values correspond to ErrorHandlingMode).
// TODO: We want to eventually remove this variable and have users rely on
// the default value (3 - SkipCleanUp).
static std::vector<std::string> TORCH_NCCL_ASYNC_ERROR_HANDLING{
    "TORCH_NCCL_ASYNC_ERROR_HANDLING",
    "NCCL_ASYNC_ERROR_HANDLING"};

// Control whether dumping debug info on watchdog timeout is enabled.
// Must be set together with TORCH_NCCL_ENABLE_MONITORING=1 and
// TORCH_NCCL_TRACE_BUFFER_SIZE > 0.
static std::vector<std::string> TORCH_NCCL_DUMP_ON_TIMEOUT{
    "TORCH_NCCL_DUMP_ON_TIMEOUT"};

// Control whether Desync Debug is enabled. Must be set together with
// TORCH_NCCL_ASYNC_ERROR_HANDLING.
static std::vector<std::string> TORCH_NCCL_DESYNC_DEBUG{
    "TORCH_NCCL_DESYNC_DEBUG",
    "NCCL_DESYNC_DEBUG"};

// Enable recording start-events for all ProcessGroupNCCL collectives and
// compute accurate per-collective timing. (Note: end-events are recorded by
// default. Turning on this flag can increase the chance of a watchdog hang,
// because it performs a CUDA event query that eventually calls the
// cudaEventElapsedTime() API.)
static std::vector<std::string> TORCH_NCCL_ENABLE_TIMING{
    "TORCH_NCCL_ENABLE_TIMING",
    "NCCL_ENABLE_TIMING"};

// Enable the monitoring thread, which aborts the process when the
// ProcessGroupNCCL watchdog thread gets stuck and no heartbeat is detected
// after TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC. This can happen when calling
// CUDA/NCCL APIs that may hang. It is useful to prevent jobs from staying
// stuck longer than necessary and tying up cluster resources.
static std::vector<std::string> TORCH_NCCL_ENABLE_MONITORING{
    "TORCH_NCCL_ENABLE_MONITORING"};

// Watchdog heartbeat timeout after which the monitoring thread aborts the
// process.
static std::vector<std::string> TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC{
    "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"};

// Whether to rethrow CUDA errors in the watchdog (default true).
static std::vector<std::string> TORCH_NCCL_RETHROW_CUDA_ERRORS{
    "TORCH_NCCL_RETHROW_CUDA_ERRORS"};

// Maximum number of events stored in the flight recorder's ring buffer.
// (One event could be, for example, the start or the end of a collective.)
static std::vector<std::string> TORCH_NCCL_TRACE_BUFFER_SIZE{
    "TORCH_NCCL_TRACE_BUFFER_SIZE"};

// Extra time to wait for dumping the debugging info before we exit and
// throw a timeout exception.
static std::vector<std::string> TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC{
    "TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"};

// Interval at which the monitoring thread checks the coordinated signal from
// other ranks, e.g. to dump the debugging information.
static std::vector<std::string> TORCH_NCCL_COORD_CHECK_MILSEC{
    "TORCH_NCCL_COORD_CHECK_MILSEC"};

// Whether to log C++ stack traces on unclean shutdown (default true).
static std::vector<std::string> TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN{
    "TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN"};

// Control whether the watchdog thread uses CudaEventCache for collectives.
// We noticed in the past that destroying a CudaEvent while the CUDA global
// lock is held can cause a hang.
static std::vector<std::string> TORCH_NCCL_CUDA_EVENT_CACHE{
    "TORCH_NCCL_CUDA_EVENT_CACHE"};

// Presumably toggles NaN checking of collective inputs (cf. the `nanCheck`
// parameter of collective()) — TODO confirm.
static std::vector<std::string> TORCH_NCCL_NAN_CHECK{"TORCH_NCCL_NAN_CHECK"};

// Backend name reported by getBackendName().
constexpr const char* NCCL_BACKEND_NAME = "nccl";
constexpr const char* EXCEPTION_DUMP = "exception_dump";
// 30 seconds.
constexpr int kWorkStatusUpdatePeriodMs = 30 * 1000;
// 10 minutes.
constexpr auto kProcessGroupNCCLDefaultTimeout =
    std::chrono::milliseconds{10 * 60 * 1000};
// How ProcessGroupNCCL reacts to asynchronous NCCL errors:
//   NoHandling:  do not handle asynchronous NCCL errors.
//   TearDown:    tear down the process upon error, see
//                `WorkNCCL::handleException`.
//   CleanUpOnly: just clean up collectives and abort communicators without
//                tearing down the process.
//   SkipCleanUp: (a temporary option that can be removed in the future) tear
//                down the process without cleaning up NCCL communicators.
//                This should be used as a last resort in case
//                `ncclCommAbort` itself is hanging.
enum ErrorHandlingMode {
  NoHandling = 0,
  TearDown = 1,
  CleanUpOnly = 2,
  SkipCleanUp = 3
};

// True for modes that clean up collectives / abort communicators
// (TearDown, CleanUpOnly).
// The argument is parenthesized so a compound expression (e.g. a conditional)
// is compared as a whole; note it is still evaluated twice, so do not pass an
// expression with side effects.
#define SHOULD_CLEAN_UP(a) ((a) != NoHandling && (a) != SkipCleanUp)
// True for modes that tear down the process (TearDown, SkipCleanUp).
// Same argument caveats as SHOULD_CLEAN_UP.
#define SHOULD_TEAR_DOWN(a) ((a) != NoHandling && (a) != CleanUpOnly)
// Logs a warning with the hash of a collective's data. The expansion already
// ends in a semicolon and requires a `logPrefix()` callable at the use site.
#define PRINT_COLLECTIVE_HASH_SIGNATURE(phase, opType, numel, hashValue) \
  LOG(WARNING) << logPrefix() << "Hash of " << phase << " to NCCL " << opType \
               << " with size " << numel << " is " << hashValue;
// When set, ProcessGroupNCCL does not rely on recordStream calls for
// caching-allocator safety of tensors used on both user-facing and internal
// comm streams. It instead keeps live references to those tensors until the
// user-facing streams have been synced with the comm streams
// (see WorkNCCL::stashed_for_allocator_safety_).
static std::vector<std::string> TORCH_NCCL_AVOID_RECORD_STREAMS{
    "TORCH_NCCL_AVOID_RECORD_STREAMS"};

// When set, ProcessGroupNCCL registers postAlloc and preFree hooks with the
// CUDA caching allocator, so that every tensor allocation/free can be
// registered/deregistered on all available NCCL communicators.
static std::vector<std::string> TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK{
    "TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK",
    "NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"};
#if defined(__linux__)
// Named-pipe trigger for on-demand NCCL debug dumps.
//
// When TORCH_NCCL_DEBUG_INFO_PIPE_FILE is non-empty and
// TORCH_NCCL_TRACE_BUFFER_SIZE is > 0, the constructor (re)creates the FIFO
// "<stem><rank>.pipe" and opens it read-only and non-blocking. shouldDump()
// returns true when a read from the pipe yields data. Otherwise no pipe is
// created and shouldDump() always returns false.
//
// The struct owns a raw file descriptor, so copying is disabled: a copy would
// close the same descriptor twice.
struct DumpPipe {
  DumpPipe(int rank) {
    std::string fileStem =
        getCvarString({"TORCH_NCCL_DEBUG_INFO_PIPE_FILE"}, "");
    if (fileStem.empty() ||
        getCvarInt({"TORCH_NCCL_TRACE_BUFFER_SIZE"}, 0) <= 0) {
      return;
    }
    std::string filename = c10::str(fileStem, rank, ".pipe");
    // Remove a stale pipe from a previous run; a missing file is fine.
    TORCH_CHECK(
        unlink(filename.c_str()) != -1 || errno == ENOENT,
        "Error removing existing named pipe ",
        filename);
    TORCH_CHECK(
        mkfifo(filename.c_str(), 0666) != -1,
        "Error creating named pipe ",
        filename);
    // O_NONBLOCK: open() returns immediately even without a writer, and
    // later reads in shouldDump() never block the polling thread.
    fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK);
    // Check before announcing the pipe so the log line is only emitted when
    // the pipe is actually usable.
    TORCH_CHECK(fd_ != -1, "Error opening named pipe ", filename);
    LOG(INFO) << "Pipe file " << filename
              << " has been opened, write to it to trigger NCCL Debug Dump.";
  }

  DumpPipe(const DumpPipe&) = delete;
  DumpPipe& operator=(const DumpPipe&) = delete;

  // Returns true if data was read from the pipe since the last call.
  bool shouldDump() {
    if (fd_ == -1) {
      return false;
    }
    // NOLINTNEXTLINE(*array*)
    char buf[128]{};
    // Non-blocking thanks to O_NONBLOCK. Any error (EINTR, EAGAIN, ...) is
    // treated as "no dump requested" because we will poll again later.
    ssize_t bytesRead = read(fd_, buf, sizeof(buf));
    return bytesRead > 0;
  }

  ~DumpPipe() {
    if (fd_ != -1) {
      close(fd_);
    }
  }

 private:
  int fd_ = -1;
};
#else
// Non-Linux stub: named pipes are not supported, so a dump is never requested.
struct DumpPipe {
  DumpPipe(int /*rank*/) {}
  bool shouldDump() {
    return false;
  }
};
#endif
// ProcessGroupNCCL implements NCCL bindings for c10d.
//
// All functions of the class are expected to be called in the same order
// across all processes in the process group. This is the only way that we
// can guarantee to match up the same calls among all processes.
//
// All NCCL functions provided by this class are asynchronous. More
// specifically, each NCCL call is scheduled on a separate CUDA stream that is
// different from the current CUDA stream, in order to achieve potential
// concurrency and better performance. As a result, it is the caller's
// responsibility to make sure that the CUDA stream their code runs on waits
// for the NCCL operation issued by this class.
//
// This can be done by calling:
//
// either WorkNCCL::wait() (with no timeout) or WorkNCCL::synchronize(); both
// achieve the same functionality and are synonyms.
//
// Also note that WorkNCCL::finishedGPUExecution() is a helper function only
// provided by ProcessGroupNCCL to check if the NCCL operation of WorkNCCL has
// finished execution on the GPU (not just scheduled).
//
// Example on using the NCCL process group
//
// ProcessGroupNCCL pg(store, rank, size);
// std::shared_ptr<WorkNCCL> work = pg.allreduce(tensors);
//
// // At this point, the NCCL kernel has already been queued successfully.
// // Now, let the current stream wait for NCCL to finish; this function is
// // an async operation as well.
//
// work->wait();
//
// // Now continue on other work in the current stream.
class TORCH_API ProcessGroupNCCL : public Backend {
public:
// WorkNCCL: the Work handle for an NCCL operation issued by ProcessGroupNCCL
// (see initWork()). It holds the CUDA start/end events of the enqueued
// operation, the communicator it ran on, timeout/exception state, and the
// outputs/future exposed to callers. Member functions are defined out of line.
class WorkNCCL : public Work, public std::enable_shared_from_this<WorkNCCL> {
public:
friend struct WorkInfo;
// Constructor takes the single CUDA device this work runs on, plus the owning
// process group's uid/description, the rank, op type and sequence number.
WorkNCCL(
std::string pgUID,
std::string pgDesc,
at::Device& device,
int rank,
OpType opType,
uint64_t seq,
bool isP2P = false,
const char* profilingTitle = nullptr,
const std::optional<std::vector<at::Tensor>>& inputs = std::nullopt,
bool desyncDebug = false,
bool enableTiming = false,
bool cudaEventCacheEnabled = false,
DebugLevel distDebugLevel = DebugLevel::Off);
// Copy constructor doing a partial copy without outputs_. The cleanup thread
// monitors and removes finished works; however, it would deadlock when
// destructing outputs_ tensors that are view tensors in the autograd graph.
WorkNCCL(const WorkNCCL& w);
~WorkNCCL() override;
// Checks if the NCCL kernel has started to execute.
bool isStarted();
// Checks if request has completed. In this specific case of NCCL, it checks
// if the NCCL operation has completed on the GPU in its own NCCL stream.
// Non-blocking operation.
bool isCompleted() override;
bool isSuccess() const override;
// Same as calling synchronize() for NCCL work if timeout is not set.
// Otherwise, it will block the CPU thread until the NCCL work is completed
// or timed out. If timeout, exception will be thrown.
bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
void abort() override;
// Let current stream wait on the completion of the NCCL work
// Throws on exceptions.
void synchronize() override;
// Synchronize streams by blocking each on the NCCL stream
void synchronizeStream();
// Helper function to handle exception (throw if needed).
void handleException(ErrorHandlingMode asyncErrorHandling);
// Helper function that checks if the NCCL kernels have finished
// execution on the GPUs
bool finishedGPUExecution();
// Get a Future object that will be marked as completed internally.
c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
// Get a Future holding the result of this work (e.g. success or different
// error types) instead of the tensor output.
c10::intrusive_ptr<c10::ivalue::Future> getFutureResult() override;
float getDuration() const override;
uint64_t getSequencenumber() const override;
const std::string& logPrefix() const;
// Helper function that sets an exception_ptr on the WorkNCCL object.
void setException(std::exception_ptr exception_ptr);
// Helper function that returns True if the WorkNCCL object has timed out
// and False otherwise.
// In case of timeout, set exception on the WorkNCCL object.
bool checkTimeout(
std::optional<std::chrono::milliseconds> timeout = std::nullopt);
// Print the traceback of the collective at call time
void printTraceback() const;
std::vector<at::Tensor> result() override;
protected:
// The process group unique id
std::string pgUID_;
// The process group description
std::string pgDesc_;
// The CUDA device this work operates on.
at::Device device_;
// The start CUDA event of NCCL operator tracking this work item. These
// start CUDA events are needed by desync debugging if enabled.
std::shared_ptr<at::cuda::CUDAEvent> ncclStartEvent_;
// The end CUDA event of NCCL operator tracking this work item.
std::shared_ptr<at::cuda::CUDAEvent> ncclEndEvent_;
// The NCCL communicator used for this work item.
std::shared_ptr<NCCLComm> ncclComm_;
// whether this work is a barrier op
bool isBarrierOp_{false};
// Clone of blockingWait_ from ProcessGroupNCCL.
bool blockingWait_{false};
// Clone of avoidRecordStreams_ from ProcessGroupNCCL.
bool avoidRecordStreams_{false};
// Clone of opTimeout_ from ProcessGroupNCCL.
std::chrono::milliseconds opTimeout_{};
// Ephemeral timeouts are owned by exactly one work,
// and reset after that work completes.
// There may be more than one ephemeral timeout active at the same time,
// and this variable is used to track the ownership of ephemeral timeout.
// (The "Ephermeral" misspelling is part of the member name; keep it.)
std::chrono::milliseconds ownedEphermeralTimeout_ =
std::chrono::milliseconds(0);
// Time point representing when the work started.
std::chrono::time_point<std::chrono::steady_clock> workStartTime_;
// Record the sequential number of collective or p2p.
uint64_t seq_;
bool isP2P_;
// Indicates if the nccl start event has been updated to the store trace.
// This will be used by desync debug.
bool startTraceUpdated_{false};
// Record collective sizes for debug. We only record the size on the first
// device as multi-device per process is deprecated.
// NOTE: -1 converts to SIZE_MAX (size_t is unsigned); presumably used as a
// "not recorded" sentinel — TODO confirm.
size_t numelIn_ = -1;
size_t numelOut_ = -1;
// Wrapper method for the static checkForNCCLErrors which can be overridden
// for tests.
virtual std::exception_ptr checkForNCCLErrors();
friend std::ostream& operator<<(
std::ostream& output,
const WorkNCCL& workNCCL);
private:
// Checks for NCCL errors and sets an appropriate exception_ptr.
void checkAndSetException();
// Just checks whether GPU execution has started, without modifying
// exception_ptr.
bool startedGPUExecutionInternal() const;
// Just checks whether GPU execution has completed, without modifying
// exception_ptr.
bool finishedGPUExecutionInternal() const;
// Reference to the store so that we can write aborted communicators
// to the store.
c10::intrusive_ptr<Store> store_;
// Store a reference to NCCL collective's outputs, used by result and to
// give a more descriptive message when representing the Work as a string.
std::shared_ptr<std::vector<at::Tensor>> outputs_;
// TORCH_NCCL_AVOID_RECORD_STREAMS implementation helper.
// Stores references to participating non-output tensors (ie inputs,
// flattened intermediates).
// We'll clear this list in synchronizeStream, just after user-facing
// stream(s) are synced with the nccl work stream(s).
// By keeping these refs (as well as outputs_) alive until after the
// collective's work rejoins the user-facing streams, we achieve
// caching allocator safety without any recordStream calls.
// For in-place collectives, some refs stashed here may alias outputs_,
// but that doesn't do any harm.
std::shared_ptr<std::vector<at::Tensor>> stashed_for_allocator_safety_;
// The future returned by getFuture.
c10::intrusive_ptr<at::ivalue::Future> future_;
// the future result (e.g., success or failure) of the work
c10::intrusive_ptr<at::ivalue::Future> futureWorkResult_;
// Whether collective timing is enabled for this work (cf. the enableTiming
// constructor argument).
bool timingEnabled_;
// unique id used to tell the trace buffer that this
// work has completed
std::optional<uint64_t> trace_id_;
DebugLevel distDebugLevel_;
friend class ProcessGroupNCCL;
};
// Cache of CUDA events, split by whether timing is enabled. get() returns a
// shared instance (presumably a process-wide singleton — verify in the .cpp).
class CUDAEventCache {
public:
CUDAEventCache();
// Returns a CUDA event with timing enabled iff `timing` is true. Whether the
// event is taken from / returned to the cache (e.g. via a custom shared_ptr
// deleter) is implemented out of line — TODO confirm.
std::shared_ptr<at::cuda::CUDAEvent> create(bool timing);
static CUDAEventCache& get();
private:
// Presumably guards eventsArray_ against concurrent create() calls.
std::mutex cacheMutex_;
// NOTE: We intentionally store raw pointers so that
// we do not attempt to destroy the event objects on process exit,
// because cuda may be gone.
std::array<std::deque<at::cuda::CUDAEvent*>, 2>
eventsArray_; // 0 for timing=false, 1 for timing=true
};
// Construction options for ProcessGroupNCCL, extending Backend::Options.
struct Options : Backend::Options {
// NOTE: timeout in ProcessGroupNCCL::Options denote the timeout for
// operations. This is only used when blockingWait_ is enabled.
explicit Options(bool is_high_priority_stream = false);
// return intrusive_ptr of the object
static c10::intrusive_ptr<Options> create(
bool is_high_priority_stream = false) {
return c10::make_intrusive<Options>(is_high_priority_stream);
}
// Schedule NCCL operations on high priority CUDA streams
bool is_high_priority_stream;
#ifdef NCCL_HAS_COMM_NONBLOCKING
// NCCL communicator configuration (ncclConfig_t), starting from
// NCCL_CONFIG_INITIALIZER. Presumably handed to NCCL when communicators are
// created — TODO confirm in initNCCLComm.
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
#endif
// Optional "parent" backend and color to create communicators from
// via `ncclCommSplit`
std::shared_ptr<ProcessGroupNCCL> split_from;
// Color to use for `ncclCommSplit`, values:
// * Non-negative value: in group;
// * NCCL_SPLIT_NOCOLOR (-1): not in group;
// * NCCL_SPLIT_NOCOLOR - 1: uninitialized.
// [Note 1]: the type must be `int` instead of `int64_t` because NCCL API
// accepts int. Otherwise, an implicit conversion may happen at the API call
// and the value may become negative.
// [Note 2]: this member is pybinded to Python, the value passed from Python
// must be within the numerical range of C++ int. Otherwise, Python will
// raise a RuntimeError saying type is incompatible. See also
// `_process_group_color` in `distributed_c10d.py`.
#ifdef NCCL_HAS_COMM_SPLIT
int split_color{NCCL_SPLIT_NOCOLOR - 1};
#else
// [Note 3]: for older NCCL versions, NCCL_SPLIT_NOCOLOR is not defined. But
// `split_color` is pybinded to Python, so we need to define it. So we use
// the int value of `NCCL_SPLIT_NOCOLOR` (-1) instead.
int split_color{-2};
#endif
// Presumably the global ranks of this group's members — TODO confirm.
std::vector<uint64_t> global_ranks_in_group;
// Presumably this process group's name — TODO confirm.
std::string group_name;
};
// Helper class related to TORCH_NCCL_DESYNC_DEBUG.
//
// Logs collective start/end events to the store and runs desync root-cause
// analysis when the watchdog hits a timeout. It stays disabled until init()
// is called; while disabled every other member function is a no-op.
class DesyncDebugger {
 public:
  // Initialize and enable DesyncDebugger.
  void init(int rank, int size, c10::intrusive_ptr<Store> store);
  // Run desync debug. This function is called by watchdog at time of timeout.
  void run();
  // Log work start to store.
  void logWorkStart(WorkNCCL& work);
  // Log work end to store.
  void logWorkEnd(WorkNCCL& work);

 private:
  // Whether desync debug is enabled.
  // If false, all functions are no-op.
  bool enabled_{false};
  // From ProcessGroupNCCL. Default to -1 so that a debugger that was never
  // init()-ed holds a recognizable sentinel instead of an indeterminate
  // value (reading an indeterminate int is undefined behavior).
  int rank_{-1};
  int size_{-1};
  // Reference to the store so that we can log start/end event.
  c10::intrusive_ptr<Store> store_;
  // The store keys to trace the last NCCL collective kernel CUDA events -
  // start event and end event respectively. These are used to do desync root
  // cause analysis.
  std::string traceKeyStart_;
  std::string traceKeyEnd_;
};
// If you wish to create multiple process groups, each with a potentially
// different rank and size, you can do so by passing a new store instance
// to each one. If you have only a single store object, you can
// use the `c10d::PrefixStore` to derive scoped instances.
// This is also what the Python API in torch.distributed does.
//
// The process group instance keeps a reference to the store because
// it may be used long after the constructor runs. In fact, the constructor
// doesn't create any NCCL communicators. A single NCCL communicator can
// only be used on a specific set of devices, and are therefore created
// on-demand when a collective runs. If another collective is executed later,
// against a different set of devices, the process group creates another NCCL
// communicator. These NCCL communicators are cached and reused if possible.
//
ProcessGroupNCCL(
c10::intrusive_ptr<Store> store,
int rank,
int size,
c10::intrusive_ptr<Options> options = Options::create());
// This constructor includes the deprecated `groupName` argument.
// If you have existing code that uses the `groupName`, you can replace
// it by specifying a `c10d::PrefixStore(groupName, store)` for store.
C10_DEPRECATED ProcessGroupNCCL(
const c10::intrusive_ptr<Store>& store,
int rank,
int size,
const std::string& groupName,
c10::intrusive_ptr<Options> options = Options::create())
: ProcessGroupNCCL(store, rank, size, std::move(options)) {}
~ProcessGroupNCCL() override;
// This function returns a local uid for ProcessGroupNCCL.
uint64_t getUid() {
return static_cast<uint64_t>(local_id_);
}
c10::intrusive_ptr<Options> getOptions() {
return options_;
}
const std::string getBackendName() const override {
return std::string(NCCL_BACKEND_NAME);
}
bool supportsSplitting() const override {
return true;
}
void startCoalescing() override;
c10::intrusive_ptr<Work> endCoalescing() override;
// For specifying a composite optype, such as ALLGATHER and REDUCE_SCATTER
c10::intrusive_ptr<Work> endCoalescing(OpType optype);
c10::intrusive_ptr<Work> broadcast(
std::vector<at::Tensor>& tensors,
const BroadcastOptions& opts = BroadcastOptions()) override;
c10::intrusive_ptr<Work> _broadcast_oop(
at::Tensor& outputTensors,
at::Tensor& inputTensors,
const BroadcastOptions& opts = BroadcastOptions());
c10::intrusive_ptr<Work> allreduce_sparse(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
c10::intrusive_ptr<Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
c10::intrusive_ptr<Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts =
AllreduceCoalescedOptions()) override;
c10::intrusive_ptr<Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) override;
c10::intrusive_ptr<Work> _reduce_oop(
at::Tensor& outputTensors,
at::Tensor& inputTensors,
const ReduceOptions& opts = ReduceOptions());
c10::intrusive_ptr<Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<Work> _allgather_base(
at::Tensor& outputbuffer,
at::Tensor& inputbuffer,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<Work> allgather_coalesced(
std::vector<std::vector<at::Tensor>>& outputTensorLists,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
std::vector<at::Tensor>& outputs,
std::vector<at::Tensor>& inputs,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
c10::intrusive_ptr<Work> _reduce_scatter_base(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
std::vector<at::Tensor>& outputs,
std::vector<at::Tensor>& inputs,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
c10::intrusive_ptr<Work> barrier(
const BarrierOptions& opts = BarrierOptions()) override;
c10::intrusive_ptr<Work> alltoall_base(
at::Tensor& outputTensor,
at::Tensor& inputTensor,
std::vector<int64_t>& outputSplitSizes,
std::vector<int64_t>& inputSplitSizes,
const AllToAllOptions& opts = AllToAllOptions()) override;
c10::intrusive_ptr<Work> alltoall(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllToAllOptions& opts = AllToAllOptions()) override;
c10::intrusive_ptr<Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) override;
c10::intrusive_ptr<Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) override;
void groupStart();
void groupEnd();
void groupEndNonblocking(const std::shared_ptr<NCCLComm>& comm);
c10::intrusive_ptr<Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) override;
c10::intrusive_ptr<Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) override;
// Unsupported Ops
c10::intrusive_ptr<Work> recvAnysource(
std::vector<at::Tensor>& tensors,
int tag) override;
// Agrees on an initial sequence number for the whole group by having rank 0
// create it and broadcast it to other ranks using the store.
void setSequenceNumberForGroup() override;
// Retrieves the current sequence number for the whole group, which should be
// in sync. If the returned number is not consistent across the group, it
// may indicate that there is some sort of collective desynchronization.
uint64_t getSequenceNumberForGroup() override;
// Return the total number of splits the communicators held by this process
// group have performed. Counts ncclCommCreateFromRanks() for ncclx v2.21.5+
uint64_t getCommSplitCounter() const;
void registerOnCompletionHook(
std::function<void(std::shared_ptr<WorkInfo>)>&& hook) override;
void waitForPendingWorks() override;
void enableCollectivesTiming() override;
// Helper function for iteratively aborting communicators in the provided map
void abortCommsFromMap(
std::unordered_map<std::string, std::shared_ptr<NCCLComm>>& ncclCommsMap,
const std::optional<std::string>& abortReason);
c10::intrusive_ptr<intra_node_comm::IntraNodeComm> initIntraNodeComm();
// Destroy (shutdown) this backend -- normal exit.
void shutdown();
// Provides an API to abort the ProcessGroup (similar to ncclCommAbort)
// instead of relying on ProcessGroupNCCL destructor.
void abort();
void eagerConnectSingleDevice(at::Device device) override;
void performNocolorSplit(at::Device device);
// If all comms on this PG are fully initialized, return true.
bool isInitialized();
// Performs NCCL user buffer registration for all buffers in
// the given MemPool
void registerMemPool(c10::cuda::MemPool* pool);
// Performs NCCL user buffer de-registration for all buffers in
// the given MemPool
void deregisterMemPool(c10::cuda::MemPool* pool);
// This method adds a temporary extension for the timeout period,
// applying to all collectives between the calling of this API and
// the completion of the first collective on the GPU. While this feature
// provides flexibility in specific scenarios, it introduces statefulness
// to timeout setting. Therefore, it is advisable to use this API sparingly
// and consider alternative approaches, such as directly setting the timeout
// or utilizing a barrier collective (one can set any timeout to the barrier),
// whenever feasible.
void addEphemeralTimeout(const std::chrono::milliseconds& timeout);
// This function is only intended for testing purposes because we don't
// want to expose the `WorkNCCL` via pybind. It verifies whether the
// `opTimeout_` of the provided WorkNCCL instance is the same as the specified
// timeout.
bool verifyWorkTimeoutForTest(
const c10::intrusive_ptr<Work>& work,
const std::chrono::milliseconds& timeout);
protected:
// Helper that broadcasts nccl unique ID to all ranks through the store
void broadcastUniqueNCCLID(
ncclUniqueId* ncclID,
bool isSingleP2POp,
const std::string& devicesKey,
int p2pRank);
// Helper that looks up the cached NCCL communicators only
std::shared_ptr<NCCLComm> getNCCLComm(const std::string& deviceKey);
std::shared_ptr<NCCLComm> initNCCLComm(
const std::string& deviceKey,
at::Device& device,
OpType opType,
int p2pRank = 0,
bool isSendRecvSelf = false);
// Wrapper method which can be overridden for tests.
virtual std::exception_ptr checkForNCCLErrors(
std::shared_ptr<NCCLComm>& ncclComm);
// Initializes and returns a WorkNCCL handle for an operation of type
// `opType` issued by `rank` on `device`. The function is virtual, so
// subclasses may override it. Every implementation must ensure that, when
// `record` is true, the work object is enqueued via workEnqueue.
// NOTE(review): how `isP2P`, `profilingTitle`, `inputs` and `outputs` are
// consumed is decided by the definition — confirm there.
virtual c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
at::Device& device,
int rank,
OpType opType,
bool isP2P,
const char* profilingTitle = nullptr,
const std::vector<at::Tensor>& inputs = {},
const std::vector<at::Tensor>& outputs = {},
bool record = false);
// In the timeout case, dumps debug info such as the NCCL flight recorder to
// storage. If more complicated or blocking operations are added down the
// road, this may need to run on a side thread.
// NOTE(review): the meaning of the returned bool is defined in the .cpp.
bool dumpDebuggingInfo();
// Aborts all communicators on this rank. `abortReason` optionally describes
// why (std::nullopt when no reason is supplied).
// NOTE(review): the meaning of the returned bool is defined in the .cpp.
bool abortComms(const std::optional<std::string>& abortReason = std::nullopt);
// Helper that decides whether the nonblocking NCCL API mode should be used.
// Call this helper instead of reading the `useNonblocking_` member directly.
bool useNonblocking();
private:
// NOTE(review): by their names, these describe this PG's ranks in the global
// numbering (the first global rank, and the stride between consecutive
// members). They are assigned outside this header — confirm in the
// constructor. The default member initializers give them a deterministic
// value until then, instead of leaving them indeterminate.
int globalRankStart{0};
int globalRankStride{0};
// Helper that encapsulates work shared across all collective communication
// primitives. The callbacks have the following shapes:
//
//   ncclResult_t fn(at::Tensor& input, at::Tensor& output,
//                   ncclComm_t, at::cuda::CUDAStream&);
//   void {pre,post}(std::vector<at::cuda::CUDAStream&>);
//
// NOTE(review): the pre/post shape above is schematic only. A std::vector of
// references is ill-formed C++, so check the call sites in the definition
// for the real argument list before writing a pre/post callback.
//
// In every overload, `profilingTitle` defaults to nullptr,
// `avoidRecordStreams` to false and `nanCheck` to true. Their exact effect
// is defined in the implementation.
//
// Overload 1: a single input/output tensor pair, with `fn` only (no hooks).
template <typename Fn>
c10::intrusive_ptr<Work> collective(
at::Tensor& input,
at::Tensor& output,
Fn fn,
OpType opType,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);
// Overload 2: a single input/output tensor pair, with `pre`/`post` hooks.
template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> collective(
at::Tensor& input,
at::Tensor& output,
Fn fn,
PreProcess pre,
PostProcess post,
OpType opType,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);
// Overload 3: vectors of input/output tensors, with `pre`/`post` hooks.
template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> collective(
std::vector<at::Tensor>& inputs,
std::vector<at::Tensor>& outputs,
Fn fn,
PreProcess pre,
PostProcess post,
OpType opType,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false,
bool nanCheck = true);
// Coalesced variant of `collective`. It takes vectors of input/output
// tensors and a single `fn` callback, with no pre/post hooks and no
// `nanCheck` flag.
// NOTE(review): how `fn` is applied across the tensor pairs (e.g. inside a
// single NCCL group) is defined in the implementation — confirm there.
template <typename Fn>
c10::intrusive_ptr<Work> collectiveCoalesced(
std::vector<at::Tensor>& input,
std::vector<at::Tensor>& output,
Fn fn,
OpType opType,
const char* profilingTitle = nullptr,
bool avoidRecordStreams = false);
// Helper that encapsulates work shared across point-to-point communication
// primitives. It has the same structure as the collective helper, but
// operates on a single `tensor` exchanged with `peer`.
template <typename Fn>
c10::intrusive_ptr<Work> pointToPoint(
at::Tensor& tensor,
Fn fn,
int peer,
OpType opType,
const char* profilingTitle = nullptr);
// Overload with explicit `pre`/`post` hooks. Unlike `collective`, the hooks
// come after `opType`, and `profilingTitle` has no default here.
template <typename Fn, typename PreProcess, typename PostProcess>
c10::intrusive_ptr<Work> pointToPoint(
at::Tensor& tensor,
Fn fn,
int peer,
OpType opType,
PreProcess pre,
PostProcess post,
const char* profilingTitle);
// Allreduce on a single tensor using `opts` (default-constructed
// AllreduceOptions when omitted).
// NOTE(review): the `_impl` suffix suggests this is the shared
// implementation behind the public allreduce entry points — confirm in the
// .cpp.
c10::intrusive_ptr<Work> allreduce_impl(
at::Tensor& tensor,
const AllreduceOptions& opts = AllreduceOptions());
// Checks the given NCCL communicator for errors and returns an appropriate
// exception_ptr (nullptr if there are no errors). Declared static, so it
// has no `this` instance to consult.
static std::exception_ptr checkForNCCLErrorsInternal(
std::shared_ptr<NCCLComm>& ncclComm);
// Runs on a separate thread and checks NCCL communicators for errors. A
// dedicated thread is needed because we can't rely on the user calling
// methods like wait() or isCompleted() to detect and remediate errors. We
// also need a mechanism to safely abort NCCL communicators and remove them
// from our cache. Having a thread owned by the ProcessGroupNCCL class does
// this cleanly. Modifying the communicator cache from the WorkNCCL class
// instead could run into object-lifetime issues, since the ProcessGroupNCCL
// object might be destroyed before the WorkNCCL object.
void ncclCommWatchdog();
// Returns the CUDA device most likely associated with this backend.
// If we aren't bound to a specific device, there is no strict guarantee
// that this heuristic matches the rank-to-GPU assignment that Python layers
// use, but in practice it tends to. Fortunately, no tensor operation relies
// on this for correctness; it is only used for ancillary purposes such as
// barriers.
at::Device guessDeviceForRank() const;
// Destroys the initialized NCCL communicators in devNCCLCommMap_ stored
// under the given key. Throws if there are no communicators to destroy.
// Also removes the communicators from the cache and clears used device
// indices.
void destroyNCCLComms(const std::string& devNCCLCommMapKey);
// Body of the watchdog's inner loop. It cleans up completed work and aborts
// upon failure or timeout.
void watchdogHandler();
// NOTE(review): undocumented. By its name, this runs a loop that dispatches
// hooks (presumably on a dedicated thread) — confirm in the .cpp.
void runHookLoop();
// Generates a prefix, unique to this process group and rank, used to
// disambiguate log lines.
std::string createLogPrefix() const;
// Returns the unique prefix created by createLogPrefix.
const std::string& logPrefix() const;
// Returns the global rank of the device. This assumes users always create a
// default global process group (PG) that includes all devices. Because it
// is called in the ProcessGroupNCCL constructor, it always returns the rank_
// of the very first PG created, i.e. the default global PG.
const int& globalRank() const;
// Returns the global ranks of the members of this PG.
const std::vector<uint64_t>& groupRanks() const;
// Utility that assigns a timeout to `work`, presumably derived from
// `option` — NOTE(review): confirm the exact rule in the definition.
void assignTimeoutToWork(
const c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work,
const c10::intrusive_ptr<Options>& option);
// Broadcasts the flight-recorder dump signal (presumably to the other
// ranks — confirm in the definition).
void broadcastDumpSignal();
protected:
// Runs on its own thread, separate from the watchdog thread, because it
// checks the watchdog thread's heartbeat. If the watchdog gets stuck in some
// NCCL/CUDA call, this monitor can dump debugging information and abort the
// process. The function is virtual, so it can be overridden.
virtual void heartbeatMonitor();
// Directly triggers std::abort so that the whole process is terminated.
// `errMsg` describes the reason. Virtual, so it can be overridden (e.g. in
// tests — TODO confirm).
virtual void terminateProcess(const std::string& errMsg);
// Helper that waits until `fut` completes or `timeOutMilSec` elapses.
// `futDescription` names the future (presumably for messages).
// `throwException` and `log` (both default false) control the handling on
// timeout — NOTE(review): confirm the exact behavior in the definition.
void waitForFutureOrTimeout(
std::future<bool>& fut,
const std::chrono::milliseconds& timeOutMilSec,
const std::string& futDescription,
bool throwException = false,
bool log = false);
// Build the watchdog-timeout error message (including `extraMsg`) and exit
// message (including `exitReason`), respectively.
// NOTE(review): the exact message format lives in the .cpp.
std::string getNCCLWatchdogTimeoutErrorMsg(const std::string& extraMsg);
std::string getNCCLWatchdogTimeoutExitMsg(const std::string& exitReason);
// Presumably the watchdog thread's sleep interval between iterations, in
// milliseconds (per the name). The value is defined out of line.
static const int64_t kWatchdogThreadSleepMillis;
// The store is used to broadcast the NCCL unique ID of rank 0. This store
// carries a prefix, so it differs across ProcessGroupNCCL instances (i.e.
// across different ProcessGroups).
c10::intrusive_ptr<Store> store_;
// Reference to the store without prefix so that keys are same across all