Merge pull request #22 from DesmonDay/cpu_gpu_graph_engine
Add sample method v2
seemingwang authored Apr 20, 2022
2 parents b7ae11a + bd55c12 commit b7aa7f9
Showing 7 changed files with 286 additions and 5 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/distributed/ps/table/common_graph_table.cc
@@ -681,6 +681,7 @@ int32_t GraphTable::build_sampler(std::string sample_type) {
}
return 0;
}

int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
// #ifdef PADDLE_WITH_HETERPS
// if (gpups_mode) pthread_rwlock_rdlock(rw_lock.get());
@@ -978,6 +979,7 @@ int32_t GraphTable::random_sample_neighbors(
seq_id[index].emplace_back(idx);
id_list[index].emplace_back(node_ids[idx], sample_size, need_weight);
}

for (int i = 0; i < (int)seq_id.size(); i++) {
if (seq_id[i].size() == 0) continue;
tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
2 changes: 0 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h
@@ -117,7 +117,6 @@ node_list[8]-> node_id:17, neighbor_size:1, neighbor_offset:15
struct NeighborSampleResult {
int64_t *val;
int *actual_sample_size, sample_size, key_size;
int *offset;
std::shared_ptr<memory::Allocation> val_mem, actual_sample_size_mem;

NeighborSampleResult(int _sample_size, int _key_size, int dev_id)
@@ -130,7 +129,6 @@ struct NeighborSampleResult {
actual_sample_size_mem =
memory::AllocShared(place, _key_size * sizeof(int));
actual_sample_size = (int *)actual_sample_size_mem->ptr();
offset = NULL;
};
~NeighborSampleResult() {
// if (val != NULL) cudaFree(val);
3 changes: 3 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
@@ -86,6 +86,9 @@ class GpuPsGraphTable : public HeterComm<int64_t, int, int> {
NodeQueryResult *graph_node_sample(int gpu_id, int sample_size);
NeighborSampleResult *graph_neighbor_sample(int gpu_id, int64_t *key,
int sample_size, int len);
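  // v2 samples with one warp per key; when cpu_query_switch is true, keys
  // absent from every GPU shard fall back to the CPU graph table.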
NeighborSampleResult *graph_neighbor_sample_v2(int gpu_id, int64_t *key,
int sample_size, int len,
bool cpu_query_switch);
NodeQueryResult *query_node_list(int gpu_id, int start, int query_size);
void clear_graph_info();
void move_neighbor_sample_result_to_source_gpu(int gpu_id, int gpu_num,
242 changes: 242 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.h
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <thrust/device_vector.h>

#pragma once
#ifdef PADDLE_WITH_HETERPS
//#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h"
@@ -28,6 +30,69 @@ sample_result is to save the neighbor sampling result, its size is len *
sample_size;
*/

__global__ void get_cpu_id_index(int64_t* key, int* val, int64_t* cpu_key,
int* sum, int* index, int len) {
CUDA_KERNEL_LOOP(i, len) {
if (val[i] == -1) {
int old = atomicAdd(sum, 1);
cpu_key[old] = key[i];
index[old] = i;
}
}
}
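
// [Editor's sketch, not part of this patch] get_cpu_id_index compacts every
// key whose lookup missed (val[i] == -1) into cpu_key via an atomic counter,
// recording each key's original slot in index so CPU results can be scattered
// back later. The same compaction could be written with Thrust, at the cost
// of two passes and Thrust's temporary allocations (helper name hypothetical):
//
// #include <cstdint>
// #include <thrust/copy.h>
// #include <thrust/device_vector.h>
// #include <thrust/iterator/counting_iterator.h>
//
// struct IsMiss {  // true when a key was not found in any GPU shard
//   __host__ __device__ bool operator()(int v) const { return v == -1; }
// };
//
// // Returns how many keys must fall back to the CPU graph table.
// int compact_missing_keys(const thrust::device_vector<int64_t>& key,
//                          const thrust::device_vector<int>& val,
//                          thrust::device_vector<int64_t>* cpu_key,
//                          thrust::device_vector<int>* index) {
//   // Stable counterpart of the kernel's unordered atomicAdd compaction;
//   // order is irrelevant downstream, so both are valid.
//   auto out_end = thrust::copy_if(key.begin(), key.end(), val.begin(),
//                                  cpu_key->begin(), IsMiss());
//   thrust::copy_if(thrust::counting_iterator<int>(0),
//                   thrust::counting_iterator<int>(static_cast<int>(key.size())),
//                   val.begin(), index->begin(), IsMiss());
//   return static_cast<int>(out_end - cpu_key->begin());
// }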

template <int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__ void neighbor_sample_example_v2(GpuPsCommGraph graph,
int* node_index, int* actual_size,
int64_t* res, int sample_len,
int n) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);

int i = blockIdx.x * TILE_SIZE + threadIdx.y;
const int last_idx = min(static_cast<int>(blockIdx.x + 1) * TILE_SIZE, n);
curandState rng;
curand_init(blockIdx.x, threadIdx.y * WARP_SIZE + threadIdx.x, 0, &rng);

while (i < last_idx) {
if (node_index[i] == -1) {
actual_size[i] = 0;
i += BLOCK_WARPS;
continue;
}
int neighbor_len = graph.node_list[node_index[i]].neighbor_size;
int data_offset = graph.node_list[node_index[i]].neighbor_offset;
int offset = i * sample_len;
int64_t* data = graph.neighbor_list;
if (neighbor_len <= sample_len) {
for (int j = threadIdx.x; j < neighbor_len; j += WARP_SIZE) {
res[offset + j] = data[data_offset + j];
}
actual_size[i] = neighbor_len;
} else {
for (int j = threadIdx.x; j < sample_len; j += WARP_SIZE) {
res[offset + j] = j;
}
__syncwarp();
for (int j = sample_len + threadIdx.x; j < neighbor_len; j += WARP_SIZE) {
const int num = curand(&rng) % (j + 1);
if (num < sample_len) {
atomicMax(reinterpret_cast<unsigned int*>(res + offset + num),
static_cast<unsigned int>(j));
}
}
__syncwarp();
for (int j = threadIdx.x; j < sample_len; j += WARP_SIZE) {
const int perm_idx = res[offset + j] + data_offset;
res[offset + j] = data[perm_idx];
}
actual_size[i] = sample_len;
}
i += BLOCK_WARPS;
}
}
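
// [Editor's sketch, not part of this patch] Each warp above owns one key:
// short neighbor lists are copied verbatim; longer ones are reservoir-sampled.
// Slots start as 0..sample_len-1, each later index j overwrites a random slot
// with probability sample_len / (j + 1), racing lanes are resolved with
// atomicMax, and a final pass swaps the surviving indices for neighbor ids.
// A single-threaded analogue of that scheme:
//
// #include <cstdint>
// #include <random>
// #include <vector>
//
// // Uniformly sample up to sample_len of neighbor_len neighbors.
// std::vector<int64_t> reservoir_sample(const int64_t* neighbors,
//                                       int neighbor_len, int sample_len,
//                                       std::mt19937* rng) {
//   if (neighbor_len <= sample_len)
//     return {neighbors, neighbors + neighbor_len};
//   std::vector<int> slot(sample_len);
//   for (int j = 0; j < sample_len; ++j) slot[j] = j;  // seed the reservoir
//   for (int j = sample_len; j < neighbor_len; ++j) {
//     int num = std::uniform_int_distribution<int>(0, j)(*rng);
//     if (num < sample_len) slot[num] = j;  // kernel uses atomicMax here
//   }
//   std::vector<int64_t> out(sample_len);
//   for (int j = 0; j < sample_len; ++j) out[j] = neighbors[slot[j]];
//   return out;
// }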

__global__ void neighbor_sample_example(GpuPsCommGraph graph, int* node_index,
int* actual_size, int64_t* res,
int sample_len, int* sample_status,
@@ -402,6 +467,7 @@ void GpuPsGraphTable::build_graph_from_cpu(
}
cudaDeviceSynchronize();
}

NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
int64_t* key,
int sample_size,
@@ -620,6 +686,182 @@ NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
return result;
}

NeighborSampleResult* GpuPsGraphTable::graph_neighbor_sample_v2(
int gpu_id, int64_t* key, int sample_size, int len, bool cpu_query_switch) {
NeighborSampleResult* result =
new NeighborSampleResult(sample_size, len, resource_->dev_id(gpu_id));

if (len == 0) {
return result;
}

platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id));
platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id));
int* actual_sample_size = result->actual_sample_size;
int64_t* val = result->val;
int total_gpu = resource_->total_device();
auto stream = resource_->local_stream(gpu_id, 0);

int grid_size = (len - 1) / block_size_ + 1;

int h_left[total_gpu]; // NOLINT
int h_right[total_gpu]; // NOLINT

auto d_left = memory::Alloc(place, total_gpu * sizeof(int));
auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());

cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream);
cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream);
//
auto d_idx = memory::Alloc(place, len * sizeof(int));
int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());

auto d_shard_keys = memory::Alloc(place, len * sizeof(int64_t));
int64_t* d_shard_keys_ptr = reinterpret_cast<int64_t*>(d_shard_keys->ptr());
auto d_shard_vals = memory::Alloc(place, sample_size * len * sizeof(int64_t));
int64_t* d_shard_vals_ptr = reinterpret_cast<int64_t*>(d_shard_vals->ptr());
auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
int* d_shard_actual_sample_size_ptr =
reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());

split_input_to_shard(key, d_idx_ptr, len, d_left_ptr, d_right_ptr, gpu_id);

heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len,
stream);

cudaStreamSynchronize(stream);

cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int),
cudaMemcpyDeviceToHost);
for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
if (shard_len == 0) {
continue;
}
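    // val_storage layout per shard: shard_len ints of node index, shard_len
    // ints of actual_size, then shard_len * sample_size int64 neighbor ids;
    // the two int arrays pack into shard_len int64 slots, hence the
    // (1 + sample_size) factor below.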
create_storage(gpu_id, i, shard_len * sizeof(int64_t),
shard_len * (1 + sample_size) * sizeof(int64_t));
}
walk_to_dest(gpu_id, total_gpu, h_left, h_right, d_shard_keys_ptr, NULL);

  // For cpu_query_switch: declared outside the per-GPU loop so these
  // buffers survive until the CPU fallback pass below.
std::vector<thrust::device_vector<int64_t>> cpu_keys_list;
std::vector<thrust::device_vector<int>> cpu_index_list;
thrust::device_vector<int64_t> tmp1;
thrust::device_vector<int> tmp2;
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
// Insert empty object
cpu_keys_list.emplace_back(tmp1);
cpu_index_list.emplace_back(tmp2);
continue;
}
auto& node = path_[gpu_id][i].nodes_.back();
cudaStreamSynchronize(node.in_stream);
platform::CUDADeviceGuard guard(resource_->dev_id(i));
// If not found, val is -1.
tables_[i]->get(reinterpret_cast<int64_t*>(node.key_storage),
reinterpret_cast<int*>(node.val_storage),
h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, gpu_id));

auto shard_len = h_right[i] - h_left[i] + 1;
auto graph = gpu_graph_list[i];
int* id_array = reinterpret_cast<int*>(node.val_storage);
int* actual_size_array = id_array + shard_len;
int64_t* sample_array = (int64_t*)(id_array + shard_len * 2);
constexpr int WARP_SIZE = 32;
constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
constexpr int TILE_SIZE = BLOCK_WARPS * 16;
const dim3 block(WARP_SIZE, BLOCK_WARPS);
const dim3 grid((shard_len + TILE_SIZE - 1) / TILE_SIZE);
neighbor_sample_example_v2<
WARP_SIZE, BLOCK_WARPS,
TILE_SIZE><<<grid, block, 0, resource_->remote_stream(i, gpu_id)>>>(
graph, id_array, actual_size_array, sample_array, sample_size,
shard_len);

// cpu_graph_table->random_sample_neighbors
if (cpu_query_switch) {
thrust::device_vector<int64_t> cpu_keys_ptr(shard_len);
thrust::device_vector<int> index_ptr(shard_len + 1, 0);
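      // index_ptr[0] doubles as the atomic counter (`sum` in
      // get_cpu_id_index); index_ptr[1..shard_len] records each CPU-bound
      // key's slot within this shard.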
int64_t* node_id_array = reinterpret_cast<int64_t*>(node.key_storage);
int grid_size2 = (shard_len - 1) / block_size_ + 1;
get_cpu_id_index<<<grid_size2, block_size_, 0,
resource_->remote_stream(i, gpu_id)>>>(
node_id_array, id_array,
thrust::raw_pointer_cast(cpu_keys_ptr.data()),
thrust::raw_pointer_cast(index_ptr.data()),
thrust::raw_pointer_cast(index_ptr.data()) + 1, shard_len);

cpu_keys_list.emplace_back(cpu_keys_ptr);
cpu_index_list.emplace_back(index_ptr);
}
}

for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
continue;
}
cudaStreamSynchronize(resource_->remote_stream(i, gpu_id));
}

if (cpu_query_switch) {
for (int i = 0; i < total_gpu; ++i) {
if (h_left[i] == -1) {
continue;
}
auto shard_len = h_right[i] - h_left[i] + 1;
int* cpu_index = new int[shard_len + 1];
cudaMemcpy(cpu_index, thrust::raw_pointer_cast(cpu_index_list[i].data()),
(shard_len + 1) * sizeof(int), cudaMemcpyDeviceToHost);
if (cpu_index[0] > 0) {
int number_on_cpu = cpu_index[0];
int64_t* cpu_keys = new int64_t[number_on_cpu];
cudaMemcpy(cpu_keys, thrust::raw_pointer_cast(cpu_keys_list[i].data()),
number_on_cpu * sizeof(int64_t), cudaMemcpyDeviceToHost);

std::vector<std::shared_ptr<char>> buffers(number_on_cpu);
std::vector<int> ac(number_on_cpu);
auto status = cpu_graph_table->random_sample_neighbors(
cpu_keys, sample_size, buffers, ac, false);

auto& node = path_[gpu_id][i].nodes_.back();
int* id_array = reinterpret_cast<int*>(node.val_storage);
int* actual_size_array = id_array + shard_len;
int64_t* sample_array = (int64_t*)(id_array + shard_len * 2);
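        // Scatter CPU results back to each key's original slot:
        // cpu_index[j + 1] is that slot, and random_sample_neighbors
        // reports sizes in bytes, hence the sizeof(int64_t) division.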
for (int j = 0; j < number_on_cpu; j++) {
int offset = cpu_index[j + 1] * sample_size;
ac[j] = ac[j] / sizeof(int64_t);
cudaMemcpy(sample_array + offset, (int64_t*)(buffers[j].get()),
sizeof(int64_t) * ac[j], cudaMemcpyHostToDevice);
cudaMemcpy(actual_size_array + cpu_index[j + 1], ac.data() + j,
sizeof(int), cudaMemcpyHostToDevice);
}
}
}
}
move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, sample_size,
h_left, h_right, d_shard_vals_ptr,
d_shard_actual_sample_size_ptr);

fill_dvalues<<<grid_size, block_size_, 0, stream>>>(
d_shard_vals_ptr, val, d_shard_actual_sample_size_ptr, actual_sample_size,
d_idx_ptr, sample_size, len);
for (int i = 0; i < total_gpu; ++i) {
int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
if (shard_len == 0) {
continue;
}
destroy_storage(gpu_id, i);
}
cudaStreamSynchronize(stream);
return result;
}
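
// [Editor's note] Illustrative call pattern for graph_neighbor_sample_v2
// (variable names hypothetical; see test_cpu_query.cu below for a real use):
//
//   auto* r = table.graph_neighbor_sample_v2(gpu_id, d_keys, sample_size,
//                                            len, /*cpu_query_switch=*/true);
//   // r->val: device buffer of len * sample_size neighbor ids (padded);
//   // r->actual_sample_size: device buffer of per-key counts. Copy both
//   // to host before reading, then delete r.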

NodeQueryResult* GpuPsGraphTable::graph_node_sample(int gpu_id,
int sample_size) {}
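// Note: graph_node_sample is an empty stub in this commit; callers must not
// rely on its return value.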

2 changes: 2 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
@@ -193,6 +193,8 @@ void HeterComm<KeyType, ValType, GradType>::walk_to_dest(int start_index,
memory_copy(dst_place, node.key_storage, src_place,
reinterpret_cast<char*>(src_key + h_left[i]),
node.key_bytes_len, node.in_stream);
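// Pre-fill val_storage with -1 so keys that miss every GPU shard are
// detectable after the lookup (consumed by get_cpu_id_index).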
cudaMemsetAsync(node.val_storage, -1, node.val_bytes_len, node.in_stream);

if (need_copy_val) {
memory_copy(dst_place, node.val_storage, src_place,
reinterpret_cast<char*>(src_val + h_left[i]),
7 changes: 5 additions & 2 deletions paddle/fluid/framework/fleet/heter_ps/test_cpu_query.cu
@@ -70,7 +70,10 @@ TEST(TEST_FLEET, test_cpu_cache) {
platform::CUDADeviceGuard guard(0);
cudaMalloc((void **)&key, 3 * sizeof(int64_t));
cudaMemcpy(key, cpu_key, 3 * sizeof(int64_t), cudaMemcpyHostToDevice);
auto neighbor_sample_res = g.graph_neighbor_sample(0, (int64_t *)key, 2, 3);
// auto neighbor_sample_res = g.graph_neighbor_sample(0, (int64_t *)key, 2,
// 3);
auto neighbor_sample_res =
g.graph_neighbor_sample_v2(0, (int64_t *)key, 2, 3, true);
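  // With cpu_query_switch = true, keys not found on the GPU shards are
  // sampled from the CPU graph table via random_sample_neighbors.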
int64_t *res = new int64_t[7];
cudaMemcpy(res, neighbor_sample_res->val, 3 * 2 * sizeof(int64_t),
cudaMemcpyDeviceToHost);
@@ -79,7 +82,7 @@
3 * sizeof(int),
cudaMemcpyDeviceToHost); // 3, 1, 3

//{0,9} or {9,0} is expected for key 0
//{1,9} or {9,1} is expected for key 0
//{0,2} or {2,0} is expected for key 1
//{1,3} or {3,1} is expected for key 2
for (int i = 0; i < 3; i++) {
33 changes: 32 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/test_sample_rate.cu
@@ -264,6 +264,8 @@ void testSampleRate() {
res[i].push_back(result);
}
*/

// g.graph_neighbor_sample
start = 0;
auto func = [&rwlock, &g, &start, &ids](int i) {
int st = 0;
@@ -288,8 +290,37 @@
auto end1 = std::chrono::steady_clock::now();
auto tt =
std::chrono::duration_cast<std::chrono::microseconds>(end1 - start1);
std::cerr << "total time cost without cache is "
std::cerr << "total time cost without cache for v1 is "
<< tt.count() / exe_count / gpu_num1 << " us" << std::endl;

// g.graph_neighbor_sample_v2
start = 0;
auto func2 = [&rwlock, &g, &start, &ids](int i) {
int st = 0;
int size = ids.size();
for (int k = 0; k < exe_count; k++) {
st = 0;
while (st < size) {
int len = std::min(fixed_key_size, (int)ids.size() - st);
auto r = g.graph_neighbor_sample_v2(i, (int64_t *)(key[i] + st),
sample_size, len, false);
st += len;
delete r;
}
}
};
auto start2 = std::chrono::steady_clock::now();
std::thread thr2[gpu_num1];
for (int i = 0; i < gpu_num1; i++) {
thr2[i] = std::thread(func2, i);
}
for (int i = 0; i < gpu_num1; i++) thr2[i].join();
auto end2 = std::chrono::steady_clock::now();
auto tt2 =
std::chrono::duration_cast<std::chrono::microseconds>(end2 - start2);
std::cerr << "total time cost without cache for v2 is "
<< tt2.count() / exe_count / gpu_num1 << " us" << std::endl;

for (int i = 0; i < gpu_num1; i++) {
cudaFree(key[i]);
}
