Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Apr 30, 2024
1 parent 5fd9ab3 commit e0575ea
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 145 deletions.
30 changes: 30 additions & 0 deletions src/layer/arm/lstm_arm_vfpv4.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "cpu.h"
#include "mat.h"
#include "layer.h"
#include "arm_activation.h"
#include "arm_usability.h"

namespace ncnn {

#include "lstm_int8.h"

void lstm_int8_gate_output_vfpv4(const Mat& gates, const Mat& weight_hr, Mat& hidden_state, Mat& tmp_hidden_state, Mat& cell_state, Mat& top_blob, int ti, int elemtype, const Option& opt)
{
lstm_int8_gate_output(gates, weight_hr, hidden_state, tmp_hidden_state, cell_state, top_blob, ti, elemtype, opt);
}

} // namespace ncnn
336 changes: 192 additions & 144 deletions src/layer/arm/lstm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ void lstm_transform_weight_int8_asimddp(const Mat& weight_xc, const Mat& weight_
void lstm_int8_asimddp(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt);
#endif

#if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2)
void lstm_int8_gate_output_vfpv4(const Mat& gates, const Mat& weight_hr, Mat& hidden_state, Mat& tmp_hidden_state, Mat& cell_state, Mat& top_blob, int ti, int elemtype, const Option& opt);
#endif

static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt)
{
// TODO dispatch for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Expand Down Expand Up @@ -181,6 +185,193 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
}
}

static void lstm_int8_gate_output(const Mat& gates, const Mat& weight_hr, Mat& hidden_state, Mat& tmp_hidden_state, Mat& cell_state, Mat& top_blob, int ti, int elemtype, const Option& opt)
{
#if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2)
if (ncnn::cpu_support_arm_vfpv4())
{
lstm_int8_gate_output_vfpv4(gates, weight_hr, hidden_state, tmp_hidden_state, cell_state, top_blob, ti, elemtype, opt);
return;
}
#endif

const int num_output = top_blob.w;
const int hidden_size = cell_state.w;

// lstm unit
// sigmoid(I)
// sigmoid(F)
// sigmoid(O)
// tanh(G)
// c_t := f_t .* c_{t-1} + i_t .* g_t
// h_t := o_t .* tanh[c_t]
float* output_data = top_blob.row(ti);

float* cell_ptr = cell_state;
float* hidden_ptr = hidden_state;
float* tmp_hidden_ptr = tmp_hidden_state;

int remain_hidden_size_start = 0;
#if __ARM_NEON
int nn_hidden_size = hidden_size >> 2;
remain_hidden_size_start = nn_hidden_size << 2;

#pragma omp parallel for num_threads(opt.num_threads)
for (int qq = 0; qq < nn_hidden_size; qq++)
{
int q = qq * 4;

const float* gates_data = gates.row(q);

float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data);

float32x4_t _lstm_I = sigmoid_ps(_IFOG_4x4.val[0]);
float32x4_t _lstm_F = sigmoid_ps(_IFOG_4x4.val[1]);
float32x4_t _lstm_O = sigmoid_ps(_IFOG_4x4.val[2]);
float32x4_t _lstm_G = tanh_ps(_IFOG_4x4.val[3]);

float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_lstm_I, _lstm_G));
float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2));

vst1q_f32(cell_ptr + q, _cell2);

