changes based on code reviews #176

Merged — 3 commits, merged Aug 21, 2017
3 files renamed without changes.
22 changes: 2 additions & 20 deletions python/mxnet/kvstore.py
@@ -234,14 +234,6 @@ def pull(self, key, out=None, priority=0):
[ 2. 2. 2.]]
"""
assert(out is not None)
if not isinstance(out, (list, tuple)):
out = [out]
for val in out:
if not isinstance(val, (list, tuple)):
assert(val.stype == 'default')
else:
for v in val:
assert(v.stype == 'default')
ckeys, cvals = _ctype_key_value(key, out)
check_call(_LIB.MXKVStorePullEx(
self.handle, mx_uint(len(ckeys)), ckeys, cvals,
@@ -270,8 +262,8 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
other pull actions.

row_ids : NDArray or list of NDArray
The row_ids for which to pull for each value. The row_ids doesn't have to be unique
or sorted.
The row_ids for which to pull for each value. Each row_id is an 1D-NDArray \
whose values don't have to be unique nor sorted.

Examples
--------
@@ -299,16 +291,6 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
"""
assert(out is not None)
assert(row_ids is not None)
if isinstance(row_ids, NDArray):
row_ids = [row_ids]
if not isinstance(out, (list, tuple)):
out = [out]
for val in out:
if not isinstance(val, (list, tuple)):
assert(val.stype == 'row_sparse')
else:
for v in val:
assert(v.stype == 'row_sparse')
ckeys, cvals = _ctype_key_value(key, out)
_, crow_ids = _ctype_key_value(key, row_ids)
assert(len(crow_ids) == len(cvals)), "number of row_ids doesn't match number of values"
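To illustrate the clarified row_ids documentation above, here is a minimal usage sketch against this branch's Python API; the 'local' kvstore, the key 'w', and the shapes are illustrative assumptions, not part of the patch.

import mxnet as mx

kv = mx.kv.create('local')
shape = (4, 4)
# initialize a row_sparse weight under key 'w'
kv.init('w', mx.nd.ones(shape).tostype('row_sparse'))
out = mx.nd.zeros(shape).tostype('row_sparse')
# each row_id is a 1D NDArray; its values need not be unique or sorted
kv.row_sparse_pull('w', out=out, row_ids=mx.nd.array([3, 0, 0]))
print(out.asnumpy())  # rows 0 and 3 are ones; the remaining rows stay zero
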
31 changes: 3 additions & 28 deletions python/mxnet/model.py
@@ -93,29 +93,14 @@ def _create_kvstore(kvstore, num_device, arg_params):

return (kv, update_on_kvstore)

def _contains_non_default_storage(params):
if isinstance(params, (list, tuple)):
for param in params:
if param.stype != 'default':
return True
elif isinstance(params, NDArray):
return param.stype != 'default'
else:
return False

def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_on_kvstore):
"""Initialize kvstore"""
for idx, param_on_devs in enumerate(param_arrays):
name = param_names[idx]
kvstore.init(name, arg_params[name])

if update_on_kvstore:
if _contains_non_default_storage(param_on_devs):
# skip pulling row_sparse weights
warnings.warn('Detected non-default weight in kvstore to pull. Please make ' \
'sure to pull it with row_ids explicitly', RuntimeWarning)
else:
kvstore.pull(name, param_on_devs, priority=-idx)
kvstore.pull(name, param_on_devs, priority=-idx)

def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names):
"""Perform update of param_arrays from grad_arrays on kvstore."""
@@ -127,12 +112,7 @@ def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names):
# push gradient, priority is negative index
kvstore.push(name, grad_list, priority=-index)
# pull back the weights
if _contains_non_default_storage(arg_list):
# skip pulling row_sparse weights
warnings.warn('Detected non-default weight in kvstore to pull. Please make ' \
'sure to pull it with row_ids', RuntimeWarning)
else:
kvstore.pull(name, arg_list, priority=-index)
kvstore.pull(name, arg_list, priority=-index)

def _update_params(param_arrays, grad_arrays, updater, num_device,
kvstore=None, param_names=None):
@@ -147,12 +127,7 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
# push gradient, priority is negative index
kvstore.push(name, grad_list, priority=-index)
# pull back the sum gradients, to the same locations.
if _contains_non_default_storage(grad_list):
# skip pulling row_sparse weights
warnings.warn('Detected non-default weight in kvstore to pull. Please make ' \
'sure to pull it with row_ids', RuntimeWarning)
else:
kvstore.pull(name, grad_list, priority=-index)
kvstore.pull(name, grad_list, priority=-index)
for k, p in enumerate(zip(arg_list, grad_list)):
# faked an index here, to make optimizer create diff
# state for the same index but on diff devs, TODO(mli)
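Since the storage-type check now lives inside the kvstore (see kvstore_local.h below), model.py can pull unconditionally: a pull whose destination is not default-storage is simply skipped and a warning is logged. A hedged sketch of that behavior from the Python side, with an assumed 'local' kvstore and key 'w':

import mxnet as mx

kv = mx.kv.create('local')
kv.init('w', mx.nd.ones((2, 2)))
rsp_out = (mx.nd.ones((2, 2)) * 2).tostype('row_sparse')
kv.pull('w', out=rsp_out)   # skipped by the kvstore; a warning is logged once per key
print(rsp_out.asnumpy())    # still all 2s, i.e. unchanged
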
2 changes: 1 addition & 1 deletion python/mxnet/test_utils.py
@@ -29,7 +29,6 @@
import errno
import logging
from contextlib import contextmanager
import scipy.sparse as sp
import numpy as np
import numpy.testing as npt
import numpy.random as rnd
@@ -125,6 +124,7 @@ def _get_uniform_dataset_csr(num_rows, num_cols, density=0.1, dtype=None):
"""
_validate_csr_generation_inputs(num_rows, num_cols, density,
distribution="uniform")
from scipy import sparse as sp
csr = sp.rand(num_rows, num_cols, density, dtype=dtype, format="csr")
result = mx.nd.csr_matrix(csr.data, csr.indptr, csr.indices,
(num_rows, num_cols), dtype=dtype)
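The change above defers the scipy import to the call site, so importing mxnet.test_utils no longer requires scipy to be installed. A small sketch of the same pattern; the helper name is illustrative:

def random_csr(num_rows, num_cols, density=0.1, dtype=None):
    # scipy is imported lazily, only when a CSR matrix is actually generated
    from scipy import sparse as sp
    return sp.rand(num_rows, num_cols, density, dtype=dtype, format="csr")

csr = random_csr(10, 20, density=0.2)  # needs scipy only at call time
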
6 changes: 3 additions & 3 deletions src/kvstore/kvstore_dist.h
@@ -111,7 +111,7 @@ class KVStoreDist : public KVStoreLocal {
int priority) override {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray*> > grouped_vals;
GroupKVPairs(keys, values, &uniq_keys, &grouped_vals);
GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals);

for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
@@ -160,7 +160,7 @@ class KVStoreDist : public KVStoreLocal {
const int priority = 0) {
std::vector<int> uniq_keys;
std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
GroupKVPairs(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids);

for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
@@ -261,7 +261,7 @@ class KVStoreDist : public KVStoreLocal {
// first aggregate the values over keys
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray> > grouped_vals;
GroupKVPairs(keys, values, &uniq_keys, &grouped_vals);
GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals);

for (size_t i = 0; i < uniq_keys.size(); ++i) {
// merge over devcies
91 changes: 79 additions & 12 deletions src/kvstore/kvstore_local.h
@@ -30,6 +30,7 @@
#include <vector>
#include <string>
#include <utility>
#include <functional>
#include <algorithm>
#include "./comm.h"

@@ -85,7 +86,7 @@ class KVStoreLocal : public KVStore {
int priority) override {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray> > grouped_vals;
GroupKVPairs(keys, values, &uniq_keys, &grouped_vals);
GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals);

for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
@@ -114,7 +115,7 @@ class KVStoreLocal : public KVStore {
int priority) override {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray*> > grouped_vals;
GroupKVPairs(keys, values, &uniq_keys, &grouped_vals);
GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals);

for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
@@ -129,7 +130,7 @@ class KVStoreLocal : public KVStore {
int priority = 0) override {
std::vector<int> uniq_keys;
std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids;
GroupKVPairs(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids);
for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
const NDArray& local = local_[key];
@@ -174,13 +175,75 @@ class KVStoreLocal : public KVStore {

protected:
/**
* \brief group values on keys
* \brief group values on keys for push
*/
template <typename V>
void GroupKVPairsPush(const std::vector<int>& keys,
Review comment (Collaborator): Is GroupKVPairsForPush clearer?

Review comment (Collaborator): A second thought: define the validator lambda in the caller functions, or in a common place accessible to the various callers, and pass it to GroupKVPairs. That way there is no need to define extra GroupKVPairsXXXX interfaces, where XXXX stands for Push and Pull respectively. Is this feasible?

Reply (Owner, Author): I also thought about that. Which file do you think is the best place to put the free function? I started with GroupKVPairsForPushRsp, but it is too long (>100 characters) and causes lint to fail.

const std::vector<NDArray>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<NDArray>> *grouped_vals) {
// check if the storage type of a value is valid
auto validator = [this](const int key, const NDArray& nd) -> bool {
auto stype = nd.storage_type();
// valid NDArray
if (stype == kDefaultStorage || stype == kRowSparseStorage) return true;
// invalid NDArray, abort
LOG(FATAL) << "Unexpected storage type detected during kvstore push: " << stype;
return false;
};
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator);
}
/**
* \brief group values on keys for pull
*/
void GroupKVPairsPull(const std::vector<int>& keys,
Review comment (Collaborator): GroupKVPairsForPull?

const std::vector<NDArray*>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<NDArray*>> *grouped_vals) {
// check if the storage type of a value is valid
auto validator = [this](const int key, const NDArray* nd) -> bool {
// valid
if (nd->storage_type() == kDefaultStorage) return true;
// invalid, print warning messages once
if (this->warnings_printed_.find(key) == this->warnings_printed_.end()) {
LOG(INFO) << "Warning: non-default weights detected during kvstore pull. "
<< "Please make sure to use row_sparse_pull with row_ids instead.";
this->warnings_printed_.insert(key);
}
return false;
};
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator);
}
/**
* \brief group values on keys for row_sparse_pull
*/
void GroupKVPairsPullRsp(const std::vector<int>& keys,
Review comment (Collaborator): GroupKVPairsForPullRsp?

const std::vector<std::pair<NDArray*, NDArray>>& values,
std::vector<int> *uniq_keys,
std::vector<std::vector<std::pair<NDArray*, NDArray>>> *grouped_vals) {
// check if the storage type of a value is valid
auto validator = [this](const int key, const std::pair<NDArray*, NDArray>& val_rowid) -> bool {
auto val_stype = val_rowid.first->storage_type();
auto rowid_stype = val_rowid.second.storage_type();
// check storage types
CHECK_EQ(val_stype, kRowSparseStorage) << "Expected row_sparse storage type for "
<< "row_sparse_pull values, but detected storage type " << val_stype;
CHECK_EQ(rowid_stype, kDefaultStorage) << "Expected default storage type for "
<< "row_sparse_pull rowids, but detected storage type " << rowid_stype;
return true;
};
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator);
}

/**
* \brief group values on keys with validation.
* A value `v` is not included in the result if is_valid(v) returns false.
*/
template <typename V, typename FValidate>
void GroupKVPairs(const std::vector<int>& keys,
const std::vector<V>& values,
std::vector<int>* uniq_keys,
std::vector<std::vector<V> >* grouped_vals) {
std::vector<std::vector<V> >* grouped_vals,
const FValidate& is_valid) {
CHECK_EQ(keys.size(), values.size());
// TODO(mli) check if already sorted as an optimization
using Idx = std::pair<int, int>;
@@ -194,12 +257,14 @@ class KVStoreLocal : public KVStore {

int pre_key = idx[0].first - 1;
for (auto i : idx) {
if (i.first != pre_key) {
uniq_keys->push_back(i.first);
grouped_vals->push_back({values[i.second]});
pre_key = i.first;;
} else {
grouped_vals->back().push_back(values[i.second]);
if (is_valid(i.first, values[i.second])) {
if (i.first != pre_key) {
uniq_keys->push_back(i.first);
grouped_vals->push_back({values[i.second]});
pre_key = i.first;
} else {
grouped_vals->back().push_back(values[i.second]);
}
}
}
}
@@ -246,6 +311,8 @@ class KVStoreLocal : public KVStore {
std::unordered_map<std::string, int> str_key_dict_;
/// the next available integer for string->int key mapping
int next_str_key_ = 0;
/// whether printed warning due to mismatch stype in each key
std::unordered_set<int> warnings_printed_;
};
} // namespace kvstore
} // namespace mxnet
49 changes: 43 additions & 6 deletions tests/python/unittest/test_kvstore.py
@@ -52,7 +52,7 @@ def test_single_kv_pair():
def check_single_kv_pair(kv, key):
kv.push(key, mx.nd.ones(shape))
val = mx.nd.empty(shape)
kv.pull(key, out = val)
kv.pull(key, out=val)
check_diff_to_scalar(val, 1)

check_single_kv_pair(init_kv(), 3)
@@ -102,7 +102,7 @@ def test_list_kv_pair():
def check_list_kv_pair(kv, key):
kv.push(key, [mx.nd.ones(shape)*4] * len(key))
val = [mx.nd.empty(shape)] * len(key)
kv.pull(key, out = val)
kv.pull(key, out=val)
for v in val:
check_diff_to_scalar(v, 4)

@@ -122,15 +122,15 @@ def check_aggregator(kv, key, key_list):
vals = [mx.nd.ones(shape, d) for d in devs]

kv.push(key, vals)
kv.pull(key, out = vals)
kv.pull(key, out=vals)

for v in vals:
check_diff_to_scalar(v, num_devs)

# list
vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * len(key_list)
kv.push(key_list, vals)
kv.pull(key_list, out = vals)
kv.pull(key_list, out=vals)

for vv in vals:
for v in vv:
@@ -196,7 +196,7 @@ def check_updater(kv, key, key_list):
vals = [mx.nd.ones(shape, d) for d in devs]

kv.push(key, vals)
kv.pull(key, out = vals)
kv.pull(key, out=vals)

for v in vals:
check_diff_to_scalar(v, num_devs)
@@ -208,7 +208,7 @@ def check_updater(kv, key, key_list):
for i in range(num_push):
kv.push(key_list, vals)

kv.pull(key_list, out = vals)
kv.pull(key_list, out=vals)

for vv in vals:
for v in vv:
@@ -227,6 +227,43 @@ def test_get_type():
kv = mx.kv.create(kvtype)
assert kv.type == kvtype

def test_invalid_pull():
def check_invalid_single_kv_pair(kv, key):
dns_val = mx.nd.ones(shape) * 2
rsp_val = dns_val.tostype('row_sparse')
kv.pull(key, out=rsp_val)
# pull should be ignored with no values updated
check_diff_to_scalar(rsp_val, 2)
try:
# row_sparse_pull should be aborted when vals.stype != row_sparse
kv.row_sparse_pull(key, out=dns_val, row_ids=mx.nd.array([1]))
assert(False)
except:
pass

def check_invalid_list_kv_pair(kv, key):
dns_val = [mx.nd.ones(shape) * 2] * len(key)
rsp_val = [val.tostype('row_sparse') for val in dns_val]
kv.pull(key, out=rsp_val)
for v in rsp_val:
# pull should be ignored with no values updated
check_diff_to_scalar(v, 2)
try:
# row_sparse_pull should be aborted when vals.stype != row_sparse
kv.row_sparse_pull(key, out=dns_val, row_ids=[mx.nd.array([1])] * len(key))
assert(False)
except:
pass

int_kv = init_kv()
str_kv = init_kv_with_str()

check_invalid_single_kv_pair(int_kv, 3)
check_invalid_single_kv_pair(str_kv, 'a')

check_invalid_list_kv_pair(int_kv, keys)
check_invalid_list_kv_pair(str_kv, str_keys)

if __name__ == '__main__':
test_init()
test_get_type()
5 changes: 2 additions & 3 deletions tests/python/unittest/test_module.py
@@ -519,9 +519,8 @@ def fm(factor_size, feature_dim, init):
# initialize parameters by uniform random numbers
mod.init_params(initializer=init)
# use Sparse SGD with learning rate 0.1 to train
sgd = mx.optimizer.SGD(momentum=0.1, clip_gradient=5.0, learning_rate=0.01,
rescale_grad=1.0/batch_size)
mod.init_optimizer(optimizer=sgd)
adam = mx.optimizer.Adam(clip_gradient=5.0, learning_rate=0.001, rescale_grad=1.0/batch_size)
mod.init_optimizer(optimizer=adam)
# use accuracy as the metric
metric = mx.metric.create('MSE')
# train 10 epoch