-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Checkpoint M2: lookup table checkpoint #11490
Changes from 66 commits
4170196
a895916
8a17816
12de20f
b089b80
bb17604
1cb0ab3
fb27c9a
fe76244
98c30c7
f224948
8d46d1d
860360d
1c2e9bd
985026c
925e232
a9ac200
36d17d1
74384b7
050b66e
54013a9
15532c7
bbb349f
527b86b
85215df
8af8da4
5553adf
752eb08
3088084
ae12281
af0a6a1
bb10c37
549f0aa
a501766
ca27f78
ee64f57
1296d96
620698e
459690a
5250ca8
bccf8df
7efd73a
1571c25
d93dc81
8c0e1d5
49c2d0c
16ecead
6abf076
28482f8
06f6c21
5600b13
32fa832
8af4d4c
db6126c
5a4a24c
91eae9c
298588f
c073bb3
05bd9db
e589005
9764844
620999c
8e01f3b
2229db5
e684575
7fae9e0
4388ce1
b519bf0
fb7e479
dc847f1
33ff69b
88cb5d7
b6e6355
fa3d470
f57978e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <future> // NOLINT | ||
#include <ostream> | ||
|
||
#include "paddle/fluid/framework/data_type.h" | ||
#include "paddle/fluid/framework/lod_tensor.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/operators/detail/macros.h" | ||
#include "paddle/fluid/operators/send_recv_util.h" | ||
#include "paddle/fluid/string/printf.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class CheckpointNotifyOp : public framework::OperatorBase { | ||
public: | ||
CheckpointNotifyOp(const std::string& type, | ||
const framework::VariableNameMap& inputs, | ||
const framework::VariableNameMap& outputs, | ||
const framework::AttributeMap& attrs) | ||
: OperatorBase(type, inputs, outputs, attrs) {} | ||
|
||
void RunImpl(const framework::Scope& scope, | ||
const platform::Place& place) const override { | ||
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); | ||
std::string dir = Attr<std::string>("dir"); | ||
std::string lookup_table_name = Attr<std::string>("lookup_table"); | ||
|
||
distributed::RPCClient* rpc_client = | ||
distributed::RPCClient::GetInstance<RPCCLIENT_T>(); | ||
for (size_t i = 0; i < epmap.size(); i++) { | ||
auto lookup_table_save_dir = | ||
string::Sprintf("%s/%s_%d", dir, lookup_table_name, i); | ||
rpc_client->AsyncCheckpointNotify(epmap[i], lookup_table_save_dir); | ||
VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name | ||
<< " and dir:" << dir << " to " << epmap[i]; | ||
} | ||
rpc_client->Wait(); | ||
} | ||
}; | ||
|
||
class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() { | ||
AddAttr<std::vector<std::string>>( | ||
"epmap", | ||
"(string vector, default 127.0.0.1:6164)" | ||
"Server endpoints in the order of input variables for mapping") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the comment should change. no input variables. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
.SetDefault({"127.0.0.1:6164"}); | ||
AddAttr<std::string>( | ||
"dir", "(string, default '') indicate the folder checkpoint will use"); | ||
AddAttr<std::string>("lookup_table", | ||
"(string, default '') the lookup table name"); | ||
AddComment(R"DOC( | ||
CheckpointNotify operator | ||
|
||
This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at | ||
the parameter server. | ||
)DOC"); | ||
} | ||
}; | ||
|
||
class CheckpointNotifyOpShapeInference : public framework::InferShapeBase { | ||
public: | ||
void operator()(framework::InferShapeContext* ctx) const override {} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
|
||
REGISTER_OPERATOR(checkpoint_notify, ops::CheckpointNotifyOp, | ||
paddle::framework::EmptyGradOpMaker, | ||
ops::CheckpointNotifyOpMaker, | ||
ops::CheckpointNotifyOpShapeInference); |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -187,6 +187,45 @@ class RequestPrefetch final : public RequestBase { | |
framework::Scope* local_scope_; | ||
}; | ||
|
||
class RequestCheckpointNotify final : public RequestBase { | ||
public: | ||
explicit RequestCheckpointNotify(GrpcService::AsyncService* service, | ||
::grpc::ServerCompletionQueue* cq, | ||
RequestHandler* request_handler, int req_id) | ||
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { | ||
request_.reset(new VariableResponse(request_handler->scope(), | ||
request_handler->dev_ctx())); | ||
int method_id = | ||
static_cast<int>(distributed::GrpcMethod::kCheckpointNotify); | ||
service_->RequestAsyncUnary( | ||
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, | ||
reinterpret_cast<void*>(static_cast<intptr_t>(req_id))); | ||
} | ||
|
||
virtual ~RequestCheckpointNotify() {} | ||
|
||
std::string GetReqName() override { return request_->Varname(); } | ||
|
||
void Process() override { | ||
auto scope = request_->GetMutableLocalScope(); | ||
|
||
std::string checkpoint_notify = request_->Varname(); | ||
std::string checkpoint_dir = request_->OutVarname(); | ||
|
||
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify | ||
<< ", dir: " << checkpoint_dir; | ||
|
||
request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr, | ||
checkpoint_dir); | ||
Finish(reply_, &responder_); | ||
} | ||
|
||
protected: | ||
std::shared_ptr<VariableResponse> request_; | ||
sendrecv::VoidMessage reply_; | ||
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_; | ||
}; | ||
|
||
void AsyncGRPCServer::WaitServerReady() { | ||
VLOG(3) << "AsyncGRPCServer is wait server ready"; | ||
std::unique_lock<std::mutex> lock(this->mutex_ready_); | ||
|
@@ -224,6 +263,8 @@ void AsyncGRPCServer::StartServer() { | |
reqs.reserve(kRequestBufSize); | ||
|
||
for (int i = 0; i < kRequestBufSize; i++) { | ||
LOG(INFO) << "TryToRegisterNewOne on RPC NAME: " << rpc_name | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这些日志太多了,用VLOG There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
<< " I: " << i; | ||
TryToRegisterNewOne(rpc_name, i); | ||
} | ||
|
||
|
@@ -276,8 +317,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, | |
return; | ||
} | ||
|
||
VLOG(4) << "register send rpc_name:" << rpc_name | ||
<< ", handler:" << rpc_call_map_[kRequestSend]; | ||
VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name | ||
<< " REQ ID: " << req_id; | ||
|
||
auto& reqs = rpc_reqs_[rpc_name]; | ||
auto& handler = rpc_call_map_[rpc_name]; | ||
|
@@ -290,6 +331,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name, | |
b = new RequestGet(&service_, cq.get(), handler, req_id); | ||
} else if (rpc_name == kRequestPrefetch) { | ||
b = new RequestPrefetch(&service_, cq.get(), handler, req_id); | ||
} else if (rpc_name == kRequestCheckpoint) { | ||
b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id); | ||
} else { | ||
PADDLE_ENFORCE(false, "not supported rpc"); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add a vlog to log the dir name? for checking.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This op should also support pserver checkpoint without distribute_lookup_table.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.