-
Notifications
You must be signed in to change notification settings - Fork 6.8k
support str key type in kvstore #6765
Changes from 6 commits
caa1d08
946abc6
de8f0b8
6e50272
c430ce5
eac18b3
3a277a4
3b7e730
aec39f1
c1b9497
d1b92fc
6e97bcd
9e2562f
a4972b9
862cda8
bcac25f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,32 @@ def _ctype_key_value(keys, vals): | |
c_vals += c_val_i | ||
return (c_array(ctypes.c_int, c_keys), c_array(NDArrayHandle, c_vals)) | ||
|
||
def _ctype_str_key_value(keys, vals): | ||
names = [] | ||
if isinstance(keys, str): | ||
if isinstance(vals, NDArray): | ||
names.append(c_str(keys)) | ||
return (c_array(ctypes.c_char_p, names), | ||
c_array(NDArrayHandle, [vals.handle])) | ||
else: | ||
for value in vals: | ||
assert(isinstance(value, NDArray)) | ||
return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)), | ||
c_array(NDArrayHandle, [value.handle for value in vals])) | ||
else: | ||
assert(len(keys) == len(vals)) | ||
for k in keys: | ||
assert(isinstance(k, str)) | ||
c_keys = [] | ||
c_vals = [] | ||
for key, val in zip(keys, vals): | ||
c_key_i, c_val_i = _ctype_str_key_value(key, val) | ||
c_keys += c_key_i | ||
c_vals += c_val_i | ||
return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals)) | ||
|
||
def _use_str_keys(key): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can always use str key by converting int to str There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure. Do we still want to keep the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The original API needs to be kept for other packages. Python doesn't need to use it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. make sense |
||
return isinstance(key, str) or (isinstance(key, (list, tuple)) and isinstance(key[0], str)) | ||
|
||
def _updater_wrapper(updater): | ||
"""A wrapper for the user-defined handle.""" | ||
|
@@ -95,9 +121,15 @@ def init(self, key, value): | |
>>> keys = [5, 7, 9] | ||
>>> kv.init(keys, [mx.nd.ones(shape)]*len(keys)) | ||
""" | ||
ckeys, cvals = _ctype_key_value(key, value) | ||
check_call(_LIB.MXKVStoreInit( | ||
self.handle, mx_uint(len(ckeys)), ckeys, cvals)) | ||
use_str_key = _use_str_keys(key) | ||
if use_str_key: | ||
ckeys, cvals = _ctype_str_key_value(key, value) | ||
check_call(_LIB.MXKVStoreInitEx( | ||
self.handle, mx_uint(len(ckeys)), ckeys, cvals)) | ||
else: | ||
ckeys, cvals = _ctype_key_value(key, value) | ||
check_call(_LIB.MXKVStoreInit( | ||
self.handle, mx_uint(len(ckeys)), ckeys, cvals)) | ||
|
||
def push(self, key, value, priority=0): | ||
""" Pushes a single or a sequence of key-value pairs into the store. | ||
|
@@ -156,10 +188,18 @@ def push(self, key, value, priority=0): | |
[[ 4. 4. 4.] | ||
[ 4. 4. 4.]] | ||
""" | ||
ckeys, cvals = _ctype_key_value(key, value) | ||
check_call(_LIB.MXKVStorePush( | ||
self.handle, mx_uint(len(ckeys)), ckeys, cvals, | ||
ctypes.c_int(priority))) | ||
use_str_key = _use_str_keys(key) | ||
if use_str_key: | ||
ckeys, cvals = _ctype_str_key_value(key, value) | ||
check_call(_LIB.MXKVStorePushEx( | ||
self.handle, mx_uint(len(ckeys)), ckeys, cvals, | ||
ctypes.c_int(priority))) | ||
else: | ||
ckeys, cvals = _ctype_key_value(key, value) | ||
check_call(_LIB.MXKVStorePush( | ||
self.handle, mx_uint(len(ckeys)), ckeys, cvals, | ||
ctypes.c_int(priority))) | ||
|
||
|
||
def pull(self, key, out=None, priority=0): | ||
""" Pulls a single value or a sequence of values from the store. | ||
|
@@ -218,10 +258,17 @@ def pull(self, key, out=None, priority=0): | |
[ 2. 2. 2.]] | ||
""" | ||
assert(out is not None) | ||
ckeys, cvals = _ctype_key_value(key, out) | ||
check_call(_LIB.MXKVStorePull( | ||
self.handle, mx_uint(len(ckeys)), ckeys, cvals, | ||
ctypes.c_int(priority))) | ||
use_str_key = _use_str_keys(key) | ||
if use_str_key: | ||
ckeys, cvals = _ctype_str_key_value(key, out) | ||
check_call(_LIB.MXKVStorePullEx( | ||
self.handle, mx_uint(len(ckeys)), ckeys, cvals, | ||
ctypes.c_int(priority))) | ||
else: | ||
ckeys, cvals = _ctype_key_value(key, out) | ||
check_call(_LIB.MXKVStorePull( | ||
self.handle, mx_uint(len(ckeys)), ckeys, cvals, | ||
ctypes.c_int(priority))) | ||
|
||
def set_optimizer(self, optimizer): | ||
""" Registers an optimizer with the kvstore. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need for _cast_to_str_keys
always cast to str directly.
names.append(c_str(str(keys)))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok