Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
added updates for reattach function
Browse files Browse the repository at this point in the history
  • Loading branch information
samskalicky committed Jan 16, 2020
1 parent fbf6a7e commit 1e4fdc5
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/operator/subgraph/build_subgraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ void FindOutputEntries(nnvm::Graph* g,
*/
void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
std::vector<nnvm::NodeEntry> *orig_entries,
std::vector<nnvm::NodeEntry> *all_entries,
const bool skip_var = false) {
// map for creating unique var nodes for deduplicating entries from the same node
std::unordered_map<std::string, nnvm::NodeEntry> name_count_map;
Expand Down Expand Up @@ -584,7 +585,9 @@ void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,

// store the node in the map
name_count_map.emplace(var_name, e_);
}
}
all_entries->push_back(*e);

// lookup the name of the node and set it as the input dependency
*e = name_count_map[var_name];
}
Expand Down Expand Up @@ -618,8 +621,11 @@ void CreateSubgraphNode(nnvm::Graph* g,
#endif
std::vector<nnvm::NodeEntry*> input_entries;
FindInputEntries(*g, simple_nodes, subgraph_nodes, *entry_top_order_map, &input_entries);
// deduplicated array of inputs to connect to subgraph
std::vector<nnvm::NodeEntry> orig_input_entries;
CutGraphInputs(input_entries, &orig_input_entries, false);
// all original input connections, used to reattach subgraph inputs
std::vector<nnvm::NodeEntry> all_input_entries;
CutGraphInputs(input_entries, &orig_input_entries, &all_input_entries, false);
#if DEBUG_SUBGRAPH
PrintNodeEntries(input_entries);
LOG(INFO) << "Searching for output entries...";
Expand Down Expand Up @@ -661,7 +667,7 @@ void CreateSubgraphNode(nnvm::Graph* g,
}
}
} else {
ReattachGraphInputs(input_entries, &orig_input_entries);
ReattachGraphInputs(input_entries, &all_input_entries);
}
#if DEBUG_SUBGRAPH
if (n)
Expand Down

0 comments on commit 1e4fdc5

Please sign in to comment.