diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index a0e842c21765..b8f8411353bf 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -1311,6 +1311,19 @@ MXNET_DLL int MXKVStoreInit(KVStoreHandle handle, const int* keys, NDArrayHandle* vals); +/*! + * \brief Init a list of (key,value) pairs in kvstore, where each key is a string + * \param handle handle to the kvstore + * \param num the number of key-value pairs + * \param keys the list of keys + * \param vals the list of values + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStoreInitEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals); + /*! * \brief Push a list of (key,value) pairs to kvstore * \param handle handle to the kvstore @@ -1325,6 +1338,20 @@ MXNET_DLL int MXKVStorePush(KVStoreHandle handle, const int* keys, NDArrayHandle* vals, int priority); +/*! + * \brief Push a list of (key,value) pairs to kvstore, where each key is a string + * \param handle handle to the kvstore + * \param num the number of key-value pairs + * \param keys the list of keys + * \param vals the list of values + * \param priority the priority of the action + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals, + int priority); /*! * \brief pull a list of (key, value) pairs from the kvstore * \param handle handle to the kvstore @@ -1339,6 +1366,20 @@ MXNET_DLL int MXKVStorePull(KVStoreHandle handle, const int* keys, NDArrayHandle* vals, int priority); +/*! 
+ * \brief pull a list of (key, value) pairs from the kvstore, where each key is a string + * \param handle handle to the kvstore + * \param num the number of key-value pairs + * \param keys the list of keys + * \param vals the list of values + * \param priority the priority of the action + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals, + int priority); /*! * \brief user-defined updater for the kvstore * It's this updater's responsibility to delete \a recv and \a local diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index dafaf1bf9cab..a77f653d492c 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -63,6 +63,13 @@ class KVStore { */ virtual void Init(const std::vector& keys, const std::vector& values) = 0; + /*! + * \brief Initialize a list of key-value pairs to the store. + * \param str_keys a list of unique keys in string format + * \param values a list of values + */ + virtual void Init(const std::vector& str_keys, + const std::vector& values) = 0; /*! * \brief push a list of key-value pairs into the store * @@ -102,6 +109,16 @@ class KVStore { virtual void Push(const std::vector& keys, const std::vector& values, int priority = 0) = 0; + + /*! + * \brief push a list of key-value pairs into the store + * \param str_keys the list of keys in string format + * \param values the list of values + * \param priority Priority of the action. + */ + virtual void Push(const std::vector& str_keys, + const std::vector& values, + int priority = 0) = 0; /*! * \brief pull a list of key-value pairs from the store * @@ -128,6 +145,16 @@ class KVStore { virtual void Pull(const std::vector& keys, const std::vector& values, int priority = 0) = 0; + /*! 
+ * \brief pull a list of key-value pairs from the store + * \param str_keys the list of keys in string format + * \param values the list of buffers for the pulled data, they should be preallocated + * \param priority Priority of the action. + */ + virtual void Pull(const std::vector& str_keys, + const std::vector& values, + int priority = 0) = 0; + /** * \brief the prototype of user-defined updater diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index ab07421caffd..10b83b04db97 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -11,30 +11,26 @@ from . import optimizer as opt def _ctype_key_value(keys, vals): - """ - Returns ctype arrays for the key-value args. For internal use. - """ - if isinstance(keys, int): - if isinstance(vals, NDArray): - return (c_array(ctypes.c_int, [keys]), - c_array(NDArrayHandle, [vals.handle])) - else: - for value in vals: - assert(isinstance(value, NDArray)) - return (c_array(ctypes.c_int, [keys] * len(vals)), - c_array(NDArrayHandle, [value.handle for value in vals])) - else: + if isinstance(keys, (tuple, list)): assert(len(keys) == len(vals)) - for k in keys: - assert(isinstance(k, int)) c_keys = [] c_vals = [] for key, val in zip(keys, vals): c_key_i, c_val_i = _ctype_key_value(key, val) c_keys += c_key_i c_vals += c_val_i - return (c_array(ctypes.c_int, c_keys), c_array(NDArrayHandle, c_vals)) - + return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals)) + names = [] + keys = str(keys) + if isinstance(vals, NDArray): + names.append(c_str(keys)) + return (c_array(ctypes.c_char_p, names), + c_array(NDArrayHandle, [vals.handle])) + else: + for value in vals: + assert(isinstance(value, NDArray)) + return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)), + c_array(NDArrayHandle, [value.handle for value in vals])) def _updater_wrapper(updater): """A wrapper for the user-defined handle.""" @@ -74,7 +70,7 @@ def init(self, key, value): Parameters ---------- - key : int or sequence of int + 
key : str or sequence of str The keys. value : NDArray or sequence of NDArray Values corresponding to the keys. @@ -84,20 +80,19 @@ def init(self, key, value): >>> # init a single key-value pair >>> shape = (2,3) >>> kv = mx.kv.create('local') - >>> kv.init(3, mx.nd.ones(shape)*2) + >>> kv.init('3', mx.nd.ones(shape)*2) >>> a = mx.nd.zeros(shape) - >>> kv.pull(3, out=a) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 2. 2. 2.] [ 2. 2. 2.]] >>> # init a list of key-value pairs - >>> keys = [5, 7, 9] + >>> keys = ['5', '7', '9'] >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys)) """ ckeys, cvals = _ctype_key_value(key, value) - check_call(_LIB.MXKVStoreInit( - self.handle, mx_uint(len(ckeys)), ckeys, cvals)) + check_call(_LIB.MXKVStoreInitEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals)) def push(self, key, value, priority=0): """ Pushes a single or a sequence of key-value pairs into the store. @@ -110,7 +105,7 @@ def push(self, key, value, priority=0): Parameters ---------- - key : int or list of int + key : str or list of str Keys. value : NDArray or list of NDArray or list of list of NDArray @@ -124,8 +119,8 @@ def push(self, key, value, priority=0): Examples -------- >>> # push a single key-value pair - >>> kv.push(3, mx.nd.ones(shape)*8) - >>> kv.pull(3, out=a) # pull out the value + >>> kv.push('3', mx.nd.ones(shape)*8) + >>> kv.pull('3', out=a) # pull out the value >>> print a.asnumpy() [[ 8. 8. 8.] [ 8. 8. 8.]] @@ -133,8 +128,8 @@ def push(self, key, value, priority=0): >>> # aggregate the value and the push >>> gpus = [mx.gpu(i) for i in range(4)] >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] - >>> kv.push(3, b) - >>> kv.pull(3, out=a) + >>> kv.push('3', b) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 4. 4. 4.] [ 4. 4. 4.]] @@ -157,10 +152,11 @@ def push(self, key, value, priority=0): [ 4. 4. 
4.]] """ ckeys, cvals = _ctype_key_value(key, value) - check_call(_LIB.MXKVStorePush( + check_call(_LIB.MXKVStorePushEx( self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) + def pull(self, key, out=None, priority=0): """ Pulls a single value or a sequence of values from the store. @@ -190,21 +186,21 @@ def pull(self, key, out=None, priority=0): -------- >>> # pull a single key-value pair >>> a = mx.nd.zeros(shape) - >>> kv.pull(3, out=a) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 2. 2. 2.] [ 2. 2. 2.]] >>> # pull into multiple devices >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] - >>> kv.pull(3, out=b) + >>> kv.pull('3', out=b) >>> print b[1].asnumpy() [[ 2. 2. 2.] [ 2. 2. 2.]] >>> # pull a list of key-value pairs. >>> # On single device - >>> keys = [5, 7, 9] + >>> keys = ['5', '7', '9'] >>> b = [mx.nd.zeros(shape)]*len(keys) >>> kv.pull(keys, out=b) >>> print b[1].asnumpy() @@ -219,7 +215,7 @@ def pull(self, key, out=None, priority=0): """ assert(out is not None) ckeys, cvals = _ctype_key_value(key, out) - check_call(_LIB.MXKVStorePull( + check_call(_LIB.MXKVStorePullEx( self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) @@ -348,13 +344,13 @@ def _set_updater(self, updater): ... print "update on key: %d" % key ... stored += input * 2 >>> kv._set_updater(update) - >>> kv.pull(3, out=a) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 4. 4. 4.] [ 4. 4. 4.]] - >>> kv.push(3, mx.nd.ones(shape)) + >>> kv.push('3', mx.nd.ones(shape)) update on key: 3 - >>> kv.pull(3, out=a) + >>> kv.pull('3', out=a) >>> print a.asnumpy() [[ 6. 6. 6.] [ 6. 6. 
6.]] diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 189f301e91f7..a476d84efd92 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -80,34 +80,37 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_on_kvstore): """Initialize kvstore""" for idx, param_on_devs in enumerate(param_arrays): - kvstore.init(idx, arg_params[param_names[idx]]) + name = param_names[idx] + kvstore.init(name, arg_params[name]) if update_on_kvstore: - kvstore.pull(idx, param_on_devs, priority=-idx) + kvstore.pull(name, param_on_devs, priority=-idx) -def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore): +def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names): """Perform update of param_arrays from grad_arrays on kvstore.""" for index, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: continue + name = param_names[index] # push gradient, priority is negative index - kvstore.push(index, grad_list, priority=-index) + kvstore.push(name, grad_list, priority=-index) # pull back the weights - kvstore.pull(index, arg_list, priority=-index) + kvstore.pull(name, arg_list, priority=-index) def _update_params(param_arrays, grad_arrays, updater, num_device, - kvstore=None): + kvstore=None, param_names=None): """Perform update of param_arrays from grad_arrays not on kvstore.""" for index, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: continue if kvstore: + name = param_names[index] # push gradient, priority is negative index - kvstore.push(index, grad_list, priority=-index) + kvstore.push(name, grad_list, priority=-index) # pull back the sum gradients, to the same locations. 
- kvstore.pull(index, grad_list, priority=-index) + kvstore.pull(name, grad_list, priority=-index) for k, p in enumerate(zip(arg_list, grad_list)): # faked an index here, to make optimizer create diff # state for the same index but on diff devs, TODO(mli) @@ -245,13 +248,14 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names, if update_on_kvstore: _update_params_on_kvstore(executor_manager.param_arrays, executor_manager.grad_arrays, - kvstore) + kvstore, executor_manager.param_names) else: _update_params(executor_manager.param_arrays, executor_manager.grad_arrays, updater=updater, num_device=len(ctx), - kvstore=kvstore) + kvstore=kvstore, + param_names=executor_manager.param_names) if monitor is not None: monitor.toc_print() diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index fef5c507d7e8..249122311274 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -572,13 +572,14 @@ def update(self): if self._update_on_kvstore: _update_params_on_kvstore(self._exec_group.param_arrays, self._exec_group.grad_arrays, - self._kvstore) + self._kvstore, self._exec_group.param_names) else: _update_params(self._exec_group.param_arrays, self._exec_group.grad_arrays, updater=self._updater, num_device=len(self._context), - kvstore=self._kvstore) + kvstore=self._kvstore, + param_names=self._exec_group.param_names) def get_outputs(self, merge_multi_context=True): """Gets outputs of the previous forward computation. 
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 9d60c8615027..bea6437b4c64 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -625,6 +625,21 @@ int MXKVStoreInit(KVStoreHandle handle, API_END(); } +int MXKVStoreInitEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals) { + API_BEGIN(); + std::vector v_keys(num); + std::vector v_vals(num); + for (mx_uint i = 0; i < num; ++i) { + v_keys[i] = keys[i]; + v_vals[i] = *static_cast(vals[i]); + } + static_cast(handle)->Init(v_keys, v_vals); + API_END(); +} + int MXKVStorePush(KVStoreHandle handle, mx_uint num, const int* keys, @@ -641,6 +656,22 @@ int MXKVStorePush(KVStoreHandle handle, API_END(); } +int MXKVStorePushEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals, + int priority) { + API_BEGIN(); + std::vector v_keys(num); + std::vector v_vals(num); + for (mx_uint i = 0; i < num; ++i) { + v_keys[i] = keys[i]; + v_vals[i] = *static_cast(vals[i]); + } + static_cast(handle)->Push(v_keys, v_vals, priority); + API_END(); +} + int MXKVStorePull(KVStoreHandle handle, mx_uint num, const int* keys, @@ -657,6 +688,22 @@ int MXKVStorePull(KVStoreHandle handle, API_END(); } +int MXKVStorePullEx(KVStoreHandle handle, + mx_uint num, + const char** keys, + NDArrayHandle* vals, + int priority) { + API_BEGIN(); + std::vector v_keys(num); + std::vector v_vals(num); + for (mx_uint i = 0; i < num; ++i) { + v_keys[i] = keys[i]; + v_vals[i] = static_cast(vals[i]); + } + static_cast(handle)->Pull(v_keys, v_vals, priority); + API_END(); +} + int MXKVStoreSetUpdater(KVStoreHandle handle, MXKVStoreUpdater updater, void* updater_handle) { diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc index be5662e8a6db..78d4958096cc 100644 --- a/src/kvstore/kvstore.cc +++ b/src/kvstore/kvstore.cc @@ -7,7 +7,6 @@ #include #include #include "./kvstore_local.h" -// #include "./kvstore_device.h" #if MXNET_USE_DIST_KVSTORE #include "./kvstore_dist.h" #endif // 
MXNET_USE_DIST_KVSTORE diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index caa57a20d46e..dc5f7b786244 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include "./comm.h" @@ -47,6 +48,20 @@ class KVStoreLocal : public KVStore { } } + void Init(const std::vector& str_keys, + const std::vector& values) override { + std::vector keys(str_keys.size()); + for (size_t i = 0; i < str_keys.size(); ++i) { + auto &str_key = str_keys[i]; + CHECK(str_key_dict_.find(str_key) == str_key_dict_.end()) + << "duplicate init of key " << str_key; + auto key = next_str_key_++; + str_key_dict_[str_key] = key; + keys[i] = key; + } + Init(keys, values); + } + void Push(const std::vector& keys, const std::vector& values, int priority) override { @@ -87,6 +102,22 @@ class KVStoreLocal : public KVStore { } } + void Push(const std::vector& str_keys, + const std::vector& values, + int priority) override { + std::vector keys(str_keys.size()); + LookupKeys(str_keys, &keys); + Push(keys, values, priority); + } + + void Pull(const std::vector& str_keys, + const std::vector& values, + int priority) override { + std::vector keys(str_keys.size()); + LookupKeys(str_keys, &keys); + Pull(keys, values, priority); + } + protected: /** * \brief group values on keys @@ -118,12 +149,27 @@ class KVStoreLocal : public KVStore { } } } + + void LookupKeys(const std::vector& str_keys, + std::vector *keys) { + for (size_t i = 0; i < str_keys.size(); ++i) { + auto &str_key = str_keys[i]; + CHECK(str_key_dict_.find(str_key) != str_key_dict_.end()) + << "key " << str_key << " doesn't exist. 
Did you init?"; + keys->at(i) = str_key_dict_[str_key]; + } + } + /// reducer and broadcaster Comm* comm_; /// pinned context Context pinned_ctx_; /// \brief buffer for storing local values std::unordered_map local_; + /// key mapping for string -> integer + std::unordered_map str_key_dict_; + /// the next available integer for string->int key mapping + int next_str_key_ = 0; }; } // namespace kvstore } // namespace mxnet diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index dd8149d4822e..87e5e0027241 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -4,6 +4,8 @@ shape = (4, 4) keys = [5, 7, 11] +str_keys = ['b', 'c', 'd'] + def init_kv(): """init kv """ kv = mx.kv.create() @@ -13,6 +15,14 @@ def init_kv(): kv.init(keys, [mx.nd.zeros(shape)] * len(keys)) return kv +def init_kv_with_str(): + """init kv """ + kv = mx.kv.create() + # single + kv.init('a', mx.nd.zeros(shape)) + # list + kv.init(str_keys, [mx.nd.zeros(shape)] * len(keys)) + return kv def check_diff_to_scalar(A, x): """ assert A == x""" @@ -20,59 +30,67 @@ def check_diff_to_scalar(A, x): def test_single_kv_pair(): """single key-value pair push & pull""" + def check_single_kv_pair(kv, key): + kv.push(key, mx.nd.ones(shape)) + val = mx.nd.empty(shape) + kv.pull(key, out = val) + check_diff_to_scalar(val, 1) - kv = init_kv() - kv.push(3, mx.nd.ones(shape)) - val = mx.nd.empty(shape) - kv.pull(3, out = val) - check_diff_to_scalar(val, 1) + check_single_kv_pair(init_kv(), 3) + check_single_kv_pair(init_kv_with_str(), 'a') def test_init(): """test init""" - kv = mx.kv.create() - kv.init(3, mx.nd.ones(shape)*4) - a = mx.nd.zeros(shape) - kv.pull(3, out=a) - check_diff_to_scalar(a, 4) + def check_init(kv, key): + kv.init(key, mx.nd.ones(shape)*4) + a = mx.nd.zeros(shape) + kv.pull(key, out=a) + check_diff_to_scalar(a, 4) + + check_init(mx.kv.create(), 3) + check_init(mx.kv.create(), 'a') def test_list_kv_pair(): """list 
key-value pair push & pull""" + def check_list_kv_pair(kv, key): + kv.push(key, [mx.nd.ones(shape)*4] * len(key)) + val = [mx.nd.empty(shape)] * len(key) + kv.pull(key, out = val) + for v in val: + check_diff_to_scalar(v, 4) - kv = init_kv() - - kv.push(keys, [mx.nd.ones(shape)*4] * len(keys)) - val = [mx.nd.empty(shape)] * len(keys) - kv.pull(keys, out = val) - for v in val: - check_diff_to_scalar(v, 4) + check_list_kv_pair(init_kv(), keys) + check_list_kv_pair(init_kv_with_str(), str_keys) def test_aggregator(): """aggregate value on muliple devices""" - kv = init_kv() + def check_aggregator(kv, key, key_list): + # devices + num_devs = 4 + devs = [mx.Context('cpu', i) for i in range(num_devs)] - # devices - num_devs = 4 - devs = [mx.Context('cpu', i) for i in range(num_devs)] + # single + vals = [mx.nd.ones(shape, d) for d in devs] - # single - vals = [mx.nd.ones(shape, d) for d in devs] + kv.push(key, vals) + kv.pull(key, out = vals) - kv.push(3, vals) - kv.pull(3, out = vals) + for v in vals: + check_diff_to_scalar(v, num_devs) - for v in vals: - check_diff_to_scalar(v, num_devs) + # list + vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * len(key_list) + kv.push(key_list, vals) + kv.pull(key_list, out = vals) - # list - vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * len(keys) - kv.push(keys, vals) - kv.pull(keys, out = vals) + for vv in vals: + for v in vv: + check_diff_to_scalar(v, num_devs * 2.0) - for vv in vals: - for v in vv: - check_diff_to_scalar(v, num_devs * 2.0) + check_aggregator(init_kv(), 3, keys) + check_aggregator(init_kv_with_str(), 'a', str_keys) def updater(key, recv, local): @@ -82,34 +100,41 @@ def updater(key, recv, local): def test_updater(dev = 'cpu'): """updater""" - kv = init_kv() - kv._set_updater(updater) + def check_updater(kv, key, key_list): + # devices + num_devs = 4 + devs = [mx.Context(dev, i) for i in range(num_devs)] - # devices - num_devs = 4 - devs = [mx.Context(dev, i) for i in range(num_devs)] + # single + vals = 
[mx.nd.ones(shape, d) for d in devs] - # single - vals = [mx.nd.ones(shape, d) for d in devs] + kv.push(key, vals) + kv.pull(key, out = vals) - kv.push(3, vals) - kv.pull(3, out = vals) + for v in vals: + check_diff_to_scalar(v, num_devs) - for v in vals: - check_diff_to_scalar(v, num_devs) + # list + vals = [[mx.nd.ones(shape, d) for d in devs]] * len(key_list) - # list - vals = [[mx.nd.ones(shape, d) for d in devs]] * len(keys) + num_push = 4 + for i in range(num_push): + kv.push(key_list, vals) + + kv.pull(key_list, out = vals) + + for vv in vals: + for v in vv: + check_diff_to_scalar(v, num_devs * num_push) - num_push = 4 - for i in range(num_push): - kv.push(keys, vals) + kv = init_kv() + kv._set_updater(updater) + check_updater(kv, 3, keys) - kv.pull(keys, out = vals) + str_kv = init_kv_with_str() + str_kv._set_updater(updater) + check_updater(str_kv, 'a', str_keys) - for vv in vals: - for v in vv: - check_diff_to_scalar(v, num_devs * num_push) def test_get_type(): kvtype = 'local_allreduce_cpu'