Skip to content

Commit

Permalink
aysnc send/recv, seriliaze/deserialize using threadpool
Browse files Browse the repository at this point in the history
  • Loading branch information
helinwang committed Jan 23, 2018
1 parent e7d44a2 commit 97e68a3
Showing 1 changed file with 57 additions and 40 deletions.
97 changes: 57 additions & 40 deletions paddle/operators/detail/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "grpc_client.h"
#include "paddle/framework/threadpool.h"
namespace paddle {
namespace operators {
namespace detail {
Expand All @@ -22,25 +23,32 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
sendrecv::VariableMessage req;
auto* var = scope.FindVar(var_name);
SerializeToMessage(var_name, var, ctx, &req);

// varhandle
VarHandle var_h;
var_h.ep = ep;
var_h.scope = &scope;
var_h.name = var_name;
var_h.ctx = &ctx;

// stub context
auto ch = GetChannel(ep);
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = NULL;

auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);

framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] {
auto* var = p_scope->FindVar(var_name_val);
sendrecv::VariableMessage req;
SerializeToMessage(var_name_val, var, *p_ctx, &req);

// varhandle
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = var_name_val;
var_h.ctx = p_ctx;

// stub context
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = NULL;

auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
});

req_count_++;

Expand All @@ -50,34 +58,43 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
void ProcGetResponse(const VarHandle& var_h,
const sendrecv::VariableMessage& ret_msg) {
auto* outvar = var_h.scope->FindVar(var_h.name);

std::istringstream iss(ret_msg.serialized());
DeserializeFromMessage(ret_msg, *var_h.ctx, outvar);
const VarHandle* p_var_h = &var_h;
const sendrecv::VariableMessage* p_ret_msg = &ret_msg;
framework::Async([outvar, p_var_h, p_ret_msg] {
DeserializeFromMessage(*p_ret_msg, *(p_var_h->ctx), outvar);
}).wait();
}

bool RPCClient::AsyncGetVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(var_name);

// varhandle
VarHandle var_h;
var_h.ep = ep;
var_h.scope = &scope;
var_h.name = var_name;
var_h.ctx = &ctx;

// stub context
auto ch = GetChannel(ep);
GetProcessor* s = new GetProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = ProcGetResponse;

auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);

framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {
sendrecv::VariableMessage req;
req.set_varname(var_name_val);

// varhandle
VarHandle var_h;
var_h.ep = ep_val;
var_h.scope = p_scope;
var_h.name = var_name_val;
var_h.ctx = p_ctx;

// stub context
GetProcessor* s = new GetProcessor(ch);
s->Prepare(var_h, time_out);
s->response_call_back_ = ProcGetResponse;

auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
});

req_count_++;

Expand Down

0 comments on commit 97e68a3

Please sign in to comment.