Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-331] Single machine All Reduce Topology-aware Communication (U…
Browse files Browse the repository at this point in the history
…pdated) (#11591)

* add multiroot all-reduce communication pattern

* fix bug with UpdateWeight

* fix PCI-E links appearing in weight matrix bug

* optimization to skip CopyFromTo in ReduceInner gains a bit of throughput

* remove unnecessary if statement

* Add tests

* add more tests, 6 tests left to add

* get rid of some dead code

* Add comments

* Add randomized tests for backtrack and kernighan-lin

* Fix Postprocess

* Add switch for first valid tree when num_gpus > 8, and for maximum weight when num_gpus <= 8

* Kernighan-Lin seems to find better trees

* get rid of printfs

* change defaults

* inherit from CommDevice instead of Comm

* Fix lint errors

* Add Python test using MXNET_KVSTORE_USETREE, fix CMake compilation problem, add header guard

* fix lint errors

* better header guard that works for tests

* get rid of unused variable warning

* retrigger jenkins

* resolve 2 comments

* address comment using Class to do test, get rid of extraneous test, use PCI-E as fallback for GPUs that are not linked by NVLink

* address comments

* fix a few bugs

* get rid of printfs

* get rid of print

* Comment out test for now

* fix 2 more bugs

* fix segfault

* change PrintVector, PrintTopo, PrintMatrix to LOG(INFO) instead of stdout

* Fix code alignment

* get rid of todo

* Make changes to env variable names to indicate they are TREE-related

* Add note saying when ARRAY_BOUND env var takes effect
  • Loading branch information
ctcyang authored and eric-haibin-lin committed Jul 24, 2018
1 parent 64d2e8b commit fe07d50
Show file tree
Hide file tree
Showing 9 changed files with 2,538 additions and 77 deletions.
26 changes: 26 additions & 0 deletions docs/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,32 @@ export MXNET_GPU_WORKER_NTHREADS=3
- The minimum size of a "big array".
- When the array size is bigger than this threshold, MXNET_KVSTORE_REDUCTION_NTHREADS threads are used for reduction.
- This parameter is also used as a load balancer in kvstore. It controls when to partition a single weight to all the servers. If the size of a single weight is less than MXNET_KVSTORE_BIGARRAY_BOUND then, it is sent to a single randomly picked server otherwise it is partitioned to all the servers.

* MXNET_KVSTORE_USETREE
- Values: 0(false) or 1(true) ```(default=0)```
- If true, MXNet tries to use tree reduction for Push and Pull communication.
- Otherwise, MXNet uses the default Push and Pull implementation.
- [Tree reduction technology](http://www.sysml.cc/doc/178.pdf) has been shown to be faster than the standard ```--kv-store device``` Push/Pull and ```--kv-store nccl``` Push/Pull for small batch sizes.

* MXNET_KVSTORE_LOGTREE
- Values: 0(false) or 1(true) ```(default=0)```
- If true and MXNET_KVSTORE_USETREE is set to 1, MXNet will log the reduction trees that have been generated.

* MXNET_KVSTORE_TREE_ARRAY_BOUND
- Values: Int ```(default=10000000)```
- The minimum size of a "big array".
- When the array size is bigger than this threshold and MXNET_KVSTORE_USETREE is set to 1, multiple trees are used to load balance the big gradient being communicated in order to better saturate link bandwidth.
  - Note: This environment variable only takes effect if Tree KVStore is being used (MXNET_KVSTORE_USETREE=1).

* MXNET_KVSTORE_TREE_BACKTRACK
  - Values: 0(false) or 1(true) ```(default=0)```
- If true and MXNET_KVSTORE_USETREE is set to 1, MXNet tries to use backtracking to generate the trees required for tree reduction.
- If false and MXNET_KVSTORE_USETREE is set to 1, MXNet tries to use Kernighan-Lin heuristic to generate the trees required for tree reduction.

* MXNET_KVSTORE_TREE_LINK_USAGE_PENALTY
- Values: Float ```(default=0.7)```
  - The multiplicative penalty factor applied to the weight of a link each time that link is used, discouraging repeated use of the same link when generating reduction trees.

* MXNET_ENABLE_GPU_P2P
- Values: 0(false) or 1(true) ```(default=1)```
- If true, MXNet tries to use GPU peer-to-peer communication, if available on your device,
Expand Down
121 changes: 68 additions & 53 deletions src/kvstore/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,31 @@ class CommDevice : public Comm {
}
}

const NDArray& ReduceRowSparse(int key, const std::vector<NDArray>& src,
                               int priority) {
  // Accumulates the row-sparse arrays in `src` into this key's merge
  // buffer and returns a reference to the merged result.
  auto& buf = merge_buf_[key];
  const NDArrayStorageType stype = src[0].storage_type();
  NDArray& buf_merged = buf.merged_buf(stype);
  // Lazily create one staging array per source on first use, all placed
  // on the merge buffer's context so the elementwise sum runs locally.
  if (buf.copy_buf.empty()) {
    buf.copy_buf.resize(src.size());
    for (size_t idx = 0; idx < src.size(); ++idx) {
      buf.copy_buf[idx] =
          NDArray(stype, src[0].shape(), buf_merged.ctx(), true, src[0].dtype());
    }
  }
  CHECK(src[0].storage_type() == buf.copy_buf[0].storage_type())
      << "Storage type mismatch detected. " << src[0].storage_type()
      << "(src) vs. " << buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
  // Stage every source into its copy buffer, then sum the staged copies.
  std::vector<NDArray> reduce(src.size());
  for (size_t idx = 0; idx < src.size(); ++idx) {
    CopyFromTo(src[idx], &(buf.copy_buf[idx]), priority);
    reduce[idx] = buf.copy_buf[idx];
  }
  ElementwiseSum(reduce, &buf_merged, priority);
  return buf_merged;
}

const NDArray& Reduce(int key, const std::vector<NDArray>& src,
int priority) override {
// when this reduce is called from kvstore_dist, gc is not set
Expand All @@ -490,13 +515,14 @@ class CommDevice : public Comm {

InitBuffersAndComm(src);
auto& buf = merge_buf_[key];
std::vector<NDArray> reduce(src.size());

const NDArrayStorageType stype = src[0].storage_type();
NDArray& buf_merged = buf.merged_buf(stype);
// normal dense reduce
if (stype == kDefaultStorage) {
CopyFromTo(src[0], &buf_merged, priority);

std::vector<NDArray> reduce(src.size());
reduce[0] = buf_merged;

if (buf.copy_buf.empty()) {
Expand All @@ -514,24 +540,11 @@ class CommDevice : public Comm {
CopyFromTo(src[i+1], &(buf.copy_buf[i]), priority);
reduce[i+1] = buf.copy_buf[i];
}
ElementwiseSum(reduce, &buf_merged, priority);
} else {
// sparse reduce
if (buf.copy_buf.empty()) {
// initialize buffer for copying during reduce
buf.copy_buf.resize(src.size());
for (size_t j = 0; j < src.size(); ++j) {
buf.copy_buf[j] = NDArray(stype, src[0].shape(), buf_merged.ctx(), true, src[0].dtype());
}
}
CHECK(src[0].storage_type() == buf.copy_buf[0].storage_type())
<< "Storage type mismatch detected. " << src[0].storage_type() << "(src) vs. "
<< buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
for (size_t i = 0; i < src.size(); ++i) {
CopyFromTo(src[i], &(buf.copy_buf[i]), priority);
reduce[i] = buf.copy_buf[i];
}
buf_merged = ReduceRowSparse(key, src, priority);
}
ElementwiseSum(reduce, &buf_merged, priority);
return buf_merged;
}

Expand Down Expand Up @@ -659,6 +672,42 @@ class CommDevice : public Comm {
}
}

using KeyAttrs = std::tuple<int, TShape, int>;
// Allocate each key's merge buffer on some device, balancing the total
// number of elements assigned to every device.
void InitMergeBuffer(const std::vector<Context>& devs) {
  // Assign the largest arrays first so the greedy placement stays balanced.
  std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(),
            [](const KeyAttrs& a, const KeyAttrs& b) {
              return std::get<1>(a).Size() > std::get<1>(b).Size();
            });

  // dev_id -> (context, total elements assigned to that device so far)
  std::unordered_map<int, std::pair<Context, size_t>> ctx_info;
  for (const auto& d : devs) {
    ctx_info[d.dev_id] = std::make_pair(d, 0);
  }

  for (const auto& attr : sorted_key_attrs_) {
    const int key = std::get<0>(attr);
    const TShape& shape = std::get<1>(attr);
    const int type = std::get<2>(attr);
    auto& buf = merge_buf_[key];
    // Pick the device carrying the least data so far.
    Context ctx;
    size_t min_size = std::numeric_limits<size_t>::max();
    for (const auto& kv : ctx_info) {
      if (kv.second.second <= min_size) {
        ctx = kv.second.first;
        min_size = kv.second.second;
      }
    }
    // Delayed allocation - as the dense merged buffer might not be used at all
    // if push() only sees sparse arrays
    const bool delay_alloc = true;
    buf.merged = NDArray(shape, ctx, delay_alloc, type);
    ctx_info[ctx.dev_id].second += shape.Size();
  }
  inited_ = true;
}

private:
void EnableP2P(const std::vector<Context>& devs) {
#if MXNET_USE_CUDA
Expand Down Expand Up @@ -702,43 +751,6 @@ class CommDevice : public Comm {
#endif
}

// (key, shape, dtype) attributes describing one kvstore entry.
using KeyAttrs = std::tuple<int, TShape, int>;
// try to allocate buff on device evenly
void InitMergeBuffer(const std::vector<Context>& devs) {
// Largest arrays first, so the greedy placement below stays balanced.
std::sort(sorted_key_attrs_.begin(), sorted_key_attrs_.end(), [](
const KeyAttrs& a, const KeyAttrs& b) {
return std::get<1>(a).Size() > std::get<1>(b).Size();
});

// dev_id -> (context, total elements assigned to that device so far)
std::unordered_map<int, std::pair<Context, size_t>> ctx_info;
for (auto d : devs) {
ctx_info[d.dev_id] = std::make_pair(d, 0);
}

for (size_t i = 0; i < sorted_key_attrs_.size(); ++i) {
const int key = std::get<0>(sorted_key_attrs_[i]);
const TShape& shape = std::get<1>(sorted_key_attrs_[i]);
const int type = std::get<2>(sorted_key_attrs_[i]);
auto& buf = merge_buf_[key];
// Choose the device with the fewest elements assigned so far.
Context ctx;
size_t min_size = std::numeric_limits<size_t>::max();
for (auto it = ctx_info.begin(); it != ctx_info.end(); ++it) {
size_t size = it->second.second;
if (size <= min_size) {
ctx = it->second.first;
min_size = size;
}
}
// Delayed allocation - as the dense merged buffer might not be used at all if push()
// only sees sparse arrays
bool delay_alloc = true;
buf.merged = NDArray(shape, ctx, delay_alloc, type);
ctx_info[ctx.dev_id].second += shape.Size();
}
inited_ = true;
}

std::vector<KeyAttrs> sorted_key_attrs_;
/// \brief temporal space for pushing and pulling
struct BufferEntry {
/// \brief the dense merged value for reduce and broadcast operations
Expand Down Expand Up @@ -773,7 +785,10 @@ class CommDevice : public Comm {
NDArray sparse_merged;
};
std::unordered_map<int, BufferEntry> merge_buf_;

public:
bool inited_;
std::vector<KeyAttrs> sorted_key_attrs_;
};

} // namespace kvstore
Expand Down
Loading

0 comments on commit fe07d50

Please sign in to comment.