Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
[v0.1.0] Multi-output fprop_cache tentative fix (#657)
Browse files Browse the repository at this point in the history
Contains multiple fixes to GetOutputElement, BatchNorm, autodiff, fprop_cache to integrate multi-output batchnorm and fprop_cache
  • Loading branch information
Krovatkin authored Mar 18, 2018
1 parent feeaed5 commit 995671a
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 42 deletions.
2 changes: 1 addition & 1 deletion src/ngraph/autodiff/adjoints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ std::shared_ptr<Node> autodiff::Adjoints::get(const std::shared_ptr<Node>& x)
void autodiff::Adjoints::add_delta(const std::shared_ptr<Node>& x,
const std::shared_ptr<Node>& delta)
{
if (!x->has_same_type(delta))
if (!x->has_same_type(delta) && delta->get_shape() != x->get_outputs().at(0).get_shape())
{
throw ngraph_error("Autodiff internal error: Mismatch on backprop and op in add_delta.");
}
Expand Down
19 changes: 17 additions & 2 deletions src/ngraph/ops/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,23 @@ void ngraph::op::BatchNorm::generate_adjoints(autodiff::Adjoints& adjoints,
auto gamma = get_input_op(0);
auto beta = get_input_op(1);
auto input = get_input_op(2);
auto mean = std::make_shared<op::GetOutputElement>(shared_from_this(), 1);
auto var = std::make_shared<op::GetOutputElement>(shared_from_this(), 2);

//Extract mean and variance outputs from BatchNorm
//as these are used by BatchNormBackprop.
//The users of the outputs (GetOutputElements' Inputs) aren't sorted
//and get_n() is used to sort the inputs in the same order as Batchnorm's outputs
//Next, Mean and Variance (`at(1)` and `at(2)`) are extracted
//Please see `add_output` in `BatchNorm::BatchNorm` for more details
std::vector<std::shared_ptr<Node>> goes(get_outputs().size());

for (auto _input : get_output_inputs(0))
{
auto goe = std::dynamic_pointer_cast<op::GetOutputElement>(_input->get_node());
goes.at(goe->get_n()) = _input->get_node();
}

auto mean = goes.at(1);
auto var = goes.at(2);
auto bbn = std::make_shared<op::BatchNormBackprop>(
get_eps_value(), gamma, beta, input, mean, var, delta);
auto dinput = std::make_shared<op::GetOutputElement>(bbn, 0);
Expand Down
11 changes: 11 additions & 0 deletions src/ngraph/ops/get_output_element.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ namespace ngraph
}

protected:
// Back-propagation hook for GetOutputElement.
//
// adjoints: accumulator that receives the delta for the upstream node.
// delta:    incoming gradient for this output element.
//
// Only the element selected with index 0 forwards its delta; every other
// index is dropped. Per the original author's note this filters out the
// mean and variance updates of batchnorm, since dinput is the only update
// required. This logic needs to be generalized as new multi-output ops
// are introduced.
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const std::shared_ptr<Node>& delta) override
{
// Any index other than 0 contributes nothing to the adjoints.
if (get_n() != 0)
{
return;
}

// Node that produces the value feeding this op's first (only used) input.
const auto& producer = get_inputs().at(0).get_output().get_node();
adjoints.add_delta(producer, delta);
}
size_t m_n;
};
}
Expand Down
61 changes: 22 additions & 39 deletions src/ngraph/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,56 +189,39 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
{
using namespace ngraph;

// Traverse fprop to make a map that stores parameters with the same
// shape and element type as the nodes in fprop
NodeMap node_param_map;
ngraph::traverse_nodes(fprop, [&node_param_map](std::shared_ptr<Node> node) {
node_param_map.add(
node, std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
});

// Traverse bprop to find all of the nodes in the graph
std::unordered_set<std::shared_ptr<Node>> in_bprop;
ngraph::traverse_nodes(bprop, [&in_bprop](std::shared_ptr<Node> node) {
if (in_bprop.count(node) == 0)

if (node->get_outputs().size() == 1)
{
in_bprop.insert(node);
if (in_bprop.count(node) == 0)
{
in_bprop.insert(node);
}
}

});

// Get the input parameters of fprop
std::unordered_set<std::shared_ptr<Node>> fprop_params;
for (auto node : fprop->get_parameters())
{
if (fprop_params.count(node) == 0)
// Traverse fprop to make a map that stores parameters with the same
// shape and element type as the nodes in fprop
FpropCache fprop_cache;
fprop_cache.node_param_map = std::make_shared<NodeMap>();
ngraph::traverse_nodes(fprop, [&fprop_cache, &in_bprop](std::shared_ptr<Node> node) {
if (in_bprop.count(node) != 0)
{
fprop_params.insert(node);
fprop_cache.node_param_map->add(
node, std::make_shared<op::Parameter>(node->get_element_type(), node->get_shape()));
}
}
});

// Find all of the nodes that are intermediate values of fprop and used in
// bprop
// and store those nodes that aren't needed in bprop
FpropCache fprop_cache;
std::vector<std::shared_ptr<Node>> unused_nodes;
for (auto kv : node_param_map.get_node_map())
{
// if it's not in bprop, mark it unused
if (in_bprop.count(kv.first) == 0)
{
unused_nodes.push_back(kv.first);
}
// otherwise save it in the outputs
else
{
fprop_cache.fprop_output_nodes.push_back(kv.first);
}
}

// erase all unused nodes from the map
for (auto node : unused_nodes)
for (auto kv : fprop_cache.node_param_map->get_node_map())
{
node_param_map.get_node_map().erase(node);
fprop_cache.fprop_output_nodes.push_back(kv.first);
}

// create the new outputs for fprop and the new fprop function
Expand All @@ -262,13 +245,13 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,

// clone the nodes in bprop, replacing fprop-related nodes with the
// intermediate parameters
ngraph::clone_nodes(bprop->get_ops(), node_param_map);
ngraph::clone_nodes(bprop->get_ops(), *(fprop_cache.node_param_map));

// get cloned bprop results
ResultVector cloned_results;
for (auto node : bprop->get_results())
{
auto result = std::dynamic_pointer_cast<op::Result>(node_param_map.get(node));
auto result = std::dynamic_pointer_cast<op::Result>(fprop_cache.node_param_map->get(node));
if (!result)
{
throw ngraph_error("Expected op::Result values for op::Result keys in node_param_map");
Expand All @@ -281,14 +264,14 @@ ngraph::FpropCache ngraph::cache_fprop(std::shared_ptr<ngraph::Function> fprop,
for (auto param : adjoints)
{
bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(node_param_map.get(param)));
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(param)));
}

// add the cached fprop nodes as inputs to bprop
for (auto x : fprop_cache.fprop_output_nodes)
{
bprop_input_params.push_back(
std::dynamic_pointer_cast<op::Parameter>(node_param_map.get(x)));
std::dynamic_pointer_cast<op::Parameter>(fprop_cache.node_param_map->get(x)));
}

// create the new bprop function
Expand Down
2 changes: 2 additions & 0 deletions src/ngraph/util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace ngraph
{
class Node;
class Function;
class NodeMap;
class stopwatch;

namespace runtime
Expand Down Expand Up @@ -229,6 +230,7 @@ namespace ngraph
std::shared_ptr<Function> fprop;
std::shared_ptr<Function> bprop;
std::vector<std::shared_ptr<Node>> fprop_output_nodes;
std::shared_ptr<NodeMap> node_param_map;
};

/**
Expand Down

0 comments on commit 995671a

Please sign in to comment.