Skip to content

Commit

Permalink
optimize start server
Browse files Browse the repository at this point in the history
  • Loading branch information
seemingwang committed Apr 1, 2021
1 parent bd6b545 commit a768899
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 55 deletions.
15 changes: 4 additions & 11 deletions paddle/fluid/distributed/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,9 @@ uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
return 0;
}
VLOG(0) << "GraphBrpcServer::start registe_ps_server";
_environment->registe_ps_server(ip, port, _rank);
VLOG(0) << "GraphBrpcServer::start wait";
cv_.wait(lock, [&] { return stoped_; });

PSHost host;
host.ip = ip;
host.port = port;
host.rank = _rank;
VLOG(0) << "GraphBrpcServer::start return host.rank";
return host.rank;
// cv_.wait(lock, [&] { return stoped_; });
return 0;
}

// Reports the port field of the address returned by _server.listen_address().
int32_t GraphBrpcServer::port() {
  const auto &listen_addr = _server.listen_address();
  return listen_addr.port;
}
Expand Down Expand Up @@ -232,11 +224,12 @@ int32_t GraphBrpcService::stop_server(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto *p_server = _server;
GraphBrpcServer *p_server = (GraphBrpcServer *)_server;
std::thread t_stop([p_server]() {
p_server->stop();
LOG(INFO) << "Server Stoped";
});
p_server->export_cv()->notify_all();
t_stop.detach();
return 0;
}
Expand Down
8 changes: 5 additions & 3 deletions paddle/fluid/distributed/service/graph_brpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,17 @@ class GraphBrpcServer : public PSServer {
virtual uint64_t start(const std::string &ip, uint32_t port);
// Stops the wrapped _server and returns 0 on every path shown here.
// NOTE(review): this listing comes from a commit-diff view, so the active and
// the commented-out cv_.notify_all() below are most likely the removed and the
// added version of the same line — confirm against the checked-out header.
virtual int32_t stop() {
// The lock is a local, so mutex_ stays held until this function returns.
std::unique_lock<std::mutex> lock(mutex_);
// Already flagged as stopped: return without touching _server again.
if (stoped_) return 0;
stoped_ = true;
cv_.notify_all();

// cv_.notify_all();
// Calls Stop(1000) and then Join() on _server before reporting 0.
_server.Stop(1000);
_server.Join();
return 0;
}
virtual int32_t port();

// Hands out a raw pointer to the member condition variable cv_.
std::condition_variable *export_cv() {
  return &cv_;
}

private:
virtual int32_t initialize();
mutable std::mutex mutex_;
Expand Down Expand Up @@ -104,7 +106,7 @@ class GraphBrpcService : public PsBaseService {
std::mutex _initialize_shard_mutex;
std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
std::vector<float> _ori_values;
const int sample_nodes_ranges = 3;
const int sample_nodes_ranges = 23;
};

} // namespace distributed
Expand Down
53 changes: 33 additions & 20 deletions paddle/fluid/distributed/service/graph_py_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ void GraphPyService::set_up(std::string ips_str, int shard_num,
std::vector<std::string> edge_types) {
set_shard_num(shard_num);
set_num_node_types(node_types.size());
// set_client_Id(client_id);
// set_rank(rank);

for (size_t table_id = 0; table_id < node_types.size(); table_id++) {
this->table_id_map[node_types[table_id]] = this->table_id_map.size();
Expand Down Expand Up @@ -89,27 +87,33 @@ void GraphPyClient::start_client() {
worker_ptr->configure(worker_proto, dense_regions, _ps_env, client_id);
worker_ptr->set_shard_num(get_shard_num());
}
// Creates, configures and starts the GraphBrpcServer stored in pserver_ptr,
// using the server_list / port_list entries at index rank.
// NOTE(review): this span is a commit-diff listing, so it appears to hold both
// the removed variant (server built inside a detached std::thread, followed by
// sleep(3)) and the added variant (server built inline, with an optional
// blocking wait controlled by `block`) — confirm against the checked-out file.
void GraphPyServer::start_server() {
void GraphPyServer::start_server(bool block) {
std::string ip = server_list[rank];
uint32_t port = std::stoul(port_list[rank]);
// NOTE(review): the lambda captures the locals ip and port by reference and
// the thread is detached below, so they may be read after this function has
// returned — confirm this variant is the one being removed.
server_thread = new std::thread([this, &ip, &port]() {
::paddle::distributed::PSParameter server_proto = this->GetServerProto();
::paddle::distributed::PSParameter server_proto = this->GetServerProto();

auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&this->host_sign_list,
this->host_sign_list.size()); // test
pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto));
VLOG(0) << "pserver-ptr created ";
std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog);
pserver_ptr->configure(server_proto, _ps_env, rank, empty_vec);
pserver_ptr->start(ip, port);
});
server_thread->detach();
sleep(3);
// Same construct / configure / start sequence as above, without the thread.
auto _ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&this->host_sign_list,
this->host_sign_list.size()); // test
pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
(paddle::distributed::GraphBrpcServer*)
paddle::distributed::PSServerFactory::create(server_proto));
VLOG(0) << "pserver-ptr created ";
// configure() is given a vector holding one default-constructed ProgramDesc.
std::vector<framework::ProgramDesc> empty_vec;
framework::ProgramDesc empty_prog;
empty_vec.push_back(empty_prog);
pserver_ptr->configure(server_proto, _ps_env, rank, empty_vec);
pserver_ptr->start(ip, port);
// When `block` is true, park the caller on the server's condition variable.
// NOTE(review): the wait uses a function-local mutex and no predicate, so a
// spurious wakeup, or a notify_all() issued before the wait begins, is not
// accounted for — confirm this is acceptable.
std::condition_variable* cv_ = pserver_ptr->export_cv();
if (block) {
std::mutex mutex_;
std::unique_lock<std::mutex> lock(mutex_);
cv_->wait(lock);
}

