Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Copy only when necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Aug 9, 2018
1 parent 29abfbb commit 0333927
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions python/mxnet/symbol/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,23 @@ def _cut_subgraph(subg):
# This construct a subgraph for given output nodes.
# If an output node is one of the input nodes, we call identity to make sure
# that outputs nodes are different from input nodes.
def _construct_subgraph(sym_out, sym_states):
def _construct_subgraph(sym_out, sym_states, name):
sym_out = _as_list(sym_out)
sym_states = _as_list(sym_states)
all_outputs = []
all_outputs.extend(sym_out)
all_outputs.extend(sym_states)
flat_out = []
for o in sym_out:
flat_out.append(symbol.op.identity(o))
if o.list_attr().get("__subgraph_name__", "") != name:
flat_out.append(symbol.op.identity(o))
else:
flat_out.append(o)
for s in sym_states:
flat_out.append(symbol.op.identity(s))
if s.list_attr().get("__subgraph_name__", "") != name:
flat_out.append(symbol.op.identity(s))
else:
flat_out.append(s)
return symbol.Group(flat_out)

def foreach(body, data, init_states, name="foreach"):
Expand Down Expand Up @@ -242,7 +248,7 @@ def check_data(inputs, in_type, msg):
num_out_data = len(sym_out)
num_states = len(sym_states)
num_outputs = num_out_data + num_states
g = _construct_subgraph(sym_out, sym_states)
g = _construct_subgraph(sym_out, sym_states, name)

input_syms = _get_graph_inputs(g)
cut_syms = _cut_subgraph(g)
Expand Down Expand Up @@ -456,7 +462,10 @@ def _create_subgraph(graph_vars, graph_func, subgraph_name):
# nnvm cut-graph does not allow inputs and outputs overlap
# so we calculate the name of inputs, and copy outputs once it overlaps with inputs
# group all outputs of graph_func
graph = symbol.Group(list(map(symbol.op.identity, outputs + final_state)))
make_identity = lambda x: \
x if x.list_attr().get("__subgraph_name__", "") == subgraph_name \
else symbol.op.identity(x)
graph = symbol.Group(list(map(make_identity, outputs + final_state)))
return graph, num_out_data, num_outputs

def _union_inputs(*graphs):
Expand Down Expand Up @@ -612,7 +621,10 @@ def _create_subgraph(graph_vars, graph_func, subgraph_name):
# nnvm cut-graph does not allow inputs and outputs overlap
# so we calculate the name of inputs, and copy outputs once it overlaps with inputs
# group all outputs of graph_func
graph = symbol.Group(list(map(symbol.op.identity, outputs)))
make_identity = lambda x: \
x if x.list_attr().get("__subgraph_name__", "") == subgraph_name \
else symbol.op.identity(x)
graph = symbol.Group(list(map(make_identity, outputs)))
return graph, num_outputs

def _union_inputs(*graphs):
Expand Down

0 comments on commit 0333927

Please sign in to comment.