
Commit

[MXNET-749] Correct usages of CutSubgraph in 3 control flow operators (#12078)

* Fix cut graph

* Copy only when necessary

* Add unittest for while_loop

* Add unittest for foreach

* Add unittest for cond

* Avoid magic number: 0 => kUndefinedStorage
junrushao authored and eric-haibin-lin committed Aug 10, 2018
1 parent bb0f8a6 commit af15853
Showing 4 changed files with 123 additions and 15 deletions.
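
For context, the bug shows up when control flow operators are stacked so that the output of one subgraph feeds the body of the next: CutSubgraph could then claim nodes that belong to a different subgraph. A minimal sketch of the affected pattern (mirroring the new unit tests below; standalone, so mx.cpu() stands in for the test suite's default_context):

import mxnet as mx
from mxnet import gluon

class StackedForeach(gluon.HybridBlock):
    def hybrid_forward(self, F, inputs, states):
        def step(data, states):
            return data + 1, states
        # out1 is produced by the first foreach subgraph and consumed by the
        # second; cutting the second subgraph must copy it, not absorb it.
        out1, states1 = F.contrib.foreach(step, inputs, states)
        out2, _ = F.contrib.foreach(step, out1, states1)
        return out2

data = mx.nd.normal(shape=(5, 10))
states = mx.nd.normal(shape=(10,))
layer = StackedForeach()
layer.initialize(ctx=mx.cpu())
layer.hybridize()  # hybridization triggers the graph cutting fixed here
print(layer(data, [states]).shape)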
28 changes: 16 additions & 12 deletions python/mxnet/symbol/contrib.py
@@ -127,7 +127,7 @@ def _cut_subgraph(subg):
 # This construct a subgraph for given output nodes.
 # If an output node is one of the input nodes, we call identity to make sure
 # that outputs nodes are different from input nodes.
-def _construct_subgraph(sym_out, sym_states):
+def _construct_subgraph(sym_out, sym_states, name):
     sym_out = _as_list(sym_out)
     sym_states = _as_list(sym_states)
     all_outputs = []
@@ -137,18 +137,16 @@ def _construct_subgraph(sym_out, sym_states):

     flat_out = []
     all_input_names = g.list_inputs()
-    output_names = [o.name for o in sym_out]
+    output_names = {o.name for o in sym_out}
     for o in sym_out:
-        if o.name in all_input_names:
+        if o.name in all_input_names or o.list_attr().get("__subgraph_name__", "") != name:
             flat_out.append(symbol.op.identity(o))
         else:
             flat_out.append(o)

     for s in sym_states:
-        if s.name in all_input_names or s.name in output_names:
-            # There is a problem if the outputs are the same as the inputs
-            # or the first output. By calling identity, we can make sure that
-            # all symbols will refer to different NDArrays.
+        if s.name in all_input_names or s.name in output_names or \
+                s.list_attr().get("__subgraph_name__", "") != name:
             flat_out.append(symbol.op.identity(s))
         else:
             flat_out.append(s)
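
The new != name guards rely on the __subgraph_name__ attribute that foreach, while_loop, and cond stamp on every symbol built inside their bodies via AttrScope. A small illustration of how that tag is written and read back (a sketch following the same pattern contrib.py uses; the symbol names are auto-generated):

import mxnet as mx
from mxnet.attribute import AttrScope

x = mx.sym.var("x")
with AttrScope(__subgraph_name__="my_loop"):
    y = x + 1              # built inside the scope: tagged
z = y * 2                  # built outside the scope: untagged

print(y.list_attr().get("__subgraph_name__", ""))   # "my_loop"
print(z.list_attr().get("__subgraph_name__", ""))   # ""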
@@ -256,7 +254,7 @@ def check_data(inputs, in_type, msg):
     num_out_data = len(sym_out)
     num_states = len(sym_states)
     num_outputs = num_out_data + num_states
-    g = _construct_subgraph(sym_out, sym_states)
+    g = _construct_subgraph(sym_out, sym_states, name)

     input_syms = _get_graph_inputs(g)
     cut_syms = _cut_subgraph(g)
@@ -469,9 +467,12 @@ def _create_subgraph(graph_vars, graph_func, subgraph_name):
         num_outputs = len(outputs) + len(final_state)
         # nnvm cut-graph does not allow inputs and outputs overlap
         # so we calculate the name of inputs, and copy outputs once it overlaps with inputs
-        all_input_names = symbol.Group(outputs + final_state).list_inputs()
-        make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x
         # group all outputs of graph_func
+        all_input_names = symbol.Group(outputs + final_state).list_inputs()
+        in_input = lambda x: x.name in all_input_names
+        in_graph = lambda x: x.list_attr().get("__subgraph_name__", "") == subgraph_name
+        make_identity = lambda x: symbol.op.identity(x) if in_input(x) or not in_graph(x) \
+                                  else x
         graph = symbol.Group(list(map(make_identity, outputs + final_state)))
         return graph, num_out_data, num_outputs
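
Why the identity copy: nnvm's cut-graph cannot take an output that is literally one of the subgraph's inputs (or, after this change, a node created under a different subgraph scope), so such outputs are routed through an identity op to obtain a fresh node the subgraph can own. A toy standalone example of the degenerate pass-through case:

import mxnet as mx

data = mx.sym.var("data")
# A loop body that returns its input unchanged: the output *is* the input
# node, so the subgraph builder wraps it as identity(data) to get a
# distinct output node.
out = mx.sym.identity(data)
print(out.list_arguments())   # ['data']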

@@ -627,9 +628,12 @@ def _create_subgraph(graph_vars, graph_func, subgraph_name):
         num_outputs = len(outputs)
         # nnvm cut-graph does not allow inputs and outputs overlap
         # so we calculate the name of inputs, and copy outputs once it overlaps with inputs
-        all_input_names = symbol.Group(outputs).list_inputs()
-        make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x
         # group all outputs of graph_func
+        all_input_names = symbol.Group(outputs).list_inputs()
+        in_input = lambda x: x.name in all_input_names
+        in_graph = lambda x: x.list_attr().get("__subgraph_name__", "") == subgraph_name
+        make_identity = lambda x: symbol.op.identity(x) if in_input(x) or not in_graph(x) \
+                                  else x
         graph = symbol.Group(list(map(make_identity, outputs)))
         return graph, num_outputs
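
The same guard now appears in both while_loop and cond. Written out as a plain function for readability (a sketch only, not the library's code):

from mxnet import symbol

def make_identity(x, all_input_names, subgraph_name):
    # Copy x through identity when cutting would otherwise fail: either x
    # is itself one of the subgraph's inputs, or it was created under a
    # different (or no) __subgraph_name__ scope.
    in_input = x.name in all_input_names
    in_graph = x.list_attr().get("__subgraph_name__", "") == subgraph_name
    return symbol.op.identity(x) if in_input or not in_graph else x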

6 changes: 3 additions & 3 deletions src/c_api/c_api_symbolic.cc
@@ -372,13 +372,13 @@ int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols,
   // a subgraph.
   API_BEGIN();
   nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
-  std::string subg_attr = "__subgraph_name__";
+  const std::string subg_attr = "__subgraph_name__";
   auto out_node = s->outputs[0].node;
   auto it = out_node->attrs.dict.find(subg_attr);
   if (it != out_node->attrs.dict.end()) {
-    std::string subg_name = it->second;
+    const std::string &subg_name = it->second;
     std::vector<nnvm::NodeEntry *> input_entries;
-    DFSVisit(s->outputs, [subg_attr, subg_name, &input_entries]
+    DFSVisit(s->outputs, [&subg_attr, &subg_name, &input_entries]
              (nnvm::NodePtr n) {
       // If the node itself isn't in the subgraph, we ignore it.
       auto it = n->attrs.dict.find(subg_attr);
3 changes: 3 additions & 0 deletions src/operator/control_flow.cc
@@ -1225,6 +1225,9 @@ static bool BackwardCondStorageType(const nnvm::NodeAttrs& attrs,
     CHECK(sync_in_in(input_locs, out_attrs, &subg_out_attrs, is_udf));
     return ret;
   };
+  for (const dim_t &cond_in : params.cond_input_locs) {
+    (*out_attrs)[cond_in] = kDefaultStorage;
+  }
   bool succ_0 = sub_pass(attrs.subgraphs[1], params.then_input_locs);
   bool succ_1 = sub_pass(attrs.subgraphs[2], params.else_input_locs);
   return succ_0 && succ_1;
101 changes: 101 additions & 0 deletions tests/python/unittest/test_contrib_control_flow.py
@@ -1664,6 +1664,107 @@ def test_foreach_rnn():
         check_foreach_rnn(cell_type, num_states)


+@with_seed()
+def test_cut_subgraph_foreach():
+    class TestLayer(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestLayer, self).__init__(prefix=prefix, params=params)
+
+        def hybrid_forward(self, F, inputs, states):
+            def step1(data, states):
+                return data + 1, states
+            out1, states1 = F.contrib.foreach(step1, inputs, states)
+            out2, states2 = F.contrib.foreach(step1, out1, states)
+            def step2(data, states):
+                return data + states[0], states1
+            out, states = F.contrib.foreach(step2, out2, states)
+            return out
+
+    data = mx.nd.normal(loc=0, scale=1, shape=(5, 10))
+    states = mx.nd.normal(loc=0, scale=1, shape=(10))
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    res1 = layer(data, [states])
+
+    with mx.autograd.record():
+        res1 = layer(data, [states])
+
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    layer.hybridize()
+    res2 = layer(data, [states])
+
+    with mx.autograd.record():
+        res2 = layer(data, [states])
+    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_cut_subgraph_while_loop():
+    class TestLayer(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestLayer, self).__init__(prefix=prefix, params=params)
+        def hybrid_forward(self, F, data):
+            out1, data1 = F.contrib.while_loop(
+                cond=lambda i: i <= 5,
+                func=lambda i: (None, (i + 1, )),
+                loop_vars=(data, ),
+                max_iterations=10,
+            )
+            out2, data2 = F.contrib.while_loop(
+                cond=lambda i: data1[0],
+                func=lambda i: (None, (i + 1, )),
+                loop_vars=data1[0],
+                max_iterations=10,
+            )
+            return data2[0]
+    data = mx.nd.normal(loc=0, scale=1, shape=(1, ))
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    res1 = layer(data)
+    with mx.autograd.record():
+        res1 = layer(data)
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    layer.hybridize()
+    res2 = layer(data)
+    with mx.autograd.record():
+        res2 = layer(data)
+    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_cut_subgraph_cond():
+    class TestLayer(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestLayer, self).__init__(prefix=prefix, params=params)
+        def hybrid_forward(self, F, data):
+            (data1, ) = F.contrib.cond(
+                data > 0.5,
+                then_func=lambda: data * 2,
+                else_func=lambda: data * 3,
+            )
+            (data2, ) = F.contrib.cond(
+                data1 > 0.5,
+                then_func=lambda: data1 * 2,
+                else_func=lambda: data1 * 3,
+            )
+            return data2
+    data = mx.nd.normal(loc=0, scale=1, shape=(1, ))
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    res1 = layer(data)
+    with mx.autograd.record():
+        res1 = layer(data)
+    layer = TestLayer()
+    layer.initialize(ctx=default_context())
+    layer.hybridize()
+    res2 = layer(data)
+    with mx.autograd.record():
+        res2 = layer(data)
+    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
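
All three new tests share one pattern: run the block imperatively, run it again under autograd.record, then hybridize a fresh instance and check that both paths agree. A sketch of that pattern factored into a helper (default_context and assert_almost_equal come from the surrounding test utilities; check_cut_subgraph is a hypothetical name):

def check_cut_subgraph(layer_cls, *inputs):
    # Imperative run.
    layer = layer_cls()
    layer.initialize(ctx=default_context())
    res1 = layer(*inputs)
    with mx.autograd.record():
        res1 = layer(*inputs)
    # Hybridized (symbolic) run on a fresh instance.
    layer = layer_cls()
    layer.initialize(ctx=default_context())
    layer.hybridize()
    res2 = layer(*inputs)
    with mx.autograd.record():
        res2 = layer(*inputs)
    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)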
