Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] added erf support #4992

Closed
wants to merge 34 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
d94770d
added erf support
brightening-eyes Aug 29, 2023
424445d
set the type to 21 to have more tests
brightening-eyes Aug 29, 2023
55a52dc
lets say tests pass or not
brightening-eyes Aug 29, 2023
3f99959
Merge branch 'master' into opt/erf
brightening-eyes Aug 29, 2023
89eb070
completely fixed vulkan shaders
brightening-eyes Aug 30, 2023
3a4c65d
compiled spv files are not needed
brightening-eyes Aug 30, 2023
455c3e9
working on other architectures
brightening-eyes Sep 1, 2023
a7e8802
fixed loongarch stuff
brightening-eyes Sep 1, 2023
e7d7c34
trying to fix arm stuff
brightening-eyes Sep 2, 2023
b0bb078
working on arm stuff
brightening-eyes Sep 2, 2023
91d94a7
fixing x86 stuff
brightening-eyes Sep 2, 2023
dbc5441
mips implementation
brightening-eyes Sep 2, 2023
b39106c
Merge remote-tracking branch 'origin' into opt/erf
brightening-eyes Sep 2, 2023
5d4a226
trying to fix x86 stuff
brightening-eyes Sep 3, 2023
4f9c6df
fixing arm again
brightening-eyes Sep 3, 2023
e076ac2
added ERF to the list of operators
brightening-eyes Sep 3, 2023
05b668b
Merge branch 'opt/erf' of https://github.com/brightening-eyes/ncnn in…
brightening-eyes Sep 3, 2023
c0b256c
fixing arm and x86 operators
brightening-eyes Sep 3, 2023
290c4e3
fixed the errors with the x86 architecture
brightening-eyes Sep 3, 2023
6dccf37
x86: msvc has __m256 and __m512 exponents, but there was nothing for …
brightening-eyes Sep 3, 2023
b36ea45
mips fix
brightening-eyes Sep 3, 2023
4eb96af
Merge remote-tracking branch 'origin' into opt/erf
brightening-eyes Sep 4, 2023
15f9aeb
fixed arm
brightening-eyes Sep 4, 2023
4bb5c30
arm should now be fixed
brightening-eyes Sep 4, 2023
ea2af2e
x86 tests should pass
brightening-eyes Sep 4, 2023
f473607
forgot to return back the name of unaryop.comp as a shader
brightening-eyes Sep 4, 2023
c78c17f
rewrite of erf for x86 stuff to fix the tests
brightening-eyes Sep 4, 2023
cc8b1d0
arm test fix. its now like x86 thing
brightening-eyes Sep 4, 2023
15c2ed4
use provided abs stuff, in order to check if tests work or not
brightening-eyes Sep 5, 2023
8602ed8
removed vectorization stuff, it failes the tests
brightening-eyes Sep 5, 2023
befd7c1
recheck
brightening-eyes Sep 5, 2023
a7f5eb1
fixing x86 without vector intrinsics
brightening-eyes Sep 6, 2023
74c7806
x86 fix
brightening-eyes Sep 6, 2023
64ceec5
some checks. maybe tests pass?
brightening-eyes Sep 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/developer-guide/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -1781,3 +1781,4 @@ Operation type:
- 17 = LOG10
- 18 = ROUND
- 19 = TRUNC
- 20 = ERF
6 changes: 6 additions & 0 deletions src/layer/arm/unaryop_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,9 @@ int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_TRUNC)
return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return UnaryOp::forward_inplace(bottom_top_blob, opt);

return 0;
}

Expand Down Expand Up @@ -686,6 +689,9 @@ int UnaryOp_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt)
if (op_type == Operation_TRUNC)
return unary_op_inplace_bf16s<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return UnaryOp::forward_inplace(bottom_top_blob, opt);

return 0;
}
#endif // NCNN_BF16
Expand Down
3 changes: 3 additions & 0 deletions src/layer/loongarch/unaryop_loongarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,9 @@ int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt)
if (op_type == Operation_TRUNC)
return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return UnaryOp::forward_inplace(bottom_top_blob, opt);

return 0;
}

