Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

support str key type in kvstore #6765

Merged
merged 16 commits into from
Jun 23, 2017
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,19 @@ MXNET_DLL int MXKVStoreInit(KVStoreHandle handle,
const int* keys,
NDArrayHandle* vals);

/*!
* \brief Init a list of (key,value) pairs in kvstore, where each key is a string.
*        This is the string-key counterpart of MXKVStoreInit, which takes integer keys.
* \param handle handle to the kvstore
* \param num the number of key-value pairs
* \param keys the list of string keys, one per key-value pair
* \param vals the list of values, one per key-value pair
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStoreInitEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals);

/*!
* \brief Push a list of (key,value) pairs to kvstore
* \param handle handle to the kvstore
Expand All @@ -1325,6 +1338,20 @@ MXNET_DLL int MXKVStorePush(KVStoreHandle handle,
const int* keys,
NDArrayHandle* vals,
int priority);
/*!
* \brief Push a list of (key,value) pairs to kvstore, where each key is a string.
*        This is the string-key counterpart of MXKVStorePush, which takes integer keys.
* \param handle handle to the kvstore
* \param num the number of key-value pairs
* \param keys the list of string keys, one per key-value pair
* \param vals the list of values to push, one per key-value pair
* \param priority the priority of the action
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals,
int priority);
/*!
* \brief pull a list of (key, value) pairs from the kvstore
* \param handle handle to the kvstore
Expand All @@ -1339,6 +1366,20 @@ MXNET_DLL int MXKVStorePull(KVStoreHandle handle,
const int* keys,
NDArrayHandle* vals,
int priority);
/*!
* \brief Pull a list of (key, value) pairs from the kvstore, where each key is a string.
*        This is the string-key counterpart of MXKVStorePull, which takes integer keys.
* \param handle handle to the kvstore
* \param num the number of key-value pairs
* \param keys the list of string keys, one per key-value pair
* \param vals the list of NDArray handles that receive the pulled values, one per key-value pair
* \param priority the priority of the action
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals,
int priority);
/*!
* \brief user-defined updater for the kvstore
* It's this updater's responsibility to delete \a recv and \a local
Expand Down
27 changes: 27 additions & 0 deletions include/mxnet/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ class KVStore {
*/
virtual void Init(const std::vector<int>& keys,
const std::vector<NDArray>& values) = 0;
/*!
* \brief Initialize a list of key-value pairs in the store, using string keys.
* \param str_keys a list of unique keys in string format
* \param values a list of values
*/
virtual void Init(const std::vector<std::string>& str_keys,
const std::vector<NDArray>& values) = 0;
/*!
* \brief push a list of key-value pairs into the store
*
Expand Down Expand Up @@ -102,6 +109,16 @@ class KVStore {
virtual void Push(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority = 0) = 0;

/*!
* \brief push a list of key-value pairs into the store, using string keys
* \param str_keys the list of keys in string format
* \param values the list of values
* \param priority Priority of the action.
*/
virtual void Push(const std::vector<std::string>& str_keys,
const std::vector<NDArray>& values,
int priority = 0) = 0;
/*!
* \brief pull a list of key-value pairs from the store
*
Expand All @@ -128,6 +145,16 @@ class KVStore {
virtual void Pull(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
int priority = 0) = 0;
/*!
* \brief pull a list of key-value pairs from the store, using string keys
* \param str_keys the list of keys in string format
* \param values the list of buffers for the pulled data, they should be preallocated
* \param priority Priority of the action.
*/
virtual void Pull(const std::vector<std::string>& str_keys,
const std::vector<NDArray*>& values,
int priority = 0) = 0;


