Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

support str key type in kvstore #6765

Merged
merged 16 commits into from
Jun 23, 2017
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,19 @@ MXNET_DLL int MXKVStoreInit(KVStoreHandle handle,
const int* keys,
NDArrayHandle* vals);

/*!
* \brief Init a list of (key,value) pairs in kvstore, where each key is a string
* \param handle handle to the kvstore
* \param num the number of key-value pairs
* \param keys the list of keys
* \param vals the list of values
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStoreInitEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals);

/*!
* \brief Push a list of (key,value) pairs to kvstore
* \param handle handle to the kvstore
Expand All @@ -1325,6 +1338,20 @@ MXNET_DLL int MXKVStorePush(KVStoreHandle handle,
const int* keys,
NDArrayHandle* vals,
int priority);
/*!
* \brief Push a list of (key,value) pairs to kvstore, where each key is a string
* \param handle handle to the kvstore
* \param num the number of key-value pairs
* \param keys the list of keys
* \param vals the list of values
* \param priority the priority of the action
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStorePushEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals,
int priority);
/*!
* \brief pull a list of (key, value) pairs from the kvstore
* \param handle handle to the kvstore
Expand All @@ -1339,6 +1366,20 @@ MXNET_DLL int MXKVStorePull(KVStoreHandle handle,
const int* keys,
NDArrayHandle* vals,
int priority);
/*!
* \brief pull a list of (key, value) pairs from the kvstore, where each key is a string
* \param handle handle to the kvstore
* \param num the number of key-value pairs
* \param keys the list of keys
* \param vals the list of values
* \param priority the priority of the action
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStorePullEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals,
int priority);
/*!
* \brief user-defined updater for the kvstore
* It's this updater's responsibility to delete \a recv and \a local
Expand Down
27 changes: 27 additions & 0 deletions include/mxnet/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ class KVStore {
*/
virtual void Init(const std::vector<int>& keys,
const std::vector<NDArray>& values) = 0;
/*!
* \brief Initialize a list of key-value pair to the store.
* \param keys a list of unique keys in string format
* \param values a list of values
*/
virtual void Init(const std::vector<std::string>& str_keys,
const std::vector<NDArray>& values) = 0;
/*!
* \brief push a list of key-value pairs into the store
*
Expand Down Expand Up @@ -102,6 +109,16 @@ class KVStore {
virtual void Push(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority = 0) = 0;

/*!
* \brief push a list of key-value pairs into the store
* \param keys the list of keys in string format
* \param values the list of values
* \param priority Priority of the action.
*/
virtual void Push(const std::vector<std::string>& str_keys,
const std::vector<NDArray>& values,
int priority = 0) = 0;
/*!
* \brief pull a list of key-value pairs from the store
*
Expand All @@ -128,6 +145,16 @@ class KVStore {
virtual void Pull(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
int priority = 0) = 0;
/*!
* \brief pull a list of key-value pairs from the store
* \param keys the list of keys in string format
* \param values the list of buffers for the pulled data, they should be preallocated
* \param priority Priority of the action.
*/
virtual void Pull(const std::vector<std::string>& str_keys,
const std::vector<NDArray*>& values,
int priority = 0) = 0;


