Update gloo in dygraph (#55537)
* update broadcast gloo in dygraph

* update

* update reduce gloo in dygraph

* update reduce gloo in dygraph

* update

* update allreduce allgather

* update all

* update

* update

* update
Xing-lil authored Jul 20, 2023
1 parent 982e0a9 commit 1d1e548
Showing 9 changed files with 261 additions and 182 deletions.
244 changes: 73 additions & 171 deletions paddle/fluid/distributed/collective/process_group_gloo.cc

Large diffs are not rendered by default.
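The rewrite of this file is not shown, but together with the GetCommContext() declaration added to the header below, the likely pattern is that each collective now delegates to the shared phi::distributed::GlooCommContext instead of assembling gloo options inline. A minimal hedged sketch of that shape; the helper name and exact signatures are assumptions, not the actual diff:

    // Sketch only: the collectives presumably collapse into thin wrappers
    // around the shared comm context (AllReduce shown as a representative).
    void AllReduceViaCommContext(ProcessGroupGloo* pg,
                                 phi::DenseTensor* out_tensor,
                                 const phi::DenseTensor& in_tensor,
                                 const AllreduceOptions& opts) {
      auto* comm_context = pg->GetCommContext();  // declared in the .h hunk below
      comm_context->AllReduce(
          out_tensor, in_tensor, static_cast<int>(opts.reduce_op), /*tag=*/0);
    }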

3 changes: 3 additions & 0 deletions paddle/fluid/distributed/collective/process_group_gloo.h
@@ -20,6 +20,7 @@

#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_without_stream.h"
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#include "paddle/phi/core/distributed/store/store.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"

@@ -225,6 +226,8 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
return GetDeviceContext(place);
}

phi::distributed::GlooCommContext* GetCommContext();

// Helper functions for Gloo.
static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
const std::string& hostname);
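GetCommContext() is only declared in this hunk. A plausible implementation, assuming the context is resolved from phi's CommContextManager keyed by the group id, as Paddle's other process groups do; the lookup key and error message are assumptions:

    phi::distributed::GlooCommContext* ProcessGroupGloo::GetCommContext() {
      const auto& manager = phi::distributed::CommContextManager::GetInstance();
      auto* comm_context = static_cast<phi::distributed::GlooCommContext*>(
          manager.Get(std::to_string(gid_)));
      PADDLE_ENFORCE_NE(comm_context,
                        nullptr,
                        phi::errors::Unavailable("GlooCommContext is not set."));
      return comm_context;
    }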
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/collective/types.h
@@ -24,7 +24,7 @@ namespace paddle {
namespace distributed {

// TODO(shenliang03): To support AVG for reduce
-enum class ReduceOp : std::uint8_t { SUM = 0, AVG, MAX, MIN, PRODUCT };
+enum class ReduceOp : std::uint8_t { SUM = 0, MAX, MIN, PRODUCT, AVG };

struct AllreduceOptions {
ReduceOp reduce_op = ReduceOp::SUM;
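Note that the reordering changes the enum's underlying integer values: AVG moves from 1 to 4, so MAX, MIN, and PRODUCT each shift down by one. Presumably this lines ReduceOp up with the plain reduce_type ints that GlooCommContext and SetReduceFunc (below) consume. A minimal illustration of the new numbering:

    // With the new declaration order the underlying values are
    //   SUM = 0, MAX = 1, MIN = 2, PRODUCT = 3, AVG = 4,
    // so a cast is enough when handing the op to the comm context:
    static_assert(static_cast<int>(ReduceOp::MAX) == 1, "MAX now maps to 1");
    static_assert(static_cast<int>(ReduceOp::AVG) == 4, "AVG moved to the end");
    int reduce_type = static_cast<int>(ReduceOp::PRODUCT);  // == 3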
79 changes: 75 additions & 4 deletions paddle/phi/core/distributed/gloo_comm_context.cc
@@ -17,8 +17,11 @@

#include <gloo/allgather.h>
#include <gloo/allreduce.h>
#include <gloo/barrier.h>
#include <gloo/broadcast.h>
#include <gloo/gather.h>
#include <gloo/reduce.h>
#include <gloo/scatter.h>
#include <gloo/types.h>

#include "paddle/phi/common/data_type.h"
@@ -41,7 +44,8 @@ GlooCommContext::GlooCommContext(

void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
-int root) {
+int root,
+uint32_t tag) {
// gloo only uses CPU now
CommStaticCheck::SameShape(*out_tensor,
in_tensor,
@@ -56,24 +60,29 @@ void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor,
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
}
opts.setRoot(root);
opts.setTag(tag);
gloo::broadcast(opts);
}

void GlooCommContext::AllGather(phi::DenseTensor* out_tensor,
-const phi::DenseTensor& in_tensor) {
+const phi::DenseTensor& in_tensor,
+uint32_t tag) {
// gloo only uses CPU now

gloo::AllgatherOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype();
opts.setTag(tag);
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
gloo::allgather(opts);
}

void GlooCommContext::AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
-int reduce_type) {
+int reduce_type,
+uint32_t tag) {
gloo::AllreduceOptions opts(gloo_context_);
opts.setTag(tag);
const auto& dtype = in_tensor.dtype();
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
@@ -84,15 +93,77 @@ void GlooCommContext::AllReduce(phi::DenseTensor* out_tensor,
void GlooCommContext::Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int reduce_type,
-int root) {
+int root,
+uint32_t tag) {
gloo::ReduceOptions opts(gloo_context_);
opts.setRoot(root);
opts.setTag(tag);
const auto& dtype = in_tensor.dtype();
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
GENERATE_FUNC(dtype, SetReduceFunc, &opts, reduce_type);
gloo::reduce(opts);
}

void GlooCommContext::Gather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int src,
uint32_t tag) {
gloo::GatherOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype();
opts.setTag(tag);
opts.setRoot(src);
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
if (rank_ == src) {
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
}
gloo::gather(opts);
}

void GlooCommContext::Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int src,
int size,
uint32_t tag) {
gloo::ScatterOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype();
if (rank_ == src) {
GENERATE_FUNC(dtype, SetInputForScatter, &opts, in_tensor, size);
}
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
opts.setRoot(src);
opts.setTag(tag);
gloo::scatter(opts);
}

void GlooCommContext::Barrier() {
gloo::BarrierOptions opts(gloo_context_);
gloo::barrier(opts);
}

void GlooCommContext::Send(const phi::DenseTensor& in_tensor,
int dst,
uint32_t tag) {
SendRecvOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype();
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
opts.setSrc(gloo_context_.get()->rank);
opts.setDst(dst);
opts.setTag(tag);
send_recv(&opts);
}

void GlooCommContext::Recv(phi::DenseTensor* out_tensor,
int src,
uint32_t tag) {
SendRecvOptions opts(gloo_context_);
const auto& dtype = out_tensor->dtype();
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
opts.setSrc(src);
opts.setDst(gloo_context_.get()->rank);
opts.setTag(tag);
send_recv(&opts);
}

} // namespace distributed
} // namespace phi
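Every collective now accepts a tag, which gloo uses to keep concurrent operations on the same context from matching each other's messages; for the point-to-point Send/Recv pair, both endpoints must pass the same tag. A hedged usage sketch; the ctx accessor, rank, and payload tensor are assumptions:

    // Both ranks run this block; the tag (7) pairs the send with the recv.
    phi::distributed::GlooCommContext* ctx = GetCommContext();
    if (rank == 0) {
      ctx->Send(payload, /*dst=*/1, /*tag=*/7);   // blocks until matched
    } else if (rank == 1) {
      ctx->Recv(&payload, /*src=*/0, /*tag=*/7);  // tag must match the Send
    }
    ctx->Barrier();  // tag-free: synchronizes every rank in the group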
29 changes: 25 additions & 4 deletions paddle/phi/core/distributed/gloo_comm_context.h
@@ -35,17 +35,38 @@ class GlooCommContext final : public CommContext {

void Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
-int root);
+int root,
+uint32_t tag = 0);
void AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
-int reduce_type);
+int reduce_type,
+uint32_t tag = 0);
void Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int reduce_type,
-int root);
+int root,
+uint32_t tag = 0);

void AllGather(phi::DenseTensor* out_tensor,
-const phi::DenseTensor& in_tensor);
+const phi::DenseTensor& in_tensor,
+uint32_t tag = 0);

void Gather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int src,
uint32_t tag = 0);

void Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int src,
int size,
uint32_t tag = 0);

void Barrier();

void Send(const phi::DenseTensor& in_tensor, int dst, uint32_t tag = 0);

void Recv(phi::DenseTensor* out_tensor, int src, uint32_t tag = 0);

private:
DISABLE_COPY_AND_ASSIGN(GlooCommContext);
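The tag parameters default to 0, so every existing call site stays source-compatible. For the two new collectives, the contract implied by the .cc above is: Gather collects each rank's in_tensor into out_tensor on the root, and Scatter expects the root's in_tensor to hold size equal chunks laid out back to back, one per rank. A hedged sketch for four ranks; the ctx pointer, tensor names, and shapes are assumptions:

    // nranks = 4: root 0 scatters one 2-element chunk to each rank.
    // full on rank 0:  [a0 a1 | b0 b1 | c0 c1 | d0 d1]   (8 elements)
    // chunk on every rank: 2 elements after the call.
    ctx->Scatter(&chunk, full, /*src=*/0, /*size=*/4, /*tag=*/3);
    // Gather is the inverse: chunks are collected back on the root.
    ctx->Gather(&full, chunk, /*src=*/0, /*tag=*/4);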
15 changes: 15 additions & 0 deletions paddle/phi/core/distributed/gloo_utils.cc
@@ -91,5 +91,20 @@ std::shared_ptr<gloo::transport::Device> CreateGlooDevice() {
}
}

void send_recv(SendRecvOptions* opts) {
const auto& context = opts->context;
gloo::transport::UnboundBuffer* in = opts->in.get();
gloo::transport::UnboundBuffer* out = opts->out.get();
const auto slot = gloo::Slot::build(kSendRecvSlotPrefix, opts->tag);

if (context->rank == opts->src) {
in->send(opts->dst, slot);
in->waitSend(opts->timeout);
} else if (context->rank == opts->dst) {
out->recv(opts->src, slot);
out->waitRecv(opts->timeout);
}
}

} // namespace distributed
} // namespace phi
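send_recv() is symmetric: both endpoints call it with identical src/dst/tag settings, and the rank comparison decides which side posts the send and which posts the receive. The slot built from kSendRecvSlotPrefix and the tag is what pairs the two unbound buffers, so mismatched tags never rendezvous and eventually hit the timeout. A hedged sketch of driving it directly; the buffer setup is illustrative, and the Send/Recv wrappers in gloo_comm_context.cc do this for you:

    // Rank 0 sends n floats from data to rank 1; both ranks execute this.
    phi::distributed::SendRecvOptions opts(gloo_ctx);
    if (gloo_ctx->rank == 0) opts.setInput(data, n);   // sender posts its buffer
    if (gloo_ctx->rank == 1) opts.setOutput(data, n);  // receiver posts its buffer
    opts.setSrc(0);
    opts.setDst(1);
    opts.setTag(42);                     // must be identical on both ranks
    phi::distributed::send_recv(&opts);  // blocks until send and recv pair up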
64 changes: 64 additions & 0 deletions paddle/phi/core/distributed/gloo_utils.h
@@ -14,6 +14,7 @@

#pragma once

#include <gloo/allreduce.h>
#include <gloo/math.h>
#include <gloo/transport/tcp/device.h>
#include <gloo/types.h>
@@ -103,6 +104,19 @@ void SetInput(P* opts, const phi::DenseTensor& tensor) {
tensor.numel());
}

template <typename T, typename P>
void SetInputForScatter(P* opts, const phi::DenseTensor& tensor, int nranks) {
std::vector<T*> ret;
ret.reserve(nranks);
T* raw_pointer = reinterpret_cast<T*>(const_cast<void*>(tensor.data()));
size_t offset = 0;
for (int i = 0; i < nranks; i++) {
ret.push_back(raw_pointer + offset);
offset += tensor.numel() / nranks;
}
opts->setInputs(ret, tensor.numel() / nranks);
}

template <typename T, typename P>
void SetReduceFunc(P* opts, int reduce_type) {
// gloo only support mutable data input
@@ -136,5 +150,55 @@
// env preparation
std::shared_ptr<gloo::transport::Device> CreateGlooDevice();

constexpr uint8_t kSendRecvSlotPrefix = 0x08;

class SendRecvOptions {
public:
explicit SendRecvOptions(const std::shared_ptr<gloo::Context>& context)
: context(context), timeout(context->getTimeout()) {}

template <typename T>
void setInput(T* ptr, size_t elements) {
this->in = context->createUnboundBuffer(ptr, elements * sizeof(T));
}

template <typename T>
void setOutput(T* ptr, size_t elements) {
this->out = context->createUnboundBuffer(ptr, elements * sizeof(T));
}

void setSrc(int src) { this->src = src; }

void setDst(int dst) { this->dst = dst; }

void setTag(uint32_t tag) { this->tag = tag; }

void setTimeout(std::chrono::milliseconds timeout) {
this->timeout = timeout;
}

protected:
std::shared_ptr<gloo::Context> context;
std::unique_ptr<gloo::transport::UnboundBuffer> in;
std::unique_ptr<gloo::transport::UnboundBuffer> out;

// Rank of process to send_recv from.
int src = -1;

// Rank of process to send_recv to.
int dst = -1;

// Tag for this operation.
// Must be unique across operations executing in parallel.
uint32_t tag = 0;

// End-to-end timeout for this operation.
std::chrono::milliseconds timeout;

friend void send_recv(SendRecvOptions*);
};

void send_recv(SendRecvOptions* opts);

} // namespace distributed
} // namespace phi
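SetInputForScatter is what lets Scatter treat one flat tensor as per-rank inputs: it hands gloo nranks pointers into the same buffer, each advanced by numel / nranks elements, so it implicitly assumes numel divides evenly by nranks. A small self-contained illustration of the pointer arithmetic (values are illustrative):

    #include <vector>

    int main() {
      float buf[8] = {0, 1, 2, 3, 4, 5, 6, 7};  // numel = 8
      const int nranks = 4;
      const int chunk = 8 / nranks;             // 2 elements per rank
      std::vector<float*> inputs;
      for (int i = 0; i < nranks; ++i) inputs.push_back(buf + i * chunk);
      // inputs == {buf+0, buf+2, buf+4, buf+6}; SetInputForScatter then calls
      // opts->setInputs(inputs, chunk), so gloo sends inputs[i] to rank i.
      return 0;
    }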
2 changes: 1 addition & 1 deletion test/collective/collective_allgather_api.py
@@ -14,7 +14,7 @@

import os

-import test_collective_api_base as test_base
+import legacy_test.test_collective_api_base as test_base

import paddle
import paddle.distributed as dist
5 changes: 4 additions & 1 deletion test/collective/collective_allreduce_api.py
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
+from legacy_test.test_collective_api_base import (
+    TestCollectiveAPIRunnerBase,
+    runtime_main,
+)

import paddle
import paddle.distributed as dist
