Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix data_structure problems in gpu graph_engine #42321

Merged
merged 3 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion paddle/fluid/distributed/ps/service/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
set(BRPC_SRCS ps_client.cc server.cc)
set_source_files_properties(${BRPC_SRCS})

set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context)

if(WITH_HETERPS)

set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context rocksdb)

else()

set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context)

endif()

brpc_library(sendrecv_rpc SRCS
${BRPC_SRCS}
Expand Down
107 changes: 23 additions & 84 deletions paddle/fluid/distributed/ps/service/graph_brpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ int GraphBrpcClient::get_server_index_by_id(int64_t id) {
}

std::future<int32_t> GraphBrpcClient::get_node_feat(
const uint32_t &table_id, const std::vector<int64_t> &node_ids,
const uint32_t &table_id, int idx_, const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res) {
std::vector<int> request2server;
Expand Down Expand Up @@ -124,9 +124,11 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_GET_NODE_FEAT);
closure->request(request_idx)->set_table_id(table_id);

closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();

closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(int64_t) * node_num);
Expand All @@ -144,7 +146,8 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
return fut;
}

std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id,
int type_id, int idx_) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
server_size, [&, server_size = this->server_size ](void *done) {
int ret = 0;
Expand All @@ -167,7 +170,8 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
closure->request(server_index)->set_cmd_id(PS_GRAPH_CLEAR);
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);

closure->request(server_index)->add_params((char *)&type_id, sizeof(int));
closure->request(server_index)->add_params((char *)&idx_, sizeof(int));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
Expand All @@ -177,7 +181,7 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
return fut;
}
std::future<int32_t> GraphBrpcClient::add_graph_node(
uint32_t table_id, std::vector<int64_t> &node_id_list,
uint32_t table_id, int idx_, std::vector<int64_t> &node_id_list,
std::vector<bool> &is_weighted_list) {
std::vector<std::vector<int64_t>> request_bucket;
std::vector<std::vector<bool>> is_weighted_bucket;
Expand Down Expand Up @@ -225,6 +229,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = request_bucket[request_idx].size();
closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(),
sizeof(int64_t) * node_num);
Expand All @@ -245,7 +250,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
return fut;
}
std::future<int32_t> GraphBrpcClient::remove_graph_node(
uint32_t table_id, std::vector<int64_t> &node_id_list) {
uint32_t table_id, int idx_, std::vector<int64_t> &node_id_list) {
std::vector<std::vector<int64_t>> request_bucket;
std::vector<int> server_index_arr;
std::vector<int> index_mapping(server_size, -1);
Expand Down Expand Up @@ -286,6 +291,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = request_bucket[request_idx].size();

closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(),
sizeof(int64_t) * node_num);
Expand All @@ -299,7 +305,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
}
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
uint32_t table_id, std::vector<int64_t> node_ids, int sample_size,
uint32_t table_id, int idx_, std::vector<int64_t> node_ids, int sample_size,
// std::vector<std::vector<std::pair<int64_t, float>>> &res,
std::vector<std::vector<int64_t>> &res,
std::vector<std::vector<float>> &res_weight, bool need_weight,
Expand Down Expand Up @@ -353,6 +359,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER);
closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&idx_, sizeof(int));
closure->request(0)->add_params((char *)node_ids.data(),
sizeof(int64_t) * node_ids.size());
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
Expand Down Expand Up @@ -452,6 +459,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();

closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(int64_t) * node_num);
Expand All @@ -469,7 +477,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
return fut;
}
std::future<int32_t> GraphBrpcClient::random_sample_nodes(
uint32_t table_id, int server_index, int sample_size,
uint32_t table_id, int type_id, int idx_, int server_index, int sample_size,
std::vector<int64_t> &ids) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0;
Expand Down Expand Up @@ -498,6 +506,8 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES);
closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&type_id, sizeof(int));
closure->request(0)->add_params((char *)&idx_, sizeof(int));
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
;
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
Expand All @@ -508,83 +518,9 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
return fut;
}