// });
// server_thread->detach();
// sleep(3);
}
::paddle::distributed::PSParameter GraphPyServer::GetServerProto() {
// Generate server proto desc
Expand Down Expand Up @@ -312,5 +316,14 @@ std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
}
return res;
}

// Forwards a stop request through worker_ptr->stop_server(). Once a request
// has come back with 0, stoped_ is latched and later calls return without
// sending anything.
void GraphPyClient::stop_server() {
  VLOG(0) << "going to stop server";
  std::unique_lock<std::mutex> guard(mutex_);
  if (!stoped_) {
    auto pending = this->worker_ptr->stop_server();
    // Latch the flag only when the request's result is 0.
    const bool request_ok = (pending.get() == 0);
    if (request_ok) {
      stoped_ = true;
    }
  }
}
// Delegates to finalize_worker() on the worker held in worker_ptr.
void GraphPyClient::finalize_worker() {
  auto &worker = this->worker_ptr;
  worker->finalize_worker();
}
}
}
11 changes: 6 additions & 5 deletions paddle/fluid/distributed/service/graph_py_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,8 @@ class GraphPyServer : public GraphPyService {
}
int get_rank() { return rank; }
// Stores the given value in the rank member; `this->` is needed because the
// parameter shadows the member name.
void set_rank(int rank) {
  this->rank = rank;
}
// paddle::distributed::GraphBrpcService * get_service(){
// return pserver_ptr->get_service();
// }
void start_server();

void start_server(bool block = true);
::paddle::distributed::PSParameter GetServerProto();
std::shared_ptr<paddle::distributed::GraphBrpcServer> get_ps_server() {
return pserver_ptr;
Expand Down Expand Up @@ -151,7 +149,8 @@ class GraphPyClient : public GraphPyService {
(paddle::distributed::GraphBrpcService*)server.get_ps_server()
->get_service());
}

