vfpv4 fp16
nihui committed Sep 29, 2024
1 parent 392e38b commit 207e166
Showing 8 changed files with 13,360 additions and 300 deletions.
602 changes: 310 additions & 292 deletions src/layer/arm/gemm_arm.cpp

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/layer/arm/gemm_arm.h
@@ -51,6 +51,8 @@ class Gemm_arm : public Gemm
Mat AT_data;
Mat BT_data;
Mat CT_data;

int input_elemtype; // 0=auto 1=fp32 2=fp16 3=bf16
};

} // namespace ncnn
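
The new input_elemtype field records which storage type the int8 GEMM should assume for its inputs. A minimal sketch of how the 0/1/2/3 codes could be resolved per blob — resolve_input_elemtype is a hypothetical helper, not ncnn API, and the auto rule shown is an assumption (the real code would also consult the Option fp16/bf16 storage flags):

#include <cstddef>

// Hypothetical helper, not part of ncnn: map input_elemtype
// (0=auto 1=fp32 2=fp16 3=bf16) to a concrete code for one blob.
static int resolve_input_elemtype(int input_elemtype, size_t elemsize)
{
    if (input_elemtype != 0)
        return input_elemtype; // fixed explicitly by the layer

    // 0=auto: guess from the element size; a 2-byte element could be
    // fp16 or bf16, so assume fp16 here (assumption, not ncnn's rule)
    return elemsize == 2u ? 2 : 1;
}
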
31 changes: 31 additions & 0 deletions src/layer/arm/gemm_arm_asimddp.cpp
@@ -19,6 +19,7 @@
namespace ncnn {

#include "gemm_int8.h"
#include "gemm_int8_fp16s.h"

#if NCNN_BF16
#include "gemm_int8_bf16s.h"
@@ -79,6 +80,36 @@ void gemm_transB_packed_tile_int8_asimddp(const Mat& AT_tile, const Mat& BT_tile
gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk);
}

void pack_A_tile_fp16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales)
{
pack_A_tile_fp16_to_int8(A, AT, i, max_ii, k, max_kk, scales);
}

void transpose_pack_A_tile_fp16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales)
{
transpose_pack_A_tile_fp16_to_int8(A, AT, i, max_ii, k, max_kk, scales);
}

void pack_B_tile_fp16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale)
{
pack_B_tile_fp16_to_int8(B, BT, j, max_jj, k, max_kk, scale);
}

void transpose_pack_B_tile_fp16_to_int8_asimddp(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale)
{
transpose_pack_B_tile_fp16_to_int8(B, BT, j, max_jj, k, max_kk, scale);
}

void unpack_output_tile_int32_to_fp16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta)
{
unpack_output_tile_int32_to_fp16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta);
}

void transpose_unpack_output_tile_int32_to_fp16_asimddp(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta)
{
transpose_unpack_output_tile_int32_to_fp16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta);
}

#if NCNN_BF16
void pack_A_tile_bf16_to_int8_asimddp(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales)
{
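
These _asimddp wrappers exist so the shared kernels in gemm_int8_fp16s.h are instantiated a second time in a translation unit built with dotprod codegen; gemm_arm.cpp then chooses a variant at runtime. A sketch of that dispatch pattern, assuming this commit wires the fp16 path up the same way as the existing int8 kernels — cpu_support_arm_asimddp() is real ncnn API from cpu.h, but this call site is illustrative, not copied from gemm_arm.cpp:

#include "cpu.h"
#include "mat.h"

// Illustrative dispatcher: prefer the dotprod-compiled wrapper when the
// CPU supports it, else fall back to the baseline instantiation of the
// same kernel.
static void pack_A_tile_fp16_to_int8_dispatch(const ncnn::Mat& A, ncnn::Mat& AT, int i, int max_ii, int k, int max_kk, const ncnn::Mat& scales)
{
#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT
    if (ncnn::cpu_support_arm_asimddp())
    {
        ncnn::pack_A_tile_fp16_to_int8_asimddp(A, AT, i, max_ii, k, max_kk, scales);
        return;
    }
#endif
    ncnn::pack_A_tile_fp16_to_int8(A, AT, i, max_ii, k, max_kk, scales);
}
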
21 changes: 21 additions & 0 deletions src/layer/arm/gemm_arm_i8mm.cpp
@@ -19,6 +19,7 @@
namespace ncnn {

#include "gemm_int8.h"
#include "gemm_int8_fp16s.h"

#if NCNN_BF16
#include "gemm_int8_bf16s.h"
@@ -69,6 +70,26 @@ void gemm_transB_packed_tile_int8_i8mm(const Mat& AT_tile, const Mat& BT_tile, M
gemm_transB_packed_tile_int8(AT_tile, BT_tile, topT_tile, i, max_ii, j, max_jj, k, max_kk);
}

void pack_A_tile_fp16_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales)
{
pack_A_tile_fp16_to_int8(A, AT, i, max_ii, k, max_kk, scales);
}

void transpose_pack_A_tile_fp16_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales)
{
transpose_pack_A_tile_fp16_to_int8(A, AT, i, max_ii, k, max_kk, scales);
}

void pack_B_tile_fp16_to_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale)
{
pack_B_tile_fp16_to_int8(B, BT, j, max_jj, k, max_kk, scale);
}

void transpose_pack_B_tile_fp16_to_int8_i8mm(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale)
{
transpose_pack_B_tile_fp16_to_int8(B, BT, j, max_jj, k, max_kk, scale);
}

#if NCNN_BF16
void pack_A_tile_bf16_to_int8_i8mm(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales)
{
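
The _i8mm copies follow the same dispatch pattern for CPUs with the Armv8.6 int8 matrix-multiply extension. Not part of this diff, but for orientation, a minimal standalone example of the instruction that tier targets — vmmlaq_s32 multiplies a 2x8 int8 matrix by an 8x2 int8 matrix and accumulates into a 2x2 int32 block:

#include <arm_neon.h>

#if __ARM_FEATURE_MATMUL_INT8
// One SMMLA step (illustrative, not ncnn code): acc holds a row-major
// 2x2 int32 block; a and b each carry 16 int8 values forming the 2x8
// and 8x2 operand matrices.
static int32x4_t smmla_step(int32x4_t acc, int8x16_t a, int8x16_t b)
{
    return vmmlaq_s32(acc, a, b);
}
#endif
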
51 changes: 51 additions & 0 deletions src/layer/arm/gemm_arm_vfpv4.cpp
@@ -27,6 +27,10 @@ namespace ncnn {
#include "gemm_bf16s_fp16s.h"
#include "gemm_fp16s.h"

#if NCNN_INT8
#include "gemm_int8_fp16s.h"
#endif

extern void pack_A_tile(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk);

static int gemm_arm_fp16s(const Mat& A, const Mat& B, const Mat& C, Mat& top_blob, int broadcast_type_C, int transA, int transB, int output_transpose, float alpha, int constant_TILE_M, int constant_TILE_N, int constant_TILE_K, int nT, const Option& opt)
@@ -709,4 +713,51 @@ int Gemm_arm::forward_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Ma
return ret;
}

#if NCNN_INT8
void compute_A_tile_fp16_int8_scales_vfpv4(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii)
{
compute_A_tile_fp16_int8_scales(A, scales, B_scale, out_descales, i, max_ii);
}

void transpose_compute_A_tile_fp16_int8_scales_vfpv4(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii)
{
transpose_compute_A_tile_fp16_int8_scales(A, scales, B_scale, out_descales, i, max_ii);
}

void pack_A_tile_fp16_to_int8_vfpv4(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales)
{
pack_A_tile_fp16_to_int8(A, AT, i, max_ii, k, max_kk, scales);
}

void transpose_pack_A_tile_fp16_to_int8_vfpv4(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales)
{
transpose_pack_A_tile_fp16_to_int8(A, AT, i, max_ii, k, max_kk, scales);
}

void compute_B_fp16_int8_scale_vfpv4(const Mat& B, float& scale)
{
compute_B_fp16_int8_scale(B, scale);
}

void pack_B_tile_fp16_to_int8_vfpv4(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale)
{
pack_B_tile_fp16_to_int8(B, BT, j, max_jj, k, max_kk, scale);
}

void transpose_pack_B_tile_fp16_to_int8_vfpv4(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale)
{
transpose_pack_B_tile_fp16_to_int8(B, BT, j, max_jj, k, max_kk, scale);
}

void unpack_output_tile_int32_to_fp16_vfpv4(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta)
{
unpack_output_tile_int32_to_fp16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta);
}

void transpose_unpack_output_tile_int32_to_fp16_vfpv4(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta)
{
transpose_unpack_output_tile_int32_to_fp16(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta);
}
#endif // NCNN_INT8

} // namespace ncnn
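
Taken together, the vfpv4 additions give the fp16-storage path a complete int8 pipeline: per-row quantization scales for A, a single scale for B, int8 packing of both operands, and an int32-to-fp16 unpack that applies the descales plus alpha and beta. A rough outline of how these entry points plausibly chain together for one output tile, under the assumption that the flow mirrors the existing fp32/bf16 int8 path (tiling loops, transpose variants and threading omitted; the declarations are copied from this diff):

#include "mat.h"

namespace ncnn {
// forward declarations as gemm_arm.cpp would see them
void compute_B_fp16_int8_scale_vfpv4(const Mat& B, float& scale);
void compute_A_tile_fp16_int8_scales_vfpv4(const Mat& A, Mat& scales, float B_scale, Mat& out_descales, int i, int max_ii);
void pack_A_tile_fp16_to_int8_vfpv4(const Mat& A, Mat& AT, int i, int max_ii, int k, int max_kk, const Mat& scales);
void pack_B_tile_fp16_to_int8_vfpv4(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale);
void unpack_output_tile_int32_to_fp16_vfpv4(const Mat& topT, const Mat& C, Mat& top_blob, int broadcast_type_C, int i, int max_ii, int j, int max_jj, const Mat& descales, float alpha, float beta);
} // namespace ncnn

// Sketch only: one (i, j) tile of an fp16-storage int8 GEMM.
static void gemm_fp16_int8_tile_sketch(const ncnn::Mat& A, const ncnn::Mat& B, const ncnn::Mat& C, ncnn::Mat& top_blob, ncnn::Mat& AT, ncnn::Mat& BT, ncnn::Mat& topT, int i, int max_ii, int j, int max_jj, int k, int max_kk, float alpha, float beta)
{
    // 1. quantization scales: one for all of B, per-row for the A tile,
    //    with combined dequantization factors returned in descales
    float B_scale;
    ncnn::compute_B_fp16_int8_scale_vfpv4(B, B_scale);

    ncnn::Mat scales, descales;
    ncnn::compute_A_tile_fp16_int8_scales_vfpv4(A, scales, B_scale, descales, i, max_ii);

    // 2. quantize the fp16 tiles to packed int8
    ncnn::pack_A_tile_fp16_to_int8_vfpv4(A, AT, i, max_ii, k, max_kk, scales);
    ncnn::pack_B_tile_fp16_to_int8_vfpv4(B, BT, j, max_jj, k, max_kk, B_scale);

    // 3. int8 x int8 -> int32 accumulation into topT would run here via
    //    the existing gemm_transB_packed_tile_int8 kernel (internal to
    //    gemm_arm.cpp, so not declared in this sketch)

    // 4. dequantize, blend C scaled by beta, apply alpha, store fp16
    const int broadcast_type_C = 0; // assumption for the sketch
    ncnn::unpack_output_tile_int32_to_fp16_vfpv4(topT, C, top_blob, broadcast_type_C, i, max_ii, j, max_jj, descales, alpha, beta);
}
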
36 changes: 36 additions & 0 deletions src/layer/arm/gemm_int8.h
@@ -3932,6 +3932,42 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int
}
}

static void compute_B_fp32_int8_scale(const Mat& B, float& scale)
{
float absmax = 0.f;
#if __ARM_NEON
float32x4_t _absmax = vdupq_n_f32(0.f);
#endif
for (int i = 0; i < (B.dims == 3 ? B.c : B.h); i++)
{
const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w;
const float* ptr = (const float*)B + i * B_hstep * B.elempack;

const int size = B.w * B.elempack;

int j = 0;
#if __ARM_NEON
for (; j + 3 < size; j += 4)
{
float32x4_t _p = vld1q_f32(ptr);
_absmax = vmaxq_f32(_absmax, vabsq_f32(_p));
ptr += 4;
}
#endif
for (; j < size; j++)
{
absmax = std::max(absmax, (float)fabs(ptr[0]));
ptr++;
}
}
#if __ARM_NEON
float32x2_t _aa = vmax_f32(vget_low_f32(_absmax), vget_high_f32(_absmax));
absmax = std::max(absmax, std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1)));
#endif

scale = absmax == 0.f ? 1.f : 127.f / absmax;
}

static void pack_B_tile_fp32_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale)
{
#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8
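
The scale follows the usual symmetric int8 rule: the largest magnitude in B is mapped onto 127, with a guard so an all-zero matrix gets scale 1. A self-contained scalar illustration of the same arithmetic — with absmax = 6.35 the scale is 127 / 6.35 = 20, so -6.35 quantizes to the int8 extreme -127 (std::lround stands in for the NEON round-to-nearest used by the packing kernels):

#include <algorithm>
#include <cmath>
#include <cstdio>

int main()
{
    const float B[4] = {0.1f, -6.35f, 3.2f, 0.5f};

    // same reduction as compute_B_fp32_int8_scale, minus the NEON fast path
    float absmax = 0.f;
    for (int i = 0; i < 4; i++)
        absmax = std::max(absmax, std::fabs(B[i]));

    const float scale = absmax == 0.f ? 1.f : 127.f / absmax;

    for (int i = 0; i < 4; i++)
        printf("%+.2f -> %+d\n", B[i], (int)std::lround(B[i] * scale));
    // prints +0.10 -> +2, -6.35 -> -127, +3.20 -> +64, +0.50 -> +10
    return 0;
}
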
83 changes: 75 additions & 8 deletions src/layer/arm/gemm_int8_bf16s.h
@@ -2317,6 +2317,51 @@ static void transpose_pack_A_tile_bf16_to_int8(const Mat& A, Mat& AT, int i, int
}
}

static void compute_B_bf16_int8_scale(const Mat& B, float& scale)
{
float absmax = 0.f;
#if __ARM_NEON
float32x4_t _absmax = vdupq_n_f32(0.f);
#endif
for (int i = 0; i < (B.dims == 3 ? B.c : B.h); i++)
{
const int B_hstep = B.dims == 3 ? (int)B.cstep : B.w;
const unsigned short* ptr = (const unsigned short*)B + i * B_hstep * B.elempack;

const int size = B.w * B.elempack;

int j = 0;
#if __ARM_NEON
for (; j + 7 < size; j += 8)
{
uint16x8_t _p = vld1q_u16(ptr);
float32x4_t _p0 = bfloat2float(vget_low_u16(_p));
float32x4_t _p1 = bfloat2float(vget_high_u16(_p));
_absmax = vmaxq_f32(_absmax, vabsq_f32(_p0));
_absmax = vmaxq_f32(_absmax, vabsq_f32(_p1));
ptr += 8;
}
for (; j + 3 < size; j += 4)
{
float32x4_t _p = bfloat2float(vld1_u16(ptr));
_absmax = vmaxq_f32(_absmax, vabsq_f32(_p));
ptr += 4;
}
#endif
for (; j < size; j++)
{
absmax = std::max(absmax, (float)fabs(bfloat16_to_float32(ptr[0])));
ptr++;
}
}
#if __ARM_NEON
float32x2_t _aa = vmax_f32(vget_low_f32(_absmax), vget_high_f32(_absmax));
absmax = std::max(absmax, std::max(vget_lane_f32(_aa, 0), vget_lane_f32(_aa, 1)));
#endif

scale = absmax == 0.f ? 1.f : 127.f / absmax;
}

static void pack_B_tile_bf16_to_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, int max_kk, float scale)
{
#if NCNN_RUNTIME_CPU && NCNN_ARM84I8MM && __aarch64__ && !__ARM_FEATURE_MATMUL_INT8
@@ -5708,15 +5753,37 @@ static void unpack_output_tile_int32_to_bf16(const Mat& topT, const Mat& C, Mat&
     else // if (c_elempack == 4)
     {
         uint16x4x4_t _cc0 = vld4_u16(pC);
-        _f0 = vaddq_f32(_f0, bfloat2float(_cc0.val[0]));
-        _f1 = vaddq_f32(_f1, bfloat2float(_cc0.val[1]));
-        _f2 = vaddq_f32(_f2, bfloat2float(_cc0.val[2]));
-        _f3 = vaddq_f32(_f3, bfloat2float(_cc0.val[3]));
+        if (beta == 1.f)
+        {
+            _f0 = vaddq_f32(_f0, bfloat2float(_cc0.val[0]));
+            _f1 = vaddq_f32(_f1, bfloat2float(_cc0.val[1]));
+            _f2 = vaddq_f32(_f2, bfloat2float(_cc0.val[2]));
+            _f3 = vaddq_f32(_f3, bfloat2float(_cc0.val[3]));
+        }
+        else
+        {
+            float32x4_t _beta = vdupq_n_f32(beta);
+            _f0 = vmlaq_f32(_f0, bfloat2float(_cc0.val[0]), _beta);
+            _f1 = vmlaq_f32(_f1, bfloat2float(_cc0.val[1]), _beta);
+            _f2 = vmlaq_f32(_f2, bfloat2float(_cc0.val[2]), _beta);
+            _f3 = vmlaq_f32(_f3, bfloat2float(_cc0.val[3]), _beta);
+        }
         _cc0 = vld4_u16(pC + c_hstep * 4);
-        _f4 = vaddq_f32(_f4, bfloat2float(_cc0.val[0]));
-        _f5 = vaddq_f32(_f5, bfloat2float(_cc0.val[1]));
-        _f6 = vaddq_f32(_f6, bfloat2float(_cc0.val[2]));
-        _f7 = vaddq_f32(_f7, bfloat2float(_cc0.val[3]));
+        if (beta == 1.f)
+        {
+            _f4 = vaddq_f32(_f4, bfloat2float(_cc0.val[0]));
+            _f5 = vaddq_f32(_f5, bfloat2float(_cc0.val[1]));
+            _f6 = vaddq_f32(_f6, bfloat2float(_cc0.val[2]));
+            _f7 = vaddq_f32(_f7, bfloat2float(_cc0.val[3]));
+        }
+        else
+        {
+            float32x4_t _beta = vdupq_n_f32(beta);
+            _f4 = vmlaq_f32(_f4, bfloat2float(_cc0.val[0]), _beta);
+            _f5 = vmlaq_f32(_f5, bfloat2float(_cc0.val[1]), _beta);
+            _f6 = vmlaq_f32(_f6, bfloat2float(_cc0.val[2]), _beta);
+            _f7 = vmlaq_f32(_f7, bfloat2float(_cc0.val[3]), _beta);
+        }
         pC += 16;
     }
 }
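
The unpack change above is a beta fix: the old path always added C unscaled, which is only correct when beta == 1, while the new path folds beta into a fused multiply-add via vmlaq_f32 and keeps the plain add as a fast path. Per output element, the scalar equivalent is:

// Scalar view of the vectorized change (illustrative, not ncnn code):
// blend one dequantized accumulator f with one C value c.
static inline float blend_with_c(float f, float c, float beta)
{
    if (beta == 1.f)
        return f + c;    // previous behaviour, kept as the common fast path
    return f + c * beta; // what vmlaq_f32(_f, _c, _beta) computes per lane
}
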