std::future<int32_t> GraphBrpcClient::load_graph_split_config(
uint32_t table_id, std::string path) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
server_size, [&, server_size = this->server_size ](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
if (closure->check_response(request_idx,
PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG) != 0) {
++fail_num;
break;
}
}
ret = fail_num == 0 ? 0 : -1;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < server_size; i++) {
int server_index = i;
closure->request(server_index)
->set_cmd_id(PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG);
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);
closure->request(server_index)->add_params(path);
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
closure->response(server_index), closure);
}
return fut;
}
std::future<int32_t> GraphBrpcClient::use_neighbors_sample_cache(
uint32_t table_id, size_t total_size_limit, size_t ttl) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
server_size, [&, server_size = this->server_size ](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
if (closure->check_response(
request_idx, PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE) != 0) {
++fail_num;
break;
}
}
ret = fail_num == 0 ? 0 : -1;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
size_t size_limit = total_size_limit / server_size +
(total_size_limit % server_size != 0 ? 1 : 0);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < server_size; i++) {
int server_index = i;
closure->request(server_index)
->set_cmd_id(PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE);
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);
closure->request(server_index)
->add_params((char *)&size_limit, sizeof(size_t));
closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
closure->response(server_index), closure);
}
return fut;
}
std::future<int32_t> GraphBrpcClient::pull_graph_list(
uint32_t table_id, int server_index, int start, int size, int step,
std::vector<FeatureNode> &res) {
uint32_t table_id, int type_id, int idx_, int server_index, int start,
int size, int step, std::vector<FeatureNode> &res) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
Expand Down Expand Up @@ -613,6 +549,8 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
closure->request(0)->set_cmd_id(PS_PULL_GRAPH_LIST);
closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&type_id, sizeof(int));
closure->request(0)->add_params((char *)&idx_, sizeof(int));
closure->request(0)->add_params((char *)&start, sizeof(int));
closure->request(0)->add_params((char *)&size, sizeof(int));
closure->request(0)->add_params((char *)&step, sizeof(int));
Expand All @@ -625,7 +563,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
}

std::future<int32_t> GraphBrpcClient::set_node_feat(
const uint32_t &table_id, const std::vector<int64_t> &node_ids,
const uint32_t &table_id, int idx_, const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &features) {
std::vector<int> request2server;
Expand Down Expand Up @@ -686,6 +624,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();

closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(int64_t) * node_num);
Expand Down
27 changes: 12 additions & 15 deletions paddle/fluid/distributed/ps/service/graph_brpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,40 +63,37 @@ class GraphBrpcClient : public BrpcPsClient {
virtual ~GraphBrpcClient() {}
// given a batch of nodes, sample graph_neighbors for each of them
virtual std::future<int32_t> batch_sample_neighbors(
uint32_t table_id, std::vector<int64_t> node_ids, int sample_size,
std::vector<std::vector<int64_t>>& res,
uint32_t table_id, int idx, std::vector<int64_t> node_ids,
int sample_size, std::vector<std::vector<int64_t>>& res,
std::vector<std::vector<float>>& res_weight, bool need_weight,
int server_index = -1);

virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int server_index, int start,
int size, int step,
virtual std::future<int32_t> pull_graph_list(uint32_t table_id, int type_id,
int idx, int server_index,
int start, int size, int step,
std::vector<FeatureNode>& res);
virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
int type_id, int idx,
int server_index,
int sample_size,
std::vector<int64_t>& ids);
virtual std::future<int32_t> get_node_feat(
const uint32_t& table_id, const std::vector<int64_t>& node_ids,
const uint32_t& table_id, int idx, const std::vector<int64_t>& node_ids,
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string>>& res);

virtual std::future<int32_t> set_node_feat(
const uint32_t& table_id, const std::vector<int64_t>& node_ids,
const uint32_t& table_id, int idx, const std::vector<int64_t>& node_ids,
const std::vector<std::string>& feature_names,
const std::vector<std::vector<std::string>>& features);

virtual std::future<int32_t> clear_nodes(uint32_t table_id);
virtual std::future<int32_t> clear_nodes(uint32_t table_id, int type_id,
int idx);
virtual std::future<int32_t> add_graph_node(
uint32_t table_id, std::vector<int64_t>& node_id_list,
uint32_t table_id, int idx, std::vector<int64_t>& node_id_list,
std::vector<bool>& is_weighted_list);
virtual std::future<int32_t> use_neighbors_sample_cache(uint32_t table_id,
size_t size_limit,
size_t ttl);
virtual std::future<int32_t> load_graph_split_config(uint32_t table_id,
std::string path);
virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, std::vector<int64_t>& node_id_list);
uint32_t table_id, int idx_, std::vector<int64_t>& node_id_list);
virtual int32_t Initialize();
int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; }
Expand Down
Loading