void stop_server();
void finalize_worker();
void load_edge_file(std::string name, std::string filepath, bool reverse);
void load_node_file(std::string name, std::string filepath);
int get_client_id() { return client_id; }
Expand All @@ -169,9 +168,11 @@ class GraphPyClient : public GraphPyService {
::paddle::distributed::PSParameter GetWorkerProto();

protected:
mutable std::mutex mutex_;
int client_id;
std::shared_ptr<paddle::distributed::GraphBrpcClient> worker_ptr;
std::thread* client_thread;
bool stoped_ = false;
};
}
}
32 changes: 16 additions & 16 deletions paddle/fluid/distributed/test/graph_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,8 @@ void RunBrpcPushSparse() {
host_sign_list_.push_back(ph_host2.serialize_to_string());
// test-end
// Start Server
std::thread server_thread(RunServer);
std::thread server_thread2(RunServer2);
std::thread* server_thread = new std::thread(RunServer);
std::thread* server_thread2 = new std::thread(RunServer2);
sleep(1);

std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
Expand Down Expand Up @@ -433,9 +433,9 @@ void RunBrpcPushSparse() {
client2.add_table_feat_conf("user", "d", "string", 1);
client2.add_table_feat_conf("item", "a", "float32", 1);

server1.start_server();
server1.start_server(false);
std::cout << "first server done" << std::endl;
server2.start_server();
server2.start_server(false);
std::cout << "second server done" << std::endl;
client1.start_client();
std::cout << "first client done" << std::endl;
Expand All @@ -451,8 +451,6 @@ void RunBrpcPushSparse() {
client1.load_node_file(std::string("item"), std::string(node_file_name));
client1.load_edge_file(std::string("user2item"), std::string(edge_file_name),
0);
// client2.load_edge_file(std::string("user2item"), std::string(file_name),
// 0);
nodes.clear();

nodes = client1.pull_graph_list(std::string("user"), 0, 1, 4, 1);
Expand Down Expand Up @@ -505,10 +503,10 @@ void RunBrpcPushSparse() {
client1.get_node_feat(std::string("user"), node_ids, feature_names);
ASSERT_EQ(node_feat.size(), 2);
ASSERT_EQ(node_feat[0].size(), 2);
std::cout << "get_node_feat: " << node_feat[0][0] << std::endl;
std::cout << "get_node_feat: " << node_feat[0][1] << std::endl;
std::cout << "get_node_feat: " << node_feat[1][0] << std::endl;
std::cout << "get_node_feat: " << node_feat[1][1] << std::endl;
VLOG(0) << "get_node_feat: " << node_feat[0][0];
VLOG(0) << "get_node_feat: " << node_feat[0][1];
VLOG(0) << "get_node_feat: " << node_feat[1][0];
VLOG(0) << "get_node_feat: " << node_feat[1][1];

// Test string
node_ids.clear();
Expand All @@ -522,24 +520,26 @@ void RunBrpcPushSparse() {
client1.get_node_feat(std::string("user"), node_ids, feature_names);
ASSERT_EQ(node_feat.size(), 2);
ASSERT_EQ(node_feat[0].size(), 2);
std::cout << "get_node_feat: " << node_feat[0][0].size() << std::endl;
std::cout << "get_node_feat: " << node_feat[0][1].size() << std::endl;
std::cout << "get_node_feat: " << node_feat[1][0].size() << std::endl;
std::cout << "get_node_feat: " << node_feat[1][1].size() << std::endl;
VLOG(0) << "get_node_feat: " << node_feat[0][0].size();
VLOG(0) << "get_node_feat: " << node_feat[0][1].size();
VLOG(0) << "get_node_feat: " << node_feat[1][0].size();
VLOG(0) << "get_node_feat: " << node_feat[1][1].size();

std::remove(edge_file_name);
std::remove(node_file_name);
LOG(INFO) << "Run stop_server";
worker_ptr_->stop_server();
LOG(INFO) << "Run finalize_worker";
worker_ptr_->finalize_worker();
server_thread.join();
server_thread2.join();

// server_thread.join();
// server_thread2.join();
testFeatureNodeSerializeInt();
testFeatureNodeSerializeInt64();
testFeatureNodeSerializeFloat32();
testFeatureNodeSerializeFloat64();
testGraphToBuffer();
client1.stop_server();
}

void testGraphToBuffer() {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pybind/fleet_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ void BindGraphPyClient(py::module* m) {
.def("start_client", &GraphPyClient::start_client)
.def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighboors)
.def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
.def("stop_server", &GraphPyClient::stop_server)
.def("get_node_feat",
[](GraphPyClient& self, std::string node_type,
std::vector<uint64_t> node_ids,
Expand Down

1 comment on commit a768899

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on a768899 Apr 1, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍PR: #31226 Commit ID: a768899 contains failed CI.

Please sign in to comment.