Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix cut graph
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Aug 8, 2018
1 parent 9b18af8 commit be5bd89
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 25 deletions.
26 changes: 4 additions & 22 deletions python/mxnet/symbol/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,25 +133,11 @@ def _construct_subgraph(sym_out, sym_states):
all_outputs = []
all_outputs.extend(sym_out)
all_outputs.extend(sym_states)
g = symbol.Group(all_outputs)

flat_out = []
all_input_names = g.list_inputs()
output_names = [o.name for o in sym_out]
for o in sym_out:
if o.name in all_input_names:
flat_out.append(symbol.op.identity(o))
else:
flat_out.append(o)

flat_out.append(symbol.op.identity(o))
for s in sym_states:
if s.name in all_input_names or s.name in output_names:
# There is a problem if the outputs are the same as the inputs
# or the first output. By calling identity, we can make sure that
# all symbols will refer to different NDArrays.
flat_out.append(symbol.op.identity(s))
else:
flat_out.append(s)
flat_out.append(symbol.op.identity(s))
return symbol.Group(flat_out)

def foreach(body, data, init_states, name="foreach"):
Expand Down Expand Up @@ -469,10 +455,8 @@ def _create_subgraph(graph_vars, graph_func, subgraph_name):
num_outputs = len(outputs) + len(final_state)
# nnvm cut-graph does not allow inputs and outputs overlap
# so we calculate the name of inputs, and copy outputs once it overlaps with inputs
all_input_names = symbol.Group(outputs + final_state).list_inputs()
make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x
# group all outputs of graph_func
graph = symbol.Group(list(map(make_identity, outputs + final_state)))
graph = symbol.Group(list(map(symbol.op.identity, outputs + final_state)))
return graph, num_out_data, num_outputs

def _union_inputs(*graphs):
Expand Down Expand Up @@ -627,10 +611,8 @@ def _create_subgraph(graph_vars, graph_func, subgraph_name):
num_outputs = len(outputs)
# nnvm cut-graph does not allow inputs and outputs overlap
# so we calculate the name of inputs, and copy outputs once it overlaps with inputs
all_input_names = symbol.Group(outputs).list_inputs()
make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x
# group all outputs of graph_func
graph = symbol.Group(list(map(make_identity, outputs)))
graph = symbol.Group(list(map(symbol.op.identity, outputs)))
return graph, num_outputs

def _union_inputs(*graphs):
Expand Down
6 changes: 3 additions & 3 deletions src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,13 +372,13 @@ int MXSymbolCutSubgraph(SymbolHandle sym, SymbolHandle **input_symbols,
// a subgraph.
API_BEGIN();
nnvm::Symbol *s = static_cast<nnvm::Symbol*>(sym);
std::string subg_attr = "__subgraph_name__";
const std::string subg_attr = "__subgraph_name__";
auto out_node = s->outputs[0].node;
auto it = out_node->attrs.dict.find(subg_attr);
if (it != out_node->attrs.dict.end()) {
std::string subg_name = it->second;
const std::string &subg_name = it->second;
std::vector<nnvm::NodeEntry *> input_entries;
DFSVisit(s->outputs, [subg_attr, subg_name, &input_entries]
DFSVisit(s->outputs, [&subg_attr, &subg_name, &input_entries]
(nnvm::NodePtr n) {
// If the node itself isn't in the subgraph, we ignore it.
auto it = n->attrs.dict.find(subg_attr);
Expand Down

0 comments on commit be5bd89

Please sign in to comment.