Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Copy only when necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Aug 9, 2018
1 parent 29abfbb commit 031b8f1
Showing 1 changed file with 28 additions and 6 deletions.
34 changes: 28 additions & 6 deletions python/mxnet/symbol/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,29 @@ def _cut_subgraph(subg):
# This construct a subgraph for given output nodes.
# If an output node is one of the input nodes, we call identity to make sure
# that outputs nodes are different from input nodes.
def _construct_subgraph(sym_out, sym_states):
def _construct_subgraph(sym_out, sym_states, name):
sym_out = _as_list(sym_out)
sym_states = _as_list(sym_states)
all_outputs = []
all_outputs.extend(sym_out)
all_outputs.extend(sym_states)
g = symbol.Group(all_outputs)

flat_out = []
all_input_names = g.list_inputs()
output_names = {o.name for o in sym_out}
for o in sym_out:
flat_out.append(symbol.op.identity(o))
if o.name in all_input_names or o.list_attr().get("__subgraph_name__", "") != name:
flat_out.append(symbol.op.identity(o))
else:
flat_out.append(o)

for s in sym_states:
flat_out.append(symbol.op.identity(s))
if s.name in all_input_names or s.name in output_names or \
s.list_attr().get("__subgraph_name__", "") != name:
flat_out.append(symbol.op.identity(s))
else:
flat_out.append(s)
return symbol.Group(flat_out)

def foreach(body, data, init_states, name="foreach"):
Expand Down Expand Up @@ -242,7 +254,7 @@ def check_data(inputs, in_type, msg):
num_out_data = len(sym_out)
num_states = len(sym_states)
num_outputs = num_out_data + num_states
g = _construct_subgraph(sym_out, sym_states)
g = _construct_subgraph(sym_out, sym_states, name)

input_syms = _get_graph_inputs(g)
cut_syms = _cut_subgraph(g)
Expand Down Expand Up @@ -456,7 +468,12 @@ def _create_subgraph(graph_vars, graph_func, subgraph_name):
# nnvm cut-graph does not allow inputs and outputs overlap
# so we calculate the name of inputs, and copy outputs once it overlaps with inputs
# group all outputs of graph_func
graph = symbol.Group(list(map(symbol.op.identity, outputs + final_state)))
all_input_names = symbol.Group(outputs + final_state).list_inputs()
in_input = lambda x: x.name in all_input_names
in_graph = lambda x: x.list_attr().get("__subgraph_name__", "") == subgraph_name
make_identity = lambda x: symbol.op.identity(x) if in_input(x) or not in_graph(x) \
else x
graph = symbol.Group(list(map(make_identity, outputs + final_state)))
return graph, num_out_data, num_outputs

def _union_inputs(*graphs):
Expand Down Expand Up @@ -612,7 +629,12 @@ def _create_subgraph(graph_vars, graph_func, subgraph_name):
# nnvm cut-graph does not allow inputs and outputs overlap
# so we calculate the name of inputs, and copy outputs once it overlaps with inputs
# group all outputs of graph_func
graph = symbol.Group(list(map(symbol.op.identity, outputs)))
all_input_names = symbol.Group(outputs).list_inputs()
in_input = lambda x: x.name in all_input_names
in_graph = lambda x: x.list_attr().get("__subgraph_name__", "") == subgraph_name
make_identity = lambda x: symbol.op.identity(x) if in_input(x) or not in_graph(x) \
else x
graph = symbol.Group(list(map(make_identity, outputs)))
return graph, num_outputs

def _union_inputs(*graphs):
Expand Down

0 comments on commit 031b8f1

Please sign in to comment.