diff --git a/intel_extension_for_transformers/llm/runtime/graph/README.md b/intel_extension_for_transformers/llm/runtime/graph/README.md index 4fb40bf22ab..2044563f64d 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/README.md +++ b/intel_extension_for_transformers/llm/runtime/graph/README.md @@ -61,7 +61,7 @@ cd build cmake .. cmake --build . -j ``` - +Note: add compile args ```-DNE_AVX512=OFF -DNE_AVX512_VBMI=OFF -DNE_AVX512_VNNI=OFF``` to ```cmake``` when compiling it on a CPU without AVX512 ### 2. Run LLM with Python API You can use Python API to run Hugging Face model simply. Here is the sample code: diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/inner_product.cpp b/intel_extension_for_transformers/llm/runtime/graph/core/layers/inner_product.cpp index 9e62eec3288..2233a27d642 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/layers/inner_product.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/inner_product.cpp @@ -92,6 +92,15 @@ using PerNFp32Fp32 = jblas::wrapper::gemm_pack_weight::GemmInterfaceParallelAB< jblas::utils::parallel::Parallel2DGemm>; } // namespace avx512_vnni +namespace avx2 { +JBLAS_ISA constexpr DefaultISA = JblasAVX2; +template <template <class GC, JBLAS_ISA ISA> class ProB> +using Default = jblas::wrapper::gemm_pack_weight::GemmInterfacePackWeight< + jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight<DefaultISA, jblas::gemm::GemmCore_Row_NN_2x48_AVX2, + jblas::prologue::gemm::ActivationBase, ProB, + jblas::epilogue::gemm::AccumulatorWriteBackFp32>, + jblas::utils::parallel::Parallel2DGemm>; +} // namespace avx2 } // namespace static JBLAS_CODE jblas_s4fp32kblock_f32f32_forward(float* activation, SS4Fp32* weiptr, float* output, int _m, int _n, @@ -125,6 +134,10 @@ static JBLAS_CODE jblas_s4fp32kblock_f32f32_forward(float* activation, SS4Fp32* using GemmKernel = avx512f::Default<WeiS4ClipFp32>; static GemmKernel kernel; ret = kernel.compute({_m, _n, _k, activation, lda, weiptr, output, ldo}); + } else if (_cd->AVX2()) { + using GemmKernel = avx2::Default<WeiS4ClipFp32>; + static GemmKernel kernel; + ret = kernel.compute({_m, _n, _k, activation, lda, weiptr, output, ldo}); } } else if (weiptr->mCoreType == GcCompBf16::TYPE) { if (_cd->AMX_BF16()) { @@ -159,6 +172,10 @@ static JBLAS_CODE jblas_s8fp32kblock_f32f32_forward(float* activation, SS8Fp32* using GemmKernel = avx512f::Default<WeiS8Fp32>; static GemmKernel kernel; ret = kernel.compute({_m, _n, _k, activation, lda, weiptr, output, ldo}); + } else if (_cd->AVX2()) { + using GemmKernel = avx2::Default<WeiS8Fp32>; + static GemmKernel kernel; + ret = kernel.compute({_m, _n, _k, activation, lda, weiptr, output, ldo}); } } return ret; diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_ffn.cpp b/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_ffn.cpp index 5c084825ae3..2ffa21be9f0 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_ffn.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_ffn.cpp @@ -37,9 +37,13 @@ bool jblas_fusion_FFN_SiLu_f32f32_support(void* w1ptr, void* w2ptr, void* w3ptr, if (sameKernel) { if (w1tmp->mPrologueID == int(WeightCompType::WeightS4ClipScaleFp32) || w1tmp->mPrologueID == int(WeightCompType::WeightS8ScaleFp32)) { - constexpr size_t EleNum = sizeof(GcCompInt8KBlockSet) / sizeof(GcCompInt8KBlockSet[0]); - support = contains(w1tmp->mCoreType, GcCompInt8KBlockSet, EleNum); - 
support &= hasISA(GcCompInt8KBlockSet, EleNum); + constexpr jblas::gemm::GemmCoreType cores[] = { + jblas::gemm::GemmCoreType::AMX_INT8_16x48_KBLOCK, jblas::gemm::GemmCoreType::AVX512_VNNI_3x48_KBLOCK, + jblas::gemm::GemmCoreType::AVX512F_8x48, jblas::gemm::GemmCoreType::AMX_BF16_16x48, + jblas::gemm::GemmCoreType::AVX2_2X48}; + constexpr size_t EleNum = sizeof(cores) / sizeof(cores[0]); + support = contains(w1tmp->mCoreType, cores, EleNum); + support &= hasISA(cores, EleNum); } else if (w1tmp->mPrologueID == int(WeightCompType::WeightS8ScaleFp32PerChannelN) || w1tmp->mPrologueID == int(WeightCompType::WeightS4ClipScaleFp32PerChannelN)) { constexpr size_t EleNum = sizeof(GcCompInt8Set) / sizeof(GcCompInt8Set[0]); @@ -115,6 +119,42 @@ JBLAS_CODE jblas_fusion_FFN_SiLu_s4fp32_f32f32_forward(float* activation, SS4Fp3 w2ptr, w3ptr, tmp1, ldtmp1, output, ldo, NULL, tmp2, ldtmp2, NULL}); } } + } else if (w1ptr->mCoreType == GcCompFp32::TYPE) { + if (_cd->AVX512F()) { + using GemmKernel = custom::wrapper::kblock::avx512f::GemmS4KBlock; + using SiluGemmKernel = custom::wrapper::kblock::avx512f::SiluGemmS4KBlock; + using FusedInter = custom::wrapper::transformer::FPFFNFusedInterface<SiluGemmKernel, GemmKernel>; + static FusedInter finter; + int lda = fin; + int ldtmp1 = fmid; + int ldtmp2 = fmid; + int ldo = fout; + GemmKernel::AParam paramA = {activation, lda}; + SiluGemmKernel::BParam paramW1 = {w1ptr}; + GemmKernel::BParam paramW2 = {w2ptr}; + GemmKernel::BParam paramW3 = {w3ptr}; + SiluGemmKernel::EpiParam param1 = {tmp1, ldtmp1}; + GemmKernel::EpiParam param2 = {output, ldo, NULL}; + GemmKernel::EpiParam param3 = {tmp2, ldtmp2, NULL}; + ret = finter.compute({seq, fin, fmid, fout, paramA, paramW1, paramW2, paramW3, param1, param2, param3}); + } else if (_cd->AVX2()) { + using GemmKernel = custom::wrapper::kblock::avx2::GemmS4KBlock; + using SiluGemmKernel = custom::wrapper::kblock::avx2::SiluGemmS4KBlock; + using FusedInter = custom::wrapper::transformer::FPFFNFusedInterface<SiluGemmKernel, GemmKernel>; + static FusedInter finter; + int lda = fin; + int ldtmp1 = fmid; + int ldtmp2 = fmid; + int ldo = fout; + GemmKernel::AParam paramA = {activation, lda}; + SiluGemmKernel::BParam paramW1 = {w1ptr}; + GemmKernel::BParam paramW2 = {w2ptr}; + GemmKernel::BParam paramW3 = {w3ptr}; + SiluGemmKernel::EpiParam param1 = {tmp1, ldtmp1}; + GemmKernel::EpiParam param2 = {output, ldo, NULL}; + GemmKernel::EpiParam param3 = {tmp2, ldtmp2, NULL}; + ret = finter.compute({seq, fin, fmid, fout, paramA, paramW1, paramW2, paramW3, param1, param2, param3}); + } } return ret; } @@ -160,6 +200,42 @@ JBLAS_CODE jblas_fusion_FFN_SiLu_s8fp32_f32f32_forward(float* activation, SS8Fp3 ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, tmp1, ldtmp1, &quanA2, w1ptr, w2ptr, w3ptr, tmp1, ldtmp1, output, ldo, NULL, tmp2, ldtmp2, NULL}); } + } else if (w1ptr->mCoreType == GcCompFp32::TYPE) { + if (_cd->AVX512F()) { + using GemmKernel = custom::wrapper::kblock::avx512f::GemmS8KBlock; + using SiluGemmKernel = custom::wrapper::kblock::avx512f::SiluGemmS8KBlock; + using FusedInter = custom::wrapper::transformer::FPFFNFusedInterface<SiluGemmKernel, GemmKernel>; + static FusedInter finter; + int lda = fin; + int ldtmp1 = fmid; + int ldtmp2 = fmid; + int ldo = fout; + GemmKernel::AParam paramA = {activation, lda}; + SiluGemmKernel::BParam paramW1 = {w1ptr}; + GemmKernel::BParam paramW2 = {w2ptr}; + GemmKernel::BParam paramW3 = {w3ptr}; + SiluGemmKernel::EpiParam param1 = {tmp1, ldtmp1}; + GemmKernel::EpiParam param2 = 
{output, ldo, NULL}; + GemmKernel::EpiParam param3 = {tmp2, ldtmp2, NULL}; + ret = finter.compute({seq, fin, fmid, fout, paramA, paramW1, paramW2, paramW3, param1, param2, param3}); + } else if (_cd->AVX2()) { + using GemmKernel = custom::wrapper::kblock::avx2::GemmS8KBlock; + using SiluGemmKernel = custom::wrapper::kblock::avx2::SiluGemmS8KBlock; + using FusedInter = custom::wrapper::transformer::FPFFNFusedInterface<SiluGemmKernel, GemmKernel>; + static FusedInter finter; + int lda = fin; + int ldtmp1 = fmid; + int ldtmp2 = fmid; + int ldo = fout; + GemmKernel::AParam paramA = {activation, lda}; + SiluGemmKernel::BParam paramW1 = {w1ptr}; + GemmKernel::BParam paramW2 = {w2ptr}; + GemmKernel::BParam paramW3 = {w3ptr}; + SiluGemmKernel::EpiParam param1 = {tmp1, ldtmp1}; + GemmKernel::EpiParam param2 = {output, ldo, NULL}; + GemmKernel::EpiParam param3 = {tmp2, ldtmp2, NULL}; + ret = finter.compute({seq, fin, fmid, fout, paramA, paramW1, paramW2, paramW3, param1, param2, param3}); + } } return ret; } @@ -438,7 +514,8 @@ bool jblas_fusion_FFN_Add_GeLu_f32f32_support(void* w1ptr, void* w2ptr, int seq, w1tmp->mPrologueID == int(WeightCompType::WeightS8ScaleFp32)) { constexpr jblas::gemm::GemmCoreType cores[] = { jblas::gemm::GemmCoreType::AMX_INT8_16x48_KBLOCK, jblas::gemm::GemmCoreType::AVX512_VNNI_3x48_KBLOCK, - jblas::gemm::GemmCoreType::AVX512F_8x48, jblas::gemm::GemmCoreType::AMX_BF16_16x48}; + jblas::gemm::GemmCoreType::AVX512F_8x48, jblas::gemm::GemmCoreType::AMX_BF16_16x48, + jblas::gemm::GemmCoreType::AVX2_2X48}; constexpr size_t EleNum = sizeof(cores) / sizeof(cores[0]); support = contains(w1tmp->mCoreType, cores, EleNum); support &= hasISA(cores, EleNum); @@ -516,10 +593,16 @@ JBLAS_CODE jblas_fusion_FFN_Add_GeLu_s4fp32_f32f32_forward(float* activation, SS int lda = fin; int ldtmp1 = fmid; int ldo = fout; - // FusedInter::Arguments::paramA paramA={activation, lda}; - // FusedInter::Arguments::paramW1 paramW1={w1tmp}; - // FusedInter::Arguments::paramW2 paramW2={w2tmp}; - // FusedInter::Arguments::param1 param1={tmp1, b1ptr, ldtmp1, ldtmp1}; + ret = finter.compute({seq, fin, fmid, fout, activation, lda, w1tmp, w2tmp, tmp1, b1ptr, ldtmp1, + broadcast_bias ? 0 : ldtmp1, output, b2ptr, ldo, broadcast_bias ? 0 : ldo}); + } else if (_cd->AVX2()) { + using GemmKernel = custom::wrapper::kblock::avx2::AddGemmS4KBlock; + using GeluGemmKernel = custom::wrapper::kblock::avx2::AddGeluGemmS4KBlock; + using FusedInter = custom::wrapper::transformer::FpGeluFusedInterface<GeluGemmKernel, GemmKernel>; + static FusedInter finter; + int lda = fin; + int ldtmp1 = fmid; + int ldo = fout; ret = finter.compute({seq, fin, fmid, fout, activation, lda, w1tmp, w2tmp, tmp1, b1ptr, ldtmp1, broadcast_bias ? 0 : ldtmp1, output, b2ptr, ldo, broadcast_bias ? 0 : ldo}); } @@ -532,10 +615,6 @@ JBLAS_CODE jblas_fusion_FFN_Add_GeLu_s4fp32_f32f32_forward(float* activation, SS int lda = fin; int ldtmp1 = fmid; int ldo = fout; - // FusedInter::Arguments::paramA paramA={activation, lda}; - // FusedInter::Arguments::paramW1 paramW1={w1tmp}; - // FusedInter::Arguments::paramW2 paramW2={w2tmp}; - // FusedInter::Arguments::param1 param1={tmp1, b1ptr, ldtmp1, ldtmp1}; ret = finter.compute({seq, fin, fmid, fout, activation, lda, w1tmp, w2tmp, tmp1, b1ptr, ldtmp1, broadcast_bias ? 0 : ldtmp1, output, b2ptr, ldo, broadcast_bias ? 
0 : ldo}); } @@ -604,10 +683,16 @@ JBLAS_CODE jblas_fusion_FFN_Add_GeLu_s8fp32_f32f32_forward(float* activation, SS int lda = fin; int ldtmp1 = fmid; int ldo = fout; - // FusedInter::Arguments::paramA paramA={activation, lda}; - // FusedInter::Arguments::paramW1 paramW1={w1tmp}; - // FusedInter::Arguments::paramW2 paramW2={w2tmp}; - // FusedInter::Arguments::param1 param1={tmp1, b1ptr, ldtmp1, ldtmp1}; + ret = finter.compute({seq, fin, fmid, fout, activation, lda, w1tmp, w2tmp, tmp1, b1ptr, ldtmp1, + broadcast_bias ? 0 : ldtmp1, output, b2ptr, ldo, broadcast_bias ? 0 : ldo}); + } else if (_cd->AVX2()) { + using GemmKernel = custom::wrapper::kblock::avx2::AddGemmS8KBlock; + using GeluGemmKernel = custom::wrapper::kblock::avx2::AddGeluGemmS8KBlock; + using FusedInter = custom::wrapper::transformer::FpGeluFusedInterface<GeluGemmKernel, GemmKernel>; + static FusedInter finter; + int lda = fin; + int ldtmp1 = fmid; + int ldo = fout; ret = finter.compute({seq, fin, fmid, fout, activation, lda, w1tmp, w2tmp, tmp1, b1ptr, ldtmp1, broadcast_bias ? 0 : ldtmp1, output, b2ptr, ldo, broadcast_bias ? 0 : ldo}); } diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/jblas_common.hpp b/intel_extension_for_transformers/llm/runtime/graph/core/layers/jblas_common.hpp index 5a8fde53742..87d5f936614 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/core/layers/jblas_common.hpp +++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/jblas_common.hpp @@ -63,6 +63,7 @@ static bool hasISA(const jblas::gemm::GemmCoreType* set, size_t len) { support |= _cd->AVX512F(); break; case jblas::gemm::GemmCoreType::AVX2_4X24: + case jblas::gemm::GemmCoreType::AVX2_2X48: support |= _cd->AVX2(); break; case jblas::gemm::GemmCoreType::AVX_VNNI_1x48_KBLOCK: @@ -113,6 +114,7 @@ using SS8Fp32 = jblas::prologue::weight_comp::gemm_kblcok::StorageWeightS8ScaleF using SS8Fp32PerN = jblas::prologue::weight_comp::gemm_kblcok::StorageWeightS8ScaleFp32PerChannelN; using SS4Fp32PerN = jblas::prologue::weight_comp::gemm_kblcok::StorageWeightS4ScaleFp32PerChannelN; +using GcCompAVX2 = jblas::gemm::GemmCore_Row_NN_4x24_AVX2; using GcCompFp32 = jblas::gemm::GemmCore_Row_NN_8x48_AVX512F; using GcCompInt8KBlock = jblas::gemm::kblock::GemmCore_Row_NN_3x48_AVX512_VNNI_KBLOCK; using GcCompBf16 = jblas::gemm::GemmCore_Row_NN_16x64_AMX_BF16; @@ -454,6 +456,96 @@ class FFNFusedInterface { _SiluLauncher_T mActLauncher; }; +template <class _SiluLauncher_T, class _Launcher_T> +class FPFFNFusedInterface { + public: + static_assert(std::is_same<typename _Launcher_T::AParam, typename _SiluLauncher_T::AParam>::value, + "Prologue A param of the 2 Launcher (w/wo SILU) should be the same."); + struct Arguments { + const int Seq, Fin, FMid, FOut; + const typename _Launcher_T::AParam paramA; + const typename _SiluLauncher_T::BParam paramW1; + const typename _Launcher_T::BParam paramW2, paramW3; + const typename _SiluLauncher_T::EpiParam param1; + const typename _Launcher_T::EpiParam param2, param3; + }; + using Config = typename _Launcher_T::ParallelConfig; + using ActConfig = typename _SiluLauncher_T::ParallelConfig; + using ActivationType = typename _Launcher_T::PrologueA; + using WeightType = typename _Launcher_T::PrologueB; + using GemmCore = typename _Launcher_T::GemmCore; + using LArguments = typename _Launcher_T::Param; + using CParam = typename _Launcher_T::EpiParam; + using Parallel = jblas::utils::parallel::Parallel2DGemmKBlockFixed<GemmCore>; + ActivationType* getActivationPtr() { return 
&mLauncher.mProA; } + // forward=packB+compute + JBLAS_CODE compute(const Arguments& _param) { + auto bptr = (jblas::prologue::weight_comp::gemm_kblcok::WeightBase*)(_param.paramW1.packedW); + if (bptr == nullptr) { + return JblasInvalidParam; + } + // dynamic quantization: Seq*Fin + auto cb = jblas::utils::CpuBase(); + + Parallel _paral = Parallel(); // w1&w3 from Seq* Fin=>FMid + Parallel _paral2 = Parallel(); // w2 from Seq* FMid=>Fout + _paral.update(_param.Seq, _param.FMid, _param.Fin, bptr->mBlockSize, cb.mNumThreads); + _paral2.update(_param.Seq, _param.FOut, _param.FMid, bptr->mBlockSize, cb.mNumThreads); + + omp_set_num_threads(cb.mNumThreads); +#pragma omp parallel + { + int tidx = omp_get_thread_num(); + { + int colidx, rowidx, rowsize, colsize; + _paral.getIndex(tidx, &rowidx, &colidx, &rowsize, &colsize); + if (rowsize > 0 && colsize > 0) { + ActConfig _actconfig{ + rowidx, colidx, rowsize, colsize, _paral.getMStep(), _paral.getNStep(), _paral.getKStep(), cb.mL2Cache}; + Config _config{rowidx, colidx, rowsize, colsize, _paral.getMStep(), _paral.getNStep(), _paral.getKStep(), + cb.mL2Cache}; + mActLauncher.launch( + _actconfig, {_param.Seq, _param.FMid, _param.Fin, _param.paramA, _param.paramW1, _param.param1, NULL}); + mLauncher.launch(_config, + {_param.Seq, _param.FMid, _param.Fin, _param.paramA, _param.paramW3, _param.param3, NULL}); + int row_r = jblas::utils::remainsize(rowidx, _paral.mRows, rowsize); + int col_r = jblas::utils::remainsize(colidx, _paral.mCols, colsize); + + // TODO(Yu): replace the naive inplace eltwise mul + for (int i = 0; i < row_r; i++) { + for (int j = 0; j < col_r; j++) { + _param.param1.C[(rowidx + i) * _param.param1.ldc + colidx + j] *= + _param.param3.C[(rowidx + i) * _param.param3.ldc + colidx + j]; + } + } + } + } +#pragma omp barrier + { + int colidx, rowidx, rowsize, colsize; + _paral2.getIndex(tidx, &rowidx, &colidx, &rowsize, &colsize); + if (rowsize > 0 && colsize > 0) { + Config _config{ + rowidx, colidx, rowsize, colsize, _paral2.getMStep(), _paral2.getNStep(), _paral2.getKStep(), + cb.mL2Cache}; + mLauncher.launch(_config, {_param.Seq, + _param.FOut, + _param.FMid, + {_param.param1.C, _param.param1.ldc}, + _param.paramW2, + _param.param2, + NULL}); + } + } + } + return JblasSuccess; + } + + protected: + _Launcher_T mLauncher; + _SiluLauncher_T mActLauncher; +}; + template <class _SiluLauncher_T, class _Launcher_T> class FFNFusedInterfacePerN { public: @@ -843,10 +935,29 @@ using DefaultGemmFp32 = jblas::prologue::gemm::ActivationBase, ProB, Epi>; using AddGeluGemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, custom::epilogue::Add_GeluFp32>; using AddGemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, custom::epilogue::AddFp32>; +using GemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; +using SiluGemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, custom::epilogue::SiluFp32>; using AddGeluGemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, custom::epilogue::Add_GeluFp32>; using AddGemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, custom::epilogue::AddFp32>; +using GemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; +using SiluGemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, custom::epilogue::SiluFp32>; } // namespace avx512f +namespace avx2 { +template <template <class GC, JBLAS_ISA ISA> class ProB, template <JBLAS_ISA ISA> class Epi> +using DefaultGemmFp32 = + jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight<JblasAVX2, jblas::gemm::GemmCore_Row_NN_2x48_AVX2, + 
jblas::prologue::gemm::ActivationBase, ProB, Epi>; +using AddGeluGemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, custom::epilogue::Add_GeluFp32>; +using AddGemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, custom::epilogue::AddFp32>; +using GemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; +using SiluGemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, custom::epilogue::SiluFp32>; + +using AddGeluGemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, custom::epilogue::Add_GeluFp32>; +using AddGemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, custom::epilogue::AddFp32>; +using GemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, jblas::epilogue::gemm::AccumulatorWriteBackFp32>; +using SiluGemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, custom::epilogue::SiluFp32>; +} // namespace avx2 namespace amx_bf16 { template <template <class GC, JBLAS_ISA ISA> class ProB, template <JBLAS_ISA ISA> class Epi> using DefaultGemmFp32 = diff --git a/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp b/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp index e731d33fb47..ea2e22d3e03 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp @@ -863,7 +863,7 @@ size_t jblas_quantize(const float* f32ptr, void* dstpr, const quant_params_inter using KernelRef = WeiS4ClipFp32<GcCompFp32, JblasNoSIMD>; static Kernel kernel; static KernelRef kernelref; - auto packedw = kernel.createStorage(n, k, params.group_size, params.alg == quant_alg::sym); + auto packedw = kernel.createStorage(n, k, params.group_size, params.alg == quant_alg::asym); packedw.assign(dstbptr); if (cd->AVX512_FP16()) { kernel.packTransposeWeight(n, k, f32ptr, k, &packedw); @@ -876,7 +876,7 @@ size_t jblas_quantize(const float* f32ptr, void* dstpr, const quant_params_inter using KernelRef = WeiS4ClipFp32<GcCompBf16, JblasNoSIMD>; static Kernel kernel; static KernelRef kernelref; - auto packedw = kernel.createStorage(n, k, params.group_size, params.alg == quant_alg::sym); + auto packedw = kernel.createStorage(n, k, params.group_size, params.alg == quant_alg::asym); packedw.assign(dstbptr); if (cd->AMX_BF16()) { kernel.packTransposeWeight(n, k, f32ptr, k, &packedw); @@ -926,7 +926,7 @@ size_t jblas_quantize(const float* f32ptr, void* dstpr, const quant_params_inter using KernelRef = WeiS8Fp32<GcCompFp32, JblasNoSIMD>; static Kernel kernel; static KernelRef kernelref; - auto packedw = kernel.createStorage(n, k, params.group_size, params.alg == quant_alg::sym); + auto packedw = kernel.createStorage(n, k, params.group_size, params.alg == quant_alg::asym); packedw.assign(dstbptr); if (cd->AVX512_FP16()) { kernel.packTransposeWeight(n, k, f32ptr, k, &packedw); @@ -939,7 +939,7 @@ size_t jblas_quantize(const float* f32ptr, void* dstpr, const quant_params_inter using KernelRef = WeiS8Fp32<GcCompBf16, JblasNoSIMD>; static Kernel kernel; static KernelRef kernelref; - auto packedw = kernel.createStorage(n, k, params.group_size, params.alg == quant_alg::sym); + auto packedw = kernel.createStorage(n, k, params.group_size, params.alg == quant_alg::asym); packedw.assign(dstbptr); if (cd->AMX_BF16()) { kernel.packTransposeWeight(n, k, f32ptr, k, &packedw); diff --git a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec.hpp b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec.hpp index 54a31f09687..41cc13da3dc 100644 --- 
a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec.hpp +++ b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec.hpp @@ -19,7 +19,6 @@ #include "vec_base.hpp" #include "vec_compare.hpp" #include "vec_convert.hpp" -#include "vec_load.hpp" #include "vec_set.hpp" #endif // ENGINE_EXECUTOR_INCLUDE_VEC_HPP_ diff --git a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_arithmetic.cpp b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_arithmetic.cpp index 30fca7c37c2..1e47ef8b2d3 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_arithmetic.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_arithmetic.cpp @@ -12,29 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include <cassert> - +#include "vec_load.hpp" +#include "vec_store.hpp" #include "vec_arithmetic.hpp" +#include "cmath" -inline fp32x16 sub_fp32x16(fp32x16 x, fp32x16 y) { +fp32x16 sub_fp32x16(fp32x16 x, fp32x16 y) { #if __AVX512F__ - return _mm512_sub_ps(x, y); + return {_mm512_sub_ps(x.first, y.first)}; #else return {_mm256_sub_ps(x.first, y.first), _mm256_sub_ps(x.second, y.second)}; #endif } -inline fp32x16 fmsub_fp32x16(fp32x16 x, fp32x16 y, fp32x16 z) { +fp32x16 fmsub_fp32x16(fp32x16 x, fp32x16 y, fp32x16 z) { #if __AVX512F__ - return _mm512_fmsub_ps(x, y, z); + return {_mm512_fmsub_ps(x.first, y.first, z.first)}; #else return {_mm256_fmsub_ps(x.first, y.first, z.first), _mm256_fmsub_ps(x.second, y.second, z.second)}; #endif } -inline fp32x16 maskz_fmsub_fp32x16(int mask, fp32x16 x, fp32x16 y, fp32x16 z) { +fp32x16 maskz_fmsub_fp32x16(int mask, fp32x16 x, fp32x16 y, fp32x16 z) { #if __AVX512F__ - return _mm512_maskz_fmsub_ps(mask, x, y, z); + return {_mm512_maskz_fmsub_ps(mask, x.first, y.first, z.first)}; #else __m256 first, second; MASK_DECORATOR(_mm256_blend_ps, _mm256_setzero_ps(), _mm256_fmsub_ps(x.first, y.first, z.first), mask & 255, first); @@ -44,33 +45,33 @@ inline fp32x16 maskz_fmsub_fp32x16(int mask, fp32x16 x, fp32x16 y, fp32x16 z) { #endif } -inline fp32x16 add_fp32x16(fp32x16 x, fp32x16 y) { +fp32x16 add_fp32x16(fp32x16 x, fp32x16 y) { #if __AVX512F__ - return _mm512_add_ps(x, y); + return {_mm512_add_ps(x.first, y.first)}; #else return {_mm256_add_ps(x.first, y.first), _mm256_add_ps(x.second, y.second)}; #endif } -inline fp32x16 fmadd_fp32x16(fp32x16 x, fp32x16 y, fp32x16 z) { +fp32x16 fmadd_fp32x16(fp32x16 x, fp32x16 y, fp32x16 z) { #if __AVX512F__ - return _mm512_fmadd_ps(x, y, z); + return {_mm512_fmadd_ps(x.first, y.first, z.first)}; #else return {_mm256_fmadd_ps(x.first, y.first, z.first), _mm256_fmadd_ps(x.second, y.second, z.second)}; #endif } -inline fp32x16 mul_fp32x16(fp32x16 x, fp32x16 y) { +fp32x16 mul_fp32x16(fp32x16 x, fp32x16 y) { #if __AVX512F__ - return _mm512_mul_ps(x, y); + return {_mm512_mul_ps(x.first, y.first)}; #else return {_mm256_mul_ps(x.first, y.first), _mm256_mul_ps(x.second, y.second)}; #endif } -inline fp32x16 maskz_mul_fp32x16(int mask, fp32x16 x, fp32x16 y) { +fp32x16 maskz_mul_fp32x16(int mask, fp32x16 x, fp32x16 y) { #if __AVX512F__ - return _mm512_maskz_mul_ps(mask, x, y); + return {_mm512_maskz_mul_ps(mask, x.first, y.first)}; #else __m256 first, second; MASK_DECORATOR(_mm256_blend_ps, _mm256_setzero_ps(), _mm256_mul_ps(x.first, y.first), mask & 255, first); @@ -80,31 +81,31 @@ inline fp32x16 maskz_mul_fp32x16(int mask, fp32x16 x, fp32x16 y) { } template <int rounding> -inline fp32x16 
mul_round_fp32x16(fp32x16 x, fp32x16 y) { +fp32x16 mul_round_fp32x16(fp32x16 x, fp32x16 y) { static_assert(rounding == (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) || rounding == (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC) || rounding == (_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC) || rounding == (_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC) || rounding == (_MM_FROUND_CUR_DIRECTION), "ERROR: Not support rounding"); #if __AVX512F__ - return _mm512_mul_round_ps(x, y, rounding); + return {_mm512_mul_round_ps(x.first, y.first, rounding)}; #else return {_mm256_round_ps(_mm256_mul_ps(x.first, y.first), rounding), _mm256_round_ps(_mm256_mul_ps(x.second, y.second), rounding)}; #endif } -inline fp32x16 div_fp32x16(fp32x16 x, fp32x16 y) { +fp32x16 div_fp32x16(fp32x16 x, fp32x16 y) { #if __AVX512F__ - return _mm512_div_ps(x, y); + return {_mm512_div_ps(x.first, y.first)}; #else return {_mm256_div_ps(x.first, y.first), _mm256_div_ps(x.second, y.second)}; #endif } -inline float reduce_add_fp32x16(fp32x16 x) { +float reduce_add_fp32x16(fp32x16 x) { #if __AVX512F__ - return _mm512_reduce_add_ps(x); + return {_mm512_reduce_add_ps(x.first)}; #else const __m256 x256 = _mm256_add_ps(x.first, x.second); const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x256, 1), _mm256_castps256_ps128(x256)); @@ -114,46 +115,54 @@ inline float reduce_add_fp32x16(fp32x16 x) { #endif } -inline fp32x16 sqrt_fp32x16(fp32x16 x) { +fp32x16 sqrt_fp32x16(fp32x16 x) { #if __AVX512F__ - return _mm512_sqrt_ps(x); + return {_mm512_sqrt_ps(x.first)}; #else return {_mm256_sqrt_ps(x.first), _mm256_sqrt_ps(x.second)}; #endif } -inline fp32x16 rsqrt14_fp32x16(fp32x16 x) { +fp32x16 rsqrt14_fp32x16(fp32x16 x) { #if __AVX512F__ - return _mm512_rsqrt14_ps(x); + return {_mm512_rsqrt14_ps(x.first)}; #else // the max relative error is 6x than avx512 return {_mm256_rsqrt_ps(x.first), _mm256_rsqrt_ps(x.second)}; #endif } -inline fp32x16 ceil_fp32x16(fp32x16 x) { +fp32x16 ceil_fp32x16(fp32x16 x) { #if __AVX512F__ - return _mm512_ceil_ps(x); + return {_mm512_ceil_ps(x.first)}; #else // the max relative error is 6x than avx512 return {_mm256_ceil_ps(x.first), _mm256_ceil_ps(x.second)}; #endif } -inline fp32x16 scale_fp32x16(fp32x16 x, fp32x16 y) { +fp32x16 scale_fp32x16(fp32x16 x, fp32x16 y) { #if __AVX512F__ - return _mm512_scalef_ps(x, y); + return {_mm512_scalef_ps(x.first, y.first)}; #else - // No intrinsic - assert("No intrinsic"); - return {_mm256_rsqrt_ps(x.first), _mm256_rsqrt_ps(x.second)}; + float* vec_x = new float[16]; + float* vec_y = new float[16]; + float* vec_z = new float[16]; + store_fp32x16(vec_x, x); + store_fp32x16(vec_y, y); + for (int i = 0; i < 16; i++) vec_z[i] = vec_x[i] * exp2(vec_y[i]); + fp32x16 res = load_fp32x16(vec_z); + delete[] vec_x; + delete[] vec_y; + delete[] vec_z; + return res; #endif } -inline float dot_fp32x16(fp32x16 x, fp32x16 y) { return reduce_add_fp32x16(mul_fp32x16(x, y)); } +float dot_fp32x16(fp32x16 x, fp32x16 y) { return reduce_add_fp32x16(mul_fp32x16(x, y)); } -inline fp32x16 abs_fp32x16(fp32x16 x) { +fp32x16 abs_fp32x16(fp32x16 x) { #if __AVX512F__ - return _mm512_abs_ps(x); + return {_mm512_abs_ps(x.first)}; #else return {_mm256_castsi256_ps(_mm256_abs_epi32(_mm256_castps_si256(x.first))), _mm256_castsi256_ps(_mm256_abs_epi32(_mm256_castps_si256(x.second)))}; diff --git a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_arithmetic.hpp b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_arithmetic.hpp index c261a3a26d9..71bf6f7f565 100644 --- 
a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_arithmetic.hpp +++ b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_arithmetic.hpp @@ -17,50 +17,50 @@ #include "vec_base.hpp" -inline fp32x16 sub_fp32x16(fp32x16 x, fp32x16 y); +fp32x16 sub_fp32x16(fp32x16 x, fp32x16 y); REGISTER_KERNEL_T(sub_fp32x16, fp32x16, fp32x16, fp32x16); -inline fp32x16 fmsub_fp32x16(fp32x16 x, fp32x16 y, fp32x16 z); +fp32x16 fmsub_fp32x16(fp32x16 x, fp32x16 y, fp32x16 z); REGISTER_KERNEL_T(fmsub_fp32x16, fp32x16, fp32x16, fp32x16, fp32x16); -inline fp32x16 maskz_fmsub_fp32x16(int mask, fp32x16 x, fp32x16 y, fp32x16 z); +fp32x16 maskz_fmsub_fp32x16(int mask, fp32x16 x, fp32x16 y, fp32x16 z); -inline fp32x16 add_fp32x16(fp32x16 x, fp32x16 y); +fp32x16 add_fp32x16(fp32x16 x, fp32x16 y); REGISTER_KERNEL_T(add_fp32x16, fp32x16, fp32x16, fp32x16); -inline fp32x16 fmadd_fp32x16(fp32x16 x, fp32x16 y, fp32x16 z); +fp32x16 fmadd_fp32x16(fp32x16 x, fp32x16 y, fp32x16 z); REGISTER_KERNEL_T(fmadd_fp32x16, fp32x16, fp32x16, fp32x16, fp32x16); -inline fp32x16 mul_fp32x16(fp32x16 x, fp32x16 y); +fp32x16 mul_fp32x16(fp32x16 x, fp32x16 y); REGISTER_KERNEL_T(mul_fp32x16, fp32x16, fp32x16, fp32x16); -inline fp32x16 maskz_mul_fp32x16(int mask, fp32x16 x, fp32x16 y); +fp32x16 maskz_mul_fp32x16(int mask, fp32x16 x, fp32x16 y); template <int rounding> -inline fp32x16 mul_round_fp32x16(fp32x16 x, fp32x16 y); +fp32x16 mul_round_fp32x16(fp32x16 x, fp32x16 y); -inline fp32x16 div_fp32x16(fp32x16 x, fp32x16 y); +fp32x16 div_fp32x16(fp32x16 x, fp32x16 y); REGISTER_KERNEL_T(div_fp32x16, fp32x16, fp32x16, fp32x16); -inline float reduce_add_fp32x16(fp32x16 x); +float reduce_add_fp32x16(fp32x16 x); REGISTER_KERNEL_T(reduce_add_fp32x16, float, fp32x16); -inline fp32x16 sqrt_fp32x16(fp32x16 x); +fp32x16 sqrt_fp32x16(fp32x16 x); REGISTER_KERNEL_T(sqrt_fp32x16, fp32x16, fp32x16); -inline fp32x16 rsqrt14_fp32x16(fp32x16 x); +fp32x16 rsqrt14_fp32x16(fp32x16 x); REGISTER_KERNEL_T(rsqrt14_fp32x16, fp32x16, fp32x16); -inline fp32x16 ceil_fp32x16(fp32x16 x); +fp32x16 ceil_fp32x16(fp32x16 x); REGISTER_KERNEL_T(ceil_fp32x16, fp32x16, fp32x16); -inline fp32x16 scale_fp32x16(fp32x16 x, fp32x16 y); +fp32x16 scale_fp32x16(fp32x16 x, fp32x16 y); REGISTER_KERNEL_T(scale_fp32x16, fp32x16, fp32x16, fp32x16); -inline float dot_fp32x16(fp32x16 x, fp32x16 y); +float dot_fp32x16(fp32x16 x, fp32x16 y); REGISTER_KERNEL_T(dot_fp32x16, float, fp32x16, fp32x16); -inline fp32x16 abs_fp32x16(fp32x16 x); +fp32x16 abs_fp32x16(fp32x16 x); REGISTER_KERNEL_T(abs_fp32x16, fp32x16, fp32x16); #endif // ENGINE_EXECUTOR_INCLUDE_VEC_SET_HPP_ diff --git a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_base.hpp b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_base.hpp index fb749985b51..2051f9bb9a6 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_base.hpp +++ b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_base.hpp @@ -17,15 +17,29 @@ #include <immintrin.h> #include <cstdint> -#include <utility> #if __AVX512F__ -typedef __m512 fp32x16; -typedef __m512i int32x16; +struct fp32x16 { + __m512 first; +}; + +struct s32x16 { + __m512i first; +}; +struct u32x16 { + __m512i first; +}; #else -typedef std::pair<__m256, __m256> fp32x16; -typedef std::pair<__m256i, __m256i> int32x16; +struct fp32x16 { + __m256 first, second; +}; +struct s32x16 { + __m256i first, second; +}; +struct u32x16 { + __m256i first, second; +}; #define MASK_DECORATOR(blend_func, a, b, mask, res) \ 
switch ((mask)) { \ case 1: \ @@ -54,16 +68,49 @@ typedef std::pair<__m256i, __m256i> int32x16; } #endif -typedef __m256i bf16x16; -typedef __m256i int16x16; -typedef __m128i int8x16; + +struct bf16x16 { + __m256i first; +}; + +struct fp16x16 { + __m256i first; +}; + +struct s16x16 { + __m256i first; +}; +struct s8x16 { + __m128i first; +}; +struct u8x16 { + __m128i first; +}; + #define CPU_VEC_STEP 16 template <typename T> -T load_kernel_t(const void*); +T load_kernel_t(const void* src) { + return *reinterpret_cast<const T*>(src); +} + +template <> +fp32x16 load_kernel_t<fp32x16>(const void* src); +template <> +bf16x16 load_kernel_t<bf16x16>(const void* src); template <typename T> -void store_kernel_t(void*, T); +void store_kernel_t(void* dst, T src) { + T* dst_T = reinterpret_cast<T*>(dst); + *dst_T = src; +} + +template <> +void store_kernel_t<s8x16>(void* dst, s8x16 src); +template <> +void store_kernel_t<fp32x16>(void* dst, fp32x16 src); +template <> +void store_kernel_t<bf16x16>(void* dst, bf16x16 src); template <typename dstT, typename src0T = void, typename src1T = void, typename src2T = void> struct kernel_t { diff --git a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_compare.cpp b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_compare.cpp index 6cff8569c89..be78ac3b8aa 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_compare.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_compare.cpp @@ -14,33 +14,33 @@ #include "vec_compare.hpp" -inline fp32x16 min_fp32x16(fp32x16 a, fp32x16 b) { +fp32x16 min_fp32x16(fp32x16 a, fp32x16 b) { #if __AVX512F__ - return _mm512_min_ps(a, b); + return {_mm512_min_ps(a.first, b.first)}; #else return {_mm256_min_ps(a.first, b.first), _mm256_min_ps(a.second, b.second)}; #endif } -inline int32x16 max_int32x16(int32x16 a, int32x16 b) { +s32x16 max_s32x16(s32x16 a, s32x16 b) { #if __AVX512F__ - return _mm512_max_epi32(a, b); + return {_mm512_max_epi32(a.first, b.first)}; #else return {_mm256_max_epi32(a.first, b.first), _mm256_max_epi32(a.second, b.second)}; #endif } -inline fp32x16 max_fp32x16(fp32x16 a, fp32x16 b) { +fp32x16 max_fp32x16(fp32x16 a, fp32x16 b) { #if __AVX512F__ - return _mm512_max_ps(a, b); + return {_mm512_max_ps(a.first, b.first)}; #else return {_mm256_max_ps(a.first, b.first), _mm256_max_ps(a.second, b.second)}; #endif } -inline float reduce_max_fp32x16(fp32x16 x) { +float reduce_max_fp32x16(fp32x16 x) { #if __AVX512F__ - return _mm512_reduce_max_ps(x); + return {_mm512_reduce_max_ps(x.first)}; #else const __m256 x256 = _mm256_max_ps(x.first, x.second); const __m128 x128 = _mm_max_ps(_mm256_extractf128_ps(x256, 1), _mm256_castps256_ps128(x256)); diff --git a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_compare.hpp b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_compare.hpp index 14d365eb988..2300f952683 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_compare.hpp +++ b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_compare.hpp @@ -17,13 +17,13 @@ #include "vec_base.hpp" -inline fp32x16 min_fp32x16(fp32x16 a, fp32x16 b); +fp32x16 min_fp32x16(fp32x16 a, fp32x16 b); -inline int32x16 max_int32x16(int32x16 a, int32x16 b); +s32x16 max_s32x16(s32x16 a, s32x16 b); -inline fp32x16 max_fp32x16(fp32x16 a, fp32x16 b); +fp32x16 max_fp32x16(fp32x16 a, fp32x16 b); -inline float reduce_max_fp32x16(fp32x16 x); +float reduce_max_fp32x16(fp32x16 x); 
REGISTER_KERNEL_T(reduce_max_fp32x16, float, fp32x16); #endif // ENGINE_EXECUTOR_INCLUDE_VEC_COMPARE_HPP_ diff --git a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_convert.cpp b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_convert.cpp index f52484eac70..4e57dbb915b 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_convert.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_convert.cpp @@ -12,31 +12,32 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "vec_store.hpp" #include "vec_convert.hpp" template <int rounding> -inline int32x16 cvt_roundfp32x16_int32x16(fp32x16 a) { +s32x16 cvt_roundfp32x16_s32x16(fp32x16 a) { static_assert(rounding == (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) || rounding == (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC) || rounding == (_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC) || rounding == (_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC) || rounding == (_MM_FROUND_CUR_DIRECTION), "ERROR: Not support rounding"); #if __AVX512F__ - return _mm512_cvt_roundps_epi32(a, rounding); + return {_mm512_cvt_roundps_epi32(a.first, rounding)}; #else return {_mm256_cvtps_epi32(_mm256_round_ps(a.first, rounding)), _mm256_cvtps_epi32(_mm256_round_ps(a.second, rounding))}; #endif } template <int rounding> -inline int32x16 maskz_cvt_roundfp32x16_int32x16(int mask, fp32x16 a) { +s32x16 maskz_cvt_roundfp32x16_s32x16(int mask, fp32x16 a) { static_assert(rounding == (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) || rounding == (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC) || rounding == (_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC) || rounding == (_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC) || rounding == (_MM_FROUND_CUR_DIRECTION), "ERROR: Not support rounding"); #if __AVX512F__ - return _mm512_maskz_cvt_roundps_epi32(mask, a, rounding); + return {_mm512_maskz_cvt_roundps_epi32(mask, a.first, rounding)}; #else __m256i first, second; first = _mm256_cvtps_epi32(_mm256_round_ps(a.first, rounding)); @@ -47,48 +48,48 @@ inline int32x16 maskz_cvt_roundfp32x16_int32x16(int mask, fp32x16 a) { #endif } -inline bf16x16 cvt_fp32x16_bf16x16(fp32x16 a) { +bf16x16 cvt_fp32x16_bf16x16(fp32x16 a) { #if __AVX512F__ #if __AVX512BF16__ && __GNUC__ > 11 - return _mm256_castph_si256((__m256h)_mm512_cvtneps_pbh(a)); + return {_mm512_cvtneps_pbh(a.first)}; #else - return _mm512_cvtepi32_epi16(_mm512_bsrli_epi128(_mm512_castps_si512(a), 2)); + return {_mm512_cvtepi32_epi16(_mm512_bsrli_epi128(_mm512_castps_si512(a.first), 2))}; #endif #else __m256i first = _mm256_bsrli_epi128(_mm256_castps_si256(a.first), 2); __m256i second = _mm256_bsrli_epi128(_mm256_castps_si256(a.second), 2); __m256i res = _mm256_packus_epi32(first, second); - return _mm256_permute4x64_epi64(res, 0x18); + return {_mm256_permute4x64_epi64(res, 0x18)}; #endif } -inline fp32x16 cvt_bf16x16_fp32x16(bf16x16 a) { +fp32x16 cvt_bf16x16_fp32x16(bf16x16 a) { #if __AVX512F__ #if __AVX512BF16__ && __GNUC__ > 11 - return _mm512_cvtpbh_ps((__m256bh)_mm256_castsi256_ph(a)); + return {_mm512_cvtpbh_ps(a.first)}; #else - return _mm512_castsi512_ps(_mm512_bslli_epi128(_mm512_cvtepu16_epi32(a), 2)); + return {_mm512_castsi512_ps(_mm512_bslli_epi128(_mm512_cvtepu16_epi32(a.first), 2))}; #endif #else - __m128i second = _mm256_extractf128_si256(a, 1); + __m128i second = _mm256_extractf128_si256(a.first, 1); __m256 second_fp32 = _mm256_castsi256_ps(_mm256_bslli_epi128(_mm256_cvtepu16_epi32(second), 2)); - 
__m128i first = _mm256_castsi256_si128(a); + __m128i first = _mm256_castsi256_si128(a.first); __m256 first_fp32 = _mm256_castsi256_ps(_mm256_bslli_epi128(_mm256_cvtepu16_epi32(first), 2)); return {first_fp32, second_fp32}; #endif } -inline fp32x16 maskz_cvt_bf16x16_fp32x16(int mask, bf16x16 a) { +fp32x16 maskz_cvt_bf16x16_fp32x16(int mask, bf16x16 a) { #if __AVX512F__ #if __AVX512BF16__ && __GNUC__ > 11 - return _mm512_maskz_cvtpbh_ps(mask, (__m256bh)a); + return {_mm512_maskz_cvtpbh_ps(mask, a.first)}; #else - return _mm512_castsi512_ps(_mm512_bslli_epi128(_mm512_maskz_cvtepu16_epi32(mask, a), 2)); + return {_mm512_castsi512_ps(_mm512_bslli_epi128(_mm512_maskz_cvtepu16_epi32(mask, a.first), 2))}; #endif #else - __m128i second = _mm256_extractf128_si256(a, 1); + __m128i second = _mm256_extractf128_si256(a.first, 1); __m256 second_fp32 = _mm256_castsi256_ps(_mm256_bslli_epi128(_mm256_cvtepu16_epi32(second), 2)); - __m128i first = _mm256_castsi256_si128(a); + __m128i first = _mm256_castsi256_si128(a.first); __m256 first_fp32 = _mm256_castsi256_ps(_mm256_bslli_epi128(_mm256_cvtepu16_epi32(first), 2)); MASK_DECORATOR(_mm256_blend_ps, _mm256_setzero_ps(), first_fp32, mask & 255, first_fp32); MASK_DECORATOR(_mm256_blend_ps, _mm256_setzero_ps(), second_fp32, mask >> 8, second_fp32); @@ -96,9 +97,9 @@ inline fp32x16 maskz_cvt_bf16x16_fp32x16(int mask, bf16x16 a) { #endif } -inline int8x16 cvt_uint32x16_uint8x16(int32x16 a) { +u8x16 cvt_u32x16_u8x16(u32x16 a) { #if __AVX512F__ - return _mm512_cvtusepi32_epi8(a); + return {_mm512_cvtusepi32_epi8(a.first)}; #else __m256i first = _mm256_min_epi32(_mm256_set1_epi32(255), a.first); __m256i second = _mm256_min_epi32(_mm256_set1_epi32(255), a.second); @@ -108,13 +109,13 @@ inline int8x16 cvt_uint32x16_uint8x16(int32x16 a) { -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1)); __m256i result = _mm256_or_si256(first, second); result = _mm256_permutevar8x32_epi32(result, _mm256_set_epi32(7, 6, 3, 2, 5, 1, 4, 0)); - return _mm256_castsi256_si128(result); + return {_mm256_castsi256_si128(result)}; #endif } -inline int8x16 maskz_cvt_uint32x16_uint8x16(int mask, int32x16 a) { +u8x16 maskz_cvt_u32x16_u8x16(int mask, u32x16 a) { #if __AVX512F__ - return _mm512_maskz_cvtusepi32_epi8(mask, a); + return {_mm512_maskz_cvtusepi32_epi8(mask, a.first)}; #else __m256i first, second; MASK_DECORATOR(_mm256_blend_epi32, _mm256_setzero_si256(), _mm256_min_epi32(_mm256_set1_epi32(255), a.first), @@ -127,13 +128,13 @@ inline int8x16 maskz_cvt_uint32x16_uint8x16(int mask, int32x16 a) { -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1)); __m256i result = _mm256_or_si256(first, second); result = _mm256_permutevar8x32_epi32(result, _mm256_set_epi32(7, 6, 3, 2, 5, 1, 4, 0)); - return _mm256_castsi256_si128(result); + return {_mm256_castsi256_si128(result)}; #endif } -inline int8x16 cvt_int32x16_int8x16(int32x16 a) { +s8x16 cvt_s32x16_s8x16(s32x16 a) { #if __AVX512F__ - return _mm512_cvtsepi32_epi8(a); + return {_mm512_cvtsepi32_epi8(a.first)}; #else __m256i first = _mm256_min_epi32(_mm256_set1_epi32(127), a.first); __m256i second = _mm256_min_epi32(_mm256_set1_epi32(127), a.second); @@ -145,13 +146,13 @@ inline int8x16 cvt_int32x16_int8x16(int32x16 a) { -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1)); __m256i result = _mm256_or_si256(first, second); result = _mm256_permutevar8x32_epi32(result, _mm256_set_epi32(7, 6, 3, 2, 5, 1, 4, 0)); - return _mm256_castsi256_si128(result); + return {_mm256_castsi256_si128(result)}; #endif } -inline int8x16 
maskz_cvt_int32x16_int8x16(const int mask, int32x16 a) { +s8x16 maskz_cvt_s32x16_s8x16(const int mask, s32x16 a) { #if __AVX512F__ - return _mm512_maskz_cvtsepi32_epi8(mask, a); + return {_mm512_maskz_cvtsepi32_epi8(mask, a.first)}; #else __m256i first, second; MASK_DECORATOR(_mm256_blend_epi32, _mm256_setzero_si256(), _mm256_min_epi32(_mm256_set1_epi32(127), a.first), @@ -166,6 +167,22 @@ inline int8x16 maskz_cvt_int32x16_int8x16(const int mask, int32x16 a) { -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0, -1, -1, -1, -1)); __m256i result = _mm256_or_si256(first, second); result = _mm256_permutevar8x32_epi32(result, _mm256_set_epi32(7, 6, 3, 2, 5, 1, 4, 0)); - return _mm256_castsi256_si128(result); + return {_mm256_castsi256_si128(result)}; +#endif +} + +void cvtu32x16_store_u8x16(void* base_addr, u32x16 a) { +#ifdef __AVX512F__ + _mm512_mask_cvtusepi32_storeu_epi8(base_addr, 0xffff, a.first); +#else + store_u8x16(base_addr, cvt_u32x16_u8x16(a)); +#endif +} + +void mask_cvtu32x16_store_u8x16(void* base_addr, int mask, u32x16 a) { +#ifdef __AVX512F__ + _mm512_mask_cvtusepi32_storeu_epi8(base_addr, mask, a.first); +#else + mask_store_u8x16(base_addr, mask, maskz_cvt_u32x16_u8x16(mask, a)); #endif } diff --git a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_convert.hpp b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_convert.hpp index e85e324243b..d0e58e5d0c0 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_convert.hpp +++ b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_convert.hpp @@ -18,25 +18,26 @@ #include "vec_base.hpp" template <int rounding = (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)> -inline int32x16 cvt_roundfp32x16_int32x16(fp32x16 a); +s32x16 cvt_roundfp32x16_s32x16(fp32x16 a); template <int rounding> -struct ne_cvt_roundfp32x16_int32x16_kernel_t : public kernel_t<int32x16, fp32x16> { - ne_cvt_roundfp32x16_int32x16_kernel_t() { func_ = cvt_roundfp32x16_int32x16<rounding>; } +struct ne_cvt_roundfp32x16_s32x16_kernel_t : public kernel_t<s32x16, fp32x16> { + ne_cvt_roundfp32x16_s32x16_kernel_t() { func_ = cvt_roundfp32x16_s32x16<rounding>; } }; template <int rounding = (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)> -inline int32x16 maskz_cvt_roundfp32x16_int32x16(int mask, fp32x16 a); -inline bf16x16 cvt_fp32x16_bf16x16(fp32x16 a); +s32x16 maskz_cvt_roundfp32x16_s32x16(int mask, fp32x16 a); +bf16x16 cvt_fp32x16_bf16x16(fp32x16 a); -inline fp32x16 cvt_bf16x16_fp32x16(bf16x16 a); +fp32x16 cvt_bf16x16_fp32x16(bf16x16 a); -inline fp32x16 maskz_cvt_bf16x16_fp32x16(int mask, bf16x16 a); +fp32x16 maskz_cvt_bf16x16_fp32x16(int mask, bf16x16 a); -inline int8x16 cvt_uint32x16_uint8x16(int32x16 a); +u8x16 cvt_u32x16_u8x16(u32x16 a); +u8x16 maskz_cvt_u32x16_u8x16(int mask, u32x16 a); -inline int8x16 maskz_cvt_uint32x16_uint8x16(int mask, int32x16 a); - -inline int8x16 cvt_int32x16_int8x16(int32x16 a); -inline int8x16 maskz_cvt_int32x16_int8x16(const int mask, int32x16 a); +s8x16 cvt_s32x16_s8x16(s32x16 a); +s8x16 maskz_cvt_s32x16_s8x16(const int mask, s32x16 a); +void cvtu32x16_store_u8x16(void* base_addr, u32x16 a); +void mask_cvtu32x16_store_u8x16(void* base_addr, int mask, u32x16 a); #endif // ENGINE_EXECUTOR_INCLUDE_VEC_CONVERT_HPP_ diff --git a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_load.cpp b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_load.cpp index c6a89f46280..eeaea180dda 100644 --- 
a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_load.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_load.cpp @@ -13,38 +13,3 @@ // limitations under the License. #include "vec_load.hpp" - -inline fp32x16 load_fp32x16(void const* mem_addr) { -#if __AVX512F__ - return _mm512_loadu_ps(mem_addr); -#else - float const* mem_addr_fp32 = reinterpret_cast<float const*>(mem_addr); - return {_mm256_loadu_ps(mem_addr_fp32), _mm256_loadu_ps(mem_addr_fp32 + 8)}; -#endif -} - -inline fp32x16 mask_load_fp32x16(fp32x16 src, int mask, void const* mem_addr) { -#if __AVX512F__ - return _mm512_mask_loadu_ps(src, mask, mem_addr); -#else - float const* mem_addr_fp32 = reinterpret_cast<float const*>(mem_addr); - return {_mm256_loadu_ps(mem_addr_fp32), _mm256_loadu_ps(mem_addr_fp32 + 8)}; -#endif -} - -inline bf16x16 load_bf16x16(void const* mem_addr) { - __m256i const* mem_addr_bf16 = reinterpret_cast<__m256i const*>(mem_addr); - return _mm256_loadu_si256(mem_addr_bf16); -} - -inline bf16x16 maskz_load_bf16x16(int mask, void const* mem_addr) { -#if __AVX512F__ - __m256i const* mem_addr_bf16 = reinterpret_cast<__m256i const*>(mem_addr); - return _mm256_maskz_loadu_epi16(mask, mem_addr_bf16); -#else - bf16x16 res; - MASK_DECORATOR(_mm256_blend_epi16, _mm256_setzero_si256(), - _mm256_loadu_si256(reinterpret_cast<__m256i const*>(mem_addr)), mask, res); - return res; -#endif -} diff --git a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_load.hpp b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_load.hpp index dd5bab9f0ee..435736eee87 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_load.hpp +++ b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_load.hpp @@ -17,19 +17,31 @@ #include "vec_base.hpp" -template <> -float load_kernel_t<float>(const void* src) { - return *reinterpret_cast<const float*>(src); +inline fp32x16 load_fp32x16(void const* mem_addr) { +#if __AVX512F__ + return {_mm512_loadu_ps(mem_addr)}; +#else + float const* mem_addr_fp32 = reinterpret_cast<float const*>(mem_addr); + return {_mm256_loadu_ps(mem_addr_fp32), _mm256_loadu_ps(mem_addr_fp32 + 8)}; +#endif } - -inline fp32x16 load_fp32x16(void const* mem_addr); template <> fp32x16 load_kernel_t<fp32x16>(const void* src) { return load_fp32x16(src); } -inline fp32x16 mask_load_fp32x16(fp32x16 src, int mask, void const* mem_addr); +inline fp32x16 mask_load_fp32x16(fp32x16 src, int mask, void const* mem_addr) { +#if __AVX512F__ + return {_mm512_mask_loadu_ps(src.first, mask, mem_addr)}; +#else + float const* mem_addr_fp32 = reinterpret_cast<float const*>(mem_addr); + return {_mm256_loadu_ps(mem_addr_fp32), _mm256_loadu_ps(mem_addr_fp32 + 8)}; +#endif +} -inline bf16x16 load_bf16x16(void const* mem_addr); +inline bf16x16 load_bf16x16(void const* mem_addr) { + __m256i const* mem_addr_bf16 = reinterpret_cast<__m256i const*>(mem_addr); + return {_mm256_loadu_si256(mem_addr_bf16)}; +} template <> bf16x16 load_kernel_t<bf16x16>(const void* src) { return load_bf16x16(src); diff --git a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_set.cpp b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_set.cpp index afd2ba8bedf..d16c749ee56 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_set.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_set.cpp @@ -14,57 +14,39 @@ #include "vec_set.hpp" -inline fp32x16 set1_fp32x16(const float x) { +fp32x16 
set1_fp32x16(const float x) { #if __AVX512F__ - return _mm512_set1_ps(x); + return {_mm512_set1_ps(x)}; #else return {_mm256_set1_ps(x), _mm256_set1_ps(x)}; #endif } -inline int32x16 set1_int8x16(const int8_t x) { -#if __AVX512F__ - return _mm512_set1_epi8(x); -#else - return {_mm256_set1_epi8(x), _mm256_set1_epi8(x)}; -#endif -} +s8x16 set1_s8x16(const int8_t x) { return {_mm_set1_epi8(x)}; } -inline int32x16 set1_int16x16(const int16_t x) { -#if __AVX512F__ - return _mm512_set1_epi16(x); -#else - return {_mm256_set1_epi16(x), _mm256_set1_epi16(x)}; -#endif -} +s16x16 set1_s16x16(const int16_t x) { return {_mm256_set1_epi16(x)}; } -inline int32x16 set1_fp16x16(const uint16_t x) { -#if __AVX512F__ - return _mm512_set1_epi16(x); -#else - return {_mm256_set1_epi16(x), _mm256_set1_epi16(x)}; -#endif -} +fp16x16 set1_fp16x16(const uint16_t x) { return {_mm256_set1_epi16(x)}; } -inline int32x16 set1_int32x16(const int16_t x) { +s32x16 set1_s32x16(const int32_t x) { #if __AVX512F__ - return _mm512_set1_epi32(x); + return {_mm512_set1_epi32(x)}; #else return {_mm256_set1_epi32(x), _mm256_set1_epi32(x)}; #endif } -inline int32x16 setzero_int32x16() { +s32x16 setzero_s32x16() { #if __AVX512F__ - return _mm512_setzero_epi32(); + return {_mm512_setzero_epi32()}; #else return {_mm256_setzero_si256(), _mm256_setzero_si256()}; #endif } -inline fp32x16 setzero_fp32x16() { +fp32x16 setzero_fp32x16() { #if __AVX512F__ - return _mm512_setzero_ps(); + return {_mm512_setzero_ps()}; #else return {_mm256_setzero_ps(), _mm256_setzero_ps()}; #endif diff --git a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_set.hpp b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_set.hpp index 8e4507cc1d8..84b15ca2a18 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_set.hpp +++ b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_set.hpp @@ -17,18 +17,23 @@ #include "vec_base.hpp" -inline fp32x16 set1_fp32x16(const float x); +fp32x16 set1_fp32x16(const float x); +REGISTER_KERNEL_T(set1_fp32x16, fp32x16, float); -inline int32x16 set1_int8x16(const int8_t x); +s8x16 set1_s8x16(const int8_t x); +REGISTER_KERNEL_T(set1_s8x16, s8x16, int8_t); -inline int32x16 set1_int16x16(const int16_t x); +s16x16 set1_s16x16(const int16_t x); +REGISTER_KERNEL_T(set1_s16x16, s16x16, int16_t); -inline int32x16 set1_fp16x16(const uint16_t x); +fp16x16 set1_fp16x16(const uint16_t x); +REGISTER_KERNEL_T(set1_fp16x16, fp16x16, uint16_t); -inline int32x16 set1_int32x16(const int16_t x); +s32x16 set1_s32x16(const int32_t x); +REGISTER_KERNEL_T(set1_s32x16, s32x16, int32_t); -inline int32x16 setzero_int32x16(); +s32x16 setzero_s32x16(); -inline fp32x16 setzero_fp32x16(); +fp32x16 setzero_fp32x16(); #endif // ENGINE_EXECUTOR_INCLUDE_VEC_SET_HPP_ diff --git a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_store.cpp b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_store.cpp index 1d60f6bf30e..8c862d6c20b 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_store.cpp +++ b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_store.cpp @@ -13,44 +13,3 @@ // limitations under the License. 
#include "vec_store.hpp" -#include "vec_convert.hpp" - -inline void store_int8x16(void* mem_addr, int8x16 a) { _mm_storeu_si128(reinterpret_cast<__m128i*>(mem_addr), a); } -inline void mask_store_int8x16(void* mem_addr, const int mask, int8x16 a) { -#ifdef __AVX512F__ - _mm_mask_storeu_epi8(mem_addr, mask, a); -#else - __m128i mask_reg = - _mm_set_epi8(mask & 32768, mask & 16384, mask & 8192, mask & 4096, mask & 2048, mask & 1024, mask & 512, - mask & 256, mask & 128, mask & 64, mask & 32, mask & 16, mask & 8, mask & 4, mask & 2, mask & 1); - _mm_maskmoveu_si128(a, mask_reg, reinterpret_cast<char*>(mem_addr)); -#endif -} - -inline void store_fp32x16(void* mem_addr, fp32x16 a) { -#ifdef __AVX512F__ - _mm512_storeu_ps(mem_addr, a); -#else - float* mem_addr_fp32 = reinterpret_cast<float*>(mem_addr); - _mm256_storeu_ps(mem_addr_fp32, a.first); - _mm256_storeu_ps(mem_addr_fp32 + 8, a.second); -#endif -} - -inline void store_bf16x16(void* mem_addr, bf16x16 a) { _mm256_storeu_si256(reinterpret_cast<__m256i*>(mem_addr), a); } - -inline void cvtuint32x16_store_int8x16(void* base_addr, int32x16 a) { -#ifdef __AVX512F__ - _mm512_mask_cvtusepi32_storeu_epi8(base_addr, 0xffff, a); -#else - store_int8x16(base_addr, cvt_uint32x16_uint8x16(a)); -#endif -} - -inline void mask_cvtuint32x16_store_int8x16(void* base_addr, int mask, int32x16 a) { -#ifdef __AVX512F__ - _mm512_mask_cvtusepi32_storeu_epi8(base_addr, mask, a); -#else - mask_store_int8x16(base_addr, mask, maskz_cvt_uint32x16_uint8x16(mask, a)); -#endif -} diff --git a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_store.hpp b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_store.hpp index dfc1bc3daa3..e096a0de96b 100644 --- a/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_store.hpp +++ b/intel_extension_for_transformers/llm/runtime/graph/vectors/cpu/vec_store.hpp @@ -17,19 +17,43 @@ #include "vec_base.hpp" -inline void store_int8x16(void* mem_addr, int8x16 a); +inline void store_s8x16(void* mem_addr, s8x16 a) { _mm_storeu_si128(reinterpret_cast<__m128i*>(mem_addr), a.first); } +inline void store_u8x16(void* mem_addr, u8x16 a) { _mm_storeu_si128(reinterpret_cast<__m128i*>(mem_addr), a.first); } template <> -void store_kernel_t<int8x16>(void* dst, int8x16 src) { - store_int8x16(dst, src); +void store_kernel_t<s8x16>(void* dst, s8x16 src) { + store_s8x16(dst, src); } -inline void mask_store_int8x16(void* mem_addr, const int mask, int8x16 a); +inline void mask_store_s8x16(void* mem_addr, const int mask, s8x16 a) { +#ifdef __AVX512F__ + _mm_mask_storeu_epi8(mem_addr, mask, a.first); +#else + __m128i mask_reg = + _mm_set_epi8(mask & 32768, mask & 16384, mask & 8192, mask & 4096, mask & 2048, mask & 1024, mask & 512, + mask & 256, mask & 128, mask & 64, mask & 32, mask & 16, mask & 8, mask & 4, mask & 2, mask & 1); + _mm_maskmoveu_si128(a.first, mask_reg, reinterpret_cast<char*>(mem_addr)); +#endif +} -inline void store_fp32x16(void* mem_addr, fp32x16 a); -template <> -void store_kernel_t<float>(void* dst, float src) { - float* dst_fp32 = reinterpret_cast<float*>(dst); - *dst_fp32 = src; +inline void mask_store_u8x16(void* mem_addr, const int mask, u8x16 a) { +#ifdef __AVX512F__ + _mm_mask_storeu_epi8(mem_addr, mask, a.first); +#else + __m128i mask_reg = + _mm_set_epi8(mask & 32768, mask & 16384, mask & 8192, mask & 4096, mask & 2048, mask & 1024, mask & 512, + mask & 256, mask & 128, mask & 64, mask & 32, mask & 16, mask & 8, mask & 4, mask & 2, mask & 1); + _mm_maskmoveu_si128(a.first, mask_reg, 
reinterpret_cast<char*>(mem_addr)); +#endif +} + +inline void store_fp32x16(void* mem_addr, fp32x16 a) { +#ifdef __AVX512F__ + _mm512_storeu_ps(mem_addr, a.first); +#else + float* mem_addr_fp32 = reinterpret_cast<float*>(mem_addr); + _mm256_storeu_ps(mem_addr_fp32, a.first); + _mm256_storeu_ps(mem_addr_fp32 + 8, a.second); +#endif } template <> @@ -37,13 +61,13 @@ void store_kernel_t<fp32x16>(void* dst, fp32x16 src) { store_fp32x16(dst, src); } -inline void store_bf16x16(void* mem_addr, bf16x16 a); +inline void store_bf16x16(void* mem_addr, bf16x16 a) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(mem_addr), a.first); +} + template <> void store_kernel_t<bf16x16>(void* dst, bf16x16 src) { store_bf16x16(dst, src); } -inline void cvtuint32x16_store_int8x16(void* base_addr, int32x16 a); - -inline void mask_cvtuint32x16_store_int8x16(void* base_addr, int mask, int32x16 a); #endif // ENGINE_EXECUTOR_INCLUDE_VEC_STORE_HPP_
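
For reference, a minimal usage sketch of the struct-wrapped vector helpers this patch introduces (`fp32x16`, `load_fp32x16`, `add_fp32x16`, `store_fp32x16`). It is illustrative only, and assumes the `vectors/cpu` headers are on the include path, `vec_arithmetic.cpp` is compiled and linked, and the target supports at least AVX2:

```cpp
// Sketch only: exercises the new struct-based vector types on 16 floats.
// Build with -mavx2 (or -mavx512f) so the intrinsics in the headers are available.
#include <cstdio>

#include "vec_arithmetic.hpp"
#include "vec_load.hpp"
#include "vec_store.hpp"

int main() {
  float a[16], b[16], c[16];
  for (int i = 0; i < 16; ++i) {
    a[i] = static_cast<float>(i);
    b[i] = 0.5f;
  }
  fp32x16 va = load_fp32x16(a);      // one __m512 under AVX512F, two __m256 otherwise
  fp32x16 vb = load_fp32x16(b);
  fp32x16 vc = add_fp32x16(va, vb);  // element-wise add, defined in vec_arithmetic.cpp
  store_fp32x16(c, vc);
  std::printf("c[0]=%f c[15]=%f\n", c[0], c[15]);  // expect 0.5 and 15.5
  return 0;
}
```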