forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
changes based on code reviews #176
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
#include <vector> | ||
#include <string> | ||
#include <utility> | ||
#include <functional> | ||
#include <algorithm> | ||
#include "./comm.h" | ||
|
||
|
@@ -85,7 +86,7 @@ class KVStoreLocal : public KVStore { | |
int priority) override { | ||
std::vector<int> uniq_keys; | ||
std::vector<std::vector<NDArray> > grouped_vals; | ||
GroupKVPairs(keys, values, &uniq_keys, &grouped_vals); | ||
GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals); | ||
|
||
for (size_t i = 0; i < uniq_keys.size(); ++i) { | ||
int key = uniq_keys[i]; | ||
|
@@ -114,7 +115,7 @@ class KVStoreLocal : public KVStore { | |
int priority) override { | ||
std::vector<int> uniq_keys; | ||
std::vector<std::vector<NDArray*> > grouped_vals; | ||
GroupKVPairs(keys, values, &uniq_keys, &grouped_vals); | ||
GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals); | ||
|
||
for (size_t i = 0; i < uniq_keys.size(); ++i) { | ||
int key = uniq_keys[i]; | ||
|
@@ -129,7 +130,7 @@ class KVStoreLocal : public KVStore { | |
int priority = 0) override { | ||
std::vector<int> uniq_keys; | ||
std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids; | ||
GroupKVPairs(keys, val_rowids, &uniq_keys, &grouped_val_rowids); | ||
GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids); | ||
for (size_t i = 0; i < uniq_keys.size(); ++i) { | ||
int key = uniq_keys[i]; | ||
const NDArray& local = local_[key]; | ||
|
@@ -174,13 +175,75 @@ class KVStoreLocal : public KVStore { | |
|
||
protected: | ||
/** | ||
* \brief group values on keys | ||
* \brief group values on keys for push | ||
*/ | ||
template <typename V> | ||
void GroupKVPairsPush(const std::vector<int>& keys, | ||
const std::vector<NDArray>& values, | ||
std::vector<int> *uniq_keys, | ||
std::vector<std::vector<NDArray>> *grouped_vals) { | ||
// check if the storage type of a value is valid | ||
auto validator = [this](const int key, const NDArray& nd) -> bool { | ||
auto stype = nd.storage_type(); | ||
// valid NDArray | ||
if (stype == kDefaultStorage || stype == kRowSparseStorage) return true; | ||
// invalid NDArray, abort | ||
LOG(FATAL) << "Unexpected storage type detected during kvstore push: " << stype; | ||
return false; | ||
}; | ||
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); | ||
} | ||
/** | ||
* \brief group values on keys for pull | ||
*/ | ||
void GroupKVPairsPull(const std::vector<int>& keys, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. GroupKVPairsForPull? |
||
const std::vector<NDArray*>& values, | ||
std::vector<int> *uniq_keys, | ||
std::vector<std::vector<NDArray*>> *grouped_vals) { | ||
// check if the storage type of a value is valid | ||
auto validator = [this](const int key, const NDArray* nd) -> bool { | ||
// valid | ||
if (nd->storage_type() == kDefaultStorage) return true; | ||
// invalid, print warning messages once | ||
if (this->warnings_printed_.find(key) == this->warnings_printed_.end()) { | ||
LOG(INFO) << "Warning: non-default weights detected during kvstore pull. " | ||
<< "Please make sure to use row_sparse_pull with row_ids instead."; | ||
this->warnings_printed_.insert(key); | ||
} | ||
return false; | ||
}; | ||
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); | ||
} | ||
/** | ||
* \brief group values on keys for row_sparse_pull | ||
*/ | ||
void GroupKVPairsPullRsp(const std::vector<int>& keys, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. GroupKVPairsForPullRsp? |
||
const std::vector<std::pair<NDArray*, NDArray>>& values, | ||
std::vector<int> *uniq_keys, | ||
std::vector<std::vector<std::pair<NDArray*, NDArray>>> *grouped_vals) { | ||
// check if the storage type of a value is valid | ||
auto validator = [this](const int key, const std::pair<NDArray*, NDArray>& val_rowid) -> bool { | ||
auto val_stype = val_rowid.first->storage_type(); | ||
auto rowid_stype = val_rowid.second.storage_type(); | ||
// check storage types | ||
CHECK_EQ(val_stype, kRowSparseStorage) << "Expected row_sparse storage type for " | ||
<< "row_sparse_pull values, but detected storage type " << val_stype; | ||
CHECK_EQ(rowid_stype, kDefaultStorage) << "Expected default storage type for " | ||
<< "row_sparse_pull rowids, but detected storage type " << rowid_stype; | ||
return true; | ||
}; | ||
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); | ||
} | ||
|
||
/** | ||
* \brief group values on keys with validation. | ||
* A value `v` is not included in the result if is_valid(v) returns false. | ||
*/ | ||
template <typename V, typename FValidate> | ||
void GroupKVPairs(const std::vector<int>& keys, | ||
const std::vector<V>& values, | ||
std::vector<int>* uniq_keys, | ||
std::vector<std::vector<V> >* grouped_vals) { | ||
std::vector<std::vector<V> >* grouped_vals, | ||
const FValidate& is_valid) { | ||
CHECK_EQ(keys.size(), values.size()); | ||
// TODO(mli) check if already sorted as an optimization | ||
using Idx = std::pair<int, int>; | ||
|
@@ -194,12 +257,14 @@ class KVStoreLocal : public KVStore { | |
|
||
int pre_key = idx[0].first - 1; | ||
for (auto i : idx) { | ||
if (i.first != pre_key) { | ||
uniq_keys->push_back(i.first); | ||
grouped_vals->push_back({values[i.second]}); | ||
pre_key = i.first;; | ||
} else { | ||
grouped_vals->back().push_back(values[i.second]); | ||
if (is_valid(i.first, values[i.second])) { | ||
if (i.first != pre_key) { | ||
uniq_keys->push_back(i.first); | ||
grouped_vals->push_back({values[i.second]}); | ||
pre_key = i.first; | ||
} else { | ||
grouped_vals->back().push_back(values[i.second]); | ||
} | ||
} | ||
} | ||
} | ||
|
@@ -246,6 +311,8 @@ class KVStoreLocal : public KVStore { | |
std::unordered_map<std::string, int> str_key_dict_; | ||
/// the next available integer for string->int key mapping | ||
int next_str_key_ = 0; | ||
/// whether printed warning due to mismatch stype in each key | ||
std::unordered_set<int> warnings_printed_; | ||
}; | ||
} // namespace kvstore | ||
} // namespace mxnet | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is
GroupKVPairsForPush
clearer?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A second thought is defining validator lambda function in the caller functions or a common place accessible to various callers and pass it to
GroupKVPairs
. In this way, there is no need to define extra interfaces of GroupKVPairsXXXX where XXXX stands for Push and Pull, respectively. Is this feasible?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also thought about that. What file do you think is the best place to put the free function? I started with
GroupKVPairsForPushRsp
but it is too long (>100 char) causing lint to fail..