Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MSCCL Multithreaded regression alternative state management #1352

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/include/msccl/msccl_status.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ void mscclSetInitialized(int rank, bool initialized = true);

void mscclRemoveRank(int rank);

mscclStatus& mscclGetStatus(int rank);
mscclStatus& mscclGetStatus(int rank, int rankCount = -1);

mscclSavedProxyArgs& mscclGetSavedProxyArgs(int rank);

Expand Down
2 changes: 1 addition & 1 deletion src/init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1844,7 +1844,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p

if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled())) {
NCCLCHECK(mscclInit(comm));
mscclStatus& status = mscclGetStatus(comm->rank);
mscclStatus& status = mscclGetStatus(comm->rank, comm->nRanks);
status.needsProxy |= mscclNeedsProxy;
}

Expand Down
56 changes: 38 additions & 18 deletions src/misc/msccl/msccl_status.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>
using namespace std;

struct mscclRankState {
Expand All @@ -24,25 +24,44 @@ struct mscclRankState {
};

static mutex rankStatesMutex;
static unordered_map<int, shared_ptr<mscclRankState>> rankStates;
static vector<shared_ptr<mscclRankState>> rankStates;

static inline mscclRankState& mscclGetRankState(int rank) {
// In the unlikely case of negative rank, return a per-thread state
if (rank < 0) {
static thread_local shared_ptr<mscclRankState> threadRankState(new mscclRankState());
static inline mscclRankState& mscclGetRankState(int rank, int rankCount = -1) {
static thread_local shared_ptr<mscclRankState> threadRankState;

// Callers can pass rankCount > 0 at an appropriate time to pre-allocate a state slot for every rank.
// It is assumed that all threads will call this function simultaneously with the
// same rankCount, which would avoid race conditions later in the function.
if (rankCount > 0) {
lock_guard<mutex> lock(rankStatesMutex);
if (rankStates.size() < rankCount) {
rankStates.resize((size_t)rankCount);
}
}

if (rank < 0 || rank >= rankStates.size()) {
// threadRankState is used when no rank state can be returned (rank<0 or rank not in rankStates)
if (!threadRankState) {
threadRankState.reset(new mscclRankState());
}
return *threadRankState;
}

lock_guard<mutex> lock(rankStatesMutex);
if (!rankStates[rank]) {
// When no state is yet assigned to a rank, use the current thread's threadRankState.
if (!threadRankState) {
threadRankState.reset(new mscclRankState());
}
rankStates[rank] = threadRankState;
}

auto rankStateIt = rankStates.find(rank);
if (rankStateIt == rankStates.end()) {
// Create a per-rank threadRankState rather than a per-thread one
shared_ptr<mscclRankState> newthreadRankState(new mscclRankState());
newthreadRankState->rank = rank;
rankStateIt = rankStates.insert(make_pair(rank, newthreadRankState)).first;
if (!threadRankState) {
// Cache this rank's state in threadRankState in case this thread calls with rank<0 later.
// NOTE: When multiple ranks share a thread, only the first rank's state will be used for rank<0.
threadRankState = rankStates[rank];
}
return *(rankStateIt->second);

return *rankStates[rank];
}

bool mscclInitialized(int rank) {
Expand All @@ -56,12 +75,13 @@ void mscclSetInitialized(int rank, bool initialized) {
}

void mscclRemoveRank(int rank) {
lock_guard<mutex> lock(rankStatesMutex);
rankStates.erase(rank);
if (rank < rankStates.size()) {
rankStates[rank].reset();
}
}

mscclStatus& mscclGetStatus(int rank) {
return mscclGetRankState(rank).status;
mscclStatus& mscclGetStatus(int rank, int rankCount) {
return mscclGetRankState(rank, rankCount).status;
}

mscclThreadLocalStatus& mscclGetThreadLocalStatus() {
Expand Down