scale_calc_offline_pass and bug fix test=huawei_ascend_npu
shentanyue committed Jan 18, 2022
1 parent 37c3961 commit a4dbfa9
Showing 13 changed files with 163 additions and 50 deletions.
1 change: 1 addition & 0 deletions lite/api/paddle_use_passes.h
@@ -118,3 +118,4 @@ USE_MIR_PASS(range_calc_offline_pass);
USE_MIR_PASS(p_norm_fill_constant_max_div_fuse_pass);
USE_MIR_PASS(fill_constant_calc_offline_pass);
USE_MIR_PASS(unsqueeze_calc_offline_pass);
+USE_MIR_PASS(scale_calc_offline_pass);
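For context, a MIR pass takes effect only when the registration in its .cc file is paired with a USE_MIR_PASS declaration here; both halves of the pairing for the new pass appear in this commit:

```cpp
// lite/core/optimizer/mir/elimination/scale_calc_offline_pass.cc
REGISTER_MIR_PASS(scale_calc_offline_pass,
                  paddle::lite::mir::ScaleCalcOfflinePass)
    .BindTargets({TARGET(kNNAdapter), TARGET(kARM), TARGET(kX86)});

// lite/api/paddle_use_passes.h (this file): references the registration
// symbol so the linker keeps it when Paddle Lite is linked statically
USE_MIR_PASS(scale_calc_offline_pass);
```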
@@ -27,13 +27,15 @@ int ConvertPad(Converter* converter, hal::Operation* operation) {
std::string pad_mode = ConvertPadModeCodeToGEPadMode(mode);
int32_t value =
static_cast<int32_t>(*reinterpret_cast<float*>(value_operand->buffer));
+#if NNADAPTER_HUAWEI_ASCEND_NPU_CANN_VERSION_LESS_THAN(5, 0, 3)
NNADAPTER_CHECK_EQ(pad_mode, "constant")
<< "Only support mode=constant right now, "
"but received mode is "
<< pad_mode;
NNADAPTER_CHECK_EQ(value, 0) << "Only support constant_values=0 right now, "
"but received constant_value is "
<< value;
+#endif
auto input_operator = converter->GetMappedOperator(input_operand);
if (!input_operator) {
input_operator = converter->ConvertOperand(input_operand);
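With the guard in place, the constant-mode and zero-value restrictions now apply only when building against a CANN toolkit older than 5.0.3; newer toolkits can presumably handle the remaining pad modes and nonzero constant values, so the converter no longer rejects them up front.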
@@ -23,8 +23,7 @@ namespace huawei_ascend_npu {
int ConvertResizeLinear(Converter* converter, hal::Operation* operation) {
RESIZE_LINEAR_OPERATION_EXTRACT_INPUTS_OUTPUTS
NNADAPTER_CHECK(!(align_mode == 0 && align_corners))
<< "HuiweiAscendNPU does not support align_mode=0 and "
"align_corners=true.";
<< "Unsupported align_mode=0 when align_corners=true.";

// Convert to GE operators
auto resize_linear_op =
@@ -93,16 +93,6 @@ endif()
add_library(atc_register SHARED IMPORTED GLOBAL)
set_property(TARGET atc_register PROPERTY IMPORTED_LOCATION ${HUAWEI_ASCEND_NPU_SDK_ATC_REGISTER_FILE})

-# libascend_protobuf.so
-find_library(HUAWEI_ASCEND_NPU_SDK_ATC_ASCEND_PROTOBUF_FILE NAMES ascend_protobuf
-  PATHS ${NNADAPTER_HUAWEI_ASCEND_NPU_SDK_ROOT}/atc/lib64
-  CMAKE_FIND_ROOT_PATH_BOTH)
-if(NOT HUAWEI_ASCEND_NPU_SDK_ATC_ASCEND_PROTOBUF_FILE)
-  message(FATAL_ERROR "Missing libascend_protobuf.so in ${NNADAPTER_HUAWEI_ASCEND_NPU_SDK_ROOT}/atc/lib64")
-endif()
-add_library(atc_ascend_protobuf SHARED IMPORTED GLOBAL)
-set_property(TARGET atc_ascend_protobuf PROPERTY IMPORTED_LOCATION ${HUAWEI_ASCEND_NPU_SDK_ATC_ASCEND_PROTOBUF_FILE})

# libgraph.so
find_library(HUAWEI_ASCEND_NPU_SDK_ATC_GRAPH_FILE NAMES graph
PATHS ${NNADAPTER_HUAWEI_ASCEND_NPU_SDK_ROOT}/atc/lib64
@@ -85,6 +85,8 @@ void InitializeGraphBuilder() {
std::map<ge::AscendString, ge::AscendString> global_options;
global_options.insert(
std::make_pair(ge::ir_option::SOC_VERSION, soc_version));
+  global_options.insert(std::make_pair(ge::ir_option::OP_DEBUG_LEVEL, "0"));
+  global_options.insert(std::make_pair(ge::ir_option::DEBUG_DIR, "/tmp/"));
ge::aclgrphBuildInitialize(global_options);
// Register 'FinalizeGraphBuilder' to be called at normal process
// termination
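Assuming these ge::ir_option keys mirror the ATC command-line flags of the same names, op_debug_level "0" disables operator debug-information generation, and debug_dir points whatever debug files the graph builder still produces at /tmp/ instead of the current working directory.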
@@ -52,7 +52,7 @@ namespace huawei_ascend_npu {
major, minor, patch) \
NNADAPTER_HUAWEI_ASCEND_NPU_CANN_MAJOR_VERSION * 1000 + \
NNADAPTER_HUAWEI_ASCEND_NPU_CANN_MINOR_VERSION * 100 + \
-      NNADAPTER_HUAWEI_ASCEND_NPU_CANN_PATCH_VERSION <= \
+      NNADAPTER_HUAWEI_ASCEND_NPU_CANN_PATCH_VERSION < \
major * 1000 + minor * 100 + patch
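The one-character fix makes the macro behave as its name says: with '<=', NNADAPTER_HUAWEI_ASCEND_NPU_CANN_VERSION_LESS_THAN(x, y, z) was also true when the installed CANN version was exactly x.y.z, so code gated on it (such as the pad checks above) was wrongly enabled on 5.0.3 itself. A minimal standalone model of the comparison, assuming a hypothetical installed CANN 5.0.3:

```cpp
// Versions are encoded as major*1000 + minor*100 + patch, as in the macro.
constexpr int kInstalled = 5 * 1000 + 0 * 100 + 3;  // assumed CANN 5.0.3

constexpr bool VersionLessThan(int major, int minor, int patch) {
  return kInstalled < major * 1000 + minor * 100 + patch;  // was '<='
}

static_assert(!VersionLessThan(5, 0, 3), "5.0.3 is not less than itself");
static_assert(VersionLessThan(5, 0, 4), "but it is less than 5.0.4");
```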

// Prepare AscendCL environment and register the finalizer to be called at
@@ -117,4 +117,4 @@ void FillConstantCalcOfflinePass::RemoveFillConstantPattern(

REGISTER_MIR_PASS(fill_constant_calc_offline_pass,
paddle::lite::mir::FillConstantCalcOfflinePass)
-    .BindTargets({TARGET(kNNAdapter)});
+    .BindTargets({TARGET(kNNAdapter), TARGET(kARM), TARGET(kX86)});
87 changes: 87 additions & 0 deletions lite/core/optimizer/mir/elimination/scale_calc_offline_pass.cc
@@ -0,0 +1,87 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/core/optimizer/mir/elimination/scale_calc_offline_pass.h"
#include <algorithm>
#include <cmath>
#include <list>
#include <memory>
#include <set>
#include <vector>
#include "lite/core/optimizer/mir/pass.h"
#include "lite/core/optimizer/mir/pass_registry.h"
#include "lite/core/optimizer/mir/pattern_matcher.h"
#include "lite/model_parser/cpp_desc.h"

namespace paddle {
namespace lite {
namespace mir {

void ScaleCalcOfflinePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
RemoveScalePattern(graph);
}

void ScaleCalcOfflinePass::RemoveScalePattern(
const std::unique_ptr<SSAGraph>& graph) {
for (auto& node : graph->StmtTopologicalOrder()) {
if (node->AsStmt().op_type() != "scale") continue;

std::set<const Node*> nodes2rm_;
auto& scale_instruct = node->AsStmt();
auto* scope = scale_instruct.op()->scope();
auto op_desc = scale_instruct.mutable_op_info();
// Get scale's input tensor
auto x_var = scope->FindVar(op_desc->Input("X").front());
auto x_t = x_var->GetMutable<lite::Tensor>();
if (!x_t->persistable()) {
LOG(WARNING) << "ScaleCalcOfflinePass does not support input that is not "
"persistable";
continue;
}
auto x_data = x_t->mutable_data<float>();
auto x_dims = x_t->dims();
// Get scale's attr
auto scale = op_desc->GetAttr<float>("scale");
auto bias = op_desc->GetAttr<float>("bias");
auto bias_after_scale = op_desc->GetAttr<bool>("bias_after_scale");
if (!bias_after_scale) {
bias *= scale;
}
// Get scale's output tensor
auto out_var = scope->FindVar(op_desc->Output("Out").front());
auto out_t = out_var->GetMutable<lite::Tensor>();
out_t->Resize(x_dims);
auto out_data = out_t->mutable_data<float>();
for (int i = 0; i < x_dims.production(); i++) {
out_data[i] = x_data[i] * scale + bias;
}

// Offline calc scale, only retain output tensor as persistable tensor
out_t->set_persistable(true);
auto scale_outlinks = node->outlinks;
for (auto& scale_out_link : scale_outlinks) {
scale_out_link->arg()->is_weight = true;
}
nodes2rm_.insert(node);
GraphSafeRemoveNodes(graph.get(), nodes2rm_);
}
}

} // namespace mir
} // namespace lite
} // namespace paddle

REGISTER_MIR_PASS(scale_calc_offline_pass,
paddle::lite::mir::ScaleCalcOfflinePass)
.BindTargets({TARGET(kNNAdapter), TARGET(kARM), TARGET(kX86)});
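As a sanity check on the folding: Paddle's scale op computes out = x*scale + bias when bias_after_scale is true, and out = (x + bias)*scale otherwise, so pre-multiplying bias by scale reduces both cases to a single multiply-add. A minimal standalone sketch of the arithmetic (hypothetical values, not part of the commit):

```cpp
#include <cassert>
#include <vector>

// Mirrors ScaleCalcOfflinePass: normalize bias, then apply one multiply-add.
std::vector<float> FoldScale(std::vector<float> x, float scale, float bias,
                             bool bias_after_scale) {
  if (!bias_after_scale) bias *= scale;
  for (auto& v : x) v = v * scale + bias;
  return x;
}

int main() {
  // bias_after_scale = false: out = (x + 3) * 2
  auto out = FoldScale({1.0f, 2.0f}, /*scale=*/2.0f, /*bias=*/3.0f, false);
  assert(out[0] == 8.0f && out[1] == 10.0f);
  return 0;
}
```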
38 changes: 38 additions & 0 deletions lite/core/optimizer/mir/elimination/scale_calc_offline_pass.h
@@ -0,0 +1,38 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "lite/core/optimizer/mir/pass.h"
#include "lite/core/optimizer/mir/pass_registry.h"
#include "lite/core/tensor.h"
#include "lite/core/types.h"

namespace paddle {
namespace lite {
namespace mir {

class ScaleCalcOfflinePass : public mir::StmtPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
void RemoveScalePattern(const std::unique_ptr<SSAGraph>& graph);
};

} // namespace mir
} // namespace lite
} // namespace paddle
30 changes: 12 additions & 18 deletions lite/core/optimizer/mir/elimination/unsqueeze_calc_offline_pass.cc
@@ -61,24 +61,18 @@ void UnsqueezeCalcOfflinePass::RemoveUnsqueezePattern(
auto out_t = out_var->GetMutable<lite::Tensor>();
std::vector<int64_t> output_shape(input_shape);
output_shape.insert(output_shape.end(), axes.size(), 1);

-    auto infer_output_shape = [&](int64_t* input_dimensions,
-                                  int64_t* output_dimensions) {
-      uint32_t cur_size = input_shape.size();
-      for (size_t i = 0; i < axes.size(); i++) {
-        int32_t axis = axes[i] < 0 ? axes[i] + cur_size + 1 : axes[i];
-        CHECK_GE(axis, 0);
-        CHECK_LE(axis, cur_size);
-        for (uint32_t j = cur_size; j > axis; j--) {
-          output_dimensions[j] = output_dimensions[j - 1];
-        }
-        output_dimensions[axis] = 1;
-        cur_size++;
-      }
-    };

    out_t->CopyDataFrom(*input_t);
-    infer_output_shape(input_shape.data(), output_shape.data());
+    uint32_t cur_size = input_shape.size();
+    for (size_t i = 0; i < axes.size(); i++) {
+      int32_t axis = axes[i] < 0 ? axes[i] + cur_size + 1 : axes[i];
+      CHECK_GE(axis, 0);
+      CHECK_LE(axis, cur_size);
+      for (uint32_t j = cur_size; j > axis; j--) {
+        output_shape[j] = output_shape[j - 1];
+      }
+      output_shape[axis] = 1;
+      cur_size++;
+    }
out_t->Resize(DDim(output_shape));
// Offline calc unsqueeze, only retain output tensor as persistable
// tensor
@@ -98,4 +92,4 @@ void UnsqueezeCalcOfflinePass::RemoveUnsqueezePattern(

REGISTER_MIR_PASS(unsqueeze_calc_offline_pass,
paddle::lite::mir::UnsqueezeCalcOfflinePass)
-    .BindTargets({TARGET(kNNAdapter)});
+    .BindTargets({TARGET(kNNAdapter), TARGET(kARM), TARGET(kX86)});
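The inlined loop behaves exactly like the removed lambda, but writes into the output_shape vector directly instead of through raw pointers. A standalone sketch of the axis-insertion logic, with a hypothetical driver:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the inlined unsqueeze shape inference from the pass.
std::vector<int64_t> InferUnsqueezedShape(std::vector<int64_t> shape,
                                          const std::vector<int32_t>& axes) {
  uint32_t cur_size = shape.size();
  shape.insert(shape.end(), axes.size(), 1);  // reserve slots, as the pass does
  for (int32_t a : axes) {
    int32_t axis = a < 0 ? a + static_cast<int32_t>(cur_size) + 1 : a;
    for (uint32_t j = cur_size; j > static_cast<uint32_t>(axis); j--)
      shape[j] = shape[j - 1];  // shift dims right to open a slot
    shape[axis] = 1;
    cur_size++;
  }
  return shape;
}

int main() {
  // [3, 4] with axes {0, 2} -> [1, 3, 1, 4]
  auto s = InferUnsqueezedShape({3, 4}, {0, 2});
  assert((s == std::vector<int64_t>{1, 3, 1, 4}));
  return 0;
}
```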
29 changes: 14 additions & 15 deletions lite/core/optimizer/optimizer.cc
@@ -134,17 +134,15 @@ std::unique_ptr<RuntimeProgram> RunDefaultOptimizer(
Optimizer optim(valid_places, kernel_pick_factor);

std::vector<std::string> passes_local{
{"lite_quant_dequant_fuse_pass", //
"weight_quantization_preprocess_pass", //
"op_transformation_pass", //
"remove_scale1_pass", //
"adaptive_1x1_pool2d_convert_global_pass", //
"lite_unsqueeze2_pad3d_squeeze2_fuse_pass", //

"lite_conv_elementwise_fuse_pass", // conv-elemwise-bn
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_fuse_pass", // conv-bn-elemwise
"lite_conv_conv_fuse_pass", //
{"lite_quant_dequant_fuse_pass", //
"weight_quantization_preprocess_pass", //
"op_transformation_pass", //
"remove_scale1_pass", //
"adaptive_1x1_pool2d_convert_global_pass", //
"lite_conv_elementwise_fuse_pass", // conv-elemwise-bn
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_fuse_pass", // conv-bn-elemwise
"lite_conv_conv_fuse_pass", //
// TODO(Superjomn) Refine the fusion related design to select fusion
// kernels for devices automatically.
"lite_conv_activation_fuse_pass", //
@@ -174,10 +172,7 @@ std::unique_ptr<RuntimeProgram> RunDefaultOptimizer(
"lite_conv_elementwise_tree_fuse_pass",
"lite_greater_than_cast_fuse_pass",
"fill_range_fuse_pass",
"range_calc_offline_pass",
"p_norm_fill_constant_max_div_fuse_pass",
"fill_constant_calc_offline_pass",
"unsqueeze_calc_offline_pass",
"identity_dropout_eliminate_pass",
"sparse_conv_detect_pass",
"__xpu__max_pooling_pad_zero_detect_fuse_pass",
@@ -203,8 +198,12 @@ std::unique_ptr<RuntimeProgram> RunDefaultOptimizer(
"fix_mismatched_precision_pass",
"__xpu__dynamic_lstm_fuse_pass",
"__xpu__multi_softmax_fuse_pass",
"ssd_boxes_calc_offline_pass",
"assign_value_calc_offline_pass",
"range_calc_offline_pass",
"fill_constant_calc_offline_pass",
"scale_calc_offline_pass",
"unsqueeze_calc_offline_pass",
"ssd_boxes_calc_offline_pass",
// Only for fully quantized model, infer the output scale and fix the
// attribute 'enable_int8' for all of the quantized ops.
"quantized_op_attributes_inference_pass",
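Net effect of the three hunks in this file: the *_calc_offline passes now run as a contiguous group late in the default pipeline, after the fusion passes, with the new scale_calc_offline_pass slotted between fill_constant_calc_offline_pass and unsqueeze_calc_offline_pass; lite_unsqueeze2_pad3d_squeeze2_fuse_pass is removed from this part of the list.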
6 changes: 4 additions & 2 deletions lite/kernels/nnadapter/converter/lookup_table_v2.cc
@@ -45,9 +45,11 @@ int ConvertLookupTableV2(Converter* converter, OpInfo* op, Scope* scope) {

// Padding_idx
if (op->HasAttr("padding_idx")) {
+    auto padding_idx = op->GetAttr<int64_t>("padding_idx");
    // TODO(zhupengyang): support padding_idx later.
-    CHECK_EQ(op->GetAttr<int64_t>("padding_idx"), -1L)
-        << "Only support padding_idx = -1";
+    if (padding_idx != -1 && padding_idx != 0) {
+      LOG(FATAL) << "Only support padding_idx = -1 or 0";
+    }
}

// Output operand
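If the attribute follows the usual lookup_table_v2 semantics, padding_idx designates an embedding row whose lookups return zero vectors; per the TODO, the converter accepts only the default (-1) or 0 until NNAdapter supports it properly.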
1 change: 0 additions & 1 deletion lite/operators/lookup_table_v2_op.cc
@@ -57,7 +57,6 @@ bool LookupTableV2OpLite::AttachImpl(const cpp::OpDesc &op_desc,
param_.Out = scope->FindMutableTensor(out);

param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx");
-
return true;
}

