diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h
index b7f41c4888f..3344d30b8ff 100644
--- a/lite/api/paddle_use_passes.h
+++ b/lite/api/paddle_use_passes.h
@@ -72,7 +72,8 @@ USE_MIR_PASS(fp16_attribute_pass);
 USE_MIR_PASS(apu_subgraph_pass);
 USE_MIR_PASS(quantized_op_attributes_inference_pass);
 USE_MIR_PASS(restrict_quantized_op_with_same_input_output_scale_pass);
-USE_MIR_PASS(control_flow_op_unused_inputs_and_outputs_eliminate_pass)
+USE_MIR_PASS(control_flow_op_unused_inputs_and_outputs_eliminate_pass);
+USE_MIR_PASS(control_flow_op_shared_inputs_and_outputs_place_sync_pass);
 USE_MIR_PASS(lite_scale_activation_fuse_pass);
 USE_MIR_PASS(lite_instance_norm_activation_fuse_pass);
 USE_MIR_PASS(lite_fc_prelu_fuse_pass);
diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt
index aed35b18671..ff5898504f1 100644
--- a/lite/core/mir/CMakeLists.txt
+++ b/lite/core/mir/CMakeLists.txt
@@ -61,6 +61,7 @@ lite_cc_library(mir_passes
         elimination/remove_scale1_pass.cc
         adaptive_1x1_pool2d_convert_global_pass.cc
         elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.cc
+        control_flow_op_shared_inputs_and_outputs_place_sync_pass.cc
         static_kernel_pick_pass.cc
         variable_place_inference_pass.cc
         fpga_kernel_place_correct_pass.cc
diff --git a/lite/core/mir/control_flow_op_shared_inputs_and_outputs_place_sync_pass.cc b/lite/core/mir/control_flow_op_shared_inputs_and_outputs_place_sync_pass.cc
new file mode 100644
index 00000000000..bb166ab222b
--- /dev/null
+++ b/lite/core/mir/control_flow_op_shared_inputs_and_outputs_place_sync_pass.cc
@@ -0,0 +1,87 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lite/core/mir/control_flow_op_shared_inputs_and_outputs_place_sync_pass.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void CheckAndSyncTypeOfVarNode( + Node* sub_var_node, + const std::unordered_map& ref_var_types) { + CHECK(sub_var_node->IsArg()); + auto& sub_var_name = sub_var_node->AsArg().name; + if (ref_var_types.count(sub_var_name)) { + sub_var_node->AsArg().type = ref_var_types.at(sub_var_name); + } +} + +void ControlFlowOpSharedInputsAndOutputsPlaceSyncPass::SetAllGraphs( + std::vector>* graphs) { + CHECK(graphs && !graphs->empty()); + graphs_ = graphs; +} + +void ControlFlowOpSharedInputsAndOutputsPlaceSyncPass::Apply( + const std::unique_ptr& graph) { + const std::unordered_set control_flow_op_types = { + "while", "conditional_block"}; + auto block_size = graphs_->size(); + for (auto& op_node : graph->StmtTopologicalOrder()) { + if (!op_node->IsStmt()) continue; + auto op_info = op_node->AsStmt().mutable_op_info(); + auto op_type = op_info->Type(); + if (!control_flow_op_types.count(op_type)) continue; + int sub_block_idx = op_info->GetAttr("sub_block"); + CHECK(sub_block_idx >= 0 && sub_block_idx < block_size); + std::unordered_map ref_var_types; + for (auto* var_node : op_node->inlinks) { + CHECK(var_node->IsArg()); + auto& var_name = var_node->AsArg().name; + if (!ref_var_types.count(var_name)) { + ref_var_types.insert(std::pair( + var_name, var_node->AsArg().type)); + } + } + for (auto& sub_op_node : + (*graphs_)[sub_block_idx]->StmtTopologicalOrder()) { + if (!sub_op_node->IsStmt()) continue; + for (auto* sub_var_node : sub_op_node->inlinks) { + CheckAndSyncTypeOfVarNode(sub_var_node, ref_var_types); + } + for (auto* sub_var_node : sub_op_node->outlinks) { + CheckAndSyncTypeOfVarNode(sub_var_node, ref_var_types); + } + } + } +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS( + control_flow_op_shared_inputs_and_outputs_place_sync_pass, + paddle::lite::mir::ControlFlowOpSharedInputsAndOutputsPlaceSyncPass) + .BindTargets({TARGET(kXPU)}); diff --git a/lite/core/mir/control_flow_op_shared_inputs_and_outputs_place_sync_pass.h b/lite/core/mir/control_flow_op_shared_inputs_and_outputs_place_sync_pass.h new file mode 100644 index 00000000000..c449993dbb8 --- /dev/null +++ b/lite/core/mir/control_flow_op_shared_inputs_and_outputs_place_sync_pass.h @@ -0,0 +1,93 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include <limits>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "lite/core/mir/pass.h"
+#include "lite/core/types.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+// Sync the type of variable to the shared one in subblocks
+//
+// For example:
+//   graph[0]: main block
+//                in_x(target:x86)
+//                     |
+//                     |
+//                     |
+//                while(target:host) ------- in_w(target:x86)
+//                     |
+//                     |
+//                     |
+//                out_x(target:host)
+//
+//   graph[1]: sub block
+//                in_x(target:xpu)
+//                     |
+//                     |
+//                     |
+//                fc(target:xpu) ------ in_w(target:x86)
+//                     |
+//                     |
+//                softmax(target:xpu)
+//                     |
+//                     |
+//                out_x(target:xpu)
+//
+// After the pass is applied:
+//
+//   graph[0]: main block
+//                in_x(target:x86)
+//                     |
+//                     |
+//                     |
+//                while(target:host) ------- in_w(target:x86)
+//                     |
+//                     |
+//                     |
+//                out_x(target:host)
+//
+//   graph[1]: sub block
+//                in_x(target:x86)
+//                     |
+//                     |
+//                     |
+//                fc(target:xpu) ------ in_w(target:x86)
+//                     |
+//                     |
+//                softmax(target:xpu)
+//                     |
+//                     |
+//                out_x(target:host)
+
+class ControlFlowOpSharedInputsAndOutputsPlaceSyncPass : public mir::StmtPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph> &graph) override;
+  void SetAllGraphs(std::vector<std::unique_ptr<mir::SSAGraph>> *graphs);
+
+ private:
+  std::vector<std::unique_ptr<mir::SSAGraph>> *graphs_;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/static_kernel_pick_pass.cc b/lite/core/mir/static_kernel_pick_pass.cc
index b7b2e6e72cb..ffc2f69e8cd 100644
--- a/lite/core/mir/static_kernel_pick_pass.cc
+++ b/lite/core/mir/static_kernel_pick_pass.cc
@@ -87,7 +87,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
       // Just keep a single best kernel.
       // TODO(Superjomn) reconsider this.
       instruct.kernels().emplace_back(std::move(scored.front().second));
-      VLOG(2) << "pick " << instruct.kernels().front()->name() << "\n\n";
+      VLOG(2) << "pick " << instruct.kernels().front()->summary() << "\n\n";

     } else {
       bool out_type_int8 = true;
diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h
index 5aa6c0c256f..f34fc806fb8 100644
--- a/lite/core/optimizer.h
+++ b/lite/core/optimizer.h
@@ -19,6 +19,7 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include "lite/core/mir/control_flow_op_shared_inputs_and_outputs_place_sync_pass.h"
 #include "lite/core/mir/elimination/control_flow_op_unused_inputs_and_outputs_eliminate_pass.h"
 #include "lite/core/mir/fp16_attribute_pass.h"
 #include "lite/core/mir/generate_program_pass.h"
@@ -80,6 +81,7 @@ class Optimizer {
     SpecifyKernelPickTactic(kernel_pick_factor);
     InitTargetTypeTransformPass();
     InitControlFlowOpUnusedInputsAndOutputsEliminatePass();
+    InitControlFlowOpSharedInputsAndOutputsPlaceSyncPass();

     std::vector<std::string> passes_local{
         {"lite_quant_dequant_fuse_pass",  //
@@ -157,6 +159,7 @@ class Optimizer {
          "remove_tf_redundant_ops_pass",
          "variable_place_inference_pass",  // inference arg/var's
+         "control_flow_op_shared_inputs_and_outputs_place_sync_pass",
          "__fpga_kernel_place_correct_pass",
          "mlu_postprocess_pass",
          // info(target/precision/layout/device)
@@ -169,23 +172,27 @@ class Optimizer {
                                            // different targets when last and next
                                            // node
          "variable_place_inference_pass",  //
-         "argument_type_display_pass",     //
+         "control_flow_op_shared_inputs_and_outputs_place_sync_pass",
+         "argument_type_display_pass",  //

          "io_copy_kernel_pick_pass",    //
          "argument_type_display_pass",  //

          "variable_place_inference_pass",  //
-         "argument_type_display_pass",     //
+         "control_flow_op_shared_inputs_and_outputs_place_sync_pass",
+         "argument_type_display_pass",  //

          "type_precision_cast_pass",       //
          "variable_place_inference_pass",  //
-         "argument_type_display_pass",     //
"control_flow_op_shared_inputs_and_outputs_place_sync_pass", + "argument_type_display_pass", // "type_layout_cast_pass", // add layout/layout_once op if meet // different layout when last and next node "argument_type_display_pass", // "variable_place_inference_pass", // + "control_flow_op_shared_inputs_and_outputs_place_sync_pass", "argument_type_display_pass", "runtime_context_assign_pass", @@ -277,6 +284,16 @@ class Optimizer { pass->SetAllGraphs(&graphs_); } + void InitControlFlowOpSharedInputsAndOutputsPlaceSyncPass() { + auto* pass = + mir::PassManager::Global() + .LookUp( + "control_flow_op_shared_inputs_and_outputs_place_sync_pass"); + CHECK(pass); + CHECK(!graphs_.empty()); + pass->SetAllGraphs(&graphs_); + } + // Generate C++ code which combines the inference program, model and weights. void GenCode(const std::string& code_dir);