Skip to content

Commit

Permalink
Merge pull request #11490 from seiriosPlus/ckpt_m2
Browse files Browse the repository at this point in the history
Checkpoint M2: lookup table checkpoint
  • Loading branch information
seiriosPlus authored Jun 26, 2018
2 parents bb29800 + f57978e commit b20fa02
Show file tree
Hide file tree
Showing 20 changed files with 617 additions and 86 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ if(WITH_DISTRIBUTE)
endif()

set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
foreach(dist_op "prefetch_op" "listen_and_serv_op" "send_op" "recv_op" "send_barrier_op" "fetch_barrier_op")
foreach(dist_op "prefetch_op" "checkpoint_notify_op" "listen_and_serv_op" "send_op" "recv_op" "send_barrier_op" "fetch_barrier_op")
op_library(${dist_op} DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(${dist_op}.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endforeach()
Expand All @@ -216,7 +216,7 @@ if(WITH_DISTRIBUTE)
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif()
else()
set(DEPS_OPS ${DEPS_OPS} prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
set(DEPS_OPS ${DEPS_OPS} checkpoint_notify_op prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
endif()

op_library(cross_entropy_op DEPS cross_entropy)
Expand Down
88 changes: 88 additions & 0 deletions paddle/fluid/operators/checkpoint_notify_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <future> // NOLINT
#include <ostream>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace operators {

class CheckpointNotifyOp : public framework::OperatorBase {
public:
CheckpointNotifyOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}

void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::string dir = Attr<std::string>("dir");
std::string lookup_table_name = Attr<std::string>("lookup_table");

distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < epmap.size(); i++) {
auto lookup_table_save_dir =
string::Sprintf("%s/%s_%d", dir, lookup_table_name, i);
rpc_client->AsyncCheckpointNotify(epmap[i], lookup_table_save_dir);
VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name
<< " and dir:" << dir << " to " << epmap[i];
}
rpc_client->Wait();
}
};

class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Parameter Server endpoints in the order")
.SetDefault({"127.0.0.1:6164"});
AddAttr<std::string>(
"dir", "(string, default '') indicate the folder checkpoint will use");
AddAttr<std::string>("lookup_table",
"(string, default '') the lookup table name");
AddComment(R"DOC(
CheckpointNotify operator
This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at
the parameter server.
)DOC");
}
};

class CheckpointNotifyOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(checkpoint_notify, ops::CheckpointNotifyOp,
paddle::framework::EmptyGradOpMaker,
ops::CheckpointNotifyOpMaker,
ops::CheckpointNotifyOpShapeInference);
17 changes: 17 additions & 0 deletions paddle/fluid/operators/distributed/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,23 @@ void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
req_count_++;
}

void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out) {
const auto ch = GetChannel(ep);

CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
s->Prepare(time_out);

sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir);

auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}

void GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/operators/distributed/grpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,20 @@ class FetchBarrierProcessor : public BaseProcessor {
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};

class CheckpointNotifyProcessor : public BaseProcessor {
public:
explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}

virtual ~CheckpointNotifyProcessor() {}

virtual void Process() {}
sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};

class GRPCClient : public RPCClient {
public:
GRPCClient() {}
Expand All @@ -197,6 +211,9 @@ class GRPCClient : public RPCClient {
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;

void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_grpc_deadline) override;

void Wait() override;

void SendComplete() override;
Expand Down
48 changes: 45 additions & 3 deletions paddle/fluid/operators/distributed/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,45 @@ class RequestPrefetch final : public RequestBase {
framework::Scope* local_scope_;
};

class RequestCheckpointNotify final : public RequestBase {
public:
explicit RequestCheckpointNotify(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx()));
int method_id =
static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}

virtual ~RequestCheckpointNotify() {}

std::string GetReqName() override { return request_->Varname(); }

void Process() override {
auto scope = request_->GetMutableLocalScope();

std::string checkpoint_notify = request_->Varname();
std::string checkpoint_dir = request_->OutVarname();

VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir;

request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
checkpoint_dir);
Finish(reply_, &responder_);
}

protected:
std::shared_ptr<VariableResponse> request_;
sendrecv::VoidMessage reply_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};

