
[MXNET-432] Add Foreach #11531

Merged: 135 commits, merged on Jul 2, 2018

Commits (135)
e86d92c
Test input a graph.
zheng-da Mar 20, 2018
7eacfab
Update foreach to execute the subgraph.
zheng-da Mar 21, 2018
6874f7d
print inputs/outputs in foreach.
zheng-da Mar 21, 2018
8f5e62e
Remove print.
zheng-da Mar 23, 2018
1dd35a8
add test code for foreach.
zheng-da Mar 23, 2018
1f117cb
exec foreach outside the engine.
zheng-da Apr 5, 2018
cc29fc1
Implements forward of foreach.
zheng-da Apr 5, 2018
036aada
Add support for variable numbers of inputs and outputs.
zheng-da Apr 6, 2018
84e0e24
Add a python wrapper for foreach.
zheng-da Apr 6, 2018
f2a28f0
Fix the order of inputs.
zheng-da Apr 6, 2018
1c4cf0e
hide C version of foreach.
zheng-da Apr 7, 2018
9aa896d
fix a bug temporarily.
zheng-da Apr 7, 2018
6784408
add test with lstm.
zheng-da Apr 6, 2018
f488647
Test free variables.
zheng-da Apr 7, 2018
d9b0c50
change for the new interface of InputGraph attribute.
zheng-da Apr 9, 2018
1188d97
Add attribute to the subgraph.
zheng-da Apr 11, 2018
c0cd6ac
Handle free variables.
zheng-da Apr 11, 2018
74d280b
Get all input symbols of a subgraph.
zheng-da Apr 13, 2018
2bc80e3
Fix shape, dtype and storage inference.
zheng-da Apr 13, 2018
68faa17
reorganize the output of foreach.
zheng-da Apr 14, 2018
3751ca7
Add a gluon RNN unroll with symbol foreach.
zheng-da Apr 14, 2018
b98e06d
remove unnecessary print.
zheng-da Apr 16, 2018
fc575fe
have imperative and symbolic foreach.
zheng-da Apr 16, 2018
37da6fb
Fix an error after moving foreach.
zheng-da Apr 18, 2018
f41235c
Fix imperative foreach
zheng-da Apr 18, 2018
214c1c2
Fix a minor problem.
zheng-da Apr 24, 2018
9aabc74
Use CachedOp to execute subgraph.
zheng-da Apr 30, 2018
7fc0155
update TODO.
zheng-da May 1, 2018
0d3613a
make foreach op use FStatefulComputeEx.
zheng-da May 1, 2018
f33d0f4
Add backward.
zheng-da May 2, 2018
d82dd30
Fix bugs.
zheng-da May 4, 2018
868c9f2
enable backward test in lstm.
zheng-da May 4, 2018
84e1877
Fix a bug in foreach backward for free variables.
zheng-da May 7, 2018
00b8b1c
change for the new CachedOp.
zheng-da May 9, 2018
f2e324e
Detect the backward computation.
zheng-da May 9, 2018
e1322d1
Fix bugs in foreach.
zheng-da May 9, 2018
98955a4
fix tests.
zheng-da May 10, 2018
d8c9b1f
update tests.
zheng-da May 11, 2018
32e3b17
check state shape.
zheng-da May 12, 2018
8caf708
enable nested foreach.
zheng-da May 14, 2018
d4ef381
remove print.
zheng-da May 16, 2018
4270032
fix a bug in test.
zheng-da May 17, 2018
b54f234
handle infer storage type for backward.
zheng-da May 18, 2018
14d319b
address comments.
zheng-da May 18, 2018
0e666a9
address comments.
zheng-da May 18, 2018
255c478
move some common functions out.
zheng-da May 18, 2018
2beb3f3
address comments.
zheng-da May 18, 2018
716bc6a
fix lint.
zheng-da May 18, 2018
dd5f862
Fix lint.
zheng-da May 18, 2018
b60157a
add doc.
zheng-da May 19, 2018
57b2ba5
undo modification in imperative.h
zheng-da May 19, 2018
7c49057
add doc and remove example code.
zheng-da May 19, 2018
1045908
fix lint.
zheng-da May 19, 2018
e8ec3aa
fix lint.
zheng-da May 19, 2018
e4f5808
Fix lint.
zheng-da May 19, 2018
c078cbf
make nd.foreach and sym.foreach consistent.
zheng-da May 21, 2018
b965e6e
fix compile error.
zheng-da May 21, 2018
57fcb84
address comments.
zheng-da May 21, 2018
c03c56f
update.
zheng-da May 21, 2018
224f3e2
check for loop only works for dense arrays.
zheng-da May 22, 2018
cd67c6f
move control flow op out of nn/
zheng-da May 22, 2018
742ef40
fix include.
zheng-da May 22, 2018
4492949
add a test in gluon.
zheng-da May 22, 2018
26e3e7e
small fix.
zheng-da May 22, 2018
6bff448
remove subgraph_name
zheng-da May 22, 2018
1e4cd45
create loop state for reuse in the future.
zheng-da May 22, 2018
7079e73
work for GPU.
zheng-da May 22, 2018
64f4362
Fix tests.
zheng-da May 29, 2018
31d9112
Fix bugs caused by ctypes (#29)
junrushao1994 May 30, 2018
601edbe
Add save/load json in testcases for foreach (#30)
junrushao1994 Jun 1, 2018
f4da935
support subgraph in stateful executor.
zheng-da Jun 4, 2018
90b7829
Fix compilation.
zheng-da Jun 4, 2018
ae3ea22
move code.
zheng-da May 22, 2018
0db16f0
Revert "remove subgraph_name"
zheng-da May 23, 2018
5f626ae
cut graph.
zheng-da May 25, 2018
f2c428f
rename new var nodes.
zheng-da May 26, 2018
2a69257
fix a bug when a subgraph has variable nodes.
zheng-da Jun 8, 2018
efeedd6
Fix a bug of getting symbols.
zheng-da Jun 8, 2018
28fe469
copy var nodes.
zheng-da Jun 8, 2018
cf91c59
Fix getting op states.
zheng-da Jun 13, 2018
f2edf2a
fix lint error.
zheng-da Jun 13, 2018
a35899f
address comments.
zheng-da Jun 13, 2018
8c6aca0
fix lint error.
zheng-da Jun 13, 2018
ccaf388
simplify the execution of subgraph in the main thread.
zheng-da Jun 13, 2018
25cf8ac
fix lint error.
zheng-da Jun 13, 2018
51de14c
avoid waiting for computation in each iteration.
zheng-da Jun 13, 2018
3eb0bc1
reuse cached op for inference.
zheng-da Jun 14, 2018
4a0ff21
share memory across mini-batches.
zheng-da Jun 14, 2018
8766cb2
reuse memory.
zheng-da Jun 14, 2018
97b9074
add tests for multiple batches.
zheng-da Jun 14, 2018
e38b7f4
remove entry.
zheng-da Jun 15, 2018
198bcfb
add benchmark for foreach.
zheng-da Jun 16, 2018
811acb3
benchmark large batch size.
zheng-da Jun 16, 2018
24fa83b
Fix the benchmark for GPU.
zheng-da Jun 17, 2018
550e48a
address comments.
zheng-da Jun 17, 2018
97d0332
update shape/dtype/storage inference.
zheng-da Jun 17, 2018
0b0a36e
update contrib API docs.
zheng-da Jun 17, 2018
f1ff55d
support nested foreach.
zheng-da Jun 18, 2018
156f1c8
use a single CachedOp for all iterations.
zheng-da Jun 19, 2018
871fd3b
use large dim.
zheng-da Jun 19, 2018
b5dfc3f
update benchmark.
zheng-da Jun 19, 2018
202a74c
update benchmark.
zheng-da Jun 19, 2018
0606c3c
update benchmark.
zheng-da Jun 19, 2018
6019de5
update benchmark.
zheng-da Jun 19, 2018
045186d
return symbol arrays correctly in MXSymbolCutSubgraph.
zheng-da Jun 20, 2018
484309e
return symbol arrays in MXSymbolGetInputSymbols.
zheng-da Jun 20, 2018
0ebd5e5
fix lint error.
zheng-da Jun 20, 2018
a9e253d
use cachedop to infer storage in backward.
zheng-da Jun 21, 2018
1f8469f
fix scala API.
zheng-da Jun 21, 2018
25e15a0
update comments.
zheng-da Jun 21, 2018
ff4eea0
fix scala.
zheng-da Jun 21, 2018
b8aa62a
fix test.
zheng-da Jun 21, 2018
64e4ff6
fix attribute name.
zheng-da Jun 21, 2018
d243c12
move benchmark.
zheng-da Jun 21, 2018
3afc4d4
fix the mapping of operator inputs/outputs and subgraph inputs/outputs.
zheng-da Jun 22, 2018
62901fe
add tests for dtype/shape inference.
zheng-da Jun 23, 2018
14b8fb9
reorganize tests.
zheng-da Jun 23, 2018
f7d7f17
fix a bug of cutting NodeEntry.
zheng-da Jun 23, 2018
7d012d9
fix lint error.
zheng-da Jun 23, 2018
b83253d
handle the case that outputs are inputs.
zheng-da Jun 24, 2018
275bbf1
handle the case that inputs aren't used.
zheng-da Jun 24, 2018
0e6df9a
handle the case without output data.
zheng-da Jun 25, 2018
dfadc8d
fix a bug in foreach backward.
zheng-da Jun 25, 2018
5e9cf5f
fix a bug when there isn't output data.
zheng-da Jun 25, 2018
696f53c
Fix lint error.
zheng-da Jun 26, 2018
094977d
test diff Gluon RNN cells.
zheng-da Jun 26, 2018
7b016ae
test all symbol RNN cells.
zheng-da Jun 26, 2018
9609ce8
adjust the test precision.
zheng-da Jun 26, 2018
fa8abbd
Fix a bug in getting a list of variable names.
zheng-da Jun 26, 2018
8e74d80
fix lint error.
zheng-da Jun 26, 2018
2439105
Test 1D array.
zheng-da Jun 27, 2018
53cdbfa
fix a bug when subgraph inputs and outputs share NDArray.
zheng-da Jun 27, 2018
9bff317
fix.
zheng-da Jun 28, 2018
d3687ef
fix
zheng-da Jun 28, 2018
392a7e4
add comments.
zheng-da Jul 2, 2018
189 changes: 189 additions & 0 deletions benchmark/python/control_flow/rnn.py
@@ -0,0 +1,189 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import subprocess
import mxnet as mx
from mxnet import gluon
import time
import copy

def get_gpus():
    """
    return a list of GPUs
    """
    try:
        re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True)
    except OSError:
        return []
    return range(len([i for i in re.split('\n') if 'GPU' in i]))

class TestRNNLayer(gluon.HybridBlock):
    def __init__(self, cell, prefix=None, params=None):
        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
        self.cell = cell

    def hybrid_forward(self, F, inputs, states):
        out, states = F.contrib.foreach(self.cell, inputs, states)
        return out

def benchmark_rnn(cell, rnn_data, states):
    ctx = rnn_data.context
    num_batches = 20

    # Imperative cell in an imperative layer
    cell0 = copy.deepcopy(cell)
    layer0 = TestRNNLayer(cell0)
    layer0.initialize(ctx=ctx)

    # Hybridized cell in an imperative layer
    cell1 = copy.deepcopy(cell)
    cell1.hybridize()
    layer1 = TestRNNLayer(cell1)
    layer1.initialize(ctx=ctx)

    # Hybridized layer (run once so it can be exported as a symbol later)
    cell2 = copy.deepcopy(cell)
    layer2 = TestRNNLayer(cell2)
    layer2.initialize(ctx=ctx)
    layer2.hybridize()
    layer2(rnn_data, states)

    # Hybridized cell with static memory allocation
    cell3 = copy.deepcopy(cell)
    cell3.hybridize(static_alloc=True)
    layer3 = TestRNNLayer(cell3)
    layer3.initialize(ctx=ctx)

    tic = time.time()
    for i in range(num_batches):
        res0 = layer0(rnn_data, states)
        mx.nd.waitall()
    print("Imperative inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res1 = layer1(rnn_data, states)
        mx.nd.waitall()
    print("Hybrid-cell inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res3 = layer3(rnn_data, states)
        mx.nd.waitall()
    print("Static-hybrid-cell inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        res2 = layer2(rnn_data, states)
        mx.nd.waitall()
    print("Hybrid inference takes " + str(time.time() - tic))

    layer2.export("foreach_rnn")
    symnet = mx.symbol.load('foreach_rnn-symbol.json')
    args1 = {}
    params = layer2.collect_params()
    for key in params.keys():
        args1[key] = params[key].data()
    args1['data0'] = rnn_data
    for i in range(len(states)):
        args1['data' + str(i + 1)] = states[i]
    exe = symnet.bind(ctx=ctx, args=args1)
    tic = time.time()
    for i in range(num_batches):
        exe.forward(is_train=False)
        mx.nd.waitall()
    print("Symbol inference takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res0 = layer0(rnn_data, states)
        res0.backward()
        mx.nd.waitall()
    print("Imperative training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res1 = layer1(rnn_data, states)
        res1.backward()
        mx.nd.waitall()
    print("Hybrid-cell training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res3 = layer3(rnn_data, states)
        res3.backward()
        mx.nd.waitall()
    print("Static-hybrid-cell training takes " + str(time.time() - tic))

    tic = time.time()
    for i in range(num_batches):
        with mx.autograd.record():
            res2 = layer2(rnn_data, states)
        res2.backward()
        mx.nd.waitall()
    print("Hybrid training takes " + str(time.time() - tic))

    # gradients for the backward of the foreach symbol
    args_grad1 = {}
    for key in args1.keys():
        args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx)
    exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1)
    tic = time.time()
    for i in range(num_batches):
        exe.forward(is_train=True)
        exe.backward(res2)
        mx.nd.waitall()
    print("Symbol training takes " + str(time.time() - tic))
    print("")

if __name__ == '__main__':
    ndim = 512
    seq_len = 100
    batch_sizes = [1, 32]
    cells = [gluon.rnn.GRUCell(ndim, prefix='rnn_'),
             gluon.rnn.LSTMCell(ndim, prefix='rnn_')]
    ctxs = [mx.cpu(0), mx.gpu(0)]
    for cell in cells:
        for ctx in ctxs:
            for batch_size in batch_sizes:
                if len(get_gpus()) == 0 and ctx == mx.gpu(0):
                    continue

                if isinstance(cell, gluon.rnn.GRUCell):
                    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim),
                                            ctx=mx.cpu(0))
                    states = []
                    states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim),
                                               ctx=mx.cpu(0)))
                elif isinstance(cell, gluon.rnn.LSTMCell):
                    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim),
                                            ctx=mx.cpu(0))
                    states = []
                    states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim),
                                               ctx=mx.cpu(0)))
                    states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim),
                                               ctx=mx.cpu(0)))
                if ctx == mx.gpu(0):
                    dev = "GPU"
                else:
                    dev = "CPU"
                print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev,
                                                                   batch_size))
                benchmark_rnn(cell, rnn_data, states)
1 change: 1 addition & 0 deletions docs/api/python/ndarray/contrib.md
@@ -52,6 +52,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
fft
ifft
quantize
foreach
```

## API Reference
1 change: 1 addition & 0 deletions docs/api/python/symbol/contrib.md
@@ -52,6 +52,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
fft
ifft
quantize
foreach
```

## API Reference
22 changes: 22 additions & 0 deletions include/mxnet/c_api.h
@@ -1051,6 +1051,28 @@ MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
*/
MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
const char **name);

/*!
* \brief Get the input symbols of the graph.
* \param sym The graph.
* \param inputs The input symbols of the graph.
* \param input_size the number of input symbols returned.
*/
MXNET_DLL int MXSymbolGetInputSymbols(SymbolHandle sym, SymbolHandle **inputs,
                                      int *input_size);

/*!
* \brief Cut a subgraph whose nodes are marked with a subgraph attribute.
* The input graph will be modified. A variable node will be created for each
* edge that connects to nodes outside the subgraph. The outside nodes that
* connect to the subgraph will be returned.
* \param sym The graph.
* \param inputs The nodes that connect to the subgraph.
* \param input_size The number of such nodes.
*/
MXNET_DLL int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **inputs,
                                  int *input_size);

/*!
* \brief Get the detailed information about atomic symbol.
* \param creator the AtomicSymbolCreator.
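For orientation, the two new entry points above are ordinary C API calls, so they can be reached from Python through ctypes like the rest of libmxnet. Below is a minimal sketch of such a wrapper; the function name `get_input_symbols` and the assumption that the returned handles can be wrapped directly into `Symbol` objects are mine, not part of this PR.

```python
import ctypes
import mxnet as mx
from mxnet.base import _LIB, SymbolHandle, check_call

def get_input_symbols(sym):
    """Hypothetical ctypes wrapper around MXSymbolGetInputSymbols (declared above)."""
    handles = ctypes.POINTER(SymbolHandle)()   # receives SymbolHandle **inputs
    size = ctypes.c_int(0)                     # receives int *input_size
    check_call(_LIB.MXSymbolGetInputSymbols(
        sym.handle, ctypes.byref(handles), ctypes.byref(size)))
    # Assumes each returned handle can be wrapped into a Symbol object.
    return [mx.sym.Symbol(SymbolHandle(handles[i])) for i in range(size.value)]

# Example: list the free inputs of a small graph.
a = mx.sym.var('a')
b = mx.sym.var('b')
c = mx.sym.broadcast_add(a, b)
print([s.name for s in get_input_symbols(c)])  # expected to include 'a' and 'b'
```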
11 changes: 9 additions & 2 deletions include/mxnet/op_attr_types.h
@@ -64,8 +64,10 @@ enum OpReqType {
* \sa Resource
*/
struct OpContext {
/*! \brief whether there is a backward phase to compute gradients. */
bool need_grad;
/*! \brief whether it is training phase */
int is_train;
bool is_train;
/*! \brief RunContext related resources */
RunContext run_ctx;
/*! \brief the callback when operation completes, used by asynchronize ops */
@@ -98,7 +100,12 @@ enum class ExecType {
* In current implementation, copy operator is specially handled by executor.
* This flag is used for special case treatment and future extension of different copy ops.
*/
kCrossDeviceCopy
kCrossDeviceCopy,
/*!
* \brief A subgraph execution should happen in the main thread, instead of
* in the execution engine.
*/
kSubgraphExec,
};

/*! \brief the dispatch mode of the operator */
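The new `need_grad` flag separates "a backward pass will follow" from "run in training mode" (`is_train`). At the Python level this roughly mirrors the distinction between autograd recording and `train_mode`; a small illustration, assuming the standard autograd API:

```python
import mxnet as mx

x = mx.nd.ones((2, 3))
x.attach_grad()

# Record the graph so a backward pass is possible (need_grad), while keeping
# operators such as Dropout/BatchNorm in inference behaviour (is_train false).
with mx.autograd.record(train_mode=False):
    y = (x * 2).sum()
y.backward()
print(x.grad)  # gradient of sum(2 * x) w.r.t. x, i.e. all 2s
```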
96 changes: 96 additions & 0 deletions python/mxnet/ndarray/contrib.py
@@ -21,6 +21,8 @@
import math
from ..context import current_context
from ..random import uniform
from ..base import _as_list
from . import ndarray
try:
    from .gen_contrib import *
except ImportError:
@@ -95,3 +97,97 @@ def rand_zipfian(true_classes, num_sampled, range_max, ctx=None):
expected_count_sampled = expected_prob_sampled * num_sampled
return sampled_classes, expected_count_true, expected_count_sampled
# pylint: enable=line-too-long

def foreach(body, data, init_states):
    """Run a for loop with user-defined computation over NDArrays on dimension 0.

    This operator simulates a for loop and body has the computation for an iteration
    of the for loop. It runs the computation in body on each slice from the input
    NDArrays.

    body takes two arguments as input and outputs a tuple of two elements,
    as illustrated below:

    out, states = body(data1, states)

    data1 can be either an NDArray or a list of NDArrays. If data is an NDArray,
    data1 is an NDArray. Otherwise, data1 is a list of NDArrays and has the same
    size as data. states is a list of NDArrays and has the same size as init_states.
    Similarly, out can be either an NDArray or a list of NDArrays, which are concatenated
    as the first output of foreach; states from the last execution of body
    are the second output of foreach.

    The computation done by this operator is equivalent to the pseudo code below
    when the input data is an NDArray:

    states = init_states
    outs = []
    for i in range(data.shape[0]):
        s = data[i]
        out, states = body(s, states)
        outs.append(out)
    outs = stack(*outs)


    Parameters
    ----------
    body : a Python function.
        Define computation in an iteration.
    data: an NDArray or a list of NDArrays.
        The input data.
    init_states: an NDArray or a list of NDArrays.
        The initial values of the loop states.

    Returns
    -------
    outputs: an NDArray or a list of NDArrays.
        The output data concatenated from the output of all iterations.
    states: a list of NDArrays.
        The loop states in the last iteration.

    Examples
    --------
    >>> step = lambda data, states: (data + states[0], [states[0] * 2])
    >>> data = mx.nd.random.uniform(shape=(2, 10))
    >>> states = [mx.nd.random.uniform(shape=(10))]
    >>> outs, states = mx.nd.contrib.foreach(step, data, states)
    """

    def check_input(inputs, in_type, msg):
        is_NDArray_or_list = True
        if isinstance(inputs, list):
            for i in inputs:
                if not isinstance(i, in_type):
                    is_NDArray_or_list = False
                    break
        else:
            is_NDArray_or_list = isinstance(inputs, in_type)
        assert is_NDArray_or_list, msg

    check_input(data, ndarray.NDArray, "data should be an NDArray or a list of NDArrays")
    check_input(init_states, ndarray.NDArray,
                "init_states should be an NDArray or a list of NDArrays")

    not_data_list = isinstance(data, ndarray.NDArray)
    num_iters = data.shape[0] if not_data_list else data[0].shape[0]
    states = init_states
    outputs = []
    for i in range(num_iters):
        # Take the i-th slice from each input array and run one iteration of body.
        if not_data_list:
            eles = data[i]
        else:
            eles = [d[i] for d in data]
        outs, states = body(eles, states)
        outs = _as_list(outs)
        outputs.append(outs)
    # Stack the per-iteration outputs along dimension 0.
    outputs = zip(*outputs)
    tmp_outputs = []
    for out in outputs:
        tmp_outputs.append(ndarray.op.stack(*out))
    outputs = tmp_outputs

    if not_data_list and len(outputs) == 1:
        outputs = outputs[0]
    return (outputs, states)
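As a quick sanity check of the semantics described in the docstring, the imperative foreach can be compared against the equivalent hand-written loop. The step function and shapes below are illustrative only:

```python
import mxnet as mx

# step adds the current state to the slice and doubles the state.
step = lambda data, states: (data + states[0], [states[0] * 2])

data = mx.nd.random.uniform(shape=(4, 10))
init_states = [mx.nd.ones((10,))]

outs, last_states = mx.nd.contrib.foreach(step, data, init_states)

# Reference loop, mirroring the pseudo code in the docstring above.
ref_states = init_states
ref_outs = []
for i in range(data.shape[0]):
    out, ref_states = step(data[i], ref_states)
    ref_outs.append(out)
ref_outs = mx.nd.stack(*ref_outs)

assert mx.nd.abs(outs - ref_outs).sum().asscalar() < 1e-6
assert mx.nd.abs(last_states[0] - ref_states[0]).sum().asscalar() < 1e-6
```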