Skip to content

Commit

Permalink
gru dynamic quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Apr 26, 2024
1 parent 3c425e9 commit cacf762
Show file tree
Hide file tree
Showing 7 changed files with 1,715 additions and 2,476 deletions.
2,066 changes: 571 additions & 1,495 deletions src/layer/arm/gru_arm.cpp

Large diffs are not rendered by default.

14 changes: 9 additions & 5 deletions src/layer/arm/gru_arm.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ class GRU_arm : public GRU
virtual int forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;

protected:
#if NCNN_INT8
int create_pipeline_int8(const Option& opt);
#endif
#if NCNN_ARM82
int create_pipeline_fp16s(const Option& opt);
int forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;
Expand All @@ -42,15 +39,22 @@ class GRU_arm : public GRU
int forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;
int forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;
#endif
#if NCNN_INT8
int create_pipeline_int8(const Option& opt);
void dynamic_quantize(const Mat& bottom_blob, int elemtype, Mat& bottom_blob_int8, Mat& bottom_blob_int8_descales, const Option& opt) const;
int forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;
int forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;
#endif

public:
Mat weight_xc_data_packed;
Mat bias_c_data_packed;
Mat weight_hc_data_packed;

Mat weight_data_tm;

#if NCNN_INT8
Mat weight_hc_data_int8_descales_packed;
Mat weight_xc_data_int8_descales_packed;
Mat weight_data_tm_int8_descales;
#endif
};

Expand Down
35 changes: 35 additions & 0 deletions src/layer/arm/gru_arm_asimddp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "cpu.h"
#include "mat.h"
#include "layer.h"
#include "arm_activation.h"
#include "arm_usability.h"

namespace ncnn {

#include "gru_int8.h"

// asimddp-compiled entry point for int8 GRU weight transformation.
// This translation unit is built with ARMv8.2 dot-product (asimddp) compiler
// flags, so the shared implementation in gru_int8.h (included above inside
// namespace ncnn) is instantiated here with dot-product instructions enabled.
// The function simply forwards all arguments to that shared implementation;
// the dispatcher elsewhere picks this variant at runtime when the CPU
// supports asimddp.
// NOTE(review): parameter semantics (per-gate scales, packed layouts of
// weight_data_tm / weight_data_tm_int8_descales) are defined by gru_int8.h,
// which is not visible here — see that header for the authoritative contract.
void gru_transform_weight_int8_asimddp(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, const Option& opt)
{
gru_transform_weight_int8(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, opt);
}

// asimddp-compiled entry point for the int8 GRU forward pass.
// As with the weight-transform wrapper above in this file, this exists only
// so the shared gru_int8.h implementation gets compiled with ARMv8.2
// dot-product (asimddp) instructions in this translation unit; all arguments
// are forwarded unchanged.
// NOTE(review): `elemtype` selects the output element encoding and `reverse`
// the time direction — presumably matching the caller in gru_arm.cpp; the
// exact encoding is defined in gru_int8.h, not visible here.
void gru_int8_asimddp(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, Mat& hidden_state, const Option& opt)
{
gru_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, hidden_state, opt);
}

} // namespace ncnn
Loading

0 comments on commit cacf762

Please sign in to comment.