void AsyncGRPCServer::WaitServerReady() {
VLOG(4) << "AsyncGRPCServer is wait server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_);
Expand Down Expand Up @@ -237,6 +276,7 @@ void AsyncGRPCServer::StartServer() {
reqs.reserve(kRequestBufSize);

for (int i = 0; i < kRequestBufSize; i++) {
VLOG(6) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i;
TryToRegisterNewOne(rpc_name, i);
}

Expand Down Expand Up @@ -289,8 +329,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
return;
}

VLOG(4) << "register send rpc_name:" << rpc_name
<< ", handler:" << rpc_call_map_[kRequestSend];
VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
<< " REQ ID: " << req_id;

auto& reqs = rpc_reqs_[rpc_name];
auto& handler = rpc_call_map_[rpc_name];
Expand All @@ -303,6 +343,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b = new RequestGet(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestCheckpoint) {
b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
} else {
PADDLE_ENFORCE(false, "not supported rpc");
}
Expand All @@ -321,7 +363,7 @@ void AsyncGRPCServer::HandleRequest(
while (true) {
VLOG(4) << "HandleRequest " << rpc_name << " wait next";
if (!cq->Next(&tag, &ok)) {
LOG(INFO) << "CompletionQueue " << rpc_name << " shutdown!";
VLOG(3) << "CompletionQueue " << rpc_name << " shutdown!";
break;
}

Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/operators/distributed/grpc_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ enum class GrpcMethod {
kSendVariable,
kGetVariable,
kPrefetchVariable,
kCheckpointNotify,
};

static const int kGrpcNumMethods =
static_cast<int>(GrpcMethod::kPrefetchVariable) + 1;
static_cast<int>(GrpcMethod::kCheckpointNotify) + 1;

inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) {
Expand All @@ -93,6 +94,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kPrefetchVariable:
return "/sendrecv.SendRecvService/PrefetchVariable";
case GrpcMethod::kCheckpointNotify:
return "/sendrecv.SendRecvService/CheckpointNotify";
}

// Shouldn't be reached.
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/operators/distributed/request_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,16 @@ namespace distributed {
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";

#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"

#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"

class RPCServer;

class RequestHandler {
Expand Down Expand Up @@ -69,6 +73,11 @@ class RequestHandler {
prefetch_var_name_to_prepared_ctx_ = g;
}

void SetCheckpointNotifyPreparedCtx(
std::shared_ptr<framework::ExecutorPrepareContext> g) {
checkpoint_prepared_ctx_ = g;
}

// Used for async.
void SetGradToPreparedCtx(
std::unordered_map<
Expand Down Expand Up @@ -115,6 +124,8 @@ class RequestHandler {
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
prefetch_var_name_to_prepared_ctx_;
// used for checkpoint notify
std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_prepared_ctx_;

// Used for async.
std::unordered_map<std::string,
Expand Down
23 changes: 23 additions & 0 deletions paddle/fluid/operators/distributed/request_handler_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/printf.h"

namespace paddle {
namespace operators {
namespace distributed {

// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
// to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";

bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
Expand Down Expand Up @@ -119,6 +124,24 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
return true;
}

bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const std::string& out_var_name) {
PADDLE_ENFORCE(
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");

auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear();
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
<< out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
return true;
}

} // namespace distributed
} // namespace operators
} // namespace paddle
15 changes: 15 additions & 0 deletions paddle/fluid/operators/distributed/request_handler_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ class RequestPrefetchHandler final : public RequestHandler {
const std::string& out_var_name = "") override;
};

class RequestCheckpointHandler final : public RequestHandler {
public:
explicit RequestCheckpointHandler(bool sync_mode, int checkpoint_notify_id)
: RequestHandler(sync_mode) {
this->checkpoint_notify_id = checkpoint_notify_id;
}
virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;

private:
int checkpoint_notify_id;
};

} // namespace distributed
} // namespace operators
} // namespace paddle
4 changes: 4 additions & 0 deletions paddle/fluid/operators/distributed/rpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ class RPCClient {
virtual void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;

virtual void AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_grpc_deadline) = 0;

// SendComplete tells all the server that current trainer have no more data
// to train, so that the pserver can reduce it's barrier count, and continue
// to train with other trainers.
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/operators/distributed/send_recv.proto
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ service SendRecvService {
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}

rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
}

// VariableMessage is serialized paddle variable message.
Expand Down
Loading

0 comments on commit b20fa02

Please sign in to comment.