-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
kvstore push row sparse #93
Changes from all commits
5bff75f
f50f8e6
ab40b35
d7225ec
d24e52b
23cf932
528b3f6
e1b0329
31fa6df
b89b8b6
c4fec03
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,13 +3,16 @@ | |
*/ | ||
#ifndef MXNET_KVSTORE_COMM_H_ | ||
#define MXNET_KVSTORE_COMM_H_ | ||
#include <dmlc/omp.h>
#include <algorithm>
#include <cstdint>
#include <limits>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
#include "mxnet/ndarray.h"
#include "../common/utils.h"
namespace mxnet { | ||
namespace kvstore { | ||
/** | ||
|
@@ -65,6 +68,8 @@ class CommCPU : public Comm { | |
CommCPU() {
  // Number of OpenMP threads used by the CPU reduction kernels.
  nthread_reduction_ = dmlc::GetEnv("MXNET_KVSTORE_REDUCTION_NTHREADS", 4);
  // Element-count threshold; presumably arrays larger than this take the
  // multi-threaded reduction path — the consumer is not visible here, confirm.
  bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000);
  // When non-zero, row-sparse reduce uses the serial implementation instead
  // of the parallel one (see the push path that selects between them).
  // TODO(junwu) delete the following data member, now for benchmark only
  is_serial_push_ = dmlc::GetEnv("MXNET_KVSTORE_SERIAL_PUSH", 0);
}
virtual ~CommCPU() { }
|
||
|
@@ -130,7 +135,8 @@ class CommCPU : public Comm { | |
auto result = buf.merged; | ||
Engine::Get()->PushSync([reduce, result, this](RunContext rctx) { | ||
NDArray out = result; | ||
ReduceSumCPUEx(reduce, &out); | ||
is_serial_push_? | ||
ReduceSumCPUExSerial(reduce, &out) : ReduceSumCPUExParallel(reduce, &out); | ||
}, Context::CPU(), const_vars, {result.var()}, | ||
FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce")); | ||
} | ||
|
@@ -168,7 +174,7 @@ class CommCPU : public Comm { | |
|
||
// serial implementation of reduce sum for row sparse NDArray. | ||
// TODO(haibin) use openmp kernel to parallelize the summation | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remember to remove the todo here in next PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will remove it now. |
||
inline void ReduceSumCPUEx(const std::vector<NDArray> &in, NDArray *out) { | ||
inline void ReduceSumCPUExSerial(const std::vector<NDArray> &in, NDArray *out) { | ||
using namespace rowsparse; | ||
using namespace mshadow; | ||
auto stype = out->storage_type(); | ||
|
@@ -239,6 +245,115 @@ class CommCPU : public Comm { | |
}); | ||
} | ||
|
||
// Parallel element-wise sum of row-sparse ndarrays into `out`.
// `uniq_row_idx` is the sorted, deduplicated union of all row indices in
// `nds`; `out` must already be allocated with uniq_row_idx.size() rows and
// its values zero-initialized by the caller. Each OpenMP thread owns a
// contiguous block of output rows, so no two threads write the same row
// and no synchronization is needed.
template<typename DType, typename IType>
void ReduceSumCPUExImpl(const std::vector<NDArray>& nds,
                        const std::vector<IType>& uniq_row_idx,
                        NDArray* out) {
#pragma omp parallel num_threads(nthread_reduction_)
  {
    const size_t nnr = uniq_row_idx.size();
    const int num_threads = omp_get_num_threads();
    // Ceiling division: every thread gets at most row_block_len rows.
    size_t row_block_len = (nnr + num_threads - 1) / num_threads;
    const size_t row_block_start = omp_get_thread_num() * row_block_len;
    // Threads whose block starts past the end have nothing to do.
    if (row_block_start < nnr) {
      const size_t row_block_end = std::min(row_block_start+row_block_len, nnr);

      auto out_values = out->data().FlatTo2D<cpu, DType>();
      auto out_indices = out->aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>();
      // Fill this block's slice of the output row-index array.
      for (size_t i = row_block_start; i < row_block_end; ++i) {
        out_indices[i] = uniq_row_idx[i];
      }
      // Merge every input ndarray's rows that fall inside this block.
      for (const auto& nd : nds) {
        if (nd.storage_initialized()) {
          const auto nd_indices = nd.aux_data(rowsparse::kIdx).FlatTo1D<cpu, IType>();
          const auto nd_values = nd.data().FlatTo2D<cpu, DType>();
          // NOTE(review): aux_shape(kIdx).Size() is used as the row count of
          // the 1-D index array; double-check the return type/width of
          // .Size() here, as flagged in code review.
          const auto nd_num_rows = nd.aux_shape(rowsparse::kIdx).Size();
          const IType* nd_indices_start = &nd_indices[0];
          const IType* nd_indices_end = nd_indices_start + nd_num_rows;
          // Binary search for the first input row that could land in this
          // block (input row indices are sorted for row-sparse storage).
          const IType* row_idx_ptr = std::lower_bound(nd_indices_start, nd_indices_end,
                                                      out_indices[row_block_start]);
          // skip this nd if all of its row indices are smaller than out_indices[row_block_start]
          // or current row block is not covered by [*row_idx_ptr, nd_indices_end).
          if (nd_indices_end == row_idx_ptr || *row_idx_ptr > out_indices[row_block_end-1]) {
            continue;
          }
          // Two-pointer merge: advance whichever side has the smaller row
          // index; on a match, accumulate the input row into the output row.
          for (size_t irow = row_block_start;
               irow < row_block_end && row_idx_ptr != nd_indices_end;) {
            if (out_indices[irow] == *row_idx_ptr) {
              auto out_value_cur_row = out_values[irow];
              const auto offset = row_idx_ptr - nd_indices_start;
              auto nd_value_cur_row = nd_values[offset];
              for (size_t j = 0; j < nd_value_cur_row.shape_[0]; ++j) {
                out_value_cur_row[j] += nd_value_cur_row[j];
              }
              ++irow;
              ++row_idx_ptr;
            } else if (out_indices[irow] < *row_idx_ptr) {
              ++irow;
            } else {
              ++row_idx_ptr;
            }
          }
        }
      }
    }
  }
}
|
||
/*! | ||
* \brief Given a vector of ndarrays, generate a index vector containing | ||
* all the unique row indices of the ndarrays. | ||
*/ | ||
template<typename IType> | ||
void GetUniqueRspRowIdx(const std::vector<NDArray>& nds, | ||
std::vector<IType>* uniq_row_idx) { | ||
using namespace rowsparse; | ||
size_t total_num_rows = 0; | ||
for (const auto& nd : nds) { | ||
CHECK_EQ(nd.storage_type(), kRowSparseStorage); | ||
if (nd.storage_initialized()) { | ||
total_num_rows += nd.aux_shape(kIdx).Size(); | ||
} | ||
} | ||
|
||
uniq_row_idx->resize(total_num_rows); | ||
int nthreads = omp_get_max_threads(); | ||
size_t offset = 0; | ||
for (const auto& nd : nds) { | ||
if (nd.storage_initialized()) { | ||
const IType* nd_row_idx = nd.aux_data(kIdx).dptr<IType>(); | ||
const size_t num_rows = nd.aux_shape(kIdx).Size(); | ||
#pragma omp parallel for num_threads(nthreads) | ||
for (size_t i = 0; i < num_rows; ++i) { | ||
(*uniq_row_idx)[offset+i] = nd_row_idx[i]; | ||
} | ||
offset += num_rows; | ||
} | ||
} | ||
|
||
common::ParallelSort(uniq_row_idx->begin(), uniq_row_idx->end(), nthreads); | ||
auto it = std::unique(uniq_row_idx->begin(), uniq_row_idx->end()); | ||
uniq_row_idx->resize(it - uniq_row_idx->begin()); | ||
} | ||
|
||
void ReduceSumCPUExParallel(const std::vector<NDArray>& nds, NDArray* out) { | ||
if (nds.empty()) return; | ||
using namespace rowsparse; | ||
CHECK_EQ(out->storage_type(), kRowSparseStorage) | ||
<< "Expected row sparse storage type (" | ||
<< out->storage_type() << " given)"; | ||
|
||
MSHADOW_TYPE_SWITCH(out->dtype(), DType, { | ||
MSHADOW_INT_TYPE_SWITCH(out->aux_type(kIdx), IType, { | ||
std::vector<IType> uniq_row_idx; | ||
GetUniqueRspRowIdx(nds, &uniq_row_idx); | ||
out->CheckAndAlloc({mshadow::Shape1(uniq_row_idx.size())}); | ||
out->data().FlatTo2D<cpu, DType>() = static_cast<DType>(0); | ||
ReduceSumCPUExImpl<DType, IType>(nds, uniq_row_idx, out); | ||
}); | ||
}); | ||
} | ||
|
||
template<typename DType> | ||
inline static void ReduceSumCPU( | ||
const std::vector<DType*> &dptr, size_t offset, index_t size) { | ||
|
@@ -304,6 +419,7 @@ class CommCPU : public Comm { | |
std::unordered_map<int, BufferEntry> merge_buf_; | ||
size_t bigarray_bound_; | ||
int nthread_reduction_; | ||
bool is_serial_push_; | ||
}; | ||
|
||
/** | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where does
num / num_threads + 5
come from? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From Mu's code. It can be seen under ps-lite/include/ps/internal/parallel_sort.h. 5 and 1024*16 are two magic numbers I'm not sure of.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure if
src/common/utils.h
is the best place to put the code. Maybe better to put it under src/operator
? Could be useful for other operators. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this function is more widely applicable to than just operators. Putting it under src/operator makes it look like only for operators. But the fact is that any components under src/ could use it. For example src/kvstore uses this function.