From caa1d08f55b30a41755b4a1cb212fff68580fa52 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Tue, 20 Jun 2017 05:52:42 +0000 Subject: [PATCH 1/8] update kvstore unit test --- include/mxnet/c_api.h | 41 ++++++++ include/mxnet/kvstore.h | 27 ++++++ python/mxnet/kvstore.py | 69 ++++++++++--- src/c_api/c_api.cc | 47 +++++++++ src/kvstore/kvstore.cc | 1 - src/kvstore/kvstore_local.h | 45 +++++++++ tests/python/unittest/test_kvstore.py | 133 +++++++++++++++----------- 7 files changed, 297 insertions(+), 66 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 90270f776456..9783a88ed366 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1328,6 +1328,19 @@ MXNET_DLL int MXKVStoreInit(KVStoreHandle handle, const int* keys, NDArrayHandle* vals); +/*! + * \brief Init a list of (key,value) pairs in kvstore, where each key is a string + * \param handle handle to the kvstore + * \param num the number of key-value pairs + * \param keys the list of keys + * \param vals the list of values + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStoreInitEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals); + /*! * \brief Push a list of (key,value) pairs to kvstore * \param handle handle to the kvstore @@ -1342,6 +1355,20 @@ MXNET_DLL int MXKVStorePush(KVStoreHandle handle, const int* keys, NDArrayHandle* vals, int priority); +/*! + * \brief Push a list of (key,value) pairs to kvstore, where each key is a string + * \param handle handle to the kvstore + * \param num the number of key-value pairs + * \param keys the list of keys + * \param vals the list of values + * \param priority the priority of the action + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals, + int priority); /*! * \brief pull a list of (key, value) pairs from the kvstore * \param handle handle to the kvstore @@ -1356,6 +1383,20 @@ MXNET_DLL int MXKVStorePull(KVStoreHandle handle, const int* keys, NDArrayHandle* vals, int priority); +/*! + * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string + * \param handle handle to the kvstore + * \param num the number of key-value pairs + * \param keys the list of keys + * \param vals the list of values + * \param priority the priority of the action + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals, + int priority); /*! * \brief user-defined updater for the kvstore * It's this updater's responsibility to delete \a recv and \a local diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index dafaf1bf9cab..a77f653d492c 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -63,6 +63,13 @@ class KVStore { */ virtual void Init(const std::vector& keys, const std::vector& values) = 0; + /*! + * \brief Initialize a list of key-value pair to the store. + * \param keys a list of unique keys in string format + * \param values a list of values + */ + virtual void Init(const std::vector& str_keys, + const std::vector& values) = 0; /*! * \brief push a list of key-value pairs into the store * @@ -102,6 +109,16 @@ class KVStore { virtual void Push(const std::vector& keys, const std::vector& values, int priority = 0) = 0; + + /*! + * \brief push a list of key-value pairs into the store + * \param keys the list of keys in string format + * \param values the list of values + * \param priority Priority of the action. + */ + virtual void Push(const std::vector& str_keys, + const std::vector& values, + int priority = 0) = 0; /*! * \brief pull a list of key-value pairs from the store * @@ -128,6 +145,16 @@ class KVStore { virtual void Pull(const std::vector& keys, const std::vector& values, int priority = 0) = 0; + /*! + * \brief pull a list of key-value pairs from the store + * \param keys the list of keys in string format + * \param values the list of buffers for the pulled data, they should be preallocated + * \param priority Priority of the action. + */ + virtual void Pull(const std::vector& str_keys, + const std::vector& values, + int priority = 0) = 0; + /** * \brief the prototype of user-defined updater diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index ab07421caffd..dc6c1bc06e3f 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -35,6 +35,32 @@ def _ctype_key_value(keys, vals): c_vals += c_val_i return (c_array(ctypes.c_int, c_keys), c_array(NDArrayHandle, c_vals)) +def _ctype_str_key_value(keys, vals): + names = [] + if isinstance(keys, str): + if isinstance(vals, NDArray): + names.append(c_str(keys)) + return (c_array(ctypes.c_char_p, names), + c_array(NDArrayHandle, [vals.handle])) + else: + for value in vals: + assert(isinstance(value, NDArray)) + return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)), + c_array(NDArrayHandle, [value.handle for value in vals])) + else: + assert(len(keys) == len(vals)) + for k in keys: + assert(isinstance(k, str)) + c_keys = [] + c_vals = [] + for key, val in zip(keys, vals): + c_key_i, c_val_i = _ctype_str_key_value(key, val) + c_keys += c_key_i + c_vals += c_val_i + return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals)) + +def _use_str_keys(key): + return isinstance(key, str) or (isinstance(key, (list, tuple)) and isinstance(key[0], str)) def _updater_wrapper(updater): """A wrapper for the user-defined handle.""" @@ -95,9 +121,15 @@ def init(self, key, value): >>> keys = [5, 7, 9] >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys)) """ - ckeys, cvals = _ctype_key_value(key, value) - check_call(_LIB.MXKVStoreInit( - self.handle, mx_uint(len(ckeys)), ckeys, cvals)) + use_str_key = _use_str_keys(key) + if use_str_key: + ckeys, cvals = _ctype_str_key_value(key, value) + check_call(_LIB.MXKVStoreInitEx( + self.handle, mx_uint(len(ckeys)), ckeys, cvals)) + else: + ckeys, cvals = _ctype_key_value(key, value) + check_call(_LIB.MXKVStoreInit( + self.handle, mx_uint(len(ckeys)), ckeys, cvals)) def push(self, key, value, priority=0): """ Pushes a single or a sequence of key-value pairs into the store. @@ -156,10 +188,18 @@ def push(self, key, value, priority=0): [[ 4. 4. 4.] [ 4. 4. 4.]] """ - ckeys, cvals = _ctype_key_value(key, value) - check_call(_LIB.MXKVStorePush( - self.handle, mx_uint(len(ckeys)), ckeys, cvals, - ctypes.c_int(priority))) + use_str_key = _use_str_keys(key) + if use_str_key: + ckeys, cvals = _ctype_str_key_value(key, value) + check_call(_LIB.MXKVStorePushEx( + self.handle, mx_uint(len(ckeys)), ckeys, cvals, + ctypes.c_int(priority))) + else: + ckeys, cvals = _ctype_key_value(key, value) + check_call(_LIB.MXKVStorePush( + self.handle, mx_uint(len(ckeys)), ckeys, cvals, + ctypes.c_int(priority))) + def pull(self, key, out=None, priority=0): """ Pulls a single value or a sequence of values from the store. @@ -218,10 +258,17 @@ def pull(self, key, out=None, priority=0): [ 2. 2. 2.]] """ assert(out is not None) - ckeys, cvals = _ctype_key_value(key, out) - check_call(_LIB.MXKVStorePull( - self.handle, mx_uint(len(ckeys)), ckeys, cvals, - ctypes.c_int(priority))) + use_str_key = _use_str_keys(key) + if use_str_key: + ckeys, cvals = _ctype_str_key_value(key, out) + check_call(_LIB.MXKVStorePullEx( + self.handle, mx_uint(len(ckeys)), ckeys, cvals, + ctypes.c_int(priority))) + else: + ckeys, cvals = _ctype_key_value(key, out) + check_call(_LIB.MXKVStorePull( + self.handle, mx_uint(len(ckeys)), ckeys, cvals, + ctypes.c_int(priority))) def set_optimizer(self, optimizer): """ Registers an optimizer with the kvstore. diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 9d60c8615027..bea6437b4c64 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -625,6 +625,21 @@ int MXKVStoreInit(KVStoreHandle handle, API_END(); } +int MXKVStoreInitEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals) { + API_BEGIN(); + std::vector v_keys(num); + std::vector v_vals(num); + for (mx_uint i = 0; i < num; ++i) { + v_keys[i] = keys[i]; + v_vals[i] = *static_cast(vals[i]); + } + static_cast(handle)->Init(v_keys, v_vals); + API_END(); +} + int MXKVStorePush(KVStoreHandle handle, mx_uint num, const int* keys, @@ -641,6 +656,22 @@ int MXKVStorePush(KVStoreHandle handle, API_END(); } +int MXKVStorePushEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals, + int priority) { + API_BEGIN(); + std::vector v_keys(num); + std::vector v_vals(num); + for (mx_uint i = 0; i < num; ++i) { + v_keys[i] = keys[i]; + v_vals[i] = *static_cast(vals[i]); + } + static_cast(handle)->Push(v_keys, v_vals, priority); + API_END(); +} + int MXKVStorePull(KVStoreHandle handle, mx_uint num, const int* keys, @@ -657,6 +688,22 @@ int MXKVStorePull(KVStoreHandle handle, API_END(); } +int MXKVStorePullEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals, + int priority) { + API_BEGIN(); + std::vector v_keys(num); + std::vector v_vals(num); + for (mx_uint i = 0; i < num; ++i) { + v_keys[i] = keys[i]; + v_vals[i] = static_cast(vals[i]); + } + static_cast(handle)->Pull(v_keys, v_vals, priority); + API_END(); +} + int MXKVStoreSetUpdater(KVStoreHandle handle, MXKVStoreUpdater updater, void* updater_handle) { diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc index be5662e8a6db..78d4958096cc 100644 --- a/src/kvstore/kvstore.cc +++ b/src/kvstore/kvstore.cc @@ -7,7 +7,6 @@ #include #include #include "./kvstore_local.h" -// #include "./kvstore_device.h" #if MXNET_USE_DIST_KVSTORE #include "./kvstore_dist.h" #endif // MXNET_USE_DIST_KVSTORE diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index caa57a20d46e..db3e6c443185 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -47,6 +47,20 @@ class KVStoreLocal : public KVStore { } } + void Init(const std::vector& str_keys, + const std::vector& values) override { + std::vector keys(str_keys.size()); + for (size_t i = 0; i < str_keys.size(); ++i) { + auto &str_key = str_keys[i]; + CHECK(str_key_dict_.find(str_key) == str_key_dict_.end()) + << "duplicate init of key " << str_key; + auto key = next_str_key_++; + str_key_dict_[str_key] = key; + keys[i] = key; + } + Init(keys, values); + } + void Push(const std::vector& keys, const std::vector& values, int priority) override { @@ -87,6 +101,22 @@ class KVStoreLocal : public KVStore { } } + void Push(const std::vector& str_keys, + const std::vector& values, + int priority) override { + std::vector keys(str_keys.size()); + LookupKeys(str_keys, &keys); + Push(keys, values, priority); + } + + void Pull(const std::vector& str_keys, + const std::vector& values, + int priority) override { + std::vector keys(str_keys.size()); + LookupKeys(str_keys, &keys); + Pull(keys, values, priority); + } + protected: /** * \brief group values on keys @@ -118,12 +148,27 @@ class KVStoreLocal : public KVStore { } } } + + void LookupKeys(const std::vector& str_keys, + std::vector *keys) { + for (size_t i = 0; i < str_keys.size(); ++i) { + auto &str_key = str_keys[i]; + CHECK(str_key_dict_.find(str_key) != str_key_dict_.end()) + << "key " << str_key << " doesn't exist. Did you init?"; + keys->at(i) = str_key_dict_[str_key]; + } + } + /// reducer and broadcaster Comm* comm_; /// pinned context Context pinned_ctx_; /// \brief buffer for storing local values std::unordered_map local_; + /// key mapping for string -> integer + std::unordered_map str_key_dict_; + /// the next available integer for string->int key mapping + int next_str_key_ = 0; }; } // namespace kvstore } // namespace mxnet diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index dd8149d4822e..87e5e0027241 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -4,6 +4,8 @@ shape = (4, 4) keys = [5, 7, 11] +str_keys = ['b', 'c', 'd'] + def init_kv(): """init kv """ kv = mx.kv.create() @@ -13,6 +15,14 @@ def init_kv(): kv.init(keys, [mx.nd.zeros(shape)] * len(keys)) return kv +def init_kv_with_str(): + """init kv """ + kv = mx.kv.create() + # single + kv.init('a', mx.nd.zeros(shape)) + # list + kv.init(str_keys, [mx.nd.zeros(shape)] * len(keys)) + return kv def check_diff_to_scalar(A, x): """ assert A == x""" @@ -20,59 +30,67 @@ def check_diff_to_scalar(A, x): def test_single_kv_pair(): """single key-value pair push & pull""" + def check_single_kv_pair(kv, key): + kv.push(key, mx.nd.ones(shape)) + val = mx.nd.empty(shape) + kv.pull(key, out = val) + check_diff_to_scalar(val, 1) - kv = init_kv() - kv.push(3, mx.nd.ones(shape)) - val = mx.nd.empty(shape) - kv.pull(3, out = val) - check_diff_to_scalar(val, 1) + check_single_kv_pair(init_kv(), 3) + check_single_kv_pair(init_kv_with_str(), 'a') def test_init(): """test init""" - kv = mx.kv.create() - kv.init(3, mx.nd.ones(shape)*4) - a = mx.nd.zeros(shape) - kv.pull(3, out=a) - check_diff_to_scalar(a, 4) + def check_init(kv, key): + kv.init(key, mx.nd.ones(shape)*4) + a = mx.nd.zeros(shape) + kv.pull(key, out=a) + check_diff_to_scalar(a, 4) + + check_init(mx.kv.create(), 3) + check_init(mx.kv.create(), 'a') def test_list_kv_pair(): """list key-value pair push & pull""" + def check_list_kv_pair(kv, key): + kv.push(key, [mx.nd.ones(shape)*4] * len(key)) + val = [mx.nd.empty(shape)] * len(key) + kv.pull(key, out = val) + for v in val: + check_diff_to_scalar(v, 4) - kv = init_kv() - - kv.push(keys, [mx.nd.ones(shape)*4] * len(keys)) - val = [mx.nd.empty(shape)] * len(keys) - kv.pull(keys, out = val) - for v in val: - check_diff_to_scalar(v, 4) + check_list_kv_pair(init_kv(), keys) + check_list_kv_pair(init_kv_with_str(), str_keys) def test_aggregator(): """aggregate value on muliple devices""" - kv = init_kv() + def check_aggregator(kv, key, key_list): + # devices + num_devs = 4 + devs = [mx.Context('cpu', i) for i in range(num_devs)] - # devices - num_devs = 4 - devs = [mx.Context('cpu', i) for i in range(num_devs)] + # single + vals = [mx.nd.ones(shape, d) for d in devs] - # single - vals = [mx.nd.ones(shape, d) for d in devs] + kv.push(key, vals) + kv.pull(key, out = vals) - kv.push(3, vals) - kv.pull(3, out = vals) + for v in vals: + check_diff_to_scalar(v, num_devs) - for v in vals: - check_diff_to_scalar(v, num_devs) + # list + vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * len(key_list) + kv.push(key_list, vals) + kv.pull(key_list, out = vals) - # list - vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * len(keys) - kv.push(keys, vals) - kv.pull(keys, out = vals) + for vv in vals: + for v in vv: + check_diff_to_scalar(v, num_devs * 2.0) - for vv in vals: - for v in vv: - check_diff_to_scalar(v, num_devs * 2.0) + check_aggregator(init_kv(), 3, keys) + check_aggregator(init_kv_with_str(), 'a', str_keys) def updater(key, recv, local): @@ -82,34 +100,41 @@ def updater(key, recv, local): def test_updater(dev = 'cpu'): """updater""" - kv = init_kv() - kv._set_updater(updater) + def check_updater(kv, key, key_list): + # devices + num_devs = 4 + devs = [mx.Context(dev, i) for i in range(num_devs)] - # devices - num_devs = 4 - devs = [mx.Context(dev, i) for i in range(num_devs)] + # single + vals = [mx.nd.ones(shape, d) for d in devs] - # single - vals = [mx.nd.ones(shape, d) for d in devs] + kv.push(key, vals) + kv.pull(key, out = vals) - kv.push(3, vals) - kv.pull(3, out = vals) + for v in vals: + check_diff_to_scalar(v, num_devs) - for v in vals: - check_diff_to_scalar(v, num_devs) + # list + vals = [[mx.nd.ones(shape, d) for d in devs]] * len(key_list) - # list - vals = [[mx.nd.ones(shape, d) for d in devs]] * len(keys) + num_push = 4 + for i in range(num_push): + kv.push(key_list, vals) + + kv.pull(key_list, out = vals) + + for vv in vals: + for v in vv: + check_diff_to_scalar(v, num_devs * num_push) - num_push = 4 - for i in range(num_push): - kv.push(keys, vals) + kv = init_kv() + kv._set_updater(updater) + check_updater(kv, 3, keys) - kv.pull(keys, out = vals) + str_kv = init_kv_with_str() + str_kv._set_updater(updater) + check_updater(str_kv, 'a', str_keys) - for vv in vals: - for v in vv: - check_diff_to_scalar(v, num_devs * num_push) def test_get_type(): kvtype = 'local_allreduce_cpu' From 946abc67a58243884428d5eede3e1014ca1ea1cc Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Tue, 20 Jun 2017 15:57:22 +0000 Subject: [PATCH 2/8] update model/module.py --- python/mxnet/model.py | 24 ++++++++++++++---------- python/mxnet/module/module.py | 5 +++-- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 189f301e91f7..a476d84efd92 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -80,34 +80,37 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_on_kvstore): """Initialize kvstore""" for idx, param_on_devs in enumerate(param_arrays): - kvstore.init(idx, arg_params[param_names[idx]]) + name = param_names[idx] + kvstore.init(name, arg_params[name]) if update_on_kvstore: - kvstore.pull(idx, param_on_devs, priority=-idx) + kvstore.pull(name, param_on_devs, priority=-idx) -def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore): +def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names): """Perform update of param_arrays from grad_arrays on kvstore.""" for index, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: continue + name = param_names[index] # push gradient, priority is negative index - kvstore.push(index, grad_list, priority=-index) + kvstore.push(name, grad_list, priority=-index) # pull back the weights - kvstore.pull(index, arg_list, priority=-index) + kvstore.pull(name, arg_list, priority=-index) def _update_params(param_arrays, grad_arrays, updater, num_device, - kvstore=None): + kvstore=None, param_names=None): """Perform update of param_arrays from grad_arrays not on kvstore.""" for index, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: continue if kvstore: + name = param_names[index] # push gradient, priority is negative index - kvstore.push(index, grad_list, priority=-index) + kvstore.push(name, grad_list, priority=-index) # pull back the sum gradients, to the same locations. - kvstore.pull(index, grad_list, priority=-index) + kvstore.pull(name, grad_list, priority=-index) for k, p in enumerate(zip(arg_list, grad_list)): # faked an index here, to make optimizer create diff # state for the same index but on diff devs, TODO(mli) @@ -245,13 +248,14 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names, if update_on_kvstore: _update_params_on_kvstore(executor_manager.param_arrays, executor_manager.grad_arrays, - kvstore) + kvstore, executor_manager.param_names) else: _update_params(executor_manager.param_arrays, executor_manager.grad_arrays, updater=updater, num_device=len(ctx), - kvstore=kvstore) + kvstore=kvstore, + param_names=executor_manager.param_names) if monitor is not None: monitor.toc_print() diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index fef5c507d7e8..249122311274 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -572,13 +572,14 @@ def update(self): if self._update_on_kvstore: _update_params_on_kvstore(self._exec_group.param_arrays, self._exec_group.grad_arrays, - self._kvstore) + self._kvstore, self._exec_group.param_names) else: _update_params(self._exec_group.param_arrays, self._exec_group.grad_arrays, updater=self._updater, num_device=len(self._context), - kvstore=self._kvstore) + kvstore=self._kvstore, + param_names=self._exec_group.param_names) def get_outputs(self, merge_multi_context=True): """Gets outputs of the previous forward computation. From de8f0b8090e1666977fd97897f5f69e7f8599f7c Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Tue, 20 Jun 2017 16:05:16 +0000 Subject: [PATCH 3/8] fix lint --- src/kvstore/kvstore_local.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index db3e6c443185..dc5f7b786244 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include "./comm.h" From 3a277a40736671b6387bea8c1830626287d1de20 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Wed, 21 Jun 2017 18:09:36 +0000 Subject: [PATCH 4/8] remove int keys in kvstore --- python/mxnet/kvstore.py | 108 +++++++++++++--------------------------- 1 file changed, 35 insertions(+), 73 deletions(-) diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index dc6c1bc06e3f..6af610fa44c7 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -10,31 +10,6 @@ from .base import NDArrayHandle, KVStoreHandle from . import optimizer as opt -def _ctype_key_value(keys, vals): - """ - Returns ctype arrays for the key-value args. For internal use. - """ - if isinstance(keys, int): - if isinstance(vals, NDArray): - return (c_array(ctypes.c_int, [keys]), - c_array(NDArrayHandle, [vals.handle])) - else: - for value in vals: - assert(isinstance(value, NDArray)) - return (c_array(ctypes.c_int, [keys] * len(vals)), - c_array(NDArrayHandle, [value.handle for value in vals])) - else: - assert(len(keys) == len(vals)) - for k in keys: - assert(isinstance(k, int)) - c_keys = [] - c_vals = [] - for key, val in zip(keys, vals): - c_key_i, c_val_i = _ctype_key_value(key, val) - c_keys += c_key_i - c_vals += c_val_i - return (c_array(ctypes.c_int, c_keys), c_array(NDArrayHandle, c_vals)) - def _ctype_str_key_value(keys, vals): names = [] if isinstance(keys, str): @@ -59,8 +34,13 @@ def _ctype_str_key_value(keys, vals): c_vals += c_val_i return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals)) -def _use_str_keys(key): - return isinstance(key, str) or (isinstance(key, (list, tuple)) and isinstance(key[0], str)) +def _cast_to_str_keys(keys): + if isinstance(keys, int): + return str(keys) + if isinstance(keys, (list, tuple)): + for i, key in keys: + keys[i] = str(key) if isinstance(key, int) else key + return keys def _updater_wrapper(updater): """A wrapper for the user-defined handle.""" @@ -100,7 +80,7 @@ def init(self, key, value): Parameters ---------- - key : int or sequence of int + key : str or sequence of str The keys. value : NDArray or sequence of NDArray Values corresponding to the keys. @@ -110,26 +90,20 @@ def init(self, key, value): >>> # init a single key-value pair >>> shape = (2,3) >>> kv = mx.kv.create('local') - >>> kv.init(3, mx.nd.ones(shape)*2) + >>> kv.init('3', mx.nd.ones(shape)*2) >>> a = mx.nd.zeros(shape) - >>> kv.pull(3, out=a) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 2. 2. 2.] [ 2. 2. 2.]] >>> # init a list of key-value pairs - >>> keys = [5, 7, 9] + >>> keys = ['5', '7', '9'] >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys)) """ - use_str_key = _use_str_keys(key) - if use_str_key: - ckeys, cvals = _ctype_str_key_value(key, value) - check_call(_LIB.MXKVStoreInitEx( - self.handle, mx_uint(len(ckeys)), ckeys, cvals)) - else: - ckeys, cvals = _ctype_key_value(key, value) - check_call(_LIB.MXKVStoreInit( - self.handle, mx_uint(len(ckeys)), ckeys, cvals)) + key = _cast_to_str_keys(key) + ckeys, cvals = _ctype_str_key_value(key, value) + check_call(_LIB.MXKVStoreInitEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals)) def push(self, key, value, priority=0): """ Pushes a single or a sequence of key-value pairs into the store. @@ -142,7 +116,7 @@ def push(self, key, value, priority=0): Parameters ---------- - key : int or list of int + key : str or list of str Keys. value : NDArray or list of NDArray or list of list of NDArray @@ -156,8 +130,8 @@ def push(self, key, value, priority=0): Examples -------- >>> # push a single key-value pair - >>> kv.push(3, mx.nd.ones(shape)*8) - >>> kv.pull(3, out=a) # pull out the value + >>> kv.push('3', mx.nd.ones(shape)*8) + >>> kv.pull('3', out=a) # pull out the value >>> print a.asnumpy() [[ 8. 8. 8.] [ 8. 8. 8.]] @@ -165,8 +139,8 @@ def push(self, key, value, priority=0): >>> # aggregate the value and the push >>> gpus = [mx.gpu(i) for i in range(4)] >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] - >>> kv.push(3, b) - >>> kv.pull(3, out=a) + >>> kv.push('3', b) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 4. 4. 4.] [ 4. 4. 4.]] @@ -188,17 +162,11 @@ def push(self, key, value, priority=0): [[ 4. 4. 4.] [ 4. 4. 4.]] """ - use_str_key = _use_str_keys(key) - if use_str_key: - ckeys, cvals = _ctype_str_key_value(key, value) - check_call(_LIB.MXKVStorePushEx( - self.handle, mx_uint(len(ckeys)), ckeys, cvals, - ctypes.c_int(priority))) - else: - ckeys, cvals = _ctype_key_value(key, value) - check_call(_LIB.MXKVStorePush( - self.handle, mx_uint(len(ckeys)), ckeys, cvals, - ctypes.c_int(priority))) + key = _cast_to_str_keys(key) + ckeys, cvals = _ctype_str_key_value(key, value) + check_call(_LIB.MXKVStorePushEx( + self.handle, mx_uint(len(ckeys)), ckeys, cvals, + ctypes.c_int(priority))) def pull(self, key, out=None, priority=0): @@ -230,21 +198,21 @@ def pull(self, key, out=None, priority=0): -------- >>> # pull a single key-value pair >>> a = mx.nd.zeros(shape) - >>> kv.pull(3, out=a) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 2. 2. 2.] [ 2. 2. 2.]] >>> # pull into multiple devices >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] - >>> kv.pull(3, out=b) + >>> kv.pull('3', out=b) >>> print b[1].asnumpy() [[ 2. 2. 2.] [ 2. 2. 2.]] >>> # pull a list of key-value pairs. >>> # On single device - >>> keys = [5, 7, 9] + >>> keys = ['5', '7', '9'] >>> b = [mx.nd.zeros(shape)]*len(keys) >>> kv.pull(keys, out=b) >>> print b[1].asnumpy() @@ -258,17 +226,11 @@ def pull(self, key, out=None, priority=0): [ 2. 2. 2.]] """ assert(out is not None) - use_str_key = _use_str_keys(key) - if use_str_key: - ckeys, cvals = _ctype_str_key_value(key, out) - check_call(_LIB.MXKVStorePullEx( - self.handle, mx_uint(len(ckeys)), ckeys, cvals, - ctypes.c_int(priority))) - else: - ckeys, cvals = _ctype_key_value(key, out) - check_call(_LIB.MXKVStorePull( - self.handle, mx_uint(len(ckeys)), ckeys, cvals, - ctypes.c_int(priority))) + key = _cast_to_str_keys(key) + ckeys, cvals = _ctype_str_key_value(key, out) + check_call(_LIB.MXKVStorePullEx( + self.handle, mx_uint(len(ckeys)), ckeys, cvals, + ctypes.c_int(priority))) def set_optimizer(self, optimizer): """ Registers an optimizer with the kvstore. @@ -395,13 +357,13 @@ def _set_updater(self, updater): ... print "update on key: %d" % key ... stored += input * 2 >>> kv._set_updater(update) - >>> kv.pull(3, out=a) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 4. 4. 4.] [ 4. 4. 4.]] - >>> kv.push(3, mx.nd.ones(shape)) + >>> kv.push('3', mx.nd.ones(shape)) update on key: 3 - >>> kv.pull(3, out=a) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 6. 6. 6.] [ 6. 6. 6.]] From 3b7e73037c1020da9fa7b9282e33664889a8c22e Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Wed, 21 Jun 2017 18:24:10 +0000 Subject: [PATCH 5/8] update cast to str function --- python/mxnet/kvstore.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 6af610fa44c7..fea01234d9af 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -35,12 +35,14 @@ def _ctype_str_key_value(keys, vals): return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals)) def _cast_to_str_keys(keys): + if isinstance(keys, str): + return keys if isinstance(keys, int): return str(keys) - if isinstance(keys, (list, tuple)): - for i, key in keys: - keys[i] = str(key) if isinstance(key, int) else key - return keys + str_keys = [] + for key in keys: + str_keys.append(str(key) if isinstance(key, int) else key) + return str_keys def _updater_wrapper(updater): """A wrapper for the user-defined handle.""" From d1b92fcde367bc2c827dfba2f566eee19a6f9acf Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Thu, 22 Jun 2017 05:11:40 +0000 Subject: [PATCH 6/8] remove _cast_to_str_keys --- python/mxnet/kvstore.py | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index fea01234d9af..9a2fe644c0e7 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -10,8 +10,10 @@ from .base import NDArrayHandle, KVStoreHandle from . import optimizer as opt -def _ctype_str_key_value(keys, vals): +def _ctype_key_value(keys, vals): names = [] + if isinstance(keys, int): + keys = str(keys) if isinstance(keys, str): if isinstance(vals, NDArray): names.append(c_str(keys)) @@ -24,26 +26,14 @@ def _ctype_str_key_value(keys, vals): c_array(NDArrayHandle, [value.handle for value in vals])) else: assert(len(keys) == len(vals)) - for k in keys: - assert(isinstance(k, str)) c_keys = [] c_vals = [] for key, val in zip(keys, vals): - c_key_i, c_val_i = _ctype_str_key_value(key, val) + c_key_i, c_val_i = _ctype_key_value(key, val) c_keys += c_key_i c_vals += c_val_i return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals)) -def _cast_to_str_keys(keys): - if isinstance(keys, str): - return keys - if isinstance(keys, int): - return str(keys) - str_keys = [] - for key in keys: - str_keys.append(str(key) if isinstance(key, int) else key) - return str_keys - def _updater_wrapper(updater): """A wrapper for the user-defined handle.""" def updater_handle(key, lhs_handle, rhs_handle, _): @@ -103,8 +93,7 @@ def init(self, key, value): >>> keys = ['5', '7', '9'] >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys)) """ - key = _cast_to_str_keys(key) - ckeys, cvals = _ctype_str_key_value(key, value) + ckeys, cvals = _ctype_key_value(key, value) check_call(_LIB.MXKVStoreInitEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals)) def push(self, key, value, priority=0): @@ -164,8 +153,7 @@ def push(self, key, value, priority=0): [[ 4. 4. 4.] [ 4. 4. 4.]] """ - key = _cast_to_str_keys(key) - ckeys, cvals = _ctype_str_key_value(key, value) + ckeys, cvals = _ctype_key_value(key, value) check_call(_LIB.MXKVStorePushEx( self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) @@ -228,8 +216,7 @@ def pull(self, key, out=None, priority=0): [ 2. 2. 2.]] """ assert(out is not None) - key = _cast_to_str_keys(key) - ckeys, cvals = _ctype_str_key_value(key, out) + ckeys, cvals = _ctype_key_value(key, out) check_call(_LIB.MXKVStorePullEx( self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) From 9e2562fa68919d10187207799bcce76c91b51dbc Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Thu, 22 Jun 2017 05:49:49 +0000 Subject: [PATCH 7/8] fix lint --- python/mxnet/kvstore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 9a2fe644c0e7..6be4c9868973 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -13,7 +13,7 @@ def _ctype_key_value(keys, vals): names = [] if isinstance(keys, int): - keys = str(keys) + keys = str(keys) if isinstance(keys, str): if isinstance(vals, NDArray): names.append(c_str(keys)) From a4972b9f83ebaca77ddfe307799c0868e34e58dc Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Thu, 22 Jun 2017 16:33:50 +0000 Subject: [PATCH 8/8] always cast to str --- python/mxnet/kvstore.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 6be4c9868973..10b83b04db97 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -11,20 +11,7 @@ from . import optimizer as opt def _ctype_key_value(keys, vals): - names = [] - if isinstance(keys, int): - keys = str(keys) - if isinstance(keys, str): - if isinstance(vals, NDArray): - names.append(c_str(keys)) - return (c_array(ctypes.c_char_p, names), - c_array(NDArrayHandle, [vals.handle])) - else: - for value in vals: - assert(isinstance(value, NDArray)) - return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)), - c_array(NDArrayHandle, [value.handle for value in vals])) - else: + if isinstance(keys, (tuple, list)): assert(len(keys) == len(vals)) c_keys = [] c_vals = [] @@ -33,6 +20,17 @@ def _ctype_key_value(keys, vals): c_keys += c_key_i c_vals += c_val_i return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals)) + names = [] + keys = str(keys) + if isinstance(vals, NDArray): + names.append(c_str(keys)) + return (c_array(ctypes.c_char_p, names), + c_array(NDArrayHandle, [vals.handle])) + else: + for value in vals: + assert(isinstance(value, NDArray)) + return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)), + c_array(NDArrayHandle, [value.handle for value in vals])) def _updater_wrapper(updater): """A wrapper for the user-defined handle."""