nd boxing use nccl send/recv #7936
Conversation
  hierarchy_index_helper, in_nd_sbp, visit);
} else {
  // If Split or PartialSum, go through all the ranks along the depth-dimension.
  for (int64_t i = 0; i < parallel_hierarchy.dim_vec().at(depth); i++) {
Here you can just use parallel_hierarchy.At(depth) directly.
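A minimal sketch of the suggested change (assuming parallel_hierarchy is a Shape that exposes At(), as elsewhere in OneFlow):

// Read the extent of the depth-dimension directly instead of going through dim_vec().
for (int64_t i = 0; i < parallel_hierarchy.At(depth); i++) {
  // ... iterate over the ranks along the depth-dimension as before ...
}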
CHECK_EQ(out_id, parallel_id);
const TensorSliceView& in_slice = in_slices.at(in_id);
const TensorSliceView& intersection = cur_rank_out_slice.Intersect(in_slice);
dst_recv_intersections->at(in_id) = intersection;
If the intersection is empty, the two slices have no overlap, so can we skip the assignment here? When the current dimension is broadcast, dst_recv_intersections already has "holes", so when the dimension is Split or PartialSum, is it enough to update only the in_ids that actually have an intersection?
Yes, it can indeed be removed. I missed this line:
if (intersection.IsEmpty()) { return; }
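A sketch of the patched receive-side visitor, with the names taken from the quoted diff and the suggested early return added:

CHECK_EQ(out_id, parallel_id);
const TensorSliceView& in_slice = in_slices.at(in_id);
const TensorSliceView& intersection = cur_rank_out_slice.Intersect(in_slice);
// Skip in_ids whose slice does not overlap the current rank's output slice.
if (intersection.IsEmpty()) { return; }
dst_recv_intersections->at(in_id) = intersection;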
if (in_id != parallel_id) { return; }
const TensorSliceView& out_slice = out_slices.at(out_id);
const TensorSliceView& intersection = out_slice.Intersect(cur_rank_in_slice);
src_send_intersections->at(out_id) = intersection;
Same as above: only update when the intersection is not empty.
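The corresponding sketch for the send side, again using the names from the quoted diff:

if (in_id != parallel_id) { return; }
const TensorSliceView& out_slice = out_slices.at(out_id);
const TensorSliceView& intersection = out_slice.Intersect(cur_rank_in_slice);
// Skip out_ids whose slice does not overlap the current rank's input slice.
if (intersection.IsEmpty()) { return; }
src_send_intersections->at(out_id) = intersection;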
bool NdSbpNoPartialParallel(const NdSbp& nd_sbp) {
  CHECK_GT(nd_sbp.sbp_parallel_size(), 0);
  FOR_RANGE(int64_t, i, 0, nd_sbp.sbp_parallel_size()) {
    if (nd_sbp.sbp_parallel(i).has_partial_sum_parallel()) { return false; }
  }
  return true;
}
This duplicates NdSbpHasPartialParallel in oneflow/core/job/nd_sbp_util.h; delete it?
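A minimal sketch of the suggested reuse (assuming NdSbpHasPartialParallel in oneflow/core/job/nd_sbp_util.h returns true when any axis is PartialSum):

#include "oneflow/core/job/nd_sbp_util.h"

// Drop the local duplicate and negate the existing helper instead.
bool NdSbpNoPartialParallel(const NdSbp& nd_sbp) {
  return !NdSbpHasPartialParallel(nd_sbp);
}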
OF_NCCL_CHECK(ncclGroupStart());
for (int64_t i = 0; i < parallel_num; ++i) {
  if (send_elem_cnts.at(i) != 0) {
    LOG(INFO) << parallel_id << " send " << send_elem_cnts.at(i) << " to " << i;
Use VLOG(3) here; this doesn't need to be printed on every run, otherwise the log gets far too long.
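Suggested form, as a sketch:

VLOG(3) << parallel_id << " send " << send_elem_cnts.at(i) << " to " << i;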
                              comm, cuda_stream));
  }
  if (recv_elem_cnts.at(i) != 0) {
    LOG(INFO) << parallel_id << " recv " << recv_elem_cnts.at(i) << " from " << i;
Same as above.
    }
  }
} else {
  std::unique_ptr<ep::primitive::Add> primitive =
Rename primitive to add_primitive?
void* out_buf = reinterpret_cast<void*>(buf_ptr + offset);
memset_primitive->Launch(ctx->stream(), out_buf, 0,
                         out->shape().elem_cnt() * GetSizeOfDataType(data_type));
out_tensor_slice_copier_vec.at(i)->Copy(ctx->stream(), out_buf, recv_out_ptr.at(i));
primitive->Launch(ctx->stream(), out->dptr(), out_buf, out->mut_dptr(),
                  out->shape().elem_cnt());
This doesn't seem necessary. Can the if ... else ... branches be merged like this, so there is no need to prepare a tmp buf for the output?

primitive->Launch(ctx->stream(), out->dptr(), recv_out_ptr.at(i), out->mut_dptr(),
                  recv_elem_cnts.at(i));
If the memset at line 159 is hoisted out of the for loop, can the branch where src_nd_sbp_no_partial_parallel_ is false be made simpler? That way there is also no need to prepare a tmp buf for the output:
std::unique_ptr<ep::primitive::Add> primitive =
ep::primitive::NewPrimitive<ep::primitive::AddFactory>(ctx->stream()->device_type(),
out->data_type());
CHECK(primitive);
std::unique_ptr<ep::primitive::Memset> memset_primitive =
ep::primitive::NewPrimitive<ep::primitive::MemsetFactory>(ctx->stream()->device_type());
CHECK(memset_primitive);
memset_primitive->Launch(ctx->stream(), out->mut_dptr(), 0,
out->shape().elem_cnt() * GetSizeOfDataType(data_type));
for (int64_t i = 0; i < parallel_num; ++i) {
if (out_tensor_slice_copier_vec.at(i)) {
primitive->Launch(ctx->stream(), out->dptr(), recv_out_ptr.at(i), out->mut_dptr(),
recv_elem_cnts.at(i));
}
}
When the shapes differ, the offset has to be taken into account; the data is not necessarily copied to the start of the pointer, so the original code is fine.
const int64_t machine_id = CHECK_JUST(out_parallel_desc.MachineId4ParallelId(out_id));
int64_t device_index = CHECK_JUST(out_parallel_desc.DeviceId4ParallelId(out_id));
int64_t thrd_id = EncodeStreamIdToInt64(GenerateNamedTaskStreamId(
    machine_id, out_parallel_desc.device_type(), device_index, "NCCL_SEND_RECV_BOXING"));
The same stream can't be used here; the ordering can't be guaranteed.
"NCCL_SEND_RECV_BOXING" + NewUniqueId()
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/7936/
Speed stats:
visit(hierarchy_index_helper.NdIndexToOffset(out_parallel_ids.data(),
                                             parallel_hierarchy.NumAxes()),
      hierarchy_index_helper.NdIndexToOffset(in_parallel_ids.data(),
                                             parallel_hierarchy.NumAxes()));
return;
visit doesn't need to take two arguments; passing a single in_id is enough.
Each rank's out_id is fixed. What actually happens here is that an out_id is converted to an NdIndex, then on several branches the unchanged NdIndex is converted back to an out_id, and the code below merely checks that the out_id is still equal after the two conversions. Dropping this saves many out_id conversions.
There is another reason to remove the out_id conversion: when different placements are implemented, you will find that this out_id is only meaningful as an NdIndex derived under in_parallel_desc.
For example, with [0, 1, 2, 3] -> [1, 2, 3, 4],
out_id 0 is device 1, and device 1's NdIndex under out_parallel_desc is (0, 0), whose corresponding input device is device 0. Making device 1 transfer from device 0 first is clearly a loss.
Consider [0, 1, 2, 3]: B -> [1, 2, 3, 4]: S(0).
The transfer volume should be 1/4 T, but guided by the out_parallel_desc NdIndex it rises to T.
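A sketch of what a single-argument visitor could look like (hypothetical; the variable names follow the quoted receive-side diff, and the current rank's out_id is assumed to be fixed outside the lambda):

const auto visit = [&](int64_t in_id) {
  const TensorSliceView& in_slice = in_slices.at(in_id);
  const TensorSliceView& intersection = cur_rank_out_slice.Intersect(in_slice);
  if (intersection.IsEmpty()) { return; }
  dst_recv_intersections->at(in_id) = intersection;
};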
This has been tested and works fine.
sbp=src_nd_sbp,
placement=placement,
)
graph = TestGraph()
Would this test be better placed under test/graph?
Speed stats:
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.
…flow-Inc/oneflow into dev_nd_nccl_send_recv_boxing
…flow-Inc/oneflow into dev_nd_nccl_send_recv_boxing
Has been merged into master in #8437
Use nccl send/recv to support boxing for any case where src_parallel_desc == dst_parallel_desc, device=kCUDA, and dst contains no P.