Init NCCL communicator in graph mode unifiedly #8263

Merged · 9 commits · Jun 14, 2022
Changes from 4 commits
71 changes: 71 additions & 0 deletions oneflow/core/job/eager_nccl_comm_manager.cpp
@@ -19,6 +19,8 @@ limitations under the License.
#include "oneflow/core/job/eager_nccl_comm_manager.h"
#include "oneflow/core/device/nccl_util.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/vm/vm_util.h"

#ifdef WITH_CUDA

@@ -73,8 +75,14 @@ void CreateNcclComm(ncclComm_t* comm, const int dev, const std::string& key,
OF_NCCL_CHECK(ncclCommInitRank(comm, device_vec.size(), nccl_unique_id, rank));
}

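// Returns true iff this op type was registered via
// REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT, i.e. its NCCL communicator
// should be created up front by CreateCommFromPlan.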
bool IsUnifiedNcclCommInitUseKernel(const std::string& op_type_name) {
Contributor: This name reads a bit oddly. Putting the two verbs "init" and "use" next to each other is quite Chinglish, and I could not parse it at first. Do you mean "inited in kernel" or "inited by kernel"?

Contributor (Author): It could be renamed to IsUserKenrelNeedingUnifiedNcclCommInit or NeedUnifiedNcclCommInit.

Contributor: Works for me.
return UserKernelUnifiedNcclCommInitRegistry::Instance().IsRegistered(op_type_name);
}

} // namespace

const std::string EagerNcclCommMgr::kDefaultStreamName = "DEFAULT";

EagerNcclCommMgr::~EagerNcclCommMgr() {
for (auto& device_set7device_id2comm : device_set2device_id2comm_) {
for (auto& device_id7comm : device_set7device_id2comm.second) {
@@ -136,6 +144,69 @@ ncclComm_t EagerNcclCommMgr::GetCommForDeviceAndStreamName(
return comm;
}

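// Scan this rank's tasks in the compiled Plan; for every single-exec-node user
// kernel whose op type is registered for unified NCCL comm init, collect its
// (machine_id, device_id) placement and create the communicator up front, so
// the kernel does not have to initialize it lazily on first launch.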
void EagerNcclCommMgr::CreateCommFromPlan(const Plan& plan) {
const int64_t rank = GlobalProcessCtx::Rank();
const int64_t dev = GlobalProcessCtx::LocalRank();
std::map<std::string, std::vector<std::pair<int64_t, int64_t>>> nccl_comm_key2devices;

for (const auto& task_proto : plan.task()) {
if (task_proto.machine_id() != rank) { continue; }
if (task_proto.exec_sequence().exec_node_size() != 1) { continue; }
const auto& kernel_conf = task_proto.exec_sequence().exec_node(0).kernel_conf();
const OpAttribute* op_attr = nullptr;
if (kernel_conf.has_op_attribute()) {
op_attr = &kernel_conf.op_attribute();
} else if (kernel_conf.has_op_attribute_ref()) {
const auto& ref_name = kernel_conf.op_attribute_ref();
op_attr = &plan.job_id2op_attribute_ref_table()
.at(task_proto.job_id())
.op_name2op_attribute()
.at(ref_name);
} else {
continue;
}
const auto& op_conf = op_attr->op_conf();
if (!op_conf.has_user_conf()) { continue; }
if (!IsUnifiedNcclCommInitUseKernel(op_conf.user_conf().op_type_name())) { continue; }

if (!op_attr->has_parallel_conf_signature()) { continue; }
if (!op_attr->parallel_conf_signature().has_op_parallel_conf()) { continue; }

std::vector<std::pair<int64_t, int64_t>> device_vec;
ParallelDesc parallel_desc(op_attr->parallel_conf_signature().op_parallel_conf());
for (int64_t parallel_id = 0; parallel_id < parallel_desc.parallel_num(); ++parallel_id) {
int64_t machine_id = CHECK_JUST(parallel_desc.MachineId4ParallelId(parallel_id));
int64_t device_id = CHECK_JUST(parallel_desc.DeviceId4ParallelId(parallel_id));
device_vec.emplace_back(machine_id, device_id);
}

std::string stream_name = kDefaultStreamName;
if (op_conf.has_stream_name_hint()) { stream_name = op_conf.stream_name_hint(); }
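// Note: this key must stay in sync with the one GetCommForDeviceAndStreamName
// composes, so that kernels' later lookups hit the communicator pre-created here.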
std::string key = GetNcclUniqueIdRpcKey(device_vec) + "-stream_name_hint:" + stream_name;

VLOG(3) << " EagerNcclCommMgr create nccl comm for " << op_conf.name() << ", rank = " << rank
<< ", dev = " << dev << ", key = {" << key << "}\n";
nccl_comm_key2devices.emplace(std::move(key), std::move(device_vec));
}

if (nccl_comm_key2devices.size() == 0) { return; }

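// Drain pending eager (VM) instructions on this rank, then pin the local CUDA
// device before creating NCCL communicators.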
CHECK_JUST(vm::CurrentRankSync());
CudaCurrentDeviceGuard guard(dev);

for (const auto& pair : nccl_comm_key2devices) {
const auto& key = pair.first;
auto device_id2comm_it = device7stream2device_id2comm_.find(key);
if (device_id2comm_it != device7stream2device_id2comm_.end()) {
auto comm_it = device_id2comm_it->second.find(dev);
if (comm_it != device_id2comm_it->second.end()) { continue; }
}
ncclComm_t comm;
CreateNcclComm(&comm, dev, key, pair.second);
device7stream2device_id2comm_[key][dev] = comm;
}
}

} // namespace oneflow

#endif // WITH_CUDA
39 changes: 39 additions & 0 deletions oneflow/core/job/eager_nccl_comm_manager.h
@@ -27,13 +27,17 @@ namespace oneflow {

class EagerNcclCommMgr final {
public:
static const std::string kDefaultStreamName;

OF_DISALLOW_COPY_AND_MOVE(EagerNcclCommMgr);
~EagerNcclCommMgr();

ncclComm_t GetCommForDevice(const std::set<std::pair<int64_t, int64_t>>& device_set);
ncclComm_t GetCommForDeviceAndStreamName(const std::set<std::pair<int64_t, int64_t>>& device_set,
const std::string& stream_name);

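// Called by Runtime once per Plan to pre-create communicators for op types
// registered via REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT.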
void CreateCommFromPlan(const Plan& plan);

private:
friend class Global<EagerNcclCommMgr>;
EagerNcclCommMgr() = default;
@@ -44,8 +48,43 @@ class EagerNcclCommMgr final {
std::mutex mutex_;
};

class UserKernelUnifiedNcclCommInitRegistry final {
public:
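// Constructing a Trigger registers `key`; the macro below expands to one
// static Trigger per registered op type, run during static initialization.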
struct Trigger {
explicit Trigger(const std::string& key) {
UserKernelUnifiedNcclCommInitRegistry::Instance().Register(key);
}
};

static UserKernelUnifiedNcclCommInitRegistry& Instance() {
static UserKernelUnifiedNcclCommInitRegistry reg;
return reg;
}

OF_DISALLOW_COPY_AND_MOVE(UserKernelUnifiedNcclCommInitRegistry);
~UserKernelUnifiedNcclCommInitRegistry() = default;

void Register(const std::string& key) {
bool insert_success = reg_set_.insert(key).second;
if (!insert_success) {
std::cerr << key << " was already registered in NcclCommRegistry" << std::endl;
abort();
}
}

bool IsRegistered(const std::string& key) const { return reg_set_.find(key) != reg_set_.end(); }

private:
UserKernelUnifiedNcclCommInitRegistry() = default;
std::set<std::string> reg_set_;
};

} // namespace oneflow

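// Usage sketch (with a hypothetical op type name "my_nccl_op"), placed at
// namespace scope in the kernel's .cpp/.cu file:
//   REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("my_nccl_op");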
#define REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT(op_type_name) \
static auto OF_PP_CAT(g_nccl_comm_reg_, __COUNTER__) = \
::oneflow::UserKernelUnifiedNcclCommInitRegistry::Trigger(op_type_name)

#endif // WITH_CUDA

#endif // ONEFLOW_CORE_JOB_EAGER_NCCL_COMM_MANAGER_H_
2 changes: 2 additions & 0 deletions oneflow/core/job/runtime.cpp
@@ -23,6 +23,7 @@ limitations under the License.
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/runtime_context.h"
#include "oneflow/core/job/runtime_job_descs.h"
#include "oneflow/core/job/eager_nccl_comm_manager.h"
#include "oneflow/core/thread/thread_manager.h"
#include "oneflow/core/graph/task_node.h"
#include "oneflow/core/device/cuda_util.h"
@@ -69,6 +70,7 @@ Runtime::Runtime(
Global<RuntimeJobDescs>::Get()->AddPlan(plan);
collective_boxing_scheduler_plan_token_ =
Global<boxing::collective::Scheduler>::Get()->AddPlan(plan);
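// Pre-create NCCL communicators for all registered user kernels in this plan.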
Global<EagerNcclCommMgr>::Get()->CreateCommFromPlan(plan);
}
std::vector<const TaskProto*> source_tasks;
source_tasks.reserve(plan.task().size());
16 changes: 8 additions & 8 deletions oneflow/user/kernels/data_shuffle_kernel.cu
@@ -245,11 +245,10 @@ class DataShuffleKernelState final : public user_op::OpKernelState {
public:
explicit DataShuffleKernelState(user_op::KernelInitContext* ctx)
: device_index_(-1),
- has_independent_stream_(ctx->op_conf().has_stream_name_hint()),
- stream_name_(""),
+ stream_name_(EagerNcclCommMgr::kDefaultStreamName),
parallel_desc_(ctx->parallel_desc()) {
OF_CUDA_CHECK(cudaGetDevice(&device_index_));
- if (has_independent_stream_) { stream_name_ = ctx->op_conf().stream_name_hint(); }
+ if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); }
OF_CUDA_CHECK(cudaMallocHost(
&host_num_unique_matrix_,
parallel_desc_.parallel_num() * parallel_desc_.parallel_num() * sizeof(IDX)));
@@ -283,11 +282,7 @@ class DataShuffleKernelState final : public user_op::OpKernelState {
}
EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Global<EagerNcclCommMgr>::Get());
ncclComm_t comm;
- if (has_independent_stream_) {
-   comm = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_);
- } else {
-   comm = comm_mgr->GetCommForDevice(device_set);
- }
+ comm = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_);
comm_.reset(new Comm(comm));
}

@@ -1517,4 +1512,9 @@ class UniqueKeyValuePairKernel final : public user_op::OpKernel {

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CUDA_UNIQUE_KEY_VALUE_PAIR_KERNEL, ID_DATA_TYPE_SEQ,
ID_DATA_TYPE_SEQ, IDX_DATA_TYPE_SEQ)

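// Opt the data-shuffle ops into unified NCCL comm init from the Plan.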
REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("id_shuffle");
REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("embedding_shuffle");
REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("embedding_gradient_shuffle");

} // namespace oneflow
31 changes: 12 additions & 19 deletions oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp
@@ -32,11 +32,10 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState {
public:
explicit NcclLogical2DSameDim0KernelCommState(user_op::KernelInitContext* ctx)
: is_init_(false),
- has_independent_stream_(ctx->op_conf().has_stream_name_hint()),
- stream_name_("NONE"),
+ stream_name_(EagerNcclCommMgr::kDefaultStreamName),
parallel_desc_(ctx->parallel_desc()),
this_parallel_id_(ctx->parallel_ctx().parallel_id()) {
- if (has_independent_stream_) { stream_name_ = ctx->op_conf().stream_name_hint(); }
+ if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); }
}
~NcclLogical2DSameDim0KernelCommState() override = default;

@@ -71,17 +70,12 @@ class NcclLogical2DSameDim0KernelCommState : public user_op::OpKernelState {
device_set.emplace(std::make_pair(machine_id, device_id));
}
EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Global<EagerNcclCommMgr>::Get());
- if (has_independent_stream_) {
-   comm_ = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_);
- } else {
-   comm_ = comm_mgr->GetCommForDevice(device_set);
- }
+ comm_ = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_);
num_ranks_ = group_size;
is_init_ = true;
}

bool is_init_;
- bool has_independent_stream_;
std::string stream_name_;
ParallelDesc parallel_desc_;
int64_t this_parallel_id_;
@@ -399,11 +393,10 @@ class NcclLogical2DSameDim1KernelCommState final : public user_op::OpKernelState
public:
explicit NcclLogical2DSameDim1KernelCommState(user_op::KernelInitContext* ctx)
: is_init_(false),
- has_independent_stream_(ctx->op_conf().has_stream_name_hint()),
- stream_name_("NONE"),
+ stream_name_(EagerNcclCommMgr::kDefaultStreamName),
parallel_desc_(ctx->parallel_desc()),
this_parallel_id_(ctx->parallel_ctx().parallel_id()) {
- if (has_independent_stream_) { stream_name_ = ctx->op_conf().stream_name_hint(); }
+ if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); }
}
~NcclLogical2DSameDim1KernelCommState() = default;

@@ -425,12 +418,7 @@ class NcclLogical2DSameDim1KernelCommState final : public user_op::OpKernelState
device_set.emplace(std::make_pair(machine_id, device_id));
}
EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Global<EagerNcclCommMgr>::Get());
- CHECK_NOTNULL(comm_mgr);
- if (has_independent_stream_) {
-   comm_ = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_);
- } else {
-   comm_ = comm_mgr->GetCommForDevice(device_set);
- }
+ comm_ = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_);
is_init_ = true;
}
return comm_;
Expand All @@ -440,7 +428,6 @@ class NcclLogical2DSameDim1KernelCommState final : public user_op::OpKernelState

private:
bool is_init_;
- bool has_independent_stream_;
std::string stream_name_;
ParallelDesc parallel_desc_;
int64_t this_parallel_id_;
@@ -521,6 +508,12 @@ REGISTER_USER_KERNEL("_nccl_logical_2D_same_dim1_all_reduce")
.SetCreateFn<NcclLogical2DSameDim1AllReduce>()
.SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA);

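// Opt the 2-D NCCL logical ops into unified NCCL comm init from the Plan.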
REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("_nccl_logical_2D_same_dim0_all_reduce");
REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("_nccl_logical_2D_same_dim0_all_gather");
REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("_nccl_logical_2D_same_dim0_all_gather_noncontinuous");
REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("_nccl_logical_2D_same_dim0_all2all");
REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("_nccl_logical_2D_same_dim1_all_reduce");

} // namespace oneflow

#endif // WITH_CUDA && NCCL_VERSION_CODE > 2700
18 changes: 9 additions & 9 deletions oneflow/user/kernels/nccl_logical_kernels.cpp
@@ -32,10 +32,9 @@ class NcclLogicalKernelCommState : public user_op::OpKernelState {
public:
explicit NcclLogicalKernelCommState(user_op::KernelInitContext* ctx)
: is_init_(false),
- has_independent_stream_(ctx->op_conf().has_stream_name_hint()),
- stream_name_("NONE"),
+ stream_name_(EagerNcclCommMgr::kDefaultStreamName),
parallel_desc_(ctx->parallel_desc()) {
- if (has_independent_stream_) { stream_name_ = ctx->op_conf().stream_name_hint(); }
+ if (ctx->op_conf().has_stream_name_hint()) { stream_name_ = ctx->op_conf().stream_name_hint(); }
}
~NcclLogicalKernelCommState() override = default;

@@ -48,11 +47,7 @@ class NcclLogicalKernelCommState : public user_op::OpKernelState {
device_set.emplace(std::make_pair(machine_id, device_id));
}
EagerNcclCommMgr* comm_mgr = CHECK_NOTNULL(Global<EagerNcclCommMgr>::Get());
- if (has_independent_stream_) {
-   comm_ = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_);
- } else {
-   comm_ = comm_mgr->GetCommForDevice(device_set);
- }
+ comm_ = comm_mgr->GetCommForDeviceAndStreamName(device_set, stream_name_);
is_init_ = true;
}
return comm_;
Expand All @@ -62,7 +57,6 @@ class NcclLogicalKernelCommState : public user_op::OpKernelState {

private:
bool is_init_;
- bool has_independent_stream_;
std::string stream_name_;
ParallelDesc parallel_desc_;
ncclComm_t comm_{};
@@ -447,6 +441,12 @@ REGISTER_S2S_KERNEL(float)
REGISTER_S2S_KERNEL(double)
REGISTER_S2S_KERNEL(float16)

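// Opt the NCCL logical ops into unified NCCL comm init from the Plan.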
REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("_nccl_logical_all_reduce");
REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("_nccl_logical_reduce_scatter");
REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("_nccl_logical_all_gather");
REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("_nccl_logical_all_gather_noncontinuous");
REGISTER_USER_KERNEL_UNIFIED_NCCL_COMM_INIT("_nccl_logical_s2s");

} // namespace oneflow

#endif // WITH_CUDA && NCCL_VERSION_CODE > 2700