/**
* \brief the prototype of user-defined updater
Expand Down
69 changes: 58 additions & 11 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,32 @@ def _ctype_key_value(keys, vals):
c_vals += c_val_i
return (c_array(ctypes.c_int, c_keys), c_array(NDArrayHandle, c_vals))

def _ctype_str_key_value(keys, vals):
names = []
if isinstance(keys, str):
if isinstance(vals, NDArray):
names.append(c_str(keys))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need for _cast_to_str_keys

always cast to str directly.
names.append(c_str(str(keys)))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

return (c_array(ctypes.c_char_p, names),
c_array(NDArrayHandle, [vals.handle]))
else:
for value in vals:
assert(isinstance(value, NDArray))
return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)),
c_array(NDArrayHandle, [value.handle for value in vals]))
else:
assert(len(keys) == len(vals))
for k in keys:
assert(isinstance(k, str))
c_keys = []
c_vals = []
for key, val in zip(keys, vals):
c_key_i, c_val_i = _ctype_str_key_value(key, val)
c_keys += c_key_i
c_vals += c_val_i
return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals))

def _use_str_keys(key):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can always use str key by converting int to str

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure. Do we still want to keep the MXKVStorePull(..., mx_uint *key) API then? Currently I need to create a separate one MXKVStorePullEx

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also the Pull(vector<int> keys ..) interface in kvstore.h

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original API needs to be kept for other packages. Python doesn't need to use it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sense

return isinstance(key, str) or (isinstance(key, (list, tuple)) and isinstance(key[0], str))

def _updater_wrapper(updater):
"""A wrapper for the user-defined handle."""
Expand Down Expand Up @@ -95,9 +121,15 @@ def init(self, key, value):
>>> keys = [5, 7, 9]
>>> kv.init(keys, [mx.nd.ones(shape)]*len(keys))
"""
ckeys, cvals = _ctype_key_value(key, value)
check_call(_LIB.MXKVStoreInit(
self.handle, mx_uint(len(ckeys)), ckeys, cvals))
use_str_key = _use_str_keys(key)
if use_str_key:
ckeys, cvals = _ctype_str_key_value(key, value)
check_call(_LIB.MXKVStoreInitEx(
self.handle, mx_uint(len(ckeys)), ckeys, cvals))
else:
ckeys, cvals = _ctype_key_value(key, value)
check_call(_LIB.MXKVStoreInit(
self.handle, mx_uint(len(ckeys)), ckeys, cvals))

def push(self, key, value, priority=0):
""" Pushes a single or a sequence of key-value pairs into the store.
Expand Down Expand Up @@ -156,10 +188,18 @@ def push(self, key, value, priority=0):
[[ 4. 4. 4.]
[ 4. 4. 4.]]
"""
ckeys, cvals = _ctype_key_value(key, value)
check_call(_LIB.MXKVStorePush(
self.handle, mx_uint(len(ckeys)), ckeys, cvals,
ctypes.c_int(priority)))
use_str_key = _use_str_keys(key)
if use_str_key:
ckeys, cvals = _ctype_str_key_value(key, value)
check_call(_LIB.MXKVStorePushEx(
self.handle, mx_uint(len(ckeys)), ckeys, cvals,
ctypes.c_int(priority)))
else:
ckeys, cvals = _ctype_key_value(key, value)
check_call(_LIB.MXKVStorePush(
self.handle, mx_uint(len(ckeys)), ckeys, cvals,
ctypes.c_int(priority)))


def pull(self, key, out=None, priority=0):
""" Pulls a single value or a sequence of values from the store.
Expand Down Expand Up @@ -218,10 +258,17 @@ def pull(self, key, out=None, priority=0):
[ 2. 2. 2.]]
"""
assert(out is not None)
ckeys, cvals = _ctype_key_value(key, out)
check_call(_LIB.MXKVStorePull(
self.handle, mx_uint(len(ckeys)), ckeys, cvals,
ctypes.c_int(priority)))
use_str_key = _use_str_keys(key)
if use_str_key:
ckeys, cvals = _ctype_str_key_value(key, out)
check_call(_LIB.MXKVStorePullEx(
self.handle, mx_uint(len(ckeys)), ckeys, cvals,
ctypes.c_int(priority)))
else:
ckeys, cvals = _ctype_key_value(key, out)
check_call(_LIB.MXKVStorePull(
self.handle, mx_uint(len(ckeys)), ckeys, cvals,
ctypes.c_int(priority)))

def set_optimizer(self, optimizer):
""" Registers an optimizer with the kvstore.
Expand Down
24 changes: 14 additions & 10 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,34 +80,37 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names,
update_on_kvstore):
"""Initialize kvstore"""
for idx, param_on_devs in enumerate(param_arrays):
kvstore.init(idx, arg_params[param_names[idx]])
name = param_names[idx]
kvstore.init(name, arg_params[name])

if update_on_kvstore:
kvstore.pull(idx, param_on_devs, priority=-idx)
kvstore.pull(name, param_on_devs, priority=-idx)

def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore):
def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names):
"""Perform update of param_arrays from grad_arrays on kvstore."""
for index, pair in enumerate(zip(param_arrays, grad_arrays)):
arg_list, grad_list = pair
if grad_list[0] is None:
continue
name = param_names[index]
# push gradient, priority is negative index
kvstore.push(index, grad_list, priority=-index)
kvstore.push(name, grad_list, priority=-index)
# pull back the weights
kvstore.pull(index, arg_list, priority=-index)
kvstore.pull(name, arg_list, priority=-index)

def _update_params(param_arrays, grad_arrays, updater, num_device,
kvstore=None):
kvstore=None, param_names=None):
"""Perform update of param_arrays from grad_arrays not on kvstore."""
for index, pair in enumerate(zip(param_arrays, grad_arrays)):
arg_list, grad_list = pair
if grad_list[0] is None:
continue
if kvstore:
name = param_names[index]
# push gradient, priority is negative index
kvstore.push(index, grad_list, priority=-index)
kvstore.push(name, grad_list, priority=-index)
# pull back the sum gradients, to the same locations.
kvstore.pull(index, grad_list, priority=-index)
kvstore.pull(name, grad_list, priority=-index)
for k, p in enumerate(zip(arg_list, grad_list)):
# faked an index here, to make optimizer create diff
# state for the same index but on diff devs, TODO(mli)
Expand Down Expand Up @@ -245,13 +248,14 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
if update_on_kvstore:
_update_params_on_kvstore(executor_manager.param_arrays,
executor_manager.grad_arrays,
kvstore)
kvstore, executor_manager.param_names)
else:
_update_params(executor_manager.param_arrays,
executor_manager.grad_arrays,
updater=updater,
num_device=len(ctx),
kvstore=kvstore)
kvstore=kvstore,
param_names=executor_manager.param_names)

if monitor is not None:
monitor.toc_print()
Expand Down
5 changes: 3 additions & 2 deletions python/mxnet/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,13 +572,14 @@ def update(self):
if self._update_on_kvstore:
_update_params_on_kvstore(self._exec_group.param_arrays,
self._exec_group.grad_arrays,
self._kvstore)
self._kvstore, self._exec_group.param_names)
else:
_update_params(self._exec_group.param_arrays,
self._exec_group.grad_arrays,
updater=self._updater,
num_device=len(self._context),
kvstore=self._kvstore)
kvstore=self._kvstore,
param_names=self._exec_group.param_names)

def get_outputs(self, merge_multi_context=True):
"""Gets outputs of the previous forward computation.
Expand Down
47 changes: 47 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,21 @@ int MXKVStoreInit(KVStoreHandle handle,
API_END();
}

int MXKVStoreInitEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals) {
API_BEGIN();
std::vector<std::string> v_keys(num);
std::vector<NDArray> v_vals(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_vals[i] = *static_cast<NDArray*>(vals[i]);
}
static_cast<KVStore*>(handle)->Init(v_keys, v_vals);
API_END();
}

int MXKVStorePush(KVStoreHandle handle,
mx_uint num,
const int* keys,
Expand All @@ -641,6 +656,22 @@ int MXKVStorePush(KVStoreHandle handle,
API_END();
}

int MXKVStorePushEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals,
int priority) {
API_BEGIN();
std::vector<std::string> v_keys(num);
std::vector<NDArray> v_vals(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_vals[i] = *static_cast<NDArray*>(vals[i]);
}
static_cast<KVStore*>(handle)->Push(v_keys, v_vals, priority);
API_END();
}

int MXKVStorePull(KVStoreHandle handle,
mx_uint num,
const int* keys,
Expand All @@ -657,6 +688,22 @@ int MXKVStorePull(KVStoreHandle handle,
API_END();
}

int MXKVStorePullEx(KVStoreHandle handle,
mx_uint num,
const char** keys,
NDArrayHandle* vals,
int priority) {
API_BEGIN();
std::vector<std::string> v_keys(num);
std::vector<NDArray*> v_vals(num);
for (mx_uint i = 0; i < num; ++i) {
v_keys[i] = keys[i];
v_vals[i] = static_cast<NDArray*>(vals[i]);
}
static_cast<KVStore*>(handle)->Pull(v_keys, v_vals, priority);
API_END();
}

int MXKVStoreSetUpdater(KVStoreHandle handle,
MXKVStoreUpdater updater,
void* updater_handle) {
Expand Down
1 change: 0 additions & 1 deletion src/kvstore/kvstore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <stdlib.h>
#include <dmlc/logging.h>
#include "./kvstore_local.h"
// #include "./kvstore_device.h"
#if MXNET_USE_DIST_KVSTORE
#include "./kvstore_dist.h"
#endif // MXNET_USE_DIST_KVSTORE
Expand Down
Loading