diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc index 84b30ea2d79d..31b6825f9ed7 100644 --- a/src/operator/subgraph/build_subgraph.cc +++ b/src/operator/subgraph/build_subgraph.cc @@ -554,6 +554,7 @@ void FindOutputEntries(nnvm::Graph* g, */ void CutGraphInputs(const std::vector &input_entries, std::vector *orig_entries, + std::vector *all_entries, const bool skip_var = false) { // map for creating unique var nodes for deduplicating entries from the same node std::unordered_map name_count_map; @@ -587,6 +588,8 @@ void CutGraphInputs(const std::vector &input_entries, n->attrs.dict["isArg"] = "True"; else n->attrs.dict["isArg"] = "False"; + all_entries->push_back(*e); + // lookup the name of the node and set it as the input dependency *e = name_count_map[var_name]; } @@ -620,8 +623,11 @@ void CreateSubgraphNode(nnvm::Graph* g, #endif std::vector input_entries; FindInputEntries(*g, simple_nodes, subgraph_nodes, *entry_top_order_map, &input_entries); + // deduplicated array of inputs to connect to subgraph std::vector orig_input_entries; - CutGraphInputs(input_entries, &orig_input_entries, false); + // all original input connections, used to reattach subgraph inputs + std::vector all_input_entries; + CutGraphInputs(input_entries, &orig_input_entries, &all_input_entries, false); #if DEBUG_SUBGRAPH PrintNodeEntries(input_entries); LOG(INFO) << "Searching for output entries..."; @@ -663,7 +669,7 @@ void CreateSubgraphNode(nnvm::Graph* g, } } } else { - ReattachGraphInputs(input_entries, &orig_input_entries); + ReattachGraphInputs(input_entries, &all_input_entries); } #if DEBUG_SUBGRAPH if (n)