if (num_output == hidden_size)
{
vst1q_f32(hidden_ptr + q, _lstm_H);

if (elemtype == 1)
{
// fp32
vst1q_f32(output_data + q, _lstm_H);
}
if (elemtype == 2)
{
// fp16
unsigned short* outptr = (unsigned short*)output_data + q;
#if (__ARM_FP & 2)
#if NCNN_GNU_INLINE_ASM
#if __aarch64__
asm volatile(
"fcvtn v0.4h, %2.4s \n"
"st1 {v0.4h}, [%0] \n"
: "=r"(outptr) // %0
: "0"(outptr),
"w"(_lstm_H)
: "memory", "v0");
#else // __aarch64__
asm volatile(
"vcvt.f16.f32 d0, %q2 \n"
"vst1.u16 {d0}, [%0] \n"
: "=r"(outptr) // %0
: "0"(outptr),
"w"(_lstm_H)
: "memory", "q0");
#endif // __aarch64__
#else // NCNN_GNU_INLINE_ASM
vst1_u16(outptr, (uint16x4_t)vcvt_f16_f32(_lstm_H));
#endif // NCNN_GNU_INLINE_ASM
#else
outptr[q] = float32_to_float16(hidden_ptr[q]);
outptr[q + 1] = float32_to_float16(hidden_ptr[q + 1]);
outptr[q + 2] = float32_to_float16(hidden_ptr[q + 2]);
outptr[q + 3] = float32_to_float16(hidden_ptr[q + 3]);
#endif // (__ARM_FP & 2)
}
if (elemtype == 4)
{
// bf16
vst1_u16((unsigned short*)output_data + q, float2bfloat(_lstm_H));
}
}
else
{
vst1q_f32(tmp_hidden_ptr + q, _lstm_H);
}
}
#endif // __ARM_NEON
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_hidden_size_start; q < hidden_size; q++)
{
const float* gates_data = gates.row(q);

float I = gates_data[0];
float F = gates_data[1];
float O = gates_data[2];
float G = gates_data[3];

I = 1.f / (1.f + expf(-I));
F = 1.f / (1.f + expf(-F));
O = 1.f / (1.f + expf(-O));
G = tanhf(G);

float cell2 = F * cell_ptr[q] + I * G;
float H = O * tanhf(cell2);

cell_ptr[q] = cell2;
if (num_output == hidden_size)
{
hidden_ptr[q] = H;

if (elemtype == 1)
{
output_data[q] = H;
}
if (elemtype == 2)
{
((unsigned short*)output_data)[q] = float32_to_float16(H);
}
if (elemtype == 4)
{
((unsigned short*)output_data)[q] = float32_to_bfloat16(H);
}
}
else
{
tmp_hidden_ptr[q] = H;
}
}

if (num_output != hidden_size)
{
// int nn_num_output = num_output >> 2;
// int remain_num_output_start = nn_num_output << 2;
// #pragma omp parallel for num_threads(opt.num_threads)
// for (int qq = 0; qq < nn_num_output; qq++)
// {
// int q = qq * 4;
//
// }
int remain_num_output_start = 0;
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_num_output_start; q < num_output; q++)
{
const float* hr = weight_hr.row(q);
const float* tmp_hidden_ptr = tmp_hidden_state;

float H = 0;
for (int i = 0; i < hidden_size; i++)
{
H += tmp_hidden_ptr[i] * hr[i];
}

hidden_ptr[q] = H;

if (elemtype == 1)
{
output_data[q] = H;
}
if (elemtype == 2)
{
((unsigned short*)output_data)[q] = float32_to_float16(H);
}
if (elemtype == 4)
{
((unsigned short*)output_data)[q] = float32_to_bfloat16(H);
}
}
}
}

static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
// TODO dispatch for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Expand Down Expand Up @@ -476,149 +667,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
#endif // __ARM_NEON
}

// lstm unit
// sigmoid(I)
// sigmoid(F)
// sigmoid(O)
// tanh(G)
// c_t := f_t .* c_{t-1} + i_t .* g_t
// h_t := o_t .* tanh[c_t]
float* output_data = top_blob.row(ti);

float* cell_ptr = cell_state;
float* hidden_ptr = hidden_state;
float* tmp_hidden_ptr = tmp_hidden_state;

int remain_hidden_size_start = 0;
#if __ARM_NEON
int nn_hidden_size = hidden_size >> 2;
remain_hidden_size_start = nn_hidden_size << 2;

#pragma omp parallel for num_threads(opt.num_threads)
for (int qq = 0; qq < nn_hidden_size; qq++)
{
int q = qq * 4;

const float* gates_data = gates.row(q);

float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data);

float32x4_t _lstm_I = sigmoid_ps(_IFOG_4x4.val[0]);
float32x4_t _lstm_F = sigmoid_ps(_IFOG_4x4.val[1]);
float32x4_t _lstm_O = sigmoid_ps(_IFOG_4x4.val[2]);
float32x4_t _lstm_G = tanh_ps(_IFOG_4x4.val[3]);

float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_lstm_I, _lstm_G));
float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2));

vst1q_f32(cell_ptr + q, _cell2);

if (num_output == hidden_size)
{
vst1q_f32(hidden_ptr + q, _lstm_H);

if (elemtype == 1)
{
// fp32
vst1q_f32(output_data + q, _lstm_H);
}
if (elemtype == 2)
{
// fp16
vst1_u16((unsigned short*)output_data + q, (uint16x4_t)vcvt_f16_f32(_lstm_H));
}
if (elemtype == 4)
{
// bf16
vst1_u16((unsigned short*)output_data + q, float2bfloat(_lstm_H));
}
}
else
{
vst1q_f32(tmp_hidden_ptr + q, _lstm_H);
}
}
#endif // __ARM_NEON
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_hidden_size_start; q < hidden_size; q++)
{
const float* gates_data = gates.row(q);

float I = gates_data[0];
float F = gates_data[1];
float O = gates_data[2];
float G = gates_data[3];

I = 1.f / (1.f + expf(-I));
F = 1.f / (1.f + expf(-F));
O = 1.f / (1.f + expf(-O));
G = tanhf(G);

float cell2 = F * cell_ptr[q] + I * G;
float H = O * tanhf(cell2);

cell_ptr[q] = cell2;
if (num_output == hidden_size)
{
hidden_ptr[q] = H;

if (elemtype == 1)
{
output_data[q] = H;
}
if (elemtype == 2)
{
((unsigned short*)output_data)[q] = float32_to_float16(H);
}
if (elemtype == 4)
{
((unsigned short*)output_data)[q] = float32_to_bfloat16(H);
}
}
else
{
tmp_hidden_ptr[q] = H;
}
}

if (num_output != hidden_size)
{
// int nn_num_output = num_output >> 2;
// int remain_num_output_start = nn_num_output << 2;
// #pragma omp parallel for num_threads(opt.num_threads)
// for (int qq = 0; qq < nn_num_output; qq++)
// {
// int q = qq * 4;
//
// }
int remain_num_output_start = 0;
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = remain_num_output_start; q < num_output; q++)
{
const float* hr = weight_hr.row(q);
const float* tmp_hidden_ptr = tmp_hidden_state;

float H = 0;
for (int i = 0; i < hidden_size; i++)
{
H += tmp_hidden_ptr[i] * hr[i];
}

hidden_ptr[q] = H;

if (elemtype == 1)
{
output_data[q] = H;
}
if (elemtype == 2)
{
((unsigned short*)output_data)[q] = float32_to_float16(H);
}
if (elemtype == 4)
{
((unsigned short*)output_data)[q] = float32_to_bfloat16(H);
}
}
}
lstm_int8_gate_output(gates, weight_hr, hidden_state, tmp_hidden_state, cell_state, top_blob, ti, elemtype, opt);
}
}
2 changes: 1 addition & 1 deletion src/layer/arm/rnn_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ static void rnn_int8_gate_output(const Mat& gates, Mat& hidden_state, Mat& top_b
asm volatile(
"fcvtn v0.4h, %2.4s \n"
"st1 {v0.4h}, [%0] \n"
: "=r"(_rnn_H) // %0
: "=r"(outptr) // %0
: "0"(outptr),
"w"(_rnn_H)
: "memory", "v0");
Expand Down

0 comments on commit e0575ea

Please sign in to comment.