Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aysnc send/recv, seriliaze/deserialize using threadpool. #7705

Merged
merged 2 commits into from
Jan 24, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 71 additions & 47 deletions paddle/operators/detail/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "grpc_client.h"
#include "paddle/framework/threadpool.h"
namespace paddle {
namespace operators {
namespace detail {
Expand All @@ -22,25 +23,32 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
sendrecv::VariableMessage req;
auto* var = scope.FindVar(var_name);
SerializeToMessage(var_name, var, ctx, &req);

// varhandle
VarHandle var_h;
var_h.ep = ep;
var_h.scope = &scope;
var_h.name = var_name;
var_h.ctx = &ctx;

// stub context
auto ch = GetChannel(ep);
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = NULL;

auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);

framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] {
auto* var = p_scope->FindVar(var_name_val);
sendrecv::VariableMessage req;
SerializeToMessage(var_name_val, var, *p_ctx, &req);

// varhandle
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = var_name_val;
var_h.ctx = p_ctx;

// stub context
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = NULL;

auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
});

req_count_++;

Expand All @@ -50,8 +58,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
void ProcGetResponse(const VarHandle& var_h,
const sendrecv::VariableMessage& ret_msg) {
auto* outvar = var_h.scope->FindVar(var_h.name);

std::istringstream iss(ret_msg.serialized());
DeserializeFromMessage(ret_msg, *var_h.ctx, outvar);
}

Expand All @@ -60,44 +66,63 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(var_name);

// varhandle
VarHandle var_h;
var_h.ep = ep;
var_h.scope = &scope;
var_h.name = var_name;
var_h.ctx = &ctx;

// stub context
auto ch = GetChannel(ep);
GetProcessor* s = new GetProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = ProcGetResponse;

auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);

framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {
sendrecv::VariableMessage req;
req.set_varname(var_name_val);

// varhandle
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = var_name_val;
var_h.ctx = p_ctx;

// stub context
GetProcessor* s = new GetProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = ProcGetResponse;

auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
});

req_count_++;

return true;
}

bool RPCClient::Wait() {
bool ok = true;
if (req_count_ <= 0) {
return true;
}

while (true) {
if (req_count_ <= 0) {
break;
}
std::vector<bool> a(req_count_);
std::vector<std::future<void>> waits(req_count_);

if (!Proceed()) {
for (int i = 0; i < req_count_; i++) {
waits[i] = framework::Async([i, &a, this] { a[i] = Proceed(); });
}

for (int i = 0; i < req_count_; i++) {
waits[i].wait();
}

int last_req_count = req_count_;
req_count_ = 0;

for (int i = 0; i < last_req_count; i++) {
if (!a[i]) {
return false;
}
}

return ok;
return true;
}

bool RPCClient::Proceed() {
Expand All @@ -124,7 +149,6 @@ bool RPCClient::Proceed() {

c->Process();
delete c;
req_count_--;
return true;
}

Expand Down