Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
Merge pull request #1662 from NervanaSystems/aprocter/cherry-picks
Browse files Browse the repository at this point in the history
Cherry-pick "Common pass registration for codegen and Dex (#1642)" to r0.8
  • Loading branch information
Adam Procter authored Sep 21, 2018
2 parents 2822885 + f117269 commit 0ea01ea
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 36 deletions.
51 changes: 17 additions & 34 deletions src/ngraph/runtime/cpu/cpu_external_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,20 +368,8 @@ static void
writer << "}\n";
}

void runtime::cpu::CPU_ExternalFunction::compile()
void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Manager& pass_manager)
{
if (m_is_compiled)
{
return;
}

m_mkldnn_emitter.reset(new MKLDNNEmitter());

ngraph::pass::Manager pass_manager;

// nv_cwi is required only by some frontends
// in which case they should run this pass(CPUWorkspaceInsertion) explicitly
NodeVector nv_cwi;
pass_manager.register_pass<ngraph::pass::LikeReplacement>();
pass_manager.register_pass<ngraph::pass::NopElimination>();
// TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing
Expand All @@ -396,11 +384,25 @@ void runtime::cpu::CPU_ExternalFunction::compile()
pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUCollapseDims>();
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi);
NodeVector nv_cwi; // We dont need CPUWorkspaceInsertion to return list of indices
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi, false);
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
}

void runtime::cpu::CPU_ExternalFunction::compile()
{
if (m_is_compiled)
{
return;
}

m_mkldnn_emitter.reset(new MKLDNNEmitter());

ngraph::pass::Manager pass_manager;
register_common_passes(pass_manager);
unordered_map<Node*, Node*> node_function_map;
string common_function_string;
auto femitter = bind(&ngraph::runtime::cpu::CPU_ExternalFunction::emit_op_as_function,
Expand Down Expand Up @@ -1132,27 +1134,8 @@ void runtime::cpu::CPU_ExternalFunction::build()
m_mkldnn_emitter.reset(new MKLDNNEmitter());

ngraph::pass::Manager pass_manager;
register_common_passes(pass_manager);

// nv_cwi is required only by some frontends
// in which case they should run this pass(CPUWorkspaceInsertion) explicitly
NodeVector nv_cwi;
pass_manager.register_pass<ngraph::pass::NopElimination>();
// TODO (pruthvi): Enable all the disabeled RNN fusion graph pass after fixing
// failing mxnet unit tests.
// pass_manager.register_pass<runtime::cpu::pass::LSTMFusion>();
// pass_manager.register_pass<runtime::cpu::pass::RNNFusion>();
// pass_manager.register_pass<runtime::cpu::pass::ConcatInputs>();
pass_manager.register_pass<ngraph::pass::AlgebraicSimplification>();
pass_manager.register_pass<runtime::cpu::pass::CPUBatchFusion>();
pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
pass_manager.register_pass<ngraph::pass::CoreFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
pass_manager.register_pass<runtime::cpu::pass::CPUCollapseDims>();
pass_manager.register_pass<runtime::cpu::pass::CPUWorkspaceInsertion>(nv_cwi);
pass_manager.register_pass<runtime::cpu::pass::CPUAssignment>(this);
pass_manager.register_pass<runtime::cpu::pass::CPULayout>(this);
pass_manager.register_pass<runtime::cpu::pass::CPUPostLayoutOptimizations>();
pass_manager.register_pass<ngraph::pass::GetOutputElementElimination>();
pass_manager.register_pass<ngraph::pass::Liveness>();
pass_manager.register_pass<ngraph::pass::MemoryLayout>(size_t(s_memory_pool_alignment), true);
pass_manager.run_passes(m_function, false);
Expand Down
4 changes: 4 additions & 0 deletions src/ngraph/runtime/cpu/cpu_external_function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#endif

#include "ngraph/function.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/runtime/cpu/cpu_call_frame.hpp"
#include "ngraph/runtime/cpu/cpu_layout_descriptor.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
Expand Down Expand Up @@ -139,6 +140,9 @@ namespace ngraph
#endif

private:
// Register passes that are common to codegen and DEX
void register_common_passes(ngraph::pass::Manager& pass_manager);

// For non-destructive passthrough kernels, propagate function
// input buffers to internal ops
void propagate_in_place_input(ngraph::descriptor::Output* output,
Expand Down
5 changes: 4 additions & 1 deletion src/ngraph/runtime/cpu/pass/cpu_workspace_insertion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ bool runtime::cpu::pass::CPUWorkspaceInsertion::transform(pattern::Matcher& m)
m_max_pool->get_padding_above());

ngraph::replace_node(m_max_pool_bprop, max_pool_with_indices_bprop);
m_indices_list.push_back(max_pool_with_indices_indices);
if (m_return_indices)
{
m_indices_list.push_back(max_pool_with_indices_indices);
}
return true;
}
4 changes: 3 additions & 1 deletion src/ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,17 @@ namespace ngraph
class ngraph::runtime::cpu::pass::CPUWorkspaceInsertion : public ngraph::pass::FunctionPass
{
public:
CPUWorkspaceInsertion(ngraph::NodeVector& indices_list)
CPUWorkspaceInsertion(ngraph::NodeVector& indices_list, bool return_indices = true)
: FunctionPass()
, m_indices_list(indices_list)
, m_return_indices(return_indices)
{
}

virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);

private:
ngraph::NodeVector& m_indices_list;
bool m_return_indices;
bool transform(ngraph::pattern::Matcher& m);
};

0 comments on commit 0ea01ea

Please sign in to comment.