From cd623759720bc379b428b530bc7ab81219eb74b1 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 8 Aug 2018 10:37:52 -0700 Subject: [PATCH] Fix cut graph --- python/mxnet/symbol/contrib.py | 26 ++++---------------------- src/c_api/c_api_symbolic.cc | 6 +++--- src/operator/control_flow.cc | 3 +++ 3 files changed, 10 insertions(+), 25 deletions(-) diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 1d42cf7c18f8..fb624312e1bf 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -133,25 +133,11 @@ def _construct_subgraph(sym_out, sym_states): all_outputs = [] all_outputs.extend(sym_out) all_outputs.extend(sym_states) - g = symbol.Group(all_outputs) - flat_out = [] - all_input_names = g.list_inputs() - output_names = [o.name for o in sym_out] for o in sym_out: - if o.name in all_input_names: - flat_out.append(symbol.op.identity(o)) - else: - flat_out.append(o) - + flat_out.append(symbol.op.identity(o)) for s in sym_states: - if s.name in all_input_names or s.name in output_names: - # There is a problem if the outputs are the same as the inputs - # or the first output. By calling identity, we can make sure that - # all symbols will refer to different NDArrays. - flat_out.append(symbol.op.identity(s)) - else: - flat_out.append(s) + flat_out.append(symbol.op.identity(s)) return symbol.Group(flat_out) def foreach(body, data, init_states, name="foreach"): @@ -469,10 +455,8 @@ def _create_subgraph(graph_vars, graph_func, subgraph_name): num_outputs = len(outputs) + len(final_state) # nnvm cut-graph does not allow inputs and outputs overlap # so we calculate the name of inputs, and copy outputs once it overlaps with inputs - all_input_names = symbol.Group(outputs + final_state).list_inputs() - make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x # group all outputs of graph_func - graph = symbol.Group(list(map(make_identity, outputs + final_state))) + graph = symbol.Group(list(map(symbol.op.identity, outputs + final_state))) return graph, num_out_data, num_outputs def _union_inputs(*graphs): @@ -627,10 +611,8 @@ def _create_subgraph(graph_vars, graph_func, subgraph_name): num_outputs = len(outputs) # nnvm cut-graph does not allow inputs and outputs overlap # so we calculate the name of inputs, and copy outputs once it overlaps with inputs - all_input_names = symbol.Group(outputs).list_inputs() - make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x # group all outputs of graph_func - graph = symbol.Group(list(map(make_identity, outputs))) + graph = symbol.Group(list(map(symbol.op.identity, outputs))) return graph, num_outputs def _union_inputs(*graphs): diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index c27a59a67c6e..35ecec7e11f6 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -372,13 +372,13 @@ int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols, // a subgraph. API_BEGIN(); nnvm::Symbol *s = static_cast(sym); - std::string subg_attr = "__subgraph_name__"; + const std::string subg_attr = "__subgraph_name__"; auto out_node = s->outputs[0].node; auto it = out_node->attrs.dict.find(subg_attr); if (it != out_node->attrs.dict.end()) { - std::string subg_name = it->second; + const std::string &subg_name = it->second; std::vector input_entries; - DFSVisit(s->outputs, [subg_attr, subg_name, &input_entries] + DFSVisit(s->outputs, [&subg_attr, &subg_name, &input_entries] (nnvm::NodePtr n) { // If the node itself isn't in the subgraph, we ignore it. auto it = n->attrs.dict.find(subg_attr); diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index 7c1beccb0288..f4820149b246 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -1225,6 +1225,9 @@ static bool BackwardCondStorageType(const nnvm::NodeAttrs& attrs, CHECK(sync_in_in(input_locs, out_attrs, &subg_out_attrs, is_udf)); return ret; }; + for (const dim_t &cond_in: params.cond_input_locs) { + (*out_attrs)[cond_in] = 0; + } bool succ_0 = sub_pass(attrs.subgraphs[1], params.then_input_locs); bool succ_1 = sub_pass(attrs.subgraphs[2], params.else_input_locs); return succ_0 && succ_1;