From d94770d76c0f960c8c8fa6c4e2e5cf08ec73e744 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Tue, 29 Aug 2023 12:32:08 +0330 Subject: [PATCH 01/30] added erf support --- src/layer/unaryop.cpp | 11 +++++++++++ src/layer/unaryop.h | 3 ++- src/layer/vulkan/shader/unaryop.comp | 1 + src/layer/vulkan/shader/unaryop_pack4.comp | 2 ++ src/layer/vulkan/shader/unaryop_pack8.comp | 6 ++++++ tools/onnx/onnx2ncnn.cpp | 9 +++++++++ 6 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/layer/unaryop.cpp b/src/layer/unaryop.cpp index 2fe77717ed3c..d31d49eebfbc 100644 --- a/src/layer/unaryop.cpp +++ b/src/layer/unaryop.cpp @@ -218,6 +218,14 @@ struct unary_op_trunc } }; +struct unary_op_erf +{ + float operator()(const float& x) const + { + return (float)erf(x); + } +}; + int UnaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { if (op_type == Operation_ABS) @@ -280,6 +288,9 @@ int UnaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const if (op_type == Operation_TRUNC) return unary_op_inplace(bottom_top_blob, opt); + if (op_type == Operation_ERF) + return unary_op_inplace(bottom_top_blob, opt); + return 0; } diff --git a/src/layer/unaryop.h b/src/layer/unaryop.h index 3a6926ce00ee..98cacbe34880 100644 --- a/src/layer/unaryop.h +++ b/src/layer/unaryop.h @@ -49,7 +49,8 @@ class UnaryOp : public Layer Operation_TANH = 16, Operation_LOG10 = 17, Operation_ROUND = 18, - Operation_TRUNC = 19 + Operation_TRUNC = 19, + Operation_ERF = 20 }; public: diff --git a/src/layer/vulkan/shader/unaryop.comp b/src/layer/vulkan/shader/unaryop.comp index 21e48ff187b1..6270c07756a4 100644 --- a/src/layer/vulkan/shader/unaryop.comp +++ b/src/layer/vulkan/shader/unaryop.comp @@ -89,6 +89,7 @@ void main() if (op_type == 17) res = log(v) * afp(0.434294481903); if (op_type == 18) res = round(v); if (op_type == 19) res = trunc(v); + if (op_type == 20) res = erf(v); #if NCNN_image_shader image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); diff --git a/src/layer/vulkan/shader/unaryop_pack4.comp b/src/layer/vulkan/shader/unaryop_pack4.comp index 4a66b7df463e..f0649f36b741 100644 --- a/src/layer/vulkan/shader/unaryop_pack4.comp +++ b/src/layer/vulkan/shader/unaryop_pack4.comp @@ -90,6 +90,8 @@ void main() if (op_type == 18) res = round(v); if (op_type == 19) res = trunc(v); + if (op_type == 20) res = erf(v); + #if NCNN_image_shader image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); #else diff --git a/src/layer/vulkan/shader/unaryop_pack8.comp b/src/layer/vulkan/shader/unaryop_pack8.comp index f4ced42e0441..7ff1b04924f0 100644 --- a/src/layer/vulkan/shader/unaryop_pack8.comp +++ b/src/layer/vulkan/shader/unaryop_pack8.comp @@ -172,6 +172,12 @@ void main() res[1] = trunc(v[1]); } + if (op_type == 20) + { + res[0] = erf(v[0]); + res[1] = erf(v[1]); + } + #if NCNN_image_shader image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); #else diff --git a/tools/onnx/onnx2ncnn.cpp b/tools/onnx/onnx2ncnn.cpp index 37c68cdf01a5..6c946d66b7ec 100644 --- a/tools/onnx/onnx2ncnn.cpp +++ b/tools/onnx/onnx2ncnn.cpp @@ -3714,6 +3714,10 @@ int main(int argc, char** argv) { fprintf(pp, "%-16s", "EmbedLayerNormalization"); } + else if (op == "Erf") + { + fprintf(pp, "%-16s", "UnaryOp"); + } else if (op == "Exp") { fprintf(pp, "%-16s", "UnaryOp"); @@ -4510,6 +4514,11 @@ int main(int argc, char** argv) fwrite_tensor_proto_data(B, bp); } + else if (op == "Erf") + { + int op_type = 20; + fprintf(pp, " 0=%d", op_type); + } else if (op == "Exp") { int op_type = 7; From 424445db5c801cc7cdc7a5d4bfc9fc1ad988046b Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Tue, 29 Aug 2023 12:34:01 +0330 Subject: [PATCH 02/30] set the type to 21 to have more tests --- tests/test_unaryop.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_unaryop.cpp b/tests/test_unaryop.cpp index 44274fd071f9..0573e458a853 100644 --- a/tests/test_unaryop.cpp +++ b/tests/test_unaryop.cpp @@ -15,7 +15,7 @@ #include "layer/unaryop.h" #include "testutil.h" -#define OP_TYPE_MAX 20 +#define OP_TYPE_MAX 21 static int op_type = 0; From 55a52dcf71afcb039bae03da91e95de1c322d223 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Tue, 29 Aug 2023 13:42:47 +0330 Subject: [PATCH 03/30] lets say tests pass or not --- tests/test_unaryop.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_unaryop.cpp b/tests/test_unaryop.cpp index 0573e458a853..44274fd071f9 100644 --- a/tests/test_unaryop.cpp +++ b/tests/test_unaryop.cpp @@ -15,7 +15,7 @@ #include "layer/unaryop.h" #include "testutil.h" -#define OP_TYPE_MAX 21 +#define OP_TYPE_MAX 20 static int op_type = 0; From 89eb07082edfd1cc95fe7b0f0b7378b6ba5f42f2 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Wed, 30 Aug 2023 11:01:41 +0330 Subject: [PATCH 04/30] completely fixed vulkan shaders --- src/layer/vulkan/shader/comp.spv | Bin 0 -> 10828 bytes src/layer/vulkan/shader/unaryop.comp | 15 +++++++++++++++ src/layer/vulkan/shader/unaryop_pack4.comp | 16 +++++++++++++++- src/layer/vulkan/shader/unaryop_pack8.comp | 16 +++++++++++++++- 4 files changed, 45 insertions(+), 2 deletions(-) create mode 100644 src/layer/vulkan/shader/comp.spv diff --git a/src/layer/vulkan/shader/comp.spv b/src/layer/vulkan/shader/comp.spv new file mode 100644 index 0000000000000000000000000000000000000000..1ec53401886f4a8a6c222b6f61d620e739dc1830 GIT binary patch literal 10828 zcmZvh2e?$#6@`Z?QUpa58-l%J0~?BB!I#Ix>kPVz6b0(cvjoQ!9#{_ZozVmPF0;Xy{YMf?~F}rnKIqPL0BJnx#cnIVx8#i zNc<8+7b?oWE-*dFO`3W12D6VIG7JdQsnH!jr-^&vTefKD`-rz}$9=_Hwd0kas!?lpaO>x7oJDXr7@+O1hT#ypMQYn=F`c03t;BBs`9_!%uTTc%BFn=}*bNy@n< zPMPL*s#BEn{Ot2ojXlF-ir80VLKU~q@LVU%YMT_jt;E46gY9*17g*QC3-Obu$aijB zFJJU8CZ1i@b!|L}AK%*6);eu$TkDLm<2eKNwQR+cs^`W|Y-wwe_kK$REbs9=p8mGE zGbUBJrNlEQ&8l)wA-Fo5X6j zE3x-;l^k#2FZN(v?&(?kY`eUN@t`Vdk3E0);4<%-8f$$yInQAb)0`(>36wpqOuhFt zR?|TK>v_p(gNT=HV)>1Utz8c5I`@p>n^0w4?G)l!_O`EYyI)foU^hg-)pYyVIJ5$AMhYpS*id z-=$#BI(@!(UFo|4Y@dADcMGxajnpjymuJ(T>|=1g*6vGvk2hoaHMr{&c<}(R?`>o2 zp2FKB=eZgC{?43dn%Lg#+xNRG`*{KEe&pSE_}2^Gvk(79!Ot)FcMIOLO8xr>pntwa0gOX`*-&VJ3MZbh(Xn7V#Now?Kv1a~Xy1{HPYQnv}%XEO6P zE$YmrZd-8oqHepQ&Rptt1N+^SdAk>N=2ABb?DtFRMi+JFQg;NnS5bFlQD-i7$AJAz z&Acf^ow?MtfqNHqClqz&Qg;TpPf@p^s56(ki@|^(|yWlGDn`7P{p1vE1 zS19^!Ec(o)?rw13qVAre&Rpsq2d`MvEiUTJ<$dAzo@;BKO5Z=eXI;MMKZfyH^*uHU zKN|B{OFX6)kE_MU*5XrY@kOP7Ka2O^(kJ!)I!<*c2PDK62=|2)|{}EXFj{?i5e=Jzee$UG|{o}#* zw_xd?0G3bxM6jIwGMP74x$#bNpjl z-h23(8Gk~bGwpXya+7O0Ycl5qu>B8mFTW4{o7p|eFJfPECt~vMnSUZw%RS4_!Q{NB zICD<|dv^AmhIiRB4}U7Q0JBFNou`Ao z@fTxPVD^Y}CYOOd6Thc2=W?*Tx_v(ytHGJ}J105+HYjt}WX_df`*S8&f!(9LJ<*`g znOqH4x94QN^VD-K-lZOEuEFHokNC4jhYHK({?~!WVfNpQciDdf{wC~p%zkms<5sZe zaSN6?w}Iu=d?Vi2{ojnY-#N)$U&~pOId_2VKa=;ue)u~v_b8v5yTEy_?*+Tu^CJ8` z*aMjN6lc%(f!%ZY{=J`A-hJJTH+Ij?wBI?&<^7wSKKpqPY=8Fr5ZFD+&u{8@7_1g^ z>3ak`8cWTiVC&?)??Sv>>YoP7jo_Q`8T_-@66_evoOO@kpTo>KV{tR~^?eEK@~mFK zzlh1HRqR=H!@rE#=UitlgMS5+@5#N5U&T6LoiTIieGR-6v)B6@TbumX!OqU>;H+<8 z@|pcjaAuoJ@3+9tw%6Il*7}-7hhDAU!MoJ@I{w?3oLa@1`CV{b>-XT~Pit!bK3Fb2 zKLD2+jjc)khv2%#AHm6I=8wU0>G=t`)Mjjr+OluG+TO&w)bHRaf z)M{+4TBAd+*1zIiYJDI77fepA;>`R2T-W+HIQi84{Yz^80VjVJ?{nXG|HSN*Pv3{& z@+^$4Oa5Qr`dNGgC!d;+zoh09IQiWBQ?PyV>H9ahJS$`CJgc08-n05I-sM?+hW{Ls z^Q^>~>5Ec7s}8u1pgPCm1{g1M^M=F+JomZ*Op^)!5qP`+)0Omxq(j z>=nRV4QHE6Z(ne!)!16KMu%RlE8|^i?FU{7lT)iWGy8+RU*R`$zC!g5^z+BaAbLm|ZTxvD8R;|&YSL@n%msVmF2VB>> zE}VS!vmThM`YqdBde;Y+T8*tuegkk_>xOXhnY|I1tD0>ty&Ho|t;W`>H9GWa-3;$i ztG|s5#^ls0&dklhb*)>#$!9-Xg1M^SGt8xTD{!gR*xKZWfa_X^!pUd$)?lt`wz>3f z11_~1TdUUS(5rPA-lbN5-`O6MQ>!>LcL3M5?g%HJ{R{_lRlmuYOYctLQme7G$?puV zYaIb6pV_;BxvJUb(z`3T)M{+4TBAd+){%IZT6YKUfyt>=oSA!q>st4Mlh1zk26I)v zE1660KHyTTv9-zX3$APZ3Y>gqe-+GC%{G_b{lKMGV{6qK9eTAMfOn~NG;Weu4`?9lh1y}gSo2T_spet0=U#_Y;E!r!F8>Z;N&y=XfRhb z+gy4lgG;T()~Yo+^lCj0?^5d&@UfVjTE&?;6qp-d1p_)!5qP zXMpQkkB5`b?3rM$YPPxb&H|TOjjdH{bm-N3BHpFe6Tq`EIkk#2a}Kz!buOIzDE@l% z@5htC>)}f~xB0i_$>hx2lbQ3t&OPUoGjC63o(?Y0 z#@IT~Cc5>W%~^PtXR`o&CMM_Eh%@tSaQ$q~fs@Z`buO5T{}v#;ZfG*8A`-wcZ217n8F_oSFB7>slXxlh1x01amQ!-`)V} zeF$7?HMTbShrxBNkHE=i_M>2~YPPxbJ_ate8e6N@=+LY6NxVy~i@{G|a%vT4<`QsS z>(g-Z{@(Ny-gvv_^Zg9C2Xo~!<5_U2+1R?|p9j}9zW^uiJ)XlG=N>PDy@z~eyaX<_ z8(XLL=+ftR!=>Pb%*gM9FB99J-w|ISUP62fX07{>v(Dd)x5fPbTz=-e7ni@GY(t&D z8+FG#%bs}uAD3J&%=*NAYH{CMyizS*r53MVi~YBHIe)EMylyQX+>F)c-g;wV_q9S3 zyC*-d<=s#JCU!6W8&ckV48Zcc=s>Xg^SllMb8OZBs~m{8$Gys@XKirVvktL6oiTIv zs7F3M>w~>_;tjyof!Pdyb$&v)yl}NBu=ahJJur=A=7GP`SQ?n)5bIzW&0$V4anjzqF&QN0Om9_F$if%ozr@PChj|fXg{M5?d#qn&DudvCP@2%6V@8uh|(~&KW^$oqXo( z0`@h|Z^pZV<&xj6na>&TPHc^Qp2t1FK3hJ&=2JTo>^;rr_1_a*e_j?epO?MJnYYi^ g-}jST_OdtF=REN~U~6)}eZkhqf6ixxv7amd0}%`s>Hq)$ literal 0 HcmV?d00001 diff --git a/src/layer/vulkan/shader/unaryop.comp b/src/layer/vulkan/shader/unaryop.comp index 6270c07756a4..40e8a9888aa2 100644 --- a/src/layer/vulkan/shader/unaryop.comp +++ b/src/layer/vulkan/shader/unaryop.comp @@ -46,6 +46,21 @@ layout (push_constant) uniform parameter int cstep; } p; +float erf(float x) +{ + float a1 = 0.254829592f; + float a2 = -0.284496736f; + float a3 = 1.421413741f; + float a4 = -1.453152027f; + float a5 = 1.061405429f; + float p = 0.3275911f; + float s = sign(x); + float x_abs = abs(x); + float t = 1.0f/(1.0f + p*x_abs); + float y = 1.0f - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x_abs*x_abs); + return s * y; +} + void main() { int gx = int(gl_GlobalInvocationID.x); diff --git a/src/layer/vulkan/shader/unaryop_pack4.comp b/src/layer/vulkan/shader/unaryop_pack4.comp index f0649f36b741..b48c21e14cb7 100644 --- a/src/layer/vulkan/shader/unaryop_pack4.comp +++ b/src/layer/vulkan/shader/unaryop_pack4.comp @@ -46,6 +46,21 @@ layout (push_constant) uniform parameter int cstep; } p; +afpvec4 erf(afpvec4 x) +{ + afpvec4 a1 = afpvec4(0.254829592f); + afpvec4 a2 = afpvec4(-0.284496736f); + afpvec4 a3 = afpvec4(1.421413741f); + afpvec4 a4 = afpvec4(-1.453152027f); + afpvec4 a5 = afpvec4(1.061405429f); + afpvec4 p = afpvec4(0.3275911f); + afpvec4 s = sign(x); + afpvec4 x_abs = abs(x); + afpvec4 t = 1.0f / (1.0f + p * x_abs); + afpvec4 y = 1.0f - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * exp(-x_abs * x_abs); + return s * y; +} + void main() { int gx = int(gl_GlobalInvocationID.x); @@ -89,7 +104,6 @@ void main() if (op_type == 17) res = log(v) * afp(0.434294481903); if (op_type == 18) res = round(v); if (op_type == 19) res = trunc(v); - if (op_type == 20) res = erf(v); #if NCNN_image_shader diff --git a/src/layer/vulkan/shader/unaryop_pack8.comp b/src/layer/vulkan/shader/unaryop_pack8.comp index 7ff1b04924f0..a6d9930321f2 100644 --- a/src/layer/vulkan/shader/unaryop_pack8.comp +++ b/src/layer/vulkan/shader/unaryop_pack8.comp @@ -47,6 +47,21 @@ layout (push_constant) uniform parameter int cstep; } p; +afpvec4 erf(afpvec4 x) +{ + afpvec4 a1 = afpvec4(0.254829592f); + afpvec4 a2 = afpvec4(-0.284496736f); + afpvec4 a3 = afpvec4(1.421413741f); + afpvec4 a4 = afpvec4(-1.453152027f); + afpvec4 a5 = afpvec4(1.061405429f); + afpvec4 p = afpvec4(0.3275911f); + afpvec4 s = sign(x); + afpvec4 x_abs = abs(x); + afpvec4 t = 1.0f / (1.0f + p * x_abs); + afpvec4 y = 1.0f - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * exp(-x_abs * x_abs); + return s * y; +} + void main() { int gx = int(gl_GlobalInvocationID.x); @@ -171,7 +186,6 @@ void main() res[0] = trunc(v[0]); res[1] = trunc(v[1]); } - if (op_type == 20) { res[0] = erf(v[0]); From 3a4c65da0885da53783cc221d6ef97dff75a9c89 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Wed, 30 Aug 2023 11:04:01 +0330 Subject: [PATCH 05/30] compiled spv files are not needed --- src/layer/vulkan/shader/comp.spv | Bin 10828 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/layer/vulkan/shader/comp.spv diff --git a/src/layer/vulkan/shader/comp.spv b/src/layer/vulkan/shader/comp.spv deleted file mode 100644 index 1ec53401886f4a8a6c222b6f61d620e739dc1830..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 10828 zcmZvh2e?$#6@`Z?QUpa58-l%J0~?BB!I#Ix>kPVz6b0(cvjoQ!9#{_ZozVmPF0;Xy{YMf?~F}rnKIqPL0BJnx#cnIVx8#i zNc<8+7b?oWE-*dFO`3W12D6VIG7JdQsnH!jr-^&vTefKD`-rz}$9=_Hwd0kas!?lpaO>x7oJDXr7@+O1hT#ypMQYn=F`c03t;BBs`9_!%uTTc%BFn=}*bNy@n< zPMPL*s#BEn{Ot2ojXlF-ir80VLKU~q@LVU%YMT_jt;E46gY9*17g*QC3-Obu$aijB zFJJU8CZ1i@b!|L}AK%*6);eu$TkDLm<2eKNwQR+cs^`W|Y-wwe_kK$REbs9=p8mGE zGbUBJrNlEQ&8l)wA-Fo5X6j zE3x-;l^k#2FZN(v?&(?kY`eUN@t`Vdk3E0);4<%-8f$$yInQAb)0`(>36wpqOuhFt zR?|TK>v_p(gNT=HV)>1Utz8c5I`@p>n^0w4?G)l!_O`EYyI)foU^hg-)pYyVIJ5$AMhYpS*id z-=$#BI(@!(UFo|4Y@dADcMGxajnpjymuJ(T>|=1g*6vGvk2hoaHMr{&c<}(R?`>o2 zp2FKB=eZgC{?43dn%Lg#+xNRG`*{KEe&pSE_}2^Gvk(79!Ot)FcMIOLO8xr>pntwa0gOX`*-&VJ3MZbh(Xn7V#Now?Kv1a~Xy1{HPYQnv}%XEO6P zE$YmrZd-8oqHepQ&Rptt1N+^SdAk>N=2ABb?DtFRMi+JFQg;NnS5bFlQD-i7$AJAz z&Acf^ow?MtfqNHqClqz&Qg;TpPf@p^s56(ki@|^(|yWlGDn`7P{p1vE1 zS19^!Ec(o)?rw13qVAre&Rpsq2d`MvEiUTJ<$dAzo@;BKO5Z=eXI;MMKZfyH^*uHU zKN|B{OFX6)kE_MU*5XrY@kOP7Ka2O^(kJ!)I!<*c2PDK62=|2)|{}EXFj{?i5e=Jzee$UG|{o}#* zw_xd?0G3bxM6jIwGMP74x$#bNpjl z-h23(8Gk~bGwpXya+7O0Ycl5qu>B8mFTW4{o7p|eFJfPECt~vMnSUZw%RS4_!Q{NB zICD<|dv^AmhIiRB4}U7Q0JBFNou`Ao z@fTxPVD^Y}CYOOd6Thc2=W?*Tx_v(ytHGJ}J105+HYjt}WX_df`*S8&f!(9LJ<*`g znOqH4x94QN^VD-K-lZOEuEFHokNC4jhYHK({?~!WVfNpQciDdf{wC~p%zkms<5sZe zaSN6?w}Iu=d?Vi2{ojnY-#N)$U&~pOId_2VKa=;ue)u~v_b8v5yTEy_?*+Tu^CJ8` z*aMjN6lc%(f!%ZY{=J`A-hJJTH+Ij?wBI?&<^7wSKKpqPY=8Fr5ZFD+&u{8@7_1g^ z>3ak`8cWTiVC&?)??Sv>>YoP7jo_Q`8T_-@66_evoOO@kpTo>KV{tR~^?eEK@~mFK zzlh1HRqR=H!@rE#=UitlgMS5+@5#N5U&T6LoiTIieGR-6v)B6@TbumX!OqU>;H+<8 z@|pcjaAuoJ@3+9tw%6Il*7}-7hhDAU!MoJ@I{w?3oLa@1`CV{b>-XT~Pit!bK3Fb2 zKLD2+jjc)khv2%#AHm6I=8wU0>G=t`)Mjjr+OluG+TO&w)bHRaf z)M{+4TBAd+*1zIiYJDI77fepA;>`R2T-W+HIQi84{Yz^80VjVJ?{nXG|HSN*Pv3{& z@+^$4Oa5Qr`dNGgC!d;+zoh09IQiWBQ?PyV>H9ahJS$`CJgc08-n05I-sM?+hW{Ls z^Q^>~>5Ec7s}8u1pgPCm1{g1M^M=F+JomZ*Op^)!5qP`+)0Omxq(j z>=nRV4QHE6Z(ne!)!16KMu%RlE8|^i?FU{7lT)iWGy8+RU*R`$zC!g5^z+BaAbLm|ZTxvD8R;|&YSL@n%msVmF2VB>> zE}VS!vmThM`YqdBde;Y+T8*tuegkk_>xOXhnY|I1tD0>ty&Ho|t;W`>H9GWa-3;$i ztG|s5#^ls0&dklhb*)>#$!9-Xg1M^SGt8xTD{!gR*xKZWfa_X^!pUd$)?lt`wz>3f z11_~1TdUUS(5rPA-lbN5-`O6MQ>!>LcL3M5?g%HJ{R{_lRlmuYOYctLQme7G$?puV zYaIb6pV_;BxvJUb(z`3T)M{+4TBAd+){%IZT6YKUfyt>=oSA!q>st4Mlh1zk26I)v zE1660KHyTTv9-zX3$APZ3Y>gqe-+GC%{G_b{lKMGV{6qK9eTAMfOn~NG;Weu4`?9lh1y}gSo2T_spet0=U#_Y;E!r!F8>Z;N&y=XfRhb z+gy4lgG;T()~Yo+^lCj0?^5d&@UfVjTE&?;6qp-d1p_)!5qP zXMpQkkB5`b?3rM$YPPxb&H|TOjjdH{bm-N3BHpFe6Tq`EIkk#2a}Kz!buOIzDE@l% z@5htC>)}f~xB0i_$>hx2lbQ3t&OPUoGjC63o(?Y0 z#@IT~Cc5>W%~^PtXR`o&CMM_Eh%@tSaQ$q~fs@Z`buO5T{}v#;ZfG*8A`-wcZ217n8F_oSFB7>slXxlh1x01amQ!-`)V} zeF$7?HMTbShrxBNkHE=i_M>2~YPPxbJ_ate8e6N@=+LY6NxVy~i@{G|a%vT4<`QsS z>(g-Z{@(Ny-gvv_^Zg9C2Xo~!<5_U2+1R?|p9j}9zW^uiJ)XlG=N>PDy@z~eyaX<_ z8(XLL=+ftR!=>Pb%*gM9FB99J-w|ISUP62fX07{>v(Dd)x5fPbTz=-e7ni@GY(t&D z8+FG#%bs}uAD3J&%=*NAYH{CMyizS*r53MVi~YBHIe)EMylyQX+>F)c-g;wV_q9S3 zyC*-d<=s#JCU!6W8&ckV48Zcc=s>Xg^SllMb8OZBs~m{8$Gys@XKirVvktL6oiTIv zs7F3M>w~>_;tjyof!Pdyb$&v)yl}NBu=ahJJur=A=7GP`SQ?n)5bIzW&0$V4anjzqF&QN0Om9_F$if%ozr@PChj|fXg{M5?d#qn&DudvCP@2%6V@8uh|(~&KW^$oqXo( z0`@h|Z^pZV<&xj6na>&TPHc^Qp2t1FK3hJ&=2JTo>^;rr_1_a*e_j?epO?MJnYYi^ g-}jST_OdtF=REN~U~6)}eZkhqf6ixxv7amd0}%`s>Hq)$ From 455c3e9ba04f323a420c728d7b88f44e74fe8738 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Sat, 2 Sep 2023 01:19:07 +0330 Subject: [PATCH 06/30] working on other architectures --- src/layer/arm/unaryop_arm.cpp | 35 +++++++++++ src/layer/loongarch/unaryop_loongarch.cpp | 35 +++++++++++ src/layer/x86/unaryop_x86.cpp | 74 +++++++++++++++++++++++ 3 files changed, 144 insertions(+) diff --git a/src/layer/arm/unaryop_arm.cpp b/src/layer/arm/unaryop_arm.cpp index 5a054cc7c4d6..3b3714fe17bb 100644 --- a/src/layer/arm/unaryop_arm.cpp +++ b/src/layer/arm/unaryop_arm.cpp @@ -472,6 +472,35 @@ struct unary_op_trunc #endif // __ARM_NEON }; +struct unary_op_erf +{ + float func(const float& x) const + { + return (float)erf(x); + } +#if __ARM_NEON + float32x4_t func_pack4(const float32x4_t& x) const + { + float32x4_t a1 = vmovq_n_f32(0.254829592f); + float32x4_t a2 = vmovq_n_f32(-0.284496736f); + float32x4_t a3 = vmovq_n_f32(1.421413741f); + float32x4_t a4 = vmovq_n_f32(-1.453152027f); + float32x4_t a5 = vmovq_n_f32(1.061405429f); + float32x4_t p = vmovq_n_f32(0.3275911f); + float32x4_t s = vsign_f32(x); + float32x4_t x_abs = vabs_f32(x); + float32x4_t t = vrecpeq_f32(vaddq_f32(x_abs, p)); + float32x4_t y = vsub_f32(vmulq_f32(vmulq_f32(a5, t), t), vmulq_f32(vmulq_f32(a4, t), t)); + y = vsub_f32(y, vmulq_f32(vmulq_f32(a3, t), t)); + y = vsub_f32(y, vmulq_f32(vmulq_f32(a2, t), t)); + y = vsub_f32(y, vmulq_f32(vmulq_f32(a1, t), t)); + y = vmulq_f32(y, t); + y = vmulq_f32(y, exp_f32(-vmulq_f32(x_abs, x_abs))); + return s * y; + } +#endif // __ARM_NEON +}; + } // namespace UnaryOp_arm_functor int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -550,6 +579,9 @@ int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const if (op_type == Operation_TRUNC) return unary_op_inplace(bottom_top_blob, opt); + if (op_type == Operation_ERF) + return unary_op_inplace(bottom_top_blob, opt); + return 0; } @@ -686,6 +718,9 @@ int UnaryOp_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) if (op_type == Operation_TRUNC) return unary_op_inplace_bf16s(bottom_top_blob, opt); + if (op_type == Operation_ERF) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + return 0; } #endif // NCNN_BF16 diff --git a/src/layer/loongarch/unaryop_loongarch.cpp b/src/layer/loongarch/unaryop_loongarch.cpp index 4d4818cb5af0..d49fb2e435cf 100644 --- a/src/layer/loongarch/unaryop_loongarch.cpp +++ b/src/layer/loongarch/unaryop_loongarch.cpp @@ -416,6 +416,38 @@ struct unary_op_trunc #endif // __loongarch_sx }; +struct unary_op_erf +{ + float func(const float& x) const + { + return (float)erf(x); + } +#if __loongarch_sx + __m128 func_pack4(const __m128& x) const + { + __m128 a1 = (__m128)__lsx_vreplfr2vr_s(0.254829592f); + __m128 a2 = (__m128)__lsx_vreplfr2vr_s(-0.284496736f); + __m128 a3 = (__m128)__lsx_vreplfr2vr_s(1.421413741f); + __m128 a4 = (__m128)__lsx_vreplfr2vr_s(-1.453152027f); + __m128 a5 = (__m128)__lsx_vreplfr2vr_s(1.061405429f); + __m128 p = (__m128)__lsx_vreplfr2vr_s(0.3275911f); + __m128 x2 = (__m128)__lsx_vbitclri_w((__m128i)x, 31); + __m128i tiny_mask = __lsx_vfcmp_clt_s((__m128)x2, (__m128)(__m128)__lsx_vreplgr2vr_w(c_tanh_tiny.i)); + __m128i sig_mask = __lsx_vreplgr2vr_w(1 << 31); + __m128i s = __lsx_vand_v((__m128i)x, sig_mask); + __m128 x_abs = (__m128)__lsx_vbitclri_w(x, 31); + __m128 t = (__m128)__lsx_vfadd_s(x_abs, p); + __m128 y = __lsx_vfsub_s(__lsx_vfmul_s(__lsx_vfmul_s(a5, t), t), __lsx_vfmul_s(__lsx_vfmul_s(a4, t), t)); + y = __lsx_vfsub_s(y, __lsx_vfsub_s(__lsx_vfsub_s(a3, t), t)); + y = __lsx_vfsub_s(y, __lsx_vfmul_s(__lsx_vfmul_s(a2, t), t)); + y = __lsx_vfsub_s(y, __lsx_vfsub_s(__lsx_vfmul_s(a1, t), t)); + y = __lsx_vfmul_s(y, t); + y = __lsx_vfmul_s(y, exp_ps(-__lsx_vfmul_s(x_abs, x_abs))); + return (__m128)__lsx_vfmul_s(x, y); + } +#endif // __loongarch_sx +}; + } // namespace UnaryOp_loongarch_functor int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -482,6 +514,9 @@ int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) if (op_type == Operation_TRUNC) return unary_op_inplace(bottom_top_blob, opt); + if (op_type == Operation_ERF) + return unary_op_inplace(bottom_top_blob, opt); + return 0; } diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index 8629ab2093b4..89a1d1a58029 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -642,6 +642,77 @@ struct unary_op_trunc #endif // __SSE2__ }; +struct unary_op_trunc +{ + float func(const float& x) const + { + return (float)erf(x); + } +#if __SSE2__ + __m128 func_pack4(const __m128& x) const + { + __m128 a1 = _mm_set1_ps(0.254829592f); + __m128 a2 = _mm_set1_ps(-0.284496736f); + __m128 a3 = _mm_set1_ps(1.421413741f); + __m128 a4 = _mm_set1_ps(-1.453152027f); + __m128 a5 = _mm_set1_ps(1.061405429f); + __m128 p = _mm_set1_ps(0.3275911f); + __m128 s = _mm_sign_ps(x); + __m128 x_abs = _mm_abs_ps(x); + __m128 t = _mm_rcp_ps(_mm_add_ps(x_abs, p)); + __m128 y = _mm_sub_ps(_mm_mul_ps(_mm_mul_ps(a5, t), t), _mm_mul_ps(_mm_mul_ps(a4, t), t)); + y = _mm_sub_ps(y, _mm_mul_ps(_mm_mul_ps(a3, t), t)); + y = _mm_sub_ps(y, _mm_mul_ps(_mm_mul_ps(a2, t), t)); + y = _mm_sub_ps(y, _mm_mul_ps(_mm_mul_ps(a1, t), t)); + y = _mm_mul_ps(y, t); + y = _mm_mul_ps(y, _mm_exp_ps(-_mm_mul_ps(x_abs, x_abs))); + return _mm_mul_ps(s, y); + } +#if __AVX__ + __m256 func_pack8(const __m256& x) const + { + __m256 a1 = _mm256_set1_ps(0.254829592f); + __m256 a2 = _mm256_set1_ps(-0.284496736f); + __m256 a3 = _mm256_set1_ps(1.421413741f); + __m256 a4 = _mm256_set1_ps(-1.453152027f); + __m256 a5 = _mm256_set1_ps(1.061405429f); + __m256 p = _mm256_set1_ps(0.3275911f); + __m256 s = _mm256_sign_ps(x); + __m256 x_abs = _mm256_abs_ps(x); + __m256 t = _mm256_rcp_ps(_mm256_add_ps(x_abs, p)); + __m256 y = _mm256_sub_ps(_mm256_mul_ps(_mm256_mul_ps(a5, t), t), _mm256_mul_ps(_mm256_mul_ps(a4, t), t)); + y = _mm256_sub_ps(y, _mm256_mul_ps(_mm256_mul_ps(a3, t), t)); + y = _mm256_sub_ps(y, _mm256_mul_ps(_mm256_mul_ps(a2, t), t)); + y = _mm256_sub_ps(y, _mm256_mul_ps(_mm256_mul_ps(a1, t), t)); + y = _mm256_mul_ps(y, t); + y = _mm256_mul_ps(y, _mm256_exp_ps(-_mm256_mul_ps(x_abs, x_abs))); + return _mm256_mul_ps(s, y); + } +#if __AVX512F__ + __m512 func_pack16(const __m512& x) const + { + __m512 a1 = _mm512_set1_ps(0.254829592f); + __m512 a2 = _mm512_set1_ps(-0.284496736f); + __m512 a3 = _mm512_set1_ps(1.421413741f); + __m512 a4 = _mm512_set1_ps(-1.453152027f); + __m512 a5 = _mm512_set1_ps(1.061405429f); + __m512 p = _mm512_set1_ps(0.3275911f); + __m512 s = _mm512_sign_ps(x); + __m512 x_abs = _mm512_abs_ps(x); + __m512 t = _mm512_rcp_ps(_mm512_add_ps(x_abs, p)); + __m512 y = _mm512_sub_ps(_mm512_mul_ps(_mm512_mul_ps(a5, t), t), _mm512_mul_ps(_mm512_mul_ps(a4, t), t)); + y = _mm512_sub_ps(y, _mm512_mul_ps(_mm512_mul_ps(a3, t), t)); + y = _mm512_sub_ps(y, _mm512_mul_ps(_mm512_mul_ps(a2, t), t)); + y = _mm512_sub_ps(y, _mm512_mul_ps(_mm512_mul_ps(a1, t), t)); + y = _mm512_mul_ps(y, t); + y = _mm512_mul_ps(y, _mm512_exp_ps(-_mm512_mul_ps(x_abs, x_abs))); + return _mm512_mul_ps(s, y); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + } // namespace UnaryOp_x86_functor int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -707,6 +778,9 @@ int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const if (op_type == Operation_TRUNC) return unary_op_inplace(bottom_top_blob, opt); + if (op_type == Operation_ERF) + return unary_op_inplace(bottom_top_blob, opt); + return 0; } From a7e8802edd01ec5a150262c2ae1e1efe1bc166d7 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Sat, 2 Sep 2023 01:48:05 +0330 Subject: [PATCH 07/30] fixed loongarch stuff --- src/layer/loongarch/unaryop_loongarch.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/layer/loongarch/unaryop_loongarch.cpp b/src/layer/loongarch/unaryop_loongarch.cpp index d49fb2e435cf..0eb8a7ca27aa 100644 --- a/src/layer/loongarch/unaryop_loongarch.cpp +++ b/src/layer/loongarch/unaryop_loongarch.cpp @@ -425,6 +425,7 @@ struct unary_op_erf #if __loongarch_sx __m128 func_pack4(const __m128& x) const { +__m128 ones = (__m128)__lsx_vreplfr2vr_s(1.0f); __m128 a1 = (__m128)__lsx_vreplfr2vr_s(0.254829592f); __m128 a2 = (__m128)__lsx_vreplfr2vr_s(-0.284496736f); __m128 a3 = (__m128)__lsx_vreplfr2vr_s(1.421413741f); @@ -436,7 +437,7 @@ struct unary_op_erf __m128i sig_mask = __lsx_vreplgr2vr_w(1 << 31); __m128i s = __lsx_vand_v((__m128i)x, sig_mask); __m128 x_abs = (__m128)__lsx_vbitclri_w(x, 31); - __m128 t = (__m128)__lsx_vfadd_s(x_abs, p); + __m128 t = (__m128)__lsx_vfdiv_s(ones, __lsx_vfadd_s(__lsx_vfmul_s(x_abs, p))); __m128 y = __lsx_vfsub_s(__lsx_vfmul_s(__lsx_vfmul_s(a5, t), t), __lsx_vfmul_s(__lsx_vfmul_s(a4, t), t)); y = __lsx_vfsub_s(y, __lsx_vfsub_s(__lsx_vfsub_s(a3, t), t)); y = __lsx_vfsub_s(y, __lsx_vfmul_s(__lsx_vfmul_s(a2, t), t)); From e7d7c3445f2535e247da43669d24dbed9521dfa1 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Sat, 2 Sep 2023 10:16:01 +0330 Subject: [PATCH 08/30] trying to fix arm stuff --- src/layer/arm/unaryop_arm.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/layer/arm/unaryop_arm.cpp b/src/layer/arm/unaryop_arm.cpp index 3b3714fe17bb..1a04d81125a9 100644 --- a/src/layer/arm/unaryop_arm.cpp +++ b/src/layer/arm/unaryop_arm.cpp @@ -481,14 +481,14 @@ struct unary_op_erf #if __ARM_NEON float32x4_t func_pack4(const float32x4_t& x) const { - float32x4_t a1 = vmovq_n_f32(0.254829592f); - float32x4_t a2 = vmovq_n_f32(-0.284496736f); - float32x4_t a3 = vmovq_n_f32(1.421413741f); - float32x4_t a4 = vmovq_n_f32(-1.453152027f); - float32x4_t a5 = vmovq_n_f32(1.061405429f); - float32x4_t p = vmovq_n_f32(0.3275911f); - float32x4_t s = vsign_f32(x); - float32x4_t x_abs = vabs_f32(x); + float32x4_t a1 = vdupq_n_f32(0.254829592f); + float32x4_t a2 = vdupq_n_f32(-0.284496736f); + float32x4_t a3 = vdupq_n_f32(1.421413741f); + float32x4_t a4 = vdupq_n_f32(-1.453152027f); + float32x4_t a5 = vdupq_n_f32(1.061405429f); + float32x4_t p = vdupq_n_f32(0.3275911f); + float32x4_t s = vcltq_f32(x, vdupq_n_f32(0)); + float32x4_t x_abs = vabsq_f32(x); float32x4_t t = vrecpeq_f32(vaddq_f32(x_abs, p)); float32x4_t y = vsub_f32(vmulq_f32(vmulq_f32(a5, t), t), vmulq_f32(vmulq_f32(a4, t), t)); y = vsub_f32(y, vmulq_f32(vmulq_f32(a3, t), t)); From b0bb0787a20befff8bdb5286e81fe821934bac8c Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Sat, 2 Sep 2023 10:23:54 +0330 Subject: [PATCH 09/30] working on arm stuff --- src/layer/arm/unaryop_arm.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/layer/arm/unaryop_arm.cpp b/src/layer/arm/unaryop_arm.cpp index 1a04d81125a9..dc62dbe2ecec 100644 --- a/src/layer/arm/unaryop_arm.cpp +++ b/src/layer/arm/unaryop_arm.cpp @@ -490,13 +490,13 @@ struct unary_op_erf float32x4_t s = vcltq_f32(x, vdupq_n_f32(0)); float32x4_t x_abs = vabsq_f32(x); float32x4_t t = vrecpeq_f32(vaddq_f32(x_abs, p)); - float32x4_t y = vsub_f32(vmulq_f32(vmulq_f32(a5, t), t), vmulq_f32(vmulq_f32(a4, t), t)); - y = vsub_f32(y, vmulq_f32(vmulq_f32(a3, t), t)); - y = vsub_f32(y, vmulq_f32(vmulq_f32(a2, t), t)); - y = vsub_f32(y, vmulq_f32(vmulq_f32(a1, t), t)); + float32x4_t y = vsubq_f32(vmulq_f32(vmulq_f32(a5, t), t), vmulq_f32(vmulq_f32(a4, t), t)); + y = vsubq_f32(y, vmulq_f32(vmulq_f32(a3, t), t)); + y = vsubq_f32(y, vmulq_f32(vmulq_f32(a2, t), t)); + y = vsubq_f32(y, vmulq_f32(vmulq_f32(a1, t), t)); y = vmulq_f32(y, t); - y = vmulq_f32(y, exp_f32(-vmulq_f32(x_abs, x_abs))); - return s * y; + y = vmulq_f32(y, exp_ps(vnegq_f32(vmulq_f32(x_abs, x_abs)))); + return vmulq_f32(s, y); } #endif // __ARM_NEON }; From 91d94a7b42ee6b3e2b1c9ecdef29229b9cd06a5c Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Sat, 2 Sep 2023 10:48:22 +0330 Subject: [PATCH 10/30] fixing x86 stuff --- src/layer/x86/unaryop_x86.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index 89a1d1a58029..e29d0e46f6f7 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -642,7 +642,7 @@ struct unary_op_trunc #endif // __SSE2__ }; -struct unary_op_trunc +struct unary_op_erf { float func(const float& x) const { From dbc5441986def861d5c04cff2358c68010c6f921 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Sat, 2 Sep 2023 11:01:12 +0330 Subject: [PATCH 11/30] mips implementation --- src/layer/mips/unaryop_mips.cpp | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/layer/mips/unaryop_mips.cpp b/src/layer/mips/unaryop_mips.cpp index b923535a2d8a..3997a4be84c4 100644 --- a/src/layer/mips/unaryop_mips.cpp +++ b/src/layer/mips/unaryop_mips.cpp @@ -436,6 +436,27 @@ struct unary_op_trunc #endif // __mips_msa }; +struct unary_op_sin +{ + float func(const float& x) const + { + return (float)sin(x); + } +#if __mips_msa + v4f32 func_pack4(const v4f32& x) const + { + // TODO msa optimize + float tmp[4]; + __msa_st_w((v4i32)x, tmp, 0); + tmp[0] = erf(tmp[0]); + tmp[1] = erf(tmp[1]); + tmp[2] = erf(tmp[2]); + tmp[3] = erf(tmp[3]); + return (v4f32)__msa_ld_w(tmp, 0); + } +#endif // __mips_msa +}; + } // namespace UnaryOp_mips_functor int UnaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -502,6 +523,9 @@ int UnaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const if (op_type == Operation_TRUNC) return unary_op_inplace(bottom_top_blob, opt); + if (op_type == Operation_ERF) + return unary_op_inplace(bottom_top_blob, opt); + return 0; } From 5d4a226f87952a1176e63dccdedce7395d9d47f4 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Sun, 3 Sep 2023 07:38:59 +0330 Subject: [PATCH 12/30] trying to fix x86 stuff --- src/layer/x86/unaryop_x86.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index e29d0e46f6f7..3eafe2193345 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -657,8 +657,11 @@ struct unary_op_erf __m128 a4 = _mm_set1_ps(-1.453152027f); __m128 a5 = _mm_set1_ps(1.061405429f); __m128 p = _mm_set1_ps(0.3275911f); - __m128 s = _mm_sign_ps(x); - __m128 x_abs = _mm_abs_ps(x); + const __m128 zero = _mm_set_ps1 (0.0f); + __m128 positives = _mm_and_ps(_mm_cmpgt_ps (value, zero), _mm_set_ps1(1.0f)); + __m128 negatives = _mm_and_ps(_mm_cmplt_ps (value, zero), _mm_set_ps1(-1.0f)); + __m128 s = _mm_or_ps(positives, negatives); + __m128 x_abs = _mm_andnot_ps(_mm_set1_ps(-0.0f), m); __m128 t = _mm_rcp_ps(_mm_add_ps(x_abs, p)); __m128 y = _mm_sub_ps(_mm_mul_ps(_mm_mul_ps(a5, t), t), _mm_mul_ps(_mm_mul_ps(a4, t), t)); y = _mm_sub_ps(y, _mm_mul_ps(_mm_mul_ps(a3, t), t)); From 4f9c6df580be67df71f2973573f6807e86b55bfa Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Sun, 3 Sep 2023 07:43:17 +0330 Subject: [PATCH 13/30] fixing arm again --- src/layer/arm/unaryop_arm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layer/arm/unaryop_arm.cpp b/src/layer/arm/unaryop_arm.cpp index dc62dbe2ecec..465ec87c74ee 100644 --- a/src/layer/arm/unaryop_arm.cpp +++ b/src/layer/arm/unaryop_arm.cpp @@ -487,7 +487,7 @@ struct unary_op_erf float32x4_t a4 = vdupq_n_f32(-1.453152027f); float32x4_t a5 = vdupq_n_f32(1.061405429f); float32x4_t p = vdupq_n_f32(0.3275911f); - float32x4_t s = vcltq_f32(x, vdupq_n_f32(0)); + float32x4_t s = vcltq_f32(x, vdupq_n_f32(0.0f)); float32x4_t x_abs = vabsq_f32(x); float32x4_t t = vrecpeq_f32(vaddq_f32(x_abs, p)); float32x4_t y = vsubq_f32(vmulq_f32(vmulq_f32(a5, t), t), vmulq_f32(vmulq_f32(a4, t), t)); From e076ac23003f13d3605f9d130e85c4a1cf403687 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Sun, 3 Sep 2023 12:23:22 +0330 Subject: [PATCH 14/30] added ERF to the list of operators --- docs/developer-guide/operators.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index 17acf4ec03f3..839e45402173 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -1781,3 +1781,4 @@ Operation type: - 17 = LOG10 - 18 = ROUND - 19 = TRUNC +- 20 = ERF From c0b256cb9a4113b1440002b037494814afbde80b Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Sun, 3 Sep 2023 13:23:57 +0330 Subject: [PATCH 15/30] fixing arm and x86 operators --- src/layer/arm/unaryop_arm.cpp | 2 +- src/layer/x86/unaryop_x86.cpp | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/layer/arm/unaryop_arm.cpp b/src/layer/arm/unaryop_arm.cpp index 465ec87c74ee..624de1caa846 100644 --- a/src/layer/arm/unaryop_arm.cpp +++ b/src/layer/arm/unaryop_arm.cpp @@ -487,7 +487,7 @@ struct unary_op_erf float32x4_t a4 = vdupq_n_f32(-1.453152027f); float32x4_t a5 = vdupq_n_f32(1.061405429f); float32x4_t p = vdupq_n_f32(0.3275911f); - float32x4_t s = vcltq_f32(x, vdupq_n_f32(0.0f)); + float32x4_t s = vcvtq_f32_s32(vcltq_f32(x, vdupq_n_f32(0.0f))); float32x4_t x_abs = vabsq_f32(x); float32x4_t t = vrecpeq_f32(vaddq_f32(x_abs, p)); float32x4_t y = vsubq_f32(vmulq_f32(vmulq_f32(a5, t), t), vmulq_f32(vmulq_f32(a4, t), t)); diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index 3eafe2193345..dd8a12c61ec8 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -658,17 +658,17 @@ struct unary_op_erf __m128 a5 = _mm_set1_ps(1.061405429f); __m128 p = _mm_set1_ps(0.3275911f); const __m128 zero = _mm_set_ps1 (0.0f); - __m128 positives = _mm_and_ps(_mm_cmpgt_ps (value, zero), _mm_set_ps1(1.0f)); - __m128 negatives = _mm_and_ps(_mm_cmplt_ps (value, zero), _mm_set_ps1(-1.0f)); + __m128 positives = _mm_and_ps(_mm_cmpgt_ps (x, zero), _mm_set_ps1(1.0f)); + __m128 negatives = _mm_and_ps(_mm_cmplt_ps (x, zero), _mm_set_ps1(-1.0f)); __m128 s = _mm_or_ps(positives, negatives); - __m128 x_abs = _mm_andnot_ps(_mm_set1_ps(-0.0f), m); + __m128 x_abs = _mm_andnot_ps(_mm_set1_ps(-0.0f), x); __m128 t = _mm_rcp_ps(_mm_add_ps(x_abs, p)); __m128 y = _mm_sub_ps(_mm_mul_ps(_mm_mul_ps(a5, t), t), _mm_mul_ps(_mm_mul_ps(a4, t), t)); y = _mm_sub_ps(y, _mm_mul_ps(_mm_mul_ps(a3, t), t)); y = _mm_sub_ps(y, _mm_mul_ps(_mm_mul_ps(a2, t), t)); y = _mm_sub_ps(y, _mm_mul_ps(_mm_mul_ps(a1, t), t)); y = _mm_mul_ps(y, t); - y = _mm_mul_ps(y, _mm_exp_ps(-_mm_mul_ps(x_abs, x_abs))); + y = _mm_mul_ps(y, exp_ps(_mm_sub_ps(_mm_setzero_ps(), _mm_mul_ps(x_abs, x_abs)))); return _mm_mul_ps(s, y); } #if __AVX__ From 290c4e38ee34470026eb3ac86185aac68ef91d96 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Sun, 3 Sep 2023 17:04:16 +0330 Subject: [PATCH 16/30] fixed the errors with the x86 architecture --- src/layer/x86/unaryop_x86.cpp | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index dd8a12c61ec8..2b34a6d11759 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -680,15 +680,18 @@ struct unary_op_erf __m256 a4 = _mm256_set1_ps(-1.453152027f); __m256 a5 = _mm256_set1_ps(1.061405429f); __m256 p = _mm256_set1_ps(0.3275911f); - __m256 s = _mm256_sign_ps(x); - __m256 x_abs = _mm256_abs_ps(x); + const __m256 zero = _mm256_set1_ps(0.0f); + __m256 positives = _mm256_and_ps(_mm256_cmp_ps(x, zero, _CMP_GT_OQ), _mm256_set1_ps(1.0f)); + __m256 negatives = _mm256_and_ps(_mm256_cmp_ps(x, zero, _CMP_LT_OQ), _mm256_set1_ps(-1.0f)); + __m256 s = _mm256_or_ps(positives, negatives); + __m256 x_abs = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), x); __m256 t = _mm256_rcp_ps(_mm256_add_ps(x_abs, p)); __m256 y = _mm256_sub_ps(_mm256_mul_ps(_mm256_mul_ps(a5, t), t), _mm256_mul_ps(_mm256_mul_ps(a4, t), t)); y = _mm256_sub_ps(y, _mm256_mul_ps(_mm256_mul_ps(a3, t), t)); y = _mm256_sub_ps(y, _mm256_mul_ps(_mm256_mul_ps(a2, t), t)); y = _mm256_sub_ps(y, _mm256_mul_ps(_mm256_mul_ps(a1, t), t)); y = _mm256_mul_ps(y, t); - y = _mm256_mul_ps(y, _mm256_exp_ps(-_mm256_mul_ps(x_abs, x_abs))); + y = _mm256_mul_ps(y, _mm256_exp_ps(_mm256_sub_ps(_mm256_setzero_ps(), _mm256_mul_ps(x_abs, x_abs)))); return _mm256_mul_ps(s, y); } #if __AVX512F__ @@ -700,15 +703,18 @@ struct unary_op_erf __m512 a4 = _mm512_set1_ps(-1.453152027f); __m512 a5 = _mm512_set1_ps(1.061405429f); __m512 p = _mm512_set1_ps(0.3275911f); - __m512 s = _mm512_sign_ps(x); - __m512 x_abs = _mm512_abs_ps(x); - __m512 t = _mm512_rcp_ps(_mm512_add_ps(x_abs, p)); + const __m512 zero = _mm512_set1_ps(0.0f); + __m512 positives = _mm512_and_ps(_mm512_mask_blend_ps(_mm512_cmplt_ps_mask(zero, x), x, zero), _mm512_set1_ps(1.0f)); + __m512 negatives = _mm512_and_ps(_mm512_mask_blend_ps(_mm512_cmplt_ps_mask(x, zero), x, zero), _mm512_set1_ps(-1.0f)); + __m512 s = _mm512_or_ps(positives, negatives); + __m512 x_abs = _mm512_andnot_ps(_mm512_set1_ps(-0.0f), x); + __m512 t = _mm512_div_ps(_mm512_set1_ps(1.0), _mm512_add_ps(x_abs, p)); __m512 y = _mm512_sub_ps(_mm512_mul_ps(_mm512_mul_ps(a5, t), t), _mm512_mul_ps(_mm512_mul_ps(a4, t), t)); y = _mm512_sub_ps(y, _mm512_mul_ps(_mm512_mul_ps(a3, t), t)); y = _mm512_sub_ps(y, _mm512_mul_ps(_mm512_mul_ps(a2, t), t)); y = _mm512_sub_ps(y, _mm512_mul_ps(_mm512_mul_ps(a1, t), t)); y = _mm512_mul_ps(y, t); - y = _mm512_mul_ps(y, _mm512_exp_ps(-_mm512_mul_ps(x_abs, x_abs))); + y = _mm512_mul_ps(y, _mm512_exp_ps(_mm512_sub_ps(_mm512_setzero_ps(), _mm512_mul_ps(x_abs, x_abs)))); return _mm512_mul_ps(s, y); } #endif // __AVX512F__ From 6dccf379bb30ce01e90336218b052b253221fd5f Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Sun, 3 Sep 2023 21:38:44 +0330 Subject: [PATCH 17/30] x86: msvc has __m256 and __m512 exponents, but there was nothing for gcc/clang as far as I'm aware. now they are fixed with avx_mathfun.h --- src/layer/x86/unaryop_x86.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index 2b34a6d11759..c6c7b43b7083 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -691,7 +691,7 @@ struct unary_op_erf y = _mm256_sub_ps(y, _mm256_mul_ps(_mm256_mul_ps(a2, t), t)); y = _mm256_sub_ps(y, _mm256_mul_ps(_mm256_mul_ps(a1, t), t)); y = _mm256_mul_ps(y, t); - y = _mm256_mul_ps(y, _mm256_exp_ps(_mm256_sub_ps(_mm256_setzero_ps(), _mm256_mul_ps(x_abs, x_abs)))); + y = _mm256_mul_ps(y, exp256_ps(_mm256_sub_ps(_mm256_setzero_ps(), _mm256_mul_ps(x_abs, x_abs)))); return _mm256_mul_ps(s, y); } #if __AVX512F__ @@ -714,7 +714,7 @@ struct unary_op_erf y = _mm512_sub_ps(y, _mm512_mul_ps(_mm512_mul_ps(a2, t), t)); y = _mm512_sub_ps(y, _mm512_mul_ps(_mm512_mul_ps(a1, t), t)); y = _mm512_mul_ps(y, t); - y = _mm512_mul_ps(y, _mm512_exp_ps(_mm512_sub_ps(_mm512_setzero_ps(), _mm512_mul_ps(x_abs, x_abs)))); + y = _mm512_mul_ps(y, exp512_ps(_mm512_sub_ps(_mm512_setzero_ps(), _mm512_mul_ps(x_abs, x_abs)))); return _mm512_mul_ps(s, y); } #endif // __AVX512F__ From b36ea45db856fbd63a231ac99df41dd6dd393a2b Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Sun, 3 Sep 2023 21:49:09 +0330 Subject: [PATCH 18/30] mips fix --- src/layer/loongarch/unaryop_loongarch.cpp | 5 ++--- src/layer/mips/unaryop_mips.cpp | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/layer/loongarch/unaryop_loongarch.cpp b/src/layer/loongarch/unaryop_loongarch.cpp index 0eb8a7ca27aa..d35818e1b4d0 100644 --- a/src/layer/loongarch/unaryop_loongarch.cpp +++ b/src/layer/loongarch/unaryop_loongarch.cpp @@ -433,18 +433,17 @@ __m128 ones = (__m128)__lsx_vreplfr2vr_s(1.0f); __m128 a5 = (__m128)__lsx_vreplfr2vr_s(1.061405429f); __m128 p = (__m128)__lsx_vreplfr2vr_s(0.3275911f); __m128 x2 = (__m128)__lsx_vbitclri_w((__m128i)x, 31); - __m128i tiny_mask = __lsx_vfcmp_clt_s((__m128)x2, (__m128)(__m128)__lsx_vreplgr2vr_w(c_tanh_tiny.i)); __m128i sig_mask = __lsx_vreplgr2vr_w(1 << 31); __m128i s = __lsx_vand_v((__m128i)x, sig_mask); __m128 x_abs = (__m128)__lsx_vbitclri_w(x, 31); - __m128 t = (__m128)__lsx_vfdiv_s(ones, __lsx_vfadd_s(__lsx_vfmul_s(x_abs, p))); + __m128 t = (__m128)__lsx_vfdiv_s(ones, __lsx_vfadd_s(ones, __lsx_vfmul_s(x_abs, p))); __m128 y = __lsx_vfsub_s(__lsx_vfmul_s(__lsx_vfmul_s(a5, t), t), __lsx_vfmul_s(__lsx_vfmul_s(a4, t), t)); y = __lsx_vfsub_s(y, __lsx_vfsub_s(__lsx_vfsub_s(a3, t), t)); y = __lsx_vfsub_s(y, __lsx_vfmul_s(__lsx_vfmul_s(a2, t), t)); y = __lsx_vfsub_s(y, __lsx_vfsub_s(__lsx_vfmul_s(a1, t), t)); y = __lsx_vfmul_s(y, t); y = __lsx_vfmul_s(y, exp_ps(-__lsx_vfmul_s(x_abs, x_abs))); - return (__m128)__lsx_vfmul_s(x, y); + return (__m128)__lsx_vfmul_s(s, y); } #endif // __loongarch_sx }; diff --git a/src/layer/mips/unaryop_mips.cpp b/src/layer/mips/unaryop_mips.cpp index 3997a4be84c4..0792720ce921 100644 --- a/src/layer/mips/unaryop_mips.cpp +++ b/src/layer/mips/unaryop_mips.cpp @@ -436,11 +436,11 @@ struct unary_op_trunc #endif // __mips_msa }; -struct unary_op_sin +struct unary_op_erf { float func(const float& x) const { - return (float)sin(x); + return (float)erf(x); } #if __mips_msa v4f32 func_pack4(const v4f32& x) const From 15f9aeb780a4fdb3dbcd52e5e662d8dad1a8b2a2 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Mon, 4 Sep 2023 09:44:15 +0330 Subject: [PATCH 19/30] fixed arm --- src/layer/arm/unaryop_arm.cpp | 5 +++-- tests/test_unaryop.cpp | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/layer/arm/unaryop_arm.cpp b/src/layer/arm/unaryop_arm.cpp index 624de1caa846..265cfe119887 100644 --- a/src/layer/arm/unaryop_arm.cpp +++ b/src/layer/arm/unaryop_arm.cpp @@ -487,7 +487,8 @@ struct unary_op_erf float32x4_t a4 = vdupq_n_f32(-1.453152027f); float32x4_t a5 = vdupq_n_f32(1.061405429f); float32x4_t p = vdupq_n_f32(0.3275911f); - float32x4_t s = vcvtq_f32_s32(vcltq_f32(x, vdupq_n_f32(0.0f))); + const uint32x4_t szero = vreinterpretq_u32_f32(vdupq_n_f32(-0.0f)); + uint32x4_t s = vandq_u32(vreinterpretq_u32_f32(a), szero); float32x4_t x_abs = vabsq_f32(x); float32x4_t t = vrecpeq_f32(vaddq_f32(x_abs, p)); float32x4_t y = vsubq_f32(vmulq_f32(vmulq_f32(a5, t), t), vmulq_f32(vmulq_f32(a4, t), t)); @@ -496,7 +497,7 @@ struct unary_op_erf y = vsubq_f32(y, vmulq_f32(vmulq_f32(a1, t), t)); y = vmulq_f32(y, t); y = vmulq_f32(y, exp_ps(vnegq_f32(vmulq_f32(x_abs, x_abs)))); - return vmulq_f32(s, y); + return vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(y), s)); } #endif // __ARM_NEON }; diff --git a/tests/test_unaryop.cpp b/tests/test_unaryop.cpp index 44274fd071f9..0573e458a853 100644 --- a/tests/test_unaryop.cpp +++ b/tests/test_unaryop.cpp @@ -15,7 +15,7 @@ #include "layer/unaryop.h" #include "testutil.h" -#define OP_TYPE_MAX 20 +#define OP_TYPE_MAX 21 static int op_type = 0; From 4bb5c3075896407183c99cc4a6d59209585404f5 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Mon, 4 Sep 2023 11:28:36 +0330 Subject: [PATCH 20/30] arm should now be fixed --- src/layer/arm/unaryop_arm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layer/arm/unaryop_arm.cpp b/src/layer/arm/unaryop_arm.cpp index 265cfe119887..7aa0e43c9b69 100644 --- a/src/layer/arm/unaryop_arm.cpp +++ b/src/layer/arm/unaryop_arm.cpp @@ -488,7 +488,7 @@ struct unary_op_erf float32x4_t a5 = vdupq_n_f32(1.061405429f); float32x4_t p = vdupq_n_f32(0.3275911f); const uint32x4_t szero = vreinterpretq_u32_f32(vdupq_n_f32(-0.0f)); - uint32x4_t s = vandq_u32(vreinterpretq_u32_f32(a), szero); + uint32x4_t s = vandq_u32(vreinterpretq_u32_f32(x), szero); float32x4_t x_abs = vabsq_f32(x); float32x4_t t = vrecpeq_f32(vaddq_f32(x_abs, p)); float32x4_t y = vsubq_f32(vmulq_f32(vmulq_f32(a5, t), t), vmulq_f32(vmulq_f32(a4, t), t)); From ea2af2ebc66328953e6110472c2474b679c4c085 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Mon, 4 Sep 2023 14:10:19 +0330 Subject: [PATCH 21/30] x86 tests should pass --- .../vulkan/shader/{unaryop.comp => unaryop.comp.txt} | 0 src/layer/x86/unaryop_x86.cpp | 11 ++++------- 2 files changed, 4 insertions(+), 7 deletions(-) rename src/layer/vulkan/shader/{unaryop.comp => unaryop.comp.txt} (100%) diff --git a/src/layer/vulkan/shader/unaryop.comp b/src/layer/vulkan/shader/unaryop.comp.txt similarity index 100% rename from src/layer/vulkan/shader/unaryop.comp rename to src/layer/vulkan/shader/unaryop.comp.txt diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index c6c7b43b7083..69cfd339884b 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -662,7 +662,7 @@ struct unary_op_erf __m128 negatives = _mm_and_ps(_mm_cmplt_ps (x, zero), _mm_set_ps1(-1.0f)); __m128 s = _mm_or_ps(positives, negatives); __m128 x_abs = _mm_andnot_ps(_mm_set1_ps(-0.0f), x); - __m128 t = _mm_rcp_ps(_mm_add_ps(x_abs, p)); + __m128 t = _mm_rcp_ps(_mm_mul_ps(_mm_add_ps(_mm_set_ps1(1.0f), p), x_abs)); __m128 y = _mm_sub_ps(_mm_mul_ps(_mm_mul_ps(a5, t), t), _mm_mul_ps(_mm_mul_ps(a4, t), t)); y = _mm_sub_ps(y, _mm_mul_ps(_mm_mul_ps(a3, t), t)); y = _mm_sub_ps(y, _mm_mul_ps(_mm_mul_ps(a2, t), t)); @@ -685,7 +685,7 @@ struct unary_op_erf __m256 negatives = _mm256_and_ps(_mm256_cmp_ps(x, zero, _CMP_LT_OQ), _mm256_set1_ps(-1.0f)); __m256 s = _mm256_or_ps(positives, negatives); __m256 x_abs = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), x); - __m256 t = _mm256_rcp_ps(_mm256_add_ps(x_abs, p)); + __m256 t = _mm256_rcp_ps(_mm256_mul_ps(_mm256_add_ps(_mm256_set1_ps(1.0f), p), x_abs)); __m256 y = _mm256_sub_ps(_mm256_mul_ps(_mm256_mul_ps(a5, t), t), _mm256_mul_ps(_mm256_mul_ps(a4, t), t)); y = _mm256_sub_ps(y, _mm256_mul_ps(_mm256_mul_ps(a3, t), t)); y = _mm256_sub_ps(y, _mm256_mul_ps(_mm256_mul_ps(a2, t), t)); @@ -703,12 +703,9 @@ struct unary_op_erf __m512 a4 = _mm512_set1_ps(-1.453152027f); __m512 a5 = _mm512_set1_ps(1.061405429f); __m512 p = _mm512_set1_ps(0.3275911f); - const __m512 zero = _mm512_set1_ps(0.0f); - __m512 positives = _mm512_and_ps(_mm512_mask_blend_ps(_mm512_cmplt_ps_mask(zero, x), x, zero), _mm512_set1_ps(1.0f)); - __m512 negatives = _mm512_and_ps(_mm512_mask_blend_ps(_mm512_cmplt_ps_mask(x, zero), x, zero), _mm512_set1_ps(-1.0f)); - __m512 s = _mm512_or_ps(positives, negatives); + __m512 s = _mm512_and_ps(x, _mm512_set1_ps(-0.0f)); __m512 x_abs = _mm512_andnot_ps(_mm512_set1_ps(-0.0f), x); - __m512 t = _mm512_div_ps(_mm512_set1_ps(1.0), _mm512_add_ps(x_abs, p)); + __m512 t = _mm512_div_ps(_mm512_set1_ps(1.0f), _mm512_mul_ps(_mm512_add_ps(_mm512_set1_ps(1.0f), p), x_abs)); __m512 y = _mm512_sub_ps(_mm512_mul_ps(_mm512_mul_ps(a5, t), t), _mm512_mul_ps(_mm512_mul_ps(a4, t), t)); y = _mm512_sub_ps(y, _mm512_mul_ps(_mm512_mul_ps(a3, t), t)); y = _mm512_sub_ps(y, _mm512_mul_ps(_mm512_mul_ps(a2, t), t)); From f47360767ae50927cfd201d74840f1372d0cd334 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Mon, 4 Sep 2023 16:08:16 +0330 Subject: [PATCH 22/30] forgot to return back the name of unaryop.comp as a shader --- src/layer/vulkan/shader/{unaryop.comp.txt => unaryop.comp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/layer/vulkan/shader/{unaryop.comp.txt => unaryop.comp} (100%) diff --git a/src/layer/vulkan/shader/unaryop.comp.txt b/src/layer/vulkan/shader/unaryop.comp similarity index 100% rename from src/layer/vulkan/shader/unaryop.comp.txt rename to src/layer/vulkan/shader/unaryop.comp From c78c17f494ea1fe8620db2f8cf86a5baa12f0a07 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Mon, 4 Sep 2023 18:52:56 +0330 Subject: [PATCH 23/30] rewrite of erf for x86 stuff to fix the tests --- src/layer/x86/unaryop_x86.cpp | 37 +++++++++++++---------------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index 69cfd339884b..541e452bb6dc 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -657,18 +657,13 @@ struct unary_op_erf __m128 a4 = _mm_set1_ps(-1.453152027f); __m128 a5 = _mm_set1_ps(1.061405429f); __m128 p = _mm_set1_ps(0.3275911f); - const __m128 zero = _mm_set_ps1 (0.0f); - __m128 positives = _mm_and_ps(_mm_cmpgt_ps (x, zero), _mm_set_ps1(1.0f)); - __m128 negatives = _mm_and_ps(_mm_cmplt_ps (x, zero), _mm_set_ps1(-1.0f)); - __m128 s = _mm_or_ps(positives, negatives); + __m128 s = _mm_and_ps(x, _mm_set1_ps(-0.0f)); __m128 x_abs = _mm_andnot_ps(_mm_set1_ps(-0.0f), x); __m128 t = _mm_rcp_ps(_mm_mul_ps(_mm_add_ps(_mm_set_ps1(1.0f), p), x_abs)); - __m128 y = _mm_sub_ps(_mm_mul_ps(_mm_mul_ps(a5, t), t), _mm_mul_ps(_mm_mul_ps(a4, t), t)); - y = _mm_sub_ps(y, _mm_mul_ps(_mm_mul_ps(a3, t), t)); - y = _mm_sub_ps(y, _mm_mul_ps(_mm_mul_ps(a2, t), t)); - y = _mm_sub_ps(y, _mm_mul_ps(_mm_mul_ps(a1, t), t)); - y = _mm_mul_ps(y, t); - y = _mm_mul_ps(y, exp_ps(_mm_sub_ps(_mm_setzero_ps(), _mm_mul_ps(x_abs, x_abs)))); + __m128 y = _mm_set1_ps(1.0f); + __m128 err = _mm_mul_ps(_mm_add_ps(_mm_mul_ps(_mm_add_ps(_mm_mul_ps(_mm_add_ps(_mm_mul_ps(_mm_add_ps(_mm_mul_ps(a5, t), a4), t), a3), t), a2), t), a1), t); + err = exp_ps(_mm_mul_ps(_mm_sub_ps(_mm_setzero_ps(), x_abs), x_abs)); + y = _mm_sub_ps(y, err); return _mm_mul_ps(s, y); } #if __AVX__ @@ -686,12 +681,10 @@ struct unary_op_erf __m256 s = _mm256_or_ps(positives, negatives); __m256 x_abs = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), x); __m256 t = _mm256_rcp_ps(_mm256_mul_ps(_mm256_add_ps(_mm256_set1_ps(1.0f), p), x_abs)); - __m256 y = _mm256_sub_ps(_mm256_mul_ps(_mm256_mul_ps(a5, t), t), _mm256_mul_ps(_mm256_mul_ps(a4, t), t)); - y = _mm256_sub_ps(y, _mm256_mul_ps(_mm256_mul_ps(a3, t), t)); - y = _mm256_sub_ps(y, _mm256_mul_ps(_mm256_mul_ps(a2, t), t)); - y = _mm256_sub_ps(y, _mm256_mul_ps(_mm256_mul_ps(a1, t), t)); - y = _mm256_mul_ps(y, t); - y = _mm256_mul_ps(y, exp256_ps(_mm256_sub_ps(_mm256_setzero_ps(), _mm256_mul_ps(x_abs, x_abs)))); + __m256 y = _mm256_set1_ps(1.0f); + __m256 err = _mm256_mul_ps(_mm256_add_ps(_mm256_mul_ps(_mm256_add_ps(_mm256_mul_ps(_mm256_add_ps(_mm256_mul_ps(_mm256_add_ps(_mm256_mul_ps(a5, t), a4), t), a3), t), a2), t), a1), t); + err = exp256_ps(_mm256_mul_ps(_mm256_sub_ps(_mm256_setzero_ps(), x_abs), x_abs)); + y = _mm256_sub_ps(y, err); return _mm256_mul_ps(s, y); } #if __AVX512F__ @@ -705,13 +698,11 @@ struct unary_op_erf __m512 p = _mm512_set1_ps(0.3275911f); __m512 s = _mm512_and_ps(x, _mm512_set1_ps(-0.0f)); __m512 x_abs = _mm512_andnot_ps(_mm512_set1_ps(-0.0f), x); - __m512 t = _mm512_div_ps(_mm512_set1_ps(1.0f), _mm512_mul_ps(_mm512_add_ps(_mm512_set1_ps(1.0f), p), x_abs)); - __m512 y = _mm512_sub_ps(_mm512_mul_ps(_mm512_mul_ps(a5, t), t), _mm512_mul_ps(_mm512_mul_ps(a4, t), t)); - y = _mm512_sub_ps(y, _mm512_mul_ps(_mm512_mul_ps(a3, t), t)); - y = _mm512_sub_ps(y, _mm512_mul_ps(_mm512_mul_ps(a2, t), t)); - y = _mm512_sub_ps(y, _mm512_mul_ps(_mm512_mul_ps(a1, t), t)); - y = _mm512_mul_ps(y, t); - y = _mm512_mul_ps(y, exp512_ps(_mm512_sub_ps(_mm512_setzero_ps(), _mm512_mul_ps(x_abs, x_abs)))); +__m512 t = _mm512_div_ps(_mm512_set1_ps(1.0f), _mm512_mul_ps(_mm512_add_ps(_mm512_set1_ps(1.0f), p), x_abs)); + __m512 y = _mm512_set1_ps(1.0f); + __m512 err = _mm512_mul_ps(_mm512_add_ps(_mm512_mul_ps(_mm512_add_ps(_mm512_mul_ps(_mm512_add_ps(_mm512_mul_ps(_mm512_add_ps(_mm512_mul_ps(a5, t), a4), t), a3), t), a2), t), a1), t); + err = exp512_ps(_mm512_mul_ps(_mm512_sub_ps(_mm512_setzero_ps(), x_abs), x_abs)); + y = _mm512_sub_ps(y, err); return _mm512_mul_ps(s, y); } #endif // __AVX512F__ From cc8b1d06984e22bc90a11b482f0b3d8da87e0571 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Mon, 4 Sep 2023 21:31:43 +0330 Subject: [PATCH 24/30] arm test fix. its now like x86 thing --- src/layer/arm/unaryop_arm.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/layer/arm/unaryop_arm.cpp b/src/layer/arm/unaryop_arm.cpp index 7aa0e43c9b69..ef8ea45b1e7b 100644 --- a/src/layer/arm/unaryop_arm.cpp +++ b/src/layer/arm/unaryop_arm.cpp @@ -488,15 +488,14 @@ struct unary_op_erf float32x4_t a5 = vdupq_n_f32(1.061405429f); float32x4_t p = vdupq_n_f32(0.3275911f); const uint32x4_t szero = vreinterpretq_u32_f32(vdupq_n_f32(-0.0f)); + float32x4_t sone = vdupq_n_f32(1.0f); uint32x4_t s = vandq_u32(vreinterpretq_u32_f32(x), szero); float32x4_t x_abs = vabsq_f32(x); - float32x4_t t = vrecpeq_f32(vaddq_f32(x_abs, p)); - float32x4_t y = vsubq_f32(vmulq_f32(vmulq_f32(a5, t), t), vmulq_f32(vmulq_f32(a4, t), t)); - y = vsubq_f32(y, vmulq_f32(vmulq_f32(a3, t), t)); - y = vsubq_f32(y, vmulq_f32(vmulq_f32(a2, t), t)); - y = vsubq_f32(y, vmulq_f32(vmulq_f32(a1, t), t)); - y = vmulq_f32(y, t); - y = vmulq_f32(y, exp_ps(vnegq_f32(vmulq_f32(x_abs, x_abs)))); + float32x4_t t = vrecpeq_f32(vmulq_f32(vaddq_f32(sone, p), x_abs)); + float32x4_t y = vdupq_n_f32(1.0f); + float32x4_t err = vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(a5, t), a4), t), a3), t), a2), t), a1), t); + err = exp_ps(vmulq_f32(vsubq_f32(vdupq_n_f32(0.0f), x_abs), x_abs)); + y = vsubq_f32(y, err); return vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(y), s)); } #endif // __ARM_NEON From 15c2ed4d2086f6e3502f506765ba58ca1e5674d9 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Tue, 5 Sep 2023 08:48:28 +0330 Subject: [PATCH 25/30] use provided abs stuff, in order to check if tests work or not --- src/layer/x86/unaryop_x86.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index 541e452bb6dc..6394ed3d83d8 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -658,7 +658,7 @@ struct unary_op_erf __m128 a5 = _mm_set1_ps(1.061405429f); __m128 p = _mm_set1_ps(0.3275911f); __m128 s = _mm_and_ps(x, _mm_set1_ps(-0.0f)); - __m128 x_abs = _mm_andnot_ps(_mm_set1_ps(-0.0f), x); + __m128 x_abs = abs_ps(x); __m128 t = _mm_rcp_ps(_mm_mul_ps(_mm_add_ps(_mm_set_ps1(1.0f), p), x_abs)); __m128 y = _mm_set1_ps(1.0f); __m128 err = _mm_mul_ps(_mm_add_ps(_mm_mul_ps(_mm_add_ps(_mm_mul_ps(_mm_add_ps(_mm_mul_ps(_mm_add_ps(_mm_mul_ps(a5, t), a4), t), a3), t), a2), t), a1), t); @@ -679,7 +679,7 @@ struct unary_op_erf __m256 positives = _mm256_and_ps(_mm256_cmp_ps(x, zero, _CMP_GT_OQ), _mm256_set1_ps(1.0f)); __m256 negatives = _mm256_and_ps(_mm256_cmp_ps(x, zero, _CMP_LT_OQ), _mm256_set1_ps(-1.0f)); __m256 s = _mm256_or_ps(positives, negatives); - __m256 x_abs = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), x); + __m256 x_abs = abs256_ps(x); __m256 t = _mm256_rcp_ps(_mm256_mul_ps(_mm256_add_ps(_mm256_set1_ps(1.0f), p), x_abs)); __m256 y = _mm256_set1_ps(1.0f); __m256 err = _mm256_mul_ps(_mm256_add_ps(_mm256_mul_ps(_mm256_add_ps(_mm256_mul_ps(_mm256_add_ps(_mm256_mul_ps(_mm256_add_ps(_mm256_mul_ps(a5, t), a4), t), a3), t), a2), t), a1), t); @@ -697,7 +697,7 @@ struct unary_op_erf __m512 a5 = _mm512_set1_ps(1.061405429f); __m512 p = _mm512_set1_ps(0.3275911f); __m512 s = _mm512_and_ps(x, _mm512_set1_ps(-0.0f)); - __m512 x_abs = _mm512_andnot_ps(_mm512_set1_ps(-0.0f), x); + __m512 x_abs = abs512_ps(x); __m512 t = _mm512_div_ps(_mm512_set1_ps(1.0f), _mm512_mul_ps(_mm512_add_ps(_mm512_set1_ps(1.0f), p), x_abs)); __m512 y = _mm512_set1_ps(1.0f); __m512 err = _mm512_mul_ps(_mm512_add_ps(_mm512_mul_ps(_mm512_add_ps(_mm512_mul_ps(_mm512_add_ps(_mm512_mul_ps(_mm512_add_ps(_mm512_mul_ps(a5, t), a4), t), a3), t), a2), t), a1), t); From 8602ed8da87c6ce98ecebc709b7c4957c00be822 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Tue, 5 Sep 2023 10:48:00 +0330 Subject: [PATCH 26/30] removed vectorization stuff, it failes the tests --- src/layer/arm/unaryop_arm.cpp | 33 ++-------------------- src/layer/loongarch/unaryop_loongarch.cpp | 34 +---------------------- 2 files changed, 3 insertions(+), 64 deletions(-) diff --git a/src/layer/arm/unaryop_arm.cpp b/src/layer/arm/unaryop_arm.cpp index ef8ea45b1e7b..350da9ce3a9d 100644 --- a/src/layer/arm/unaryop_arm.cpp +++ b/src/layer/arm/unaryop_arm.cpp @@ -472,35 +472,6 @@ struct unary_op_trunc #endif // __ARM_NEON }; -struct unary_op_erf -{ - float func(const float& x) const - { - return (float)erf(x); - } -#if __ARM_NEON - float32x4_t func_pack4(const float32x4_t& x) const - { - float32x4_t a1 = vdupq_n_f32(0.254829592f); - float32x4_t a2 = vdupq_n_f32(-0.284496736f); - float32x4_t a3 = vdupq_n_f32(1.421413741f); - float32x4_t a4 = vdupq_n_f32(-1.453152027f); - float32x4_t a5 = vdupq_n_f32(1.061405429f); - float32x4_t p = vdupq_n_f32(0.3275911f); - const uint32x4_t szero = vreinterpretq_u32_f32(vdupq_n_f32(-0.0f)); - float32x4_t sone = vdupq_n_f32(1.0f); - uint32x4_t s = vandq_u32(vreinterpretq_u32_f32(x), szero); - float32x4_t x_abs = vabsq_f32(x); - float32x4_t t = vrecpeq_f32(vmulq_f32(vaddq_f32(sone, p), x_abs)); - float32x4_t y = vdupq_n_f32(1.0f); - float32x4_t err = vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(a5, t), a4), t), a3), t), a2), t), a1), t); - err = exp_ps(vmulq_f32(vsubq_f32(vdupq_n_f32(0.0f), x_abs), x_abs)); - y = vsubq_f32(y, err); - return vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(y), s)); - } -#endif // __ARM_NEON -}; - } // namespace UnaryOp_arm_functor int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -580,7 +551,7 @@ int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return unary_op_inplace(bottom_top_blob, opt); if (op_type == Operation_ERF) - return unary_op_inplace(bottom_top_blob, opt); + return UnaryOp::forward_inplace(bottom_top_blob, opt); return 0; } @@ -719,7 +690,7 @@ int UnaryOp_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) return unary_op_inplace_bf16s(bottom_top_blob, opt); if (op_type == Operation_ERF) - return unary_op_inplace_bf16s(bottom_top_blob, opt); + return UnaryOp::forward_inplace(bottom_top_blob, opt); return 0; } diff --git a/src/layer/loongarch/unaryop_loongarch.cpp b/src/layer/loongarch/unaryop_loongarch.cpp index d35818e1b4d0..5865a09eeee2 100644 --- a/src/layer/loongarch/unaryop_loongarch.cpp +++ b/src/layer/loongarch/unaryop_loongarch.cpp @@ -416,38 +416,6 @@ struct unary_op_trunc #endif // __loongarch_sx }; -struct unary_op_erf -{ - float func(const float& x) const - { - return (float)erf(x); - } -#if __loongarch_sx - __m128 func_pack4(const __m128& x) const - { -__m128 ones = (__m128)__lsx_vreplfr2vr_s(1.0f); - __m128 a1 = (__m128)__lsx_vreplfr2vr_s(0.254829592f); - __m128 a2 = (__m128)__lsx_vreplfr2vr_s(-0.284496736f); - __m128 a3 = (__m128)__lsx_vreplfr2vr_s(1.421413741f); - __m128 a4 = (__m128)__lsx_vreplfr2vr_s(-1.453152027f); - __m128 a5 = (__m128)__lsx_vreplfr2vr_s(1.061405429f); - __m128 p = (__m128)__lsx_vreplfr2vr_s(0.3275911f); - __m128 x2 = (__m128)__lsx_vbitclri_w((__m128i)x, 31); - __m128i sig_mask = __lsx_vreplgr2vr_w(1 << 31); - __m128i s = __lsx_vand_v((__m128i)x, sig_mask); - __m128 x_abs = (__m128)__lsx_vbitclri_w(x, 31); - __m128 t = (__m128)__lsx_vfdiv_s(ones, __lsx_vfadd_s(ones, __lsx_vfmul_s(x_abs, p))); - __m128 y = __lsx_vfsub_s(__lsx_vfmul_s(__lsx_vfmul_s(a5, t), t), __lsx_vfmul_s(__lsx_vfmul_s(a4, t), t)); - y = __lsx_vfsub_s(y, __lsx_vfsub_s(__lsx_vfsub_s(a3, t), t)); - y = __lsx_vfsub_s(y, __lsx_vfmul_s(__lsx_vfmul_s(a2, t), t)); - y = __lsx_vfsub_s(y, __lsx_vfsub_s(__lsx_vfmul_s(a1, t), t)); - y = __lsx_vfmul_s(y, t); - y = __lsx_vfmul_s(y, exp_ps(-__lsx_vfmul_s(x_abs, x_abs))); - return (__m128)__lsx_vfmul_s(s, y); - } -#endif // __loongarch_sx -}; - } // namespace UnaryOp_loongarch_functor int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -515,7 +483,7 @@ int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) return unary_op_inplace(bottom_top_blob, opt); if (op_type == Operation_ERF) - return unary_op_inplace(bottom_top_blob, opt); + return UnaryOp::forward_inplace(bottom_top_blob, opt); return 0; } From befd7c182e74ee893cce8e1be817faae3484a0d7 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Tue, 5 Sep 2023 11:47:48 +0330 Subject: [PATCH 27/30] recheck From a7f5eb1a66c495dbbd7ff9e2e5bc22791898cf51 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Wed, 6 Sep 2023 05:31:47 +0330 Subject: [PATCH 28/30] fixing x86 without vector intrinsics --- src/layer/x86/unaryop_x86.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index 6394ed3d83d8..5f47dae19b5a 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -776,7 +776,7 @@ int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return unary_op_inplace(bottom_top_blob, opt); if (op_type == Operation_ERF) - return unary_op_inplace(bottom_top_blob, opt); + return UnaryOp::forward_inplace(bottom_top_blob, opt);; return 0; } From 74c780634c913e93cda667cfb46e3be16f458f78 Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Wed, 6 Sep 2023 06:14:31 +0330 Subject: [PATCH 29/30] x86 fix --- src/layer/x86/unaryop_x86.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index 5f47dae19b5a..4b1418084be3 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -776,7 +776,7 @@ int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const return unary_op_inplace(bottom_top_blob, opt); if (op_type == Operation_ERF) - return UnaryOp::forward_inplace(bottom_top_blob, opt);; + return UnaryOp::forward_inplace(bottom_top_blob, opt); return 0; } From 64ceec5d21cb69744ae4543d450023c33a5a573c Mon Sep 17 00:00:00 2001 From: brightening-eyes Date: Wed, 6 Sep 2023 06:36:09 +0330 Subject: [PATCH 30/30] some checks. maybe tests pass? --- src/layer/riscv/unaryop_riscv.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/layer/riscv/unaryop_riscv.cpp b/src/layer/riscv/unaryop_riscv.cpp index 444312df1de2..83fbfd4a82e4 100644 --- a/src/layer/riscv/unaryop_riscv.cpp +++ b/src/layer/riscv/unaryop_riscv.cpp @@ -360,6 +360,9 @@ int UnaryOp_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons if (op_type == Operation_TRUNC) return unary_op_inplace(bottom_top_blob, opt); + if (op_type == Operation_ERF) + return UnaryOp::forward_inplace(bottom_top_blob, opt); + return 0; #else // __riscv_vector return UnaryOp::forward_inplace(bottom_top_blob, opt); @@ -683,6 +686,9 @@ int UnaryOp_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt if (op_type == Operation_TRUNC) return unary_op_inplace_fp16s(bottom_top_blob, opt); + if (op_type == Operation_ERF) + return UnaryOp::forward_inplace(bottom_top_blob, opt); + return 0; } #endif // __riscv_vector && __riscv_zfh