From f371a360c747df4e96ea80582333b2ae86931e6a Mon Sep 17 00:00:00 2001 From: Corey Derochie Date: Thu, 26 Sep 2024 17:26:54 -0500 Subject: [PATCH 1/2] Reimplemented this fix using a backing vector instead of a map, and changing the semantics of hand-off from thread-local state to rank state. --- src/include/msccl/msccl_status.h | 2 +- src/init.cc | 2 +- src/misc/msccl/msccl_status.cc | 49 ++++++++++++++++++++------------ 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/src/include/msccl/msccl_status.h b/src/include/msccl/msccl_status.h index 077709ddc..1cab29d9e 100644 --- a/src/include/msccl/msccl_status.h +++ b/src/include/msccl/msccl_status.h @@ -14,7 +14,7 @@ void mscclSetInitialized(int rank, bool initialized = true); void mscclRemoveRank(int rank); -mscclStatus& mscclGetStatus(int rank); +mscclStatus& mscclGetStatus(int rank, int rankCount = -1); mscclSavedProxyArgs& mscclGetSavedProxyArgs(int rank); diff --git a/src/init.cc b/src/init.cc index 9c93d8ac5..9dc8474ba 100644 --- a/src/init.cc +++ b/src/init.cc @@ -1844,7 +1844,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled())) { NCCLCHECK(mscclInit(comm)); - mscclStatus& status = mscclGetStatus(comm->rank); + mscclStatus& status = mscclGetStatus(comm->rank, comm->nRanks); status.needsProxy |= mscclNeedsProxy; } diff --git a/src/misc/msccl/msccl_status.cc b/src/misc/msccl/msccl_status.cc index f2b26663b..c01563b16 100644 --- a/src/misc/msccl/msccl_status.cc +++ b/src/misc/msccl/msccl_status.cc @@ -10,7 +10,7 @@ #include #include -#include +#include using namespace std; struct mscclRankState { @@ -24,25 +24,37 @@ struct mscclRankState { }; static mutex rankStatesMutex; -static unordered_map> rankStates; +static vector> rankStates; -static inline mscclRankState& mscclGetRankState(int rank) { - // In the unlikely case of negative rank, return a per-thread state - if (rank < 0) { - static thread_local shared_ptr threadRankState(new mscclRankState()); +static inline mscclRankState& mscclGetRankState(int rank, int rankCount = -1) { + static thread_local shared_ptr threadRankState; + + if (rankCount > 0) { + lock_guard lock(rankStatesMutex); + if (rankStates.size() < rankCount) { + rankStates.resize((size_t)rankCount); + } + } + + if (rank < 0 || rank >= rankStates.size()) { + if (!threadRankState) { + threadRankState.reset(new mscclRankState()); + } return *threadRankState; } - lock_guard lock(rankStatesMutex); + if (!rankStates[rank]) { + if (!threadRankState) { + threadRankState.reset(new mscclRankState()); + } + rankStates[rank] = threadRankState; + } - auto rankStateIt = rankStates.find(rank); - if (rankStateIt == rankStates.end()) { - // Create a per rank threadRankState rather than per thread - shared_ptr newthreadRankState(new mscclRankState()); - newthreadRankState->rank = rank; - rankStateIt = rankStates.insert(make_pair(rank, newthreadRankState)).first; + if (!threadRankState) { + threadRankState = rankStates[rank]; } - return *(rankStateIt->second); + + return *rankStates[rank]; } bool mscclInitialized(int rank) { @@ -56,12 +68,13 @@ void mscclSetInitialized(int rank, bool initialized) { } void mscclRemoveRank(int rank) { - lock_guard lock(rankStatesMutex); - rankStates.erase(rank); + if (rank < rankStates.size()) { + rankStates[rank].reset(); + } } -mscclStatus& mscclGetStatus(int rank) { - return mscclGetRankState(rank).status; +mscclStatus& mscclGetStatus(int rank, int rankCount) { + return mscclGetRankState(rank, rankCount).status; } mscclThreadLocalStatus& mscclGetThreadLocalStatus() { From e88e0ecb6e3aaf3020fc9bb99ba48d84c8e7fc74 Mon Sep 17 00:00:00 2001 From: Corey Derochie Date: Thu, 26 Sep 2024 17:59:41 -0500 Subject: [PATCH 2/2] Added documentation to explain the flow better. --- src/misc/msccl/msccl_status.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/misc/msccl/msccl_status.cc b/src/misc/msccl/msccl_status.cc index c01563b16..d8ed2d698 100644 --- a/src/misc/msccl/msccl_status.cc +++ b/src/misc/msccl/msccl_status.cc @@ -29,6 +29,9 @@ static vector> rankStates; static inline mscclRankState& mscclGetRankState(int rank, int rankCount = -1) { static thread_local shared_ptr threadRankState; + // Calling code can allocate states for the number of ranks at an appropriate time. + // It is assumed that all threads will call this function simultaneously with the + // same rankCount, which would avoid race conditions later in the function. if (rankCount > 0) { lock_guard lock(rankStatesMutex); if (rankStates.size() < rankCount) { @@ -37,6 +40,7 @@ static inline mscclRankState& mscclGetRankState(int rank, int rankCount = -1) { } if (rank < 0 || rank >= rankStates.size()) { + // threadRankState is used when no rank state can be returned (rank<0 or rank not in rankStates) if (!threadRankState) { threadRankState.reset(new mscclRankState()); } @@ -44,6 +48,7 @@ static inline mscclRankState& mscclGetRankState(int rank, int rankCount = -1) { } if (!rankStates[rank]) { + // When no state is yet assigned to a rank, use the current thread's threadRankState. if (!threadRankState) { threadRankState.reset(new mscclRankState()); } @@ -51,6 +56,8 @@ static inline mscclRankState& mscclGetRankState(int rank, int rankCount = -1) { } if (!threadRankState) { + // Cache this rank's state in threadRankState in case this thread calls with rank<0 later. + // NOTE: When multiple ranks share a thread, only the first rank in will be used for rank<0. threadRankState = rankStates[rank]; }