From 7e9e1e75c042a1c1aaf1ba46ae9f1ce471161e2f Mon Sep 17 00:00:00 2001 From: muli Date: Tue, 15 Sep 2015 14:21:50 -0400 Subject: [PATCH 1/3] update python_guide on kvstore --- doc/python/python_guide.md | 138 ++++++++++++++++++++++++++++++++++--- 1 file changed, 129 insertions(+), 9 deletions(-) diff --git a/doc/python/python_guide.md b/doc/python/python_guide.md index f3c70af4709b..8e88b574e65d 100644 --- a/doc/python/python_guide.md +++ b/doc/python/python_guide.md @@ -1,17 +1,18 @@ # MXNet Python Guide -This page gives a general overvie of MXNet python package. MXNet contains a -mixed flavor of elements you might need to bake flexible and efficient -applications. There are mainly three concepts in MXNet: +This page gives a general overview of MXNet's python package. MXNet contains a +mixed flavor of elements to bake flexible and efficient +applications. There are mainly three concepts: -* Numpy style [NDArray](#ndarray-numpy-style-tensor-computations-on-cpu-gpu) offers matrix and tensor computations on both CPU and -GPU, with automatic parallelization +* Numpy style [NDArray](#ndarray-numpy-style-tensor-computations-on-cpu-gpu) + offers matrix and tensor computations on both CPU and GPU, with automatic + parallelization -* [Symbol](#symbolic-and-automatic-differentiation) makes defining a neural network extremely easy, and it provides - automatic differentiation. +* [Symbol](#symbolic-and-automatic-differentiation) makes defining a neural + network extremely easy, and provides automatic differentiation. -* [KVStore](#distributed-key-value-store) allows data synchronization between - multi-GPUs and multi-machine easily +* [KVStore](#distributed-key-value-store) easy the data synchronization between + multi-GPUs and multi-machines. ## NDArray: Numpy style tensor computations on CPU/GPU @@ -375,6 +376,125 @@ greater flexiblity. ## Distributed Key-value Store +`KVStore` is a place for data sharing. We can think it as a single object shared +across different devices (GPUs and machines), where each device can push data in +and pull data out. + +### Initialization + +Let's first consider a simple example. It initializes +a (`int`, `NDAarray`) pair into the store, and then pull the value out. + +```python +>>> mx.kv.start() # start the kvstore +>>> shape = (2,3) +>>> mx.kv.init(3, mx.nd.ones(shape)*2) +>>> a = mx.nd.zeros(shape) +>>> mx.kv.pull(3, out = a) +>>> print a.asnumpy() +[[ 2. 2. 2.] + [ 2. 2. 2.]] +``` + +### Push, Aggregation, and Updater + +For any key has been initialized, we can push a new value with the same shape to the key. + +```python +>>> mx.kv.push(3, mx.nd.ones(shape)*8) +>>> mx.kv.pull(3, out = a) # pull out the value +>>> print a.asnumpy() +[[ 8. 8. 8.] + [ 8. 8. 8.]] +``` + +The data for pushing can be on any device. Furthermore, we can push multiple +values into the same key, where `kvstore` will first sum all these +values and then push the aggregated value. + +```python +>>> gpus = [mx.gpu(i) for i in range(4)] +>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] +>>> mx.kv.push(3, b) +>>> mx.kv.pull(3, out = a) +>>> print a.asnumpy() +[[ 4. 4. 4.] + [ 4. 4. 4.]] +``` + +For each push, `kvstore` applies the pushed value into the value stored by a +`updater`. The default updater is `ASSGIN`, we can replace the default one to +control how data is merged. + +```python +>>> def update(key, input, stored): +>>> print "update on key: %d" % key +>>> stored += input * 2 +>>> mx.kv.set_updater(update) +>>> mx.kv.pull(3, out=a) +>>> print a.asnumpy() +[[ 4. 4. 4.] + [ 4. 4. 4.]] +>>> mx.kv.push(3, mx.nd.ones(shape)) +update on key: 3 +>>> mx.kv.pull(3, out=a) +>>> print a.asnumpy() +[[ 6. 6. 6.] + [ 6. 6. 6.]] +``` + +### Pull + +Similar to push, we can also pull the value into several devices by a single +call. + +```python +>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] +>>> mx.kv.pull(3, out = b) +>>> print b[1].asnumpy() +[[ 6. 6. 6.] + [ 6. 6. 6.]] +``` + +### Handle a list of key-value pairs + +All operations introduced so far are on a single key. `KVStore` also provides +list of key-value pair interface. + +On single device: + +```python +>>> keys = [5, 7, 9] +>>> mx.kv.init(keys, [mx.nd.ones(shape)]*len(keys)) +>>> mx.kv.push(keys, [mx.nd.ones(shape)]*len(keys)) +update on key: 5 +update on key: 7 +update on key: 9 +>>> b = [mx.nd.zeros(shape)]*len(keys) +>>> mx.kv.pull(keys, out = b) +>>> print b[1].asnumpy() +[[ 3. 3. 3.] + [ 3. 3. 3.]] +``` + +On multi-devices: + +```pythoon +>>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys) +>>> mx.kv.push(keys, b) +update on key: 5 +update on key: 7 +update on key: 9 +>>> mx.kv.pull(keys, out = b) +>>> print b[1][1].asnumpy() +[[ 11. 11. 11.] + [ 11. 11. 11.]] +``` + +### Multiple machines + +Base on parameter server. The `updater` will runs on the server nodes. MORE... + ## How to Choose between APIs You can mix them all as much as you like. Here are some guidelines From 6b424e316d6049070e5d9bc85efdfb0b911bccc4 Mon Sep 17 00:00:00 2001 From: muli Date: Tue, 15 Sep 2015 15:12:27 -0400 Subject: [PATCH 2/3] update kvstore.py --- doc/python/python_api.md | 9 ++- doc/python/python_guide.md | 26 +++--- python/mxnet/kvstore.py | 159 ++++++++++++++++++++++++++++++------- 3 files changed, 153 insertions(+), 41 deletions(-) diff --git a/doc/python/python_api.md b/doc/python/python_api.md index 12b2af026f8b..b1aa36418d30 100644 --- a/doc/python/python_api.md +++ b/doc/python/python_api.md @@ -14,13 +14,20 @@ :members: ``` - ## Executor API + ```eval_rst .. automodule:: mxnet.executor :members: ``` +## KVStore API + +```eval_rst +.. automodule:: mxnet.kvstore + :members: +``` + ## IO API ```eval_rst diff --git a/doc/python/python_guide.md b/doc/python/python_guide.md index 8e88b574e65d..fa1d324d8ac6 100644 --- a/doc/python/python_guide.md +++ b/doc/python/python_guide.md @@ -4,7 +4,7 @@ This page gives a general overview of MXNet's python package. MXNet contains a mixed flavor of elements to bake flexible and efficient applications. There are mainly three concepts: -* Numpy style [NDArray](#ndarray-numpy-style-tensor-computations-on-cpu-gpu) +* Numpy style [NDArray](#ndarray-numpy-style-tensor-computations-on-cpus-and-gpus) offers matrix and tensor computations on both CPU and GPU, with automatic parallelization @@ -14,7 +14,7 @@ applications. There are mainly three concepts: * [KVStore](#distributed-key-value-store) easy the data synchronization between multi-GPUs and multi-machines. -## NDArray: Numpy style tensor computations on CPU/GPU +## NDArray: Numpy style tensor computations on CPUs and GPUs `NDArray` is the basic operation unit in MXNet for matrix and tensor computations. It is similar to `numpy.ndarray`, but with two additional @@ -445,8 +445,8 @@ update on key: 3 ### Pull -Similar to push, we can also pull the value into several devices by a single -call. +We already see how to pull a single key-value pair. Similar to push, we can also +pull the value into several devices by a single call. ```python >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] @@ -458,10 +458,8 @@ call. ### Handle a list of key-value pairs -All operations introduced so far are on a single key. `KVStore` also provides -list of key-value pair interface. - -On single device: +All operations introduced so far are about a single key. `KVStore` also provides +the interface for a list of key-value pairs. For single device: ```python >>> keys = [5, 7, 9] @@ -477,7 +475,7 @@ update on key: 9 [ 3. 3. 3.]] ``` -On multi-devices: +For multi-devices: ```pythoon >>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys) @@ -502,7 +500,9 @@ You can mix them all as much as you like. Here are some guidelines * Use fine-grained operator to extend parts of of more flexible symbolic graph. * Do some dynamic NArray tricks, which are even more flexible, between the calls of forward and backward of executors. -We believe that different ways offers you different levels of flexibilty and efficiency. Normally you do not need to -be flexible in all parts of the networks, so we allow you to use the fast optimized parts, -and compose it flexibly with fine-grained operator or dynamic NArray. We believe such kind of mixture allows you to build -the deep learning architecture both efficiently and flexibly as your choice. To mix is to maximize the peformance and flexiblity. +We believe that different ways offers you different levels of flexibilty and +efficiency. Normally you do not need to be flexible in all parts of the +networks, so we allow you to use the fast optimized parts, and compose it +flexibly with fine-grained operator or dynamic NArray. We believe such kind of +mixture allows you to build the deep learning architecture both efficiently and +flexibly as your choice. To mix is to maximize the peformance and flexiblity. diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 426962751039..7736cc22cff2 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -8,8 +8,12 @@ from .base import check_call, c_array, NDArrayHandle import atexit +__all__ = ['start', 'init', 'push', 'pull', 'stop', 'set_updater'] + def _ctype_key_value(keys, vals): - """parse key-value args into ctype""" + """ + Return ctype arrays for the key-value args, for internal use + """ if isinstance(keys, int): if isinstance(vals, NDArray): return (c_array(ctypes.c_int, [keys]), @@ -32,51 +36,139 @@ def _ctype_key_value(keys, vals): return (c_array(ctypes.c_int, c_keys), c_array(NDArrayHandle, c_vals)) def start(): - """start kvstore""" + """ + Start the KV Store. One must call it before calling any other functions. + + Examples: + --------- + >>> import mxnet as mx + >>> mx.kv.start() + """ check_call(_LIB.MXKVStoreStart()) def init(key, value): - """ Initialize a list of key-value pairs + """ Initialize a single or a sequence of key-value pairs into the store. + + For each key, one must init it before push and pull Parameters ---------- - keys: int or list of int - A single key or a list of keys - values: NDArray or list of NDArray - A single value of a list of values + key : int or sequence of int + The keys + value : NDArray or sequence of NDArray + The values + + Examples + -------- + # init a single key-value pair + >>> shape = (2,3) + >>> mx.kv.init(3, mx.nd.ones(shape)*2) + >>> a = mx.nd.zeros(shape) + >>> mx.kv.pull(3, out = a) + >>> print a.asnumpy() + [[ 2. 2. 2.] + [ 2. 2. 2.]] + + # init a list of key-value pairs + >>> keys = [5, 7, 9] + >>> mx.kv.init(keys, [mx.nd.ones(shape)]*len(keys)) """ ckeys, cvals = _ctype_key_value(key, value) check_call(_LIB.MXKVStoreInit(len(ckeys), ckeys, cvals)) def push(key, value): - """ Push a value into the store + """ Push a single or a sequence of key-value pairs into the store Parameters ---------- key : int or list of int - A single key or a list of key - value: list of NDArray or list of list of NDArray - A single value of a list of value + Keys + value: NDArray or list of NDArray or list of list of NDArray + According values + + Examples + -------- + # push a single key-value pair + >>> mx.kv.push(3, mx.nd.ones(shape)*8) + >>> mx.kv.pull(3, out = a) # pull out the value + >>> print a.asnumpy() + [[ 8. 8. 8.] + [ 8. 8. 8.]] + + # aggregate the value and the push + >>> gpus = [mx.gpu(i) for i in range(4)] + >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] + >>> mx.kv.push(3, b) + >>> mx.kv.pull(3, out = a) + >>> print a.asnumpy() + [[ 4. 4. 4.] + [ 4. 4. 4.]] + + # push a list of keys. + # single device + >>> mx.kv.push(keys, [mx.nd.ones(shape)]*len(keys)) + >>> b = [mx.nd.zeros(shape)]*len(keys) + >>> mx.kv.pull(keys, out = b) + >>> print b[1].asnumpy() + [[ 1. 1. 1.] + [ 1. 1. 1.]] + # multiple devices: + + >>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys) + >>> mx.kv.push(keys, b) + >>> mx.kv.pull(keys, out = b) + >>> print b[1][1].asnumpy() + [[ 4. 4. 4.] + [ 4. 4. 4.]] """ ckeys, cvals = _ctype_key_value(key, value) check_call(_LIB.MXKVStorePush(len(ckeys), ckeys, cvals)) def pull(key, out=None): - """Pull value from the store + """ Pull a single value or a sequence of values from the store Parameters ---------- - key: int or list of int - A single key or a list of key - out: NDArray or list of NDArray - A single value of a list of value + key : int or list of int + Keys + out: NDArray or list of NDArray or list of list of NDArray + According values + + Examples + -------- + # pull a single key-value pair + >>> a = mx.nd.zeros(shape) + >>> mx.kv.pull(3, out = a) + >>> print a.asnumpy() + [[ 2. 2. 2.] + [ 2. 2. 2.]] + + # pull into multiple devices + >>> b = [mx.nd.ones(shape, gpu) for gpu in gpus] + >>> mx.kv.pull(3, out = b) + >>> print b[1].asnumpy() + [[ 2. 2. 2.] + [ 2. 2. 2.]] + + # pull a list of key-value pairs. + # On single device + >>> keys = [5, 7, 9] + >>> b = [mx.nd.zeros(shape)]*len(keys) + >>> mx.kv.pull(keys, out = b) + >>> print b[1].asnumpy() + [[ 2. 2. 2.] + [ 2. 2. 2.]] + # On multiple devices + >>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys) + >>> mx.kv.pull(keys, out = b) + >>> print b[1][1].asnumpy() + [[ 2. 2. 2.] + [ 2. 2. 2.]] """ assert(out is not None) ckeys, cvals = _ctype_key_value(key, out) check_call(_LIB.MXKVStorePull(len(ckeys), ckeys, cvals)) - return out - def _updater_wrapper(updater): """ a wrapper for the user-defined handle """ @@ -90,17 +182,30 @@ def updater_handle(key, lhs_handle, rhs_handle): _updater_func = None def set_updater(updater): - """ set a updater into the store - - Example: - - def updater(recv, local): - local += recv - kvstore.set_updater(updater) + """ + Set a push updater into the store Parameters ---------- - updater: functon + updater: function + the updater function + + Examples: + --------- + >>> def update(key, input, stored): + >>> print "update on key: %d" % key + >>> stored += input * 2 + >>> mx.kv.set_updater(update) + >>> mx.kv.pull(3, out=a) + >>> print a.asnumpy() + [[ 4. 4. 4.] + [ 4. 4. 4.]] + >>> mx.kv.push(3, mx.nd.ones(shape)) + update on key: 3 + >>> mx.kv.pull(3, out=a) + >>> print a.asnumpy() + [[ 6. 6. 6.] + [ 6. 6. 6.]] """ _updater_proto = ctypes.CFUNCTYPE( None, ctypes.c_int, NDArrayHandle, NDArrayHandle) @@ -109,7 +214,7 @@ def updater(recv, local): check_call(_LIB.MXKVStoreSetUpdater(_updater_func)) def stop(): - """ Stop kvstore """ + """ Stop the kvstore """ check_call(_LIB.MXKVStoreStop()) # need to clear _updater_func before _LIB global _updater_func From b42da79fe25cbe941f44ae9808a4d5337657e71f Mon Sep 17 00:00:00 2001 From: muli Date: Tue, 15 Sep 2015 15:14:49 -0400 Subject: [PATCH 3/3] import kvstore as kv --- example/cifar10/cifar10_multi_gpus.py | 10 ++--- example/mnist/mlp_multi_gpu.py | 16 ++++---- python/mxnet/__init__.py | 2 +- tests/python/unittest/test_kvstore.py | 58 +++++++++++++-------------- 4 files changed, 43 insertions(+), 43 deletions(-) diff --git a/example/cifar10/cifar10_multi_gpus.py b/example/cifar10/cifar10_multi_gpus.py index a88d8b508f7e..6ce65a1cbfab 100644 --- a/example/cifar10/cifar10_multi_gpus.py +++ b/example/cifar10/cifar10_multi_gpus.py @@ -10,7 +10,7 @@ # use multiple devices num_devs = 4 devs = [mx.gpu(i) for i in range(num_devs)] -mx.kvstore.start() +mx.kv.start() # define the network conv_cnt = 1 @@ -113,7 +113,7 @@ def momentum_update(key, grad, weight): updater = momentum( learning_rate = .05, weight_decay = .0001, momentum = 0.9) -mx.kvstore.set_updater(updater) +mx.kv.set_updater(updater) # infer shape batch_size = 196 @@ -142,7 +142,7 @@ def momentum_update(key, grad, weight): val[:] = np.random.uniform(-0.1, 0.1, shape) elif "gamma" in param_names[idx]: val[:] = 1.0 - mx.kvstore.init(idx, val) + mx.kv.init(idx, val) # data reader get_data.GetCifar10() @@ -203,7 +203,7 @@ def train(): for data, label in train_dataiter: tic = time.time() # pull weight - mx.kvstore.pull(sync_indices, out = sync_weights) + mx.kv.pull(sync_indices, out = sync_weights) # forward and backword data = data.asnumpy() @@ -221,7 +221,7 @@ def train(): g /= batch_size # push gradient - mx.kvstore.push(sync_indices, sync_grads) + mx.kv.push(sync_indices, sync_grads) # evaluate for d in range(num_devs): diff --git a/example/mnist/mlp_multi_gpu.py b/example/mnist/mlp_multi_gpu.py index 222696d92cc1..bb45b6448879 100644 --- a/example/mnist/mlp_multi_gpu.py +++ b/example/mnist/mlp_multi_gpu.py @@ -3,14 +3,14 @@ import numpy as np import os, gzip import sys -sys.path.append("../../tests/python") +sys.path.append("../../tests/python/common") import get_data import time # use multiple devices num_devs = 4 -devs = [mx.Context('gpu', i) for i in range(num_devs)] -mx.kvstore.start() +devs = [mx.Context('cpu', i) for i in range(num_devs)] +mx.kv.start() # symbol net data = mx.symbol.Variable('data') @@ -26,7 +26,7 @@ def updater(key, grad, weight): weight -= lr * grad / batch_size -mx.kvstore.set_updater(updater) +mx.kv.set_updater(updater) # find the params needed to be synchronized between devices param_names = mlp.list_arguments() @@ -39,14 +39,14 @@ def updater(key, grad, weight): input_shape = (batch_size / num_devs, 784) param_shapes, out_shapes, aux_shapes = mlp.infer_shape(data=input_shape) -# init param in the kvstore +# init param in the kv np.random.seed(0) for idx in sync_indices: shape = param_shapes[idx] val = mx.nd.zeros(shape) if "weight" in param_names[idx]: val[:] = np.random.uniform(-0.07, 0.07, shape) - mx.kvstore.init(idx, val) + mx.kv.init(idx, val) # allocate device's memory params = [[mx.nd.zeros(s, d) for s in param_shapes] for d in devs] @@ -86,7 +86,7 @@ def run_sgd(): for data, label in train_dataiter: # pull weight for idx in sync_indices: - mx.kvstore.pull(idx, out = [p[idx] for p in params]) + mx.kv.pull(idx, out = [p[idx] for p in params]) # forward and backward data = data.asnumpy() @@ -100,7 +100,7 @@ def run_sgd(): executors[d].backward() # push gradient for idx in sync_indices: - mx.kvstore.push(idx, [g[idx] for g in grads]) + mx.kv.push(idx, [g[idx] for g in grads]) # eval for d in range(num_devs): diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index c591fc29510b..1417b1262505 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -12,7 +12,7 @@ from .base import MXNetError from . import ndarray from . import symbol -from . import kvstore +from . import kvstore as kv from . import io # use mx.nd as short for mx.ndarray from . import ndarray as nd diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index 72b671c74f5d..4d86f5512ce5 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -4,17 +4,17 @@ shape = (4, 4) keys = [5, 7, 11] -def init_kvstore(): - """init kvstore """ - mx.kvstore.start() +def init_kv(): + """init kv """ + mx.kv.start() # single - mx.kvstore.init(3, mx.nd.zeros(shape)) + mx.kv.init(3, mx.nd.zeros(shape)) # list - mx.kvstore.init(keys, [mx.nd.zeros(shape)] * len(keys)) + mx.kv.init(keys, [mx.nd.zeros(shape)] * len(keys)) -def stop_kvstore(): - """stop kvstore """ - mx.kvstore.stop() +def stop_kv(): + """stop kv """ + mx.kv.stop() def check_diff_to_scalar(A, x): """ assert A == x""" @@ -23,32 +23,32 @@ def check_diff_to_scalar(A, x): def test_single_kv_pair(): """single key-value pair push & pull""" - init_kvstore() + init_kv() - mx.kvstore.push(3, mx.nd.ones(shape)) + mx.kv.push(3, mx.nd.ones(shape)) val = mx.nd.empty(shape) - mx.kvstore.pull(3, out = val) + mx.kv.pull(3, out = val) check_diff_to_scalar(val, 1) - stop_kvstore() + stop_kv() def test_list_kv_pair(): """list key-value pair push & pull""" - init_kvstore() + init_kv() - mx.kvstore.push(keys, [mx.nd.ones(shape)*4] * len(keys)) + mx.kv.push(keys, [mx.nd.ones(shape)*4] * len(keys)) val = [mx.nd.empty(shape)] * len(keys) - mx.kvstore.pull(keys, out = val) + mx.kv.pull(keys, out = val) for v in val: check_diff_to_scalar(v, 4) - stop_kvstore() + stop_kv() def test_aggregator(): """aggregate value on muliple devices""" - init_kvstore() + init_kv() # devices num_devs = 4 @@ -57,22 +57,22 @@ def test_aggregator(): # single vals = [mx.nd.ones(shape, d) for d in devs] - mx.kvstore.push(3, vals) - mx.kvstore.pull(3, out = vals) + mx.kv.push(3, vals) + mx.kv.pull(3, out = vals) for v in vals: check_diff_to_scalar(v, num_devs) # list vals = [[mx.nd.ones(shape, d)*2.0 for d in devs]] * len(keys) - mx.kvstore.push(keys, vals) - mx.kvstore.pull(keys, out = vals) + mx.kv.push(keys, vals) + mx.kv.pull(keys, out = vals) for vv in vals: for v in vv: check_diff_to_scalar(v, num_devs * 2.0) - stop_kvstore() + stop_kv() def updater(key, recv, local): """use updater: +=""" @@ -81,8 +81,8 @@ def updater(key, recv, local): def test_updater(dev = 'cpu'): """updater""" - init_kvstore() - mx.kvstore.set_updater(updater) + init_kv() + mx.kv.set_updater(updater) # devices num_devs = 4 @@ -91,8 +91,8 @@ def test_updater(dev = 'cpu'): # single vals = [mx.nd.ones(shape, d) for d in devs] - mx.kvstore.push(3, vals) - mx.kvstore.pull(3, out = vals) + mx.kv.push(3, vals) + mx.kv.pull(3, out = vals) for v in vals: check_diff_to_scalar(v, num_devs) @@ -102,15 +102,15 @@ def test_updater(dev = 'cpu'): num_push = 4 for i in range(num_push): - mx.kvstore.push(keys, vals) + mx.kv.push(keys, vals) - mx.kvstore.pull(keys, out = vals) + mx.kv.pull(keys, out = vals) for vv in vals: for v in vv: check_diff_to_scalar(v, num_devs * num_push) - stop_kvstore() + stop_kv() if __name__ == '__main__': test_single_kv_pair()