Update arm support for 3d tensor input in innerproduct_arm.cpp
parrotsky authored Nov 16, 2023
1 parent c62f5d1 commit 898fe16
Showing 1 changed file with 237 additions and 0 deletions.
237 changes: 237 additions & 0 deletions src/layer/arm/innerproduct_arm.cpp
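
The new branch below handles a 3-d bottom blob whose channel count equals num_input: each of the w*h spatial positions is multiplied by the num_output x num_input weight matrix. A minimal sketch of driving that path through ncnn's single-layer API follows; the shapes are arbitrary and the ParamDict ids (0 = num_output, 1 = bias_term, 2 = weight_data_size) are InnerProduct's documented parameters.

#include "layer.h"
#include "mat.h"
#include "modelbin.h"
#include "option.h"
#include "paramdict.h"

int run_innerproduct_3d()
{
    const int w = 4, h = 3, num_input = 8, num_output = 16;

    ncnn::Mat bottom(w, h, num_input); // dims == 3, c == num_input
    bottom.fill(1.f);

    ncnn::Mat weights[1];
    weights[0] = ncnn::Mat(num_input * num_output);
    weights[0].fill(0.01f);

    ncnn::ParamDict pd;
    pd.set(0, num_output);             // num_output
    pd.set(1, 0);                      // bias_term off
    pd.set(2, num_input * num_output); // weight_data_size

    ncnn::Layer* op = ncnn::create_layer("InnerProduct");
    op->load_param(pd);
    op->load_model(ncnn::ModelBinFromMatArray(weights));

    ncnn::Option opt;
    opt.num_threads = 1;
    op->create_pipeline(opt);

    ncnn::Mat top;
    op->forward(bottom, top, opt); // reaches the dims == 3 gemm path on arm

    op->destroy_pipeline(opt);
    delete op;
    return 0;
}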
@@ -409,6 +409,243 @@ int InnerProduct_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
        return 0;
    }

    if (bottom_blob.dims == 3 && bottom_blob.c == num_input)
    {
        // 3d tensor input gemm: each of the w*h positions is one input row
        int w = bottom_blob.w;
        int h = bottom_blob.h;
        int c = bottom_blob.c; // equals num_input
        size_t elemsize = bottom_blob.elemsize;
        int elempack = bottom_blob.elempack;
        Mat bottom_blob_flattened = bottom_blob.reshape(c, w * h);

        top_blob.create(num_output, w * h, elemsize, elempack, opt.blob_allocator);
        if (top_blob.empty())
            return -100;

        int num_output_elempack = 1;
#if __ARM_NEON
        if (opt.use_packing_layout)
        {
            num_output_elempack = num_output % 4 == 0 ? 4 : 1;
        }
#endif
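        // Note: weight_data_tm is prepared in InnerProduct_arm::create_pipeline;
        // when num_output_elempack == 4, each row p is assumed to hold the weights
        // of outputs 4p..4p+3 interleaved channel by channel, which is what the
        // kptr += 4 strides below rely on.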

#pragma omp parallel for num_threads(opt.num_threads)
        for (int j = 0; j < w * h; j++)
        {
#if __ARM_NEON
            if (elempack == 4 && num_output_elempack == 4)
            {
                // packed input, packed output: 4 positions x 4 outputs per iteration
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output / num_output_elempack; p++)
                {
                    const float* kptr = weight_data_tm.row(p);
                    const float* m = bottom_blob_flattened.row(j);

                    float32x4_t _sum0 = vdupq_n_f32(0.f);
                    float32x4_t _sum1 = vdupq_n_f32(0.f);
                    float32x4_t _sum2 = vdupq_n_f32(0.f);
                    float32x4_t _sum3 = vdupq_n_f32(0.f);

                    if (bias_term)
                    {
                        _sum0 = vdupq_n_f32(bias_data[p * 4 + 0]);
                        _sum1 = vdupq_n_f32(bias_data[p * 4 + 1]);
                        _sum2 = vdupq_n_f32(bias_data[p * 4 + 2]);
                        _sum3 = vdupq_n_f32(bias_data[p * 4 + 3]);
                    }

                    int i = 0;
                    for (; i < num_input; i++)
                    {
                        float32x4_t _val = vld1q_f32(m);
                        float32x4_t _w = vld1q_f32(kptr);
#if __aarch64__
                        // fused multiply-add, broadcasting one weight lane across 4 positions
                        _sum0 = vfmaq_laneq_f32(_sum0, _val, _w, 0);
                        _sum1 = vfmaq_laneq_f32(_sum1, _val, _w, 1);
                        _sum2 = vfmaq_laneq_f32(_sum2, _val, _w, 2);
                        _sum3 = vfmaq_laneq_f32(_sum3, _val, _w, 3);
#else
                        // armv7 has no laneq FMA; select lanes from the vector halves instead
                        _sum0 = vmlaq_lane_f32(_sum0, _val, vget_low_f32(_w), 0);
                        _sum1 = vmlaq_lane_f32(_sum1, _val, vget_low_f32(_w), 1);
                        _sum2 = vmlaq_lane_f32(_sum2, _val, vget_high_f32(_w), 0);
                        _sum3 = vmlaq_lane_f32(_sum3, _val, vget_high_f32(_w), 1);
#endif
                        m += 4;
                        kptr += 4;
                    }

                    _sum0 = activation_ps(_sum0, activation_type, activation_params);
                    _sum1 = activation_ps(_sum1, activation_type, activation_params);
                    _sum2 = activation_ps(_sum2, activation_type, activation_params);
                    _sum3 = activation_ps(_sum3, activation_type, activation_params);

                    vst1q_f32(outptr, _sum0);
                    vst1q_f32(outptr + 4, _sum1);
                    vst1q_f32(outptr + 8, _sum2);
                    vst1q_f32(outptr + 12, _sum3);
                    outptr += 16;
                }
            }
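            // In the pack4/pack4 kernel above, _val carries input channel i for 4
            // neighbouring positions and _w carries channel i's weights for 4 outputs,
            // so each lane-broadcast FMA produces one output value for 4 positions at once.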

            if (elempack == 1 && num_output_elempack == 4)
            {
                // unpacked input, packed output: 4 outputs per iteration
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output / num_output_elempack; p++)
                {
                    const float* kptr = weight_data_tm.row(p);
                    const float* m = bottom_blob_flattened.row(j);

                    float32x4_t _sum0 = vdupq_n_f32(0.f);
                    float32x4_t _sum1 = vdupq_n_f32(0.f);
                    float32x4_t _sum2 = vdupq_n_f32(0.f);
                    float32x4_t _sum3 = vdupq_n_f32(0.f);

                    if (bias_term)
                    {
                        _sum0 = vld1q_f32((const float*)bias_data + p * 4);
                    }

                    int i = 0;
                    for (; i + 3 < num_input; i += 4)
                    {
                        float32x4_t _val = vld1q_f32(m);

                        float32x4_t _w0 = vld1q_f32(kptr);
                        float32x4_t _w1 = vld1q_f32(kptr + 4);
                        float32x4_t _w2 = vld1q_f32(kptr + 8);
                        float32x4_t _w3 = vld1q_f32(kptr + 12);

#if __aarch64__
                        _sum0 = vfmaq_laneq_f32(_sum0, _w0, _val, 0);
                        _sum1 = vfmaq_laneq_f32(_sum1, _w1, _val, 1);
                        _sum2 = vfmaq_laneq_f32(_sum2, _w2, _val, 2);
                        _sum3 = vfmaq_laneq_f32(_sum3, _w3, _val, 3);
#else
                        _sum0 = vmlaq_lane_f32(_sum0, _w0, vget_low_f32(_val), 0);
                        _sum1 = vmlaq_lane_f32(_sum1, _w1, vget_low_f32(_val), 1);
                        _sum2 = vmlaq_lane_f32(_sum2, _w2, vget_high_f32(_val), 0);
                        _sum3 = vmlaq_lane_f32(_sum3, _w3, vget_high_f32(_val), 1);
#endif

                        m += 4;
                        kptr += 16;
                    }
                    for (; i < num_input; i++)
                    {
                        // remainder: broadcast one input value across the weight vector
                        float32x4_t _val = vld1q_dup_f32(m);
                        float32x4_t _k = vld1q_f32(kptr);
                        _sum0 = vmlaq_f32(_sum0, _val, _k);

                        m += 1;
                        kptr += 4;
                    }

                    // combine the 4 partial accumulators
                    _sum0 = vaddq_f32(_sum0, _sum1);
                    _sum2 = vaddq_f32(_sum2, _sum3);
                    _sum0 = vaddq_f32(_sum0, _sum2);

                    _sum0 = activation_ps(_sum0, activation_type, activation_params);

                    vst1q_f32(outptr, _sum0);
                    outptr += 4;
                }
            }
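            // The pack1/pack4 kernel above keeps 4 independent accumulators so the
            // FMA chains do not stall on each other; they are only combined by the
            // vaddq_f32 reduction after the input loop.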

            if (elempack == 4 && num_output_elempack == 1)
            {
                // packed input, unpacked output: one output for 4 positions per iteration
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output; p++)
                {
                    const float* kptr = (const float*)weight_data_tm + num_input * p;
                    const float* m = bottom_blob_flattened.row(j);

                    float32x4_t _sum = vdupq_n_f32(0.f);

                    if (bias_term)
                    {
                        _sum = vdupq_n_f32(bias_data[p]);
                    }

                    for (int i = 0; i < num_input; i++)
                    {
                        float32x4_t _val = vld1q_f32(m);
                        float32x4_t _k = vdupq_n_f32(kptr[0]);
                        _sum = vmlaq_f32(_sum, _val, _k);

                        m += 4;
                        kptr += 1;
                    }

                    _sum = activation_ps(_sum, activation_type, activation_params);

                    vst1q_f32(outptr, _sum);
                    outptr += 4;
                }
            }
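            // The pack4/pack1 kernel above is the transpose case: one scalar weight
            // kptr[0] is broadcast across a float32x4 of 4 packed positions, producing
            // output p for 4 positions per iteration.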
#endif // __ARM_NEON

            if (elempack == 1 && num_output_elempack == 1)
            {
                float* outptr = top_blob.row(j);

                for (int p = 0; p < num_output; p++)
                {
                    const float* kptr = (const float*)weight_data_tm + num_input * p;
                    const float* m = bottom_blob_flattened.row(j);

                    float sum = 0.f;

                    if (bias_term)
                    {
                        sum = bias_data[p];
                    }

                    int i = 0;
#if __ARM_NEON
                    float32x4_t _sum = vdupq_n_f32(0.f);
                    for (; i + 3 < num_input; i += 4)
                    {
                        float32x4_t _val = vld1q_f32(m);
                        float32x4_t _k = vld1q_f32(kptr);
                        _sum = vmlaq_f32(_sum, _val, _k);

                        m += 4;
                        kptr += 4;
                    }
#if __aarch64__
                    sum += vaddvq_f32(_sum); // horizontal add of the 4 lanes
#else
                    float32x2_t _ss = vadd_f32(vget_low_f32(_sum), vget_high_f32(_sum));
                    _ss = vpadd_f32(_ss, _ss);
                    sum += vget_lane_f32(_ss, 0);
#endif
#endif // __ARM_NEON
                    for (; i < num_input; i++)
                    {
                        sum += *m * *kptr;

                        m += 1;
                        kptr += 1;
                    }

                    sum = activation_ss(sum, activation_type, activation_params);

                    outptr[0] = sum;
                    outptr += 1;
                }
            }
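            // The scalar kernel above vectorizes the bulk of the dot product 4 floats
            // at a time, reduces the accumulator horizontally (vaddvq_f32 on aarch64,
            // paired vpadd_f32 on armv7), then handles the remainder scalar-wise.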
        }
        top_blob = top_blob.reshape(w, h, num_output); // restore 3d shape w x h x num_output
        return 0;
    }

    // flatten
    Mat bottom_blob_flattened = bottom_blob;
    if (bottom_blob.dims != 1)
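For reference, the NEON kernels in this hunk all compute the same thing; a plain scalar version is sketched below, assuming position j's input vector is row j of the flattened blob (as in the existing 2d gemm path). The function name and the flat array layout are illustrative only, not part of the commit.

// out[j][p] = act(bias[p] + sum_i in[j][i] * weight[p][i]) for j in 0..w*h
static void innerproduct_3d_reference(const float* in, const float* weight,
                                      const float* bias, float* out,
                                      int positions, int num_input, int num_output)
{
    for (int j = 0; j < positions; j++)
    {
        for (int p = 0; p < num_output; p++)
        {
            float sum = bias ? bias[p] : 0.f;
            for (int i = 0; i < num_input; i++)
            {
                sum += in[j * num_input + i] * weight[p * num_input + i];
            }
            out[j * num_output + p] = sum; // activation omitted for brevity
        }
    }
}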