/**
* \brief the prototype of user-defined updater
Expand Down
75 changes: 43 additions & 32 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,39 @@
from .base import NDArrayHandle, KVStoreHandle
from . import optimizer as opt

def _ctype_key_value(keys, vals):
"""
Returns ctype arrays for the key-value args. For internal use.
"""
if isinstance(keys, int):
def _ctype_str_key_value(keys, vals):
names = []
if isinstance(keys, str):
if isinstance(vals, NDArray):
return (c_array(ctypes.c_int, [keys]),
names.append(c_str(keys))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for _cast_to_str_keys

always cast to str directly.
names.append(c_str(str(keys)))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

return (c_array(ctypes.c_char_p, names),
c_array(NDArrayHandle, [vals.handle]))
else:
for value in vals:
assert(isinstance(value, NDArray))
return (c_array(ctypes.c_int, [keys] * len(vals)),
return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)),
c_array(NDArrayHandle, [value.handle for value in vals]))
else:
assert(len(keys) == len(vals))
for k in keys:
assert(isinstance(k, int))
assert(isinstance(k, str))
c_keys = []
c_vals = []
for key, val in zip(keys, vals):
c_key_i, c_val_i = _ctype_key_value(key, val)
c_key_i, c_val_i = _ctype_str_key_value(key, val)
c_keys += c_key_i
c_vals += c_val_i
return (c_array(ctypes.c_int, c_keys), c_array(NDArrayHandle, c_vals))
return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals))

def _cast_to_str_keys(keys):
if isinstance(keys, str):
return keys
if isinstance(keys, int):
return str(keys)
str_keys = []
for key in keys:
str_keys.append(str(key) if isinstance(key, int) else key)
return str_keys

def _updater_wrapper(updater):
"""A wrapper for the user-defined handle."""
Expand Down Expand Up @@ -74,7 +82,7 @@ def init(self, key, value):

Parameters
----------
key : int or sequence of int
key : str or sequence of str
The keys.
value : NDArray or sequence of NDArray
Values corresponding to the keys.
Expand All @@ -84,20 +92,20 @@ def init(self, key, value):
>>> # init a single key-value pair
>>> shape = (2,3)
>>> kv = mx.kv.create('local')
>>> kv.init(3, mx.nd.ones(shape)*2)
>>> kv.init('3', mx.nd.ones(shape)*2)
>>> a = mx.nd.zeros(shape)
>>> kv.pull(3, out=a)
>>> kv.pull('3', out=a)
>>> print a.asnumpy()
[[ 2. 2. 2.]
[ 2. 2. 2.]]

>>> # init a list of key-value pairs
>>> keys = [5, 7, 9]
>>> keys = ['5', '7', '9']
>>> kv.init(keys, [mx.nd.ones(shape)]*len(keys))
"""
ckeys, cvals = _ctype_key_value(key, value)
check_call(_LIB.MXKVStoreInit(
self.handle, mx_uint(len(ckeys)), ckeys, cvals))
key = _cast_to_str_keys(key)
ckeys, cvals = _ctype_str_key_value(key, value)
check_call(_LIB.MXKVStoreInitEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals))

def push(self, key, value, priority=0):
""" Pushes a single or a sequence of key-value pairs into the store.
Expand All @@ -110,7 +118,7 @@ def push(self, key, value, priority=0):

Parameters
----------
key : int or list of int
key : str or list of str
Keys.

value : NDArray or list of NDArray or list of list of NDArray
Expand All @@ -124,17 +132,17 @@ def push(self, key, value, priority=0):
Examples
--------
>>> # push a single key-value pair
>>> kv.push(3, mx.nd.ones(shape)*8)
>>> kv.pull(3, out=a) # pull out the value
>>> kv.push('3', mx.nd.ones(shape)*8)
>>> kv.pull('3', out=a) # pull out the value
>>> print a.asnumpy()
[[ 8. 8. 8.]
[ 8. 8. 8.]]

>>> # aggregate the value and the push
>>> gpus = [mx.gpu(i) for i in range(4)]
>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
>>> kv.push(3, b)
>>> kv.pull(3, out=a)
>>> kv.push('3', b)
>>> kv.pull('3', out=a)
>>> print a.asnumpy()
[[ 4. 4. 4.]
[ 4. 4. 4.]]
Expand All @@ -156,11 +164,13 @@ def push(self, key, value, priority=0):
[[ 4. 4. 4.]
[ 4. 4. 4.]]
"""
ckeys, cvals = _ctype_key_value(key, value)
check_call(_LIB.MXKVStorePush(
key = _cast_to_str_keys(key)
ckeys, cvals = _ctype_str_key_value(key, value)
check_call(_LIB.MXKVStorePushEx(
self.handle, mx_uint(len(ckeys)), ckeys, cvals,
ctypes.c_int(priority)))


def pull(self, key, out=None, priority=0):
""" Pulls a single value or a sequence of values from the store.

Expand Down Expand Up @@ -190,21 +200,21 @@ def pull(self, key, out=None, priority=0):
--------
>>> # pull a single key-value pair
>>> a = mx.nd.zeros(shape)
>>> kv.pull(3, out=a)
>>> kv.pull('3', out=a)
>>> print a.asnumpy()
[[ 2. 2. 2.]
[ 2. 2. 2.]]

>>> # pull into multiple devices
>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
>>> kv.pull(3, out=b)
>>> kv.pull('3', out=b)
>>> print b[1].asnumpy()
[[ 2. 2. 2.]
[ 2. 2. 2.]]

>>> # pull a list of key-value pairs.
>>> # On single device
>>> keys = [5, 7, 9]
>>> keys = ['5', '7', '9']
>>> b = [mx.nd.zeros(shape)]*len(keys)
>>> kv.pull(keys, out=b)
>>> print b[1].asnumpy()
Expand All @@ -218,8 +228,9 @@ def pull(self, key, out=None, priority=0):
[ 2. 2. 2.]]
"""
assert(out is not None)
ckeys, cvals = _ctype_key_value(key, out)
check_call(_LIB.MXKVStorePull(
key = _cast_to_str_keys(key)
ckeys, cvals = _ctype_str_key_value(key, out)
check_call(_LIB.MXKVStorePullEx(
self.handle, mx_uint(len(ckeys)), ckeys, cvals,
ctypes.c_int(priority)))

Expand Down Expand Up @@ -348,13 +359,13 @@ def _set_updater(self, updater):
... print "update on key: %d" % key
... stored += input * 2
>>> kv._set_updater(update)
>>> kv.pull(3, out=a)
>>> kv.pull('3', out=a)
>>> print a.asnumpy()
[[ 4. 4. 4.]
[ 4. 4. 4.]]
>>> kv.push(3, mx.nd.ones(shape))
>>> kv.push('3', mx.nd.ones(shape))
update on key: 3
>>> kv.pull(3, out=a)
>>> kv.pull('3', out=a)
>>> print a.asnumpy()
[[ 6. 6. 6.]
[ 6. 6. 6.]]
Expand Down
24 changes: 14 additions & 10 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,34 +80,37 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names,
update_on_kvstore):
"""Initialize kvstore"""
for idx, param_on_devs in enumerate(param_arrays):
kvstore.init(idx, arg_params[param_names[idx]])
name = param_names[idx]
kvstore.init(name, arg_params[name])

if update_on_kvstore:
kvstore.pull(idx, param_on_devs, priority=-idx)
kvstore.pull(name, param_on_devs, priority=-idx)

def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore):
def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names):
"""Perform update of param_arrays from grad_arrays on kvstore."""
for index, pair in enumerate(zip(param_arrays, grad_arrays)):
arg_list, grad_list = pair
if grad_list[0] is None:
continue
name = param_names[index]
# push gradient, priority is negative index
kvstore.push(index, grad_list, priority=-index)
kvstore.push(name, grad_list, priority=-index)
# pull back the weights
kvstore.pull(index, arg_list, priority=-index)
kvstore.pull(name, arg_list, priority=-index)

def _update_params(param_arrays, grad_arrays, updater, num_device,
kvstore=None):
kvstore=None, param_names=None):
"""Perform update of param_arrays from grad_arrays not on kvstore."""
for index, pair in enumerate(zip(param_arrays, grad_arrays)):
arg_list, grad_list = pair
if grad_list[0] is None:
continue
if kvstore:
name = param_names[index]
# push gradient, priority is negative index
kvstore.push(index, grad_list, priority=-index)
kvstore.push(name, grad_list, priority=-index)
# pull back the sum gradients, to the same locations.
kvstore.pull(index, grad_list, priority=-index)
kvstore.pull(name, grad_list, priority=-index)
for k, p in enumerate(zip(arg_list, grad_list)):
# faked an index here, to make optimizer create diff
# state for the same index but on diff devs, TODO(mli)
Expand Down Expand Up @@ -245,13 +248,14 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
if update_on_kvstore:
_update_params_on_kvstore(executor_manager.param_arrays,
executor_manager.grad_arrays,
kvstore)
kvstore, executor_manager.param_names)
else:
_update_params(executor_manager.param_arrays,
executor_manager.grad_arrays,
updater=updater,
num_device=len(ctx),
kvstore=kvstore)
kvstore=kvstore,
param_names=executor_manager.param_names)

if monitor is not None:
monitor.toc_print()
Expand Down
5 changes: 3 additions & 2 deletions python/mxnet/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,13 +572,14 @@ def update(self):
if self._update_on_kvstore:
_update_params_on_kvstore(self._exec_group.param_arrays,
self._exec_group.grad_arrays,
self._kvstore)
self._kvstore, self._exec_group.param_names)
else:
_update_params(self._exec_group.param_arrays,
self._exec_group.grad_arrays,
updater=self._updater,
num_device=len(self._context),
kvstore=self._kvstore)
kvstore=self._kvstore,
param_names=self._exec_group.param_names)

def get_outputs(self, merge_multi_context=True):
"""Gets outputs of the previous forward computation.
Expand Down
Loading