Expand Down
24 changes: 24 additions & 0 deletions src/layer/mips/unaryop_mips.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,27 @@ struct unary_op_trunc
#endif // __mips_msa
};

// Error-function functor used by the MIPS unary-op kernels.
struct unary_op_erf
{
    // Scalar path: evaluate erf() from the C math library and narrow to float.
    float func(const float& x) const
    {
        return static_cast<float>(erf(x));
    }
#if __mips_msa
    // 4-lane path: spill the vector to a scratch array, run the scalar erf()
    // on every lane, then reload the array as a vector.
    v4f32 func_pack4(const v4f32& x) const
    {
        // TODO msa optimize
        float lanes[4];
        __msa_st_w((v4i32)x, lanes, 0);
        for (int i = 0; i < 4; i++)
        {
            lanes[i] = erf(lanes[i]);
        }
        return (v4f32)__msa_ld_w(lanes, 0);
    }
#endif // __mips_msa
};

} // namespace UnaryOp_mips_functor

int UnaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
Expand Down Expand Up @@ -502,6 +523,9 @@ int UnaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_TRUNC)
return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return unary_op_inplace<unary_op_erf>(bottom_top_blob, opt);

return 0;
}

Expand Down
6 changes: 6 additions & 0 deletions src/layer/riscv/unaryop_riscv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ int UnaryOp_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons
if (op_type == Operation_TRUNC)
return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return UnaryOp::forward_inplace(bottom_top_blob, opt);

return 0;
#else // __riscv_vector
return UnaryOp::forward_inplace(bottom_top_blob, opt);
Expand Down Expand Up @@ -683,6 +686,9 @@ int UnaryOp_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt
if (op_type == Operation_TRUNC)
return unary_op_inplace_fp16s<unary_op_trunc_fp16s>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return UnaryOp::forward_inplace(bottom_top_blob, opt);

return 0;
}
#endif // __riscv_vector && __riscv_zfh
Expand Down
11 changes: 11 additions & 0 deletions src/layer/unaryop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,14 @@ struct unary_op_trunc
}
};

// Functor applying the Gauss error function to a single float; plugged into
// unary_op_inplace<> as the per-element operation.
struct unary_op_erf
{
    float operator()(const float& x) const
    {
        // erf() comes from the C math library; the result is narrowed to float.
        const float result = static_cast<float>(erf(x));
        return result;
    }
};

int UnaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
if (op_type == Operation_ABS)
Expand Down Expand Up @@ -280,6 +288,9 @@ int UnaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_TRUNC)
return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return unary_op_inplace<unary_op_erf>(bottom_top_blob, opt);

return 0;
}

Expand Down
3 changes: 2 additions & 1 deletion src/layer/unaryop.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class UnaryOp : public Layer
Operation_TANH = 16,
Operation_LOG10 = 17,
Operation_ROUND = 18,
Operation_TRUNC = 19
Operation_TRUNC = 19,
Operation_ERF = 20
};

public:
Expand Down
16 changes: 16 additions & 0 deletions src/layer/vulkan/shader/unaryop.comp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ layout (push_constant) uniform parameter
int cstep;
} p;

// Approximate Gauss error function (Abramowitz & Stegun formula 7.1.26):
//   erf(|x|) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2)
//   with t = 1 / (1 + k*|x|); the sign of x is re-applied at the end.
//
// Declared in terms of afp, with every literal wrapped in afp(...), so the
// arithmetic stays in the shader's working precision like the other ops in
// this file. The rational coefficient is named k rather than p so it does not
// shadow the push-constant block instance `p`.
afp erf(afp x)
{
    afp a1 = afp(0.254829592);
    afp a2 = afp(-0.284496736);
    afp a3 = afp(1.421413741);
    afp a4 = afp(-1.453152027);
    afp a5 = afp(1.061405429);
    afp k = afp(0.3275911);
    afp one = afp(1.0);

    afp s = sign(x);
    afp x_abs = abs(x);
    afp t = one / (one + k * x_abs);
    // Horner evaluation of the degree-5 polynomial in t (no constant term).
    afp poly = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
    afp y = one - poly * exp(-x_abs * x_abs);
    return s * y;
}

void main()
{
int gx = int(gl_GlobalInvocationID.x);
Expand Down Expand Up @@ -89,6 +104,7 @@ void main()
if (op_type == 17) res = log(v) * afp(0.434294481903);
if (op_type == 18) res = round(v);
if (op_type == 19) res = trunc(v);
if (op_type == 20) res = erf(v);

#if NCNN_image_shader
image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
16 changes: 16 additions & 0 deletions src/layer/vulkan/shader/unaryop_pack4.comp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ layout (push_constant) uniform parameter
int cstep;
} p;

// Approximate Gauss error function, applied per component
// (Abramowitz & Stegun formula 7.1.26):
//   erf(|x|) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2)
//   with t = 1 / (1 + k*|x|); the sign of x is re-applied at the end.
//
// Every literal is built as afpvec4 so no bare float scalar is mixed into the
// afpvec4 arithmetic. The rational coefficient is named k rather than p so it
// does not shadow the push-constant block instance `p`.
afpvec4 erf(afpvec4 x)
{
    afpvec4 a1 = afpvec4(0.254829592);
    afpvec4 a2 = afpvec4(-0.284496736);
    afpvec4 a3 = afpvec4(1.421413741);
    afpvec4 a4 = afpvec4(-1.453152027);
    afpvec4 a5 = afpvec4(1.061405429);
    afpvec4 k = afpvec4(0.3275911);
    afpvec4 one = afpvec4(1.0);

    afpvec4 s = sign(x);
    afpvec4 x_abs = abs(x);
    afpvec4 t = one / (one + k * x_abs);
    // Horner evaluation of the degree-5 polynomial in t (no constant term).
    afpvec4 poly = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
    afpvec4 y = one - poly * exp(-x_abs * x_abs);
    return s * y;
}

void main()
{
int gx = int(gl_GlobalInvocationID.x);
Expand Down Expand Up @@ -89,6 +104,7 @@ void main()
if (op_type == 17) res = log(v) * afp(0.434294481903);
if (op_type == 18) res = round(v);
if (op_type == 19) res = trunc(v);
if (op_type == 20) res = erf(v);

#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
20 changes: 20 additions & 0 deletions src/layer/vulkan/shader/unaryop_pack8.comp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@ layout (push_constant) uniform parameter
int cstep;
} p;

// Approximate Gauss error function, applied per component
// (Abramowitz & Stegun formula 7.1.26):
//   erf(|x|) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2)
//   with t = 1 / (1 + k*|x|); the sign of x is re-applied at the end.
//
// Operates on one afpvec4 at a time; main() calls it once per half of the
// pack8 value. Every literal is built as afpvec4 so no bare float scalar is
// mixed into the afpvec4 arithmetic. The rational coefficient is named k
// rather than p so it does not shadow the push-constant block instance `p`.
afpvec4 erf(afpvec4 x)
{
    afpvec4 a1 = afpvec4(0.254829592);
    afpvec4 a2 = afpvec4(-0.284496736);
    afpvec4 a3 = afpvec4(1.421413741);
    afpvec4 a4 = afpvec4(-1.453152027);
    afpvec4 a5 = afpvec4(1.061405429);
    afpvec4 k = afpvec4(0.3275911);
    afpvec4 one = afpvec4(1.0);

    afpvec4 s = sign(x);
    afpvec4 x_abs = abs(x);
    afpvec4 t = one / (one + k * x_abs);
    // Horner evaluation of the degree-5 polynomial in t (no constant term).
    afpvec4 poly = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
    afpvec4 y = one - poly * exp(-x_abs * x_abs);
    return s * y;
}

void main()
{
int gx = int(gl_GlobalInvocationID.x);
Expand Down Expand Up @@ -171,6 +186,11 @@ void main()
res[0] = trunc(v[0]);
res[1] = trunc(v[1]);
}
if (op_type == 20)
{
res[0] = erf(v[0]);
res[1] = erf(v[1]);
}

#if NCNN_image_shader
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
Expand Down
71 changes: 71 additions & 0 deletions src/layer/x86/unaryop_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,74 @@ struct unary_op_trunc
#endif // __SSE2__
};

// Element-wise Gauss error function for the x86 unary-op kernels.
//
// func() defers to the C math library. The SIMD variants evaluate the
// Abramowitz & Stegun 7.1.26 approximation
//     erf(|x|) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2)
//     t        =  1 / (1 + p*|x|)
// and then copy the sign of x onto the result (erf is an odd function).
//
// Notes on the SIMD code:
//  - t uses a true division; the hardware reciprocal estimate (_mm_rcp_ps and
//    friends) is far less precise than the approximation itself.
//  - The sign is restored by XOR-ing the sign bit of x into y. Multiplying by
//    the bare sign-bit mask (+0.0 / -0.0) would zero every lane.
//  - The AVX-512 variant does its bit operations through the epi32 forms,
//    because the ps-typed logical intrinsics require AVX512DQ while this code
//    is only guarded by __AVX512F__.
struct unary_op_erf
{
    // Scalar path: libm erf(), narrowed to float.
    float func(const float& x) const
    {
        return (float)erf(x);
    }
#if __SSE2__
    // 4 lanes (SSE2).
    __m128 func_pack4(const __m128& x) const
    {
        const __m128 a1 = _mm_set1_ps(0.254829592f);
        const __m128 a2 = _mm_set1_ps(-0.284496736f);
        const __m128 a3 = _mm_set1_ps(1.421413741f);
        const __m128 a4 = _mm_set1_ps(-1.453152027f);
        const __m128 a5 = _mm_set1_ps(1.061405429f);
        const __m128 p = _mm_set1_ps(0.3275911f);
        const __m128 one = _mm_set1_ps(1.0f);

        // Sign bit of every lane of x (all other bits cleared).
        const __m128 sign_bits = _mm_and_ps(x, _mm_set1_ps(-0.0f));
        const __m128 x_abs = abs_ps(x);

        // t = 1 / (1 + p*|x|)
        const __m128 t = _mm_div_ps(one, _mm_add_ps(one, _mm_mul_ps(p, x_abs)));

        // Horner evaluation of ((((a5*t + a4)*t + a3)*t + a2)*t + a1)*t
        __m128 poly = _mm_add_ps(_mm_mul_ps(a5, t), a4);
        poly = _mm_add_ps(_mm_mul_ps(poly, t), a3);
        poly = _mm_add_ps(_mm_mul_ps(poly, t), a2);
        poly = _mm_add_ps(_mm_mul_ps(poly, t), a1);
        poly = _mm_mul_ps(poly, t);

        // exp(-|x| * |x|)
        const __m128 gauss = exp_ps(_mm_mul_ps(_mm_sub_ps(_mm_setzero_ps(), x_abs), x_abs));

        const __m128 y = _mm_sub_ps(one, _mm_mul_ps(poly, gauss));
        return _mm_xor_ps(y, sign_bits);
    }
#if __AVX__
    // 8 lanes (AVX).
    __m256 func_pack8(const __m256& x) const
    {
        const __m256 a1 = _mm256_set1_ps(0.254829592f);
        const __m256 a2 = _mm256_set1_ps(-0.284496736f);
        const __m256 a3 = _mm256_set1_ps(1.421413741f);
        const __m256 a4 = _mm256_set1_ps(-1.453152027f);
        const __m256 a5 = _mm256_set1_ps(1.061405429f);
        const __m256 p = _mm256_set1_ps(0.3275911f);
        const __m256 one = _mm256_set1_ps(1.0f);

        // Sign bit of every lane of x (all other bits cleared).
        const __m256 sign_bits = _mm256_and_ps(x, _mm256_set1_ps(-0.0f));
        const __m256 x_abs = abs256_ps(x);

        // t = 1 / (1 + p*|x|)
        const __m256 t = _mm256_div_ps(one, _mm256_add_ps(one, _mm256_mul_ps(p, x_abs)));

        // Horner evaluation of ((((a5*t + a4)*t + a3)*t + a2)*t + a1)*t
        __m256 poly = _mm256_add_ps(_mm256_mul_ps(a5, t), a4);
        poly = _mm256_add_ps(_mm256_mul_ps(poly, t), a3);
        poly = _mm256_add_ps(_mm256_mul_ps(poly, t), a2);
        poly = _mm256_add_ps(_mm256_mul_ps(poly, t), a1);
        poly = _mm256_mul_ps(poly, t);

        // exp(-|x| * |x|)
        const __m256 gauss = exp256_ps(_mm256_mul_ps(_mm256_sub_ps(_mm256_setzero_ps(), x_abs), x_abs));

        const __m256 y = _mm256_sub_ps(one, _mm256_mul_ps(poly, gauss));
        return _mm256_xor_ps(y, sign_bits);
    }
#if __AVX512F__
    // 16 lanes (AVX-512F).
    __m512 func_pack16(const __m512& x) const
    {
        const __m512 a1 = _mm512_set1_ps(0.254829592f);
        const __m512 a2 = _mm512_set1_ps(-0.284496736f);
        const __m512 a3 = _mm512_set1_ps(1.421413741f);
        const __m512 a4 = _mm512_set1_ps(-1.453152027f);
        const __m512 a5 = _mm512_set1_ps(1.061405429f);
        const __m512 p = _mm512_set1_ps(0.3275911f);
        const __m512 one = _mm512_set1_ps(1.0f);

        // Sign bit of every lane of x, kept as integers so that only
        // AVX512F intrinsics are needed for the mask/xor steps.
        const __m512i sign_mask = _mm512_castps_si512(_mm512_set1_ps(-0.0f));
        const __m512i sign_bits = _mm512_and_epi32(_mm512_castps_si512(x), sign_mask);
        const __m512 x_abs = abs512_ps(x);

        // t = 1 / (1 + p*|x|)
        const __m512 t = _mm512_div_ps(one, _mm512_add_ps(one, _mm512_mul_ps(p, x_abs)));

        // Horner evaluation of ((((a5*t + a4)*t + a3)*t + a2)*t + a1)*t
        __m512 poly = _mm512_add_ps(_mm512_mul_ps(a5, t), a4);
        poly = _mm512_add_ps(_mm512_mul_ps(poly, t), a3);
        poly = _mm512_add_ps(_mm512_mul_ps(poly, t), a2);
        poly = _mm512_add_ps(_mm512_mul_ps(poly, t), a1);
        poly = _mm512_mul_ps(poly, t);

        // exp(-|x| * |x|)
        const __m512 gauss = exp512_ps(_mm512_mul_ps(_mm512_sub_ps(_mm512_setzero_ps(), x_abs), x_abs));

        const __m512 y = _mm512_sub_ps(one, _mm512_mul_ps(poly, gauss));
        return _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(y), sign_bits));
    }
#endif // __AVX512F__
#endif // __AVX__
#endif // __SSE2__
};

} // namespace UnaryOp_x86_functor

int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
Expand Down Expand Up @@ -707,6 +775,9 @@ int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_TRUNC)
return unary_op_inplace<unary_op_trunc>(bottom_top_blob, opt);

if (op_type == Operation_ERF)
return UnaryOp::forward_inplace(bottom_top_blob, opt);

return 0;
}

Expand Down
2 changes: 1 addition & 1 deletion tests/test_unaryop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "layer/unaryop.h"
#include "testutil.h"

#define OP_TYPE_MAX 20
#define OP_TYPE_MAX 21

static int op_type = 0;

Expand Down
9 changes: 9 additions & 0 deletions tools/onnx/onnx2ncnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3714,6 +3714,10 @@ int main(int argc, char** argv)
{
fprintf(pp, "%-16s", "EmbedLayerNormalization");
}
else if (op == "Erf")
{
fprintf(pp, "%-16s", "UnaryOp");
}
else if (op == "Exp")
{
fprintf(pp, "%-16s", "UnaryOp");
Expand Down Expand Up @@ -4510,6 +4514,11 @@ int main(int argc, char** argv)

fwrite_tensor_proto_data(B, bp);
}
else if (op == "Erf")
{
int op_type = 20;
fprintf(pp, " 0=%d", op_type);
}
else if (op == "Exp")
{
int op_type = 7;
Expand Down
Loading