Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
added updates for reattach function
Browse files Browse the repository at this point in the history
  • Loading branch information
samskalicky authored and Ubuntu committed Feb 13, 2020
1 parent fdd477b commit 7a210d6
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/operator/subgraph/build_subgraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ void FindOutputEntries(nnvm::Graph* g,
*/
void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
std::vector<nnvm::NodeEntry> *orig_entries,
std::vector<nnvm::NodeEntry> *all_entries,
const bool skip_var = false) {
// map for creating unique var nodes for deduplicating entries from the same node
std::unordered_map<std::string, nnvm::NodeEntry> name_count_map;
Expand Down Expand Up @@ -587,6 +588,8 @@ void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
n->attrs.dict["isArg"] = "True";
else
n->attrs.dict["isArg"] = "False";
all_entries->push_back(*e);

// lookup the name of the node and set it as the input dependency
*e = name_count_map[var_name];
}
Expand Down Expand Up @@ -620,8 +623,11 @@ void CreateSubgraphNode(nnvm::Graph* g,
#endif
std::vector<nnvm::NodeEntry*> input_entries;
FindInputEntries(*g, simple_nodes, subgraph_nodes, *entry_top_order_map, &input_entries);
// deduplicated array of inputs to connect to subgraph
std::vector<nnvm::NodeEntry> orig_input_entries;
CutGraphInputs(input_entries, &orig_input_entries, false);
// all original input connections, used to reattach subgraph inputs
std::vector<nnvm::NodeEntry> all_input_entries;
CutGraphInputs(input_entries, &orig_input_entries, &all_input_entries, false);
#if DEBUG_SUBGRAPH
PrintNodeEntries(input_entries);
LOG(INFO) << "Searching for output entries...";
Expand Down Expand Up @@ -663,7 +669,7 @@ void CreateSubgraphNode(nnvm::Graph* g,
}
}
} else {
ReattachGraphInputs(input_entries, &orig_input_entries);
ReattachGraphInputs(input_entries, &all_input_entries);
}
#if DEBUG_SUBGRAPH
if (n)
Expand Down

0 comments on commit 7a210d6

Please sign in to comment.