Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
support Avx2 (#493)
Browse files Browse the repository at this point in the history
* support Memcpy2D

* support gelu fusion

---------

Co-authored-by: luoyu-intel <yu.luo@intel.com>
  • Loading branch information
yuchengliu1 and luoyu-intel authored Oct 20, 2023
1 parent 0cff05a commit ea69f9a
Show file tree
Hide file tree
Showing 19 changed files with 503 additions and 270 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ cd build
cmake ..
cmake --build . -j
```

Note: add compile args ```-DNE_AVX512=OFF -DNE_AVX512_VBMI=OFF -DNE_AVX512_VNNI=OFF``` to ```cmake``` when compiling it on a CPU without AVX512
### 2. Run LLM with Python API

You can use Python API to run Hugging Face model simply. Here is the sample code:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ using PerNFp32Fp32 = jblas::wrapper::gemm_pack_weight::GemmInterfaceParallelAB<
jblas::utils::parallel::Parallel2DGemm>;
} // namespace avx512_vnni

namespace avx2 {
JBLAS_ISA constexpr DefaultISA = JblasAVX2;
template <template <class GC, JBLAS_ISA ISA> class ProB>
using Default = jblas::wrapper::gemm_pack_weight::GemmInterfacePackWeight<
jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight<DefaultISA, jblas::gemm::GemmCore_Row_NN_2x48_AVX2,
jblas::prologue::gemm::ActivationBase, ProB,
jblas::epilogue::gemm::AccumulatorWriteBackFp32>,
jblas::utils::parallel::Parallel2DGemm>;
} // namespace avx2
} // namespace

static JBLAS_CODE jblas_s4fp32kblock_f32f32_forward(float* activation, SS4Fp32* weiptr, float* output, int _m, int _n,
Expand Down Expand Up @@ -125,6 +134,10 @@ static JBLAS_CODE jblas_s4fp32kblock_f32f32_forward(float* activation, SS4Fp32*
using GemmKernel = avx512f::Default<WeiS4ClipFp32>;
static GemmKernel kernel;
ret = kernel.compute({_m, _n, _k, activation, lda, weiptr, output, ldo});
} else if (_cd->AVX2()) {
using GemmKernel = avx2::Default<WeiS4ClipFp32>;
static GemmKernel kernel;
ret = kernel.compute({_m, _n, _k, activation, lda, weiptr, output, ldo});
}
} else if (weiptr->mCoreType == GcCompBf16::TYPE) {
if (_cd->AMX_BF16()) {
Expand Down Expand Up @@ -159,6 +172,10 @@ static JBLAS_CODE jblas_s8fp32kblock_f32f32_forward(float* activation, SS8Fp32*
using GemmKernel = avx512f::Default<WeiS8Fp32>;
static GemmKernel kernel;
ret = kernel.compute({_m, _n, _k, activation, lda, weiptr, output, ldo});
} else if (_cd->AVX2()) {
using GemmKernel = avx2::Default<WeiS8Fp32>;
static GemmKernel kernel;
ret = kernel.compute({_m, _n, _k, activation, lda, weiptr, output, ldo});
}
}
return ret;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@ bool jblas_fusion_FFN_SiLu_f32f32_support(void* w1ptr, void* w2ptr, void* w3ptr,
if (sameKernel) {
if (w1tmp->mPrologueID == int(WeightCompType::WeightS4ClipScaleFp32) ||
w1tmp->mPrologueID == int(WeightCompType::WeightS8ScaleFp32)) {
constexpr size_t EleNum = sizeof(GcCompInt8KBlockSet) / sizeof(GcCompInt8KBlockSet[0]);
support = contains(w1tmp->mCoreType, GcCompInt8KBlockSet, EleNum);
support &= hasISA(GcCompInt8KBlockSet, EleNum);
constexpr jblas::gemm::GemmCoreType cores[] = {
jblas::gemm::GemmCoreType::AMX_INT8_16x48_KBLOCK, jblas::gemm::GemmCoreType::AVX512_VNNI_3x48_KBLOCK,
jblas::gemm::GemmCoreType::AVX512F_8x48, jblas::gemm::GemmCoreType::AMX_BF16_16x48,
jblas::gemm::GemmCoreType::AVX2_2X48};
constexpr size_t EleNum = sizeof(cores) / sizeof(cores[0]);
support = contains(w1tmp->mCoreType, cores, EleNum);
support &= hasISA(cores, EleNum);
} else if (w1tmp->mPrologueID == int(WeightCompType::WeightS8ScaleFp32PerChannelN) ||
w1tmp->mPrologueID == int(WeightCompType::WeightS4ClipScaleFp32PerChannelN)) {
constexpr size_t EleNum = sizeof(GcCompInt8Set) / sizeof(GcCompInt8Set[0]);
Expand Down Expand Up @@ -115,6 +119,42 @@ JBLAS_CODE jblas_fusion_FFN_SiLu_s4fp32_f32f32_forward(float* activation, SS4Fp3
w2ptr, w3ptr, tmp1, ldtmp1, output, ldo, NULL, tmp2, ldtmp2, NULL});
}
}
} else if (w1ptr->mCoreType == GcCompFp32::TYPE) {
if (_cd->AVX512F()) {
using GemmKernel = custom::wrapper::kblock::avx512f::GemmS4KBlock;
using SiluGemmKernel = custom::wrapper::kblock::avx512f::SiluGemmS4KBlock;
using FusedInter = custom::wrapper::transformer::FPFFNFusedInterface<SiluGemmKernel, GemmKernel>;
static FusedInter finter;
int lda = fin;
int ldtmp1 = fmid;
int ldtmp2 = fmid;
int ldo = fout;
GemmKernel::AParam paramA = {activation, lda};
SiluGemmKernel::BParam paramW1 = {w1ptr};
GemmKernel::BParam paramW2 = {w2ptr};
GemmKernel::BParam paramW3 = {w3ptr};
SiluGemmKernel::EpiParam param1 = {tmp1, ldtmp1};
GemmKernel::EpiParam param2 = {output, ldo, NULL};
GemmKernel::EpiParam param3 = {tmp2, ldtmp2, NULL};
ret = finter.compute({seq, fin, fmid, fout, paramA, paramW1, paramW2, paramW3, param1, param2, param3});
} else if (_cd->AVX2()) {
using GemmKernel = custom::wrapper::kblock::avx2::GemmS4KBlock;
using SiluGemmKernel = custom::wrapper::kblock::avx2::SiluGemmS4KBlock;
using FusedInter = custom::wrapper::transformer::FPFFNFusedInterface<SiluGemmKernel, GemmKernel>;
static FusedInter finter;
int lda = fin;
int ldtmp1 = fmid;
int ldtmp2 = fmid;
int ldo = fout;
GemmKernel::AParam paramA = {activation, lda};
SiluGemmKernel::BParam paramW1 = {w1ptr};
GemmKernel::BParam paramW2 = {w2ptr};
GemmKernel::BParam paramW3 = {w3ptr};
SiluGemmKernel::EpiParam param1 = {tmp1, ldtmp1};
GemmKernel::EpiParam param2 = {output, ldo, NULL};
GemmKernel::EpiParam param3 = {tmp2, ldtmp2, NULL};
ret = finter.compute({seq, fin, fmid, fout, paramA, paramW1, paramW2, paramW3, param1, param2, param3});
}
}
return ret;
}
Expand Down Expand Up @@ -160,6 +200,42 @@ JBLAS_CODE jblas_fusion_FFN_SiLu_s8fp32_f32f32_forward(float* activation, SS8Fp3
ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, tmp1, ldtmp1, &quanA2, w1ptr,
w2ptr, w3ptr, tmp1, ldtmp1, output, ldo, NULL, tmp2, ldtmp2, NULL});
}
} else if (w1ptr->mCoreType == GcCompFp32::TYPE) {
if (_cd->AVX512F()) {
using GemmKernel = custom::wrapper::kblock::avx512f::GemmS8KBlock;
using SiluGemmKernel = custom::wrapper::kblock::avx512f::SiluGemmS8KBlock;
using FusedInter = custom::wrapper::transformer::FPFFNFusedInterface<SiluGemmKernel, GemmKernel>;
static FusedInter finter;
int lda = fin;
int ldtmp1 = fmid;
int ldtmp2 = fmid;
int ldo = fout;
GemmKernel::AParam paramA = {activation, lda};
SiluGemmKernel::BParam paramW1 = {w1ptr};
GemmKernel::BParam paramW2 = {w2ptr};
GemmKernel::BParam paramW3 = {w3ptr};
SiluGemmKernel::EpiParam param1 = {tmp1, ldtmp1};
GemmKernel::EpiParam param2 = {output, ldo, NULL};
GemmKernel::EpiParam param3 = {tmp2, ldtmp2, NULL};
ret = finter.compute({seq, fin, fmid, fout, paramA, paramW1, paramW2, paramW3, param1, param2, param3});
} else if (_cd->AVX2()) {
using GemmKernel = custom::wrapper::kblock::avx2::GemmS8KBlock;
using SiluGemmKernel = custom::wrapper::kblock::avx2::SiluGemmS8KBlock;
using FusedInter = custom::wrapper::transformer::FPFFNFusedInterface<SiluGemmKernel, GemmKernel>;
static FusedInter finter;
int lda = fin;
int ldtmp1 = fmid;
int ldtmp2 = fmid;
int ldo = fout;
GemmKernel::AParam paramA = {activation, lda};
SiluGemmKernel::BParam paramW1 = {w1ptr};
GemmKernel::BParam paramW2 = {w2ptr};
GemmKernel::BParam paramW3 = {w3ptr};
SiluGemmKernel::EpiParam param1 = {tmp1, ldtmp1};
GemmKernel::EpiParam param2 = {output, ldo, NULL};
GemmKernel::EpiParam param3 = {tmp2, ldtmp2, NULL};
ret = finter.compute({seq, fin, fmid, fout, paramA, paramW1, paramW2, paramW3, param1, param2, param3});
}
}
return ret;
}
Expand Down Expand Up @@ -460,7 +536,8 @@ bool jblas_fusion_FFN_Add_GeLu_f32f32_support(void* w1ptr, void* w2ptr, int seq,
w1tmp->mPrologueID == int(WeightCompType::WeightS8ScaleFp32)) {
constexpr jblas::gemm::GemmCoreType cores[] = {
jblas::gemm::GemmCoreType::AMX_INT8_16x48_KBLOCK, jblas::gemm::GemmCoreType::AVX512_VNNI_3x48_KBLOCK,
jblas::gemm::GemmCoreType::AVX512F_8x48, jblas::gemm::GemmCoreType::AMX_BF16_16x48};
jblas::gemm::GemmCoreType::AVX512F_8x48, jblas::gemm::GemmCoreType::AMX_BF16_16x48,
jblas::gemm::GemmCoreType::AVX2_2X48};
constexpr size_t EleNum = sizeof(cores) / sizeof(cores[0]);
support = contains(w1tmp->mCoreType, cores, EleNum);
support &= hasISA(cores, EleNum);
Expand Down Expand Up @@ -538,10 +615,16 @@ JBLAS_CODE jblas_fusion_FFN_Add_GeLu_s4fp32_f32f32_forward(float* activation, SS
int lda = fin;
int ldtmp1 = fmid;
int ldo = fout;
// FusedInter::Arguments::paramA paramA={activation, lda};
// FusedInter::Arguments::paramW1 paramW1={w1tmp};
// FusedInter::Arguments::paramW2 paramW2={w2tmp};
// FusedInter::Arguments::param1 param1={tmp1, b1ptr, ldtmp1, ldtmp1};
ret = finter.compute({seq, fin, fmid, fout, activation, lda, w1tmp, w2tmp, tmp1, b1ptr, ldtmp1,
broadcast_bias ? 0 : ldtmp1, output, b2ptr, ldo, broadcast_bias ? 0 : ldo});
} else if (_cd->AVX2()) {
using GemmKernel = custom::wrapper::kblock::avx2::AddGemmS4KBlock;
using GeluGemmKernel = custom::wrapper::kblock::avx2::AddGeluGemmS4KBlock;
using FusedInter = custom::wrapper::transformer::FpGeluFusedInterface<GeluGemmKernel, GemmKernel>;
static FusedInter finter;
int lda = fin;
int ldtmp1 = fmid;
int ldo = fout;
ret = finter.compute({seq, fin, fmid, fout, activation, lda, w1tmp, w2tmp, tmp1, b1ptr, ldtmp1,
broadcast_bias ? 0 : ldtmp1, output, b2ptr, ldo, broadcast_bias ? 0 : ldo});
}
Expand All @@ -554,10 +637,6 @@ JBLAS_CODE jblas_fusion_FFN_Add_GeLu_s4fp32_f32f32_forward(float* activation, SS
int lda = fin;
int ldtmp1 = fmid;
int ldo = fout;
// FusedInter::Arguments::paramA paramA={activation, lda};
// FusedInter::Arguments::paramW1 paramW1={w1tmp};
// FusedInter::Arguments::paramW2 paramW2={w2tmp};
// FusedInter::Arguments::param1 param1={tmp1, b1ptr, ldtmp1, ldtmp1};
ret = finter.compute({seq, fin, fmid, fout, activation, lda, w1tmp, w2tmp, tmp1, b1ptr, ldtmp1,
broadcast_bias ? 0 : ldtmp1, output, b2ptr, ldo, broadcast_bias ? 0 : ldo});
}
Expand Down Expand Up @@ -626,10 +705,16 @@ JBLAS_CODE jblas_fusion_FFN_Add_GeLu_s8fp32_f32f32_forward(float* activation, SS
int lda = fin;
int ldtmp1 = fmid;
int ldo = fout;
// FusedInter::Arguments::paramA paramA={activation, lda};
// FusedInter::Arguments::paramW1 paramW1={w1tmp};
// FusedInter::Arguments::paramW2 paramW2={w2tmp};
// FusedInter::Arguments::param1 param1={tmp1, b1ptr, ldtmp1, ldtmp1};
ret = finter.compute({seq, fin, fmid, fout, activation, lda, w1tmp, w2tmp, tmp1, b1ptr, ldtmp1,
broadcast_bias ? 0 : ldtmp1, output, b2ptr, ldo, broadcast_bias ? 0 : ldo});
} else if (_cd->AVX2()) {
using GemmKernel = custom::wrapper::kblock::avx2::AddGemmS8KBlock;
using GeluGemmKernel = custom::wrapper::kblock::avx2::AddGeluGemmS8KBlock;
using FusedInter = custom::wrapper::transformer::FpGeluFusedInterface<GeluGemmKernel, GemmKernel>;
static FusedInter finter;
int lda = fin;
int ldtmp1 = fmid;
int ldo = fout;
ret = finter.compute({seq, fin, fmid, fout, activation, lda, w1tmp, w2tmp, tmp1, b1ptr, ldtmp1,
broadcast_bias ? 0 : ldtmp1, output, b2ptr, ldo, broadcast_bias ? 0 : ldo});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ static bool hasISA(const jblas::gemm::GemmCoreType* set, size_t len) {
support |= _cd->AVX512F();
break;
case jblas::gemm::GemmCoreType::AVX2_4X24:
case jblas::gemm::GemmCoreType::AVX2_2X48:
support |= _cd->AVX2();
break;
case jblas::gemm::GemmCoreType::AVX_VNNI_1x48_KBLOCK:
Expand Down Expand Up @@ -113,6 +114,7 @@ using SS8Fp32 = jblas::prologue::weight_comp::gemm_kblcok::StorageWeightS8ScaleF
using SS8Fp32PerN = jblas::prologue::weight_comp::gemm_kblcok::StorageWeightS8ScaleFp32PerChannelN;
using SS4Fp32PerN = jblas::prologue::weight_comp::gemm_kblcok::StorageWeightS4ScaleFp32PerChannelN;

using GcCompAVX2 = jblas::gemm::GemmCore_Row_NN_4x24_AVX2;
using GcCompFp32 = jblas::gemm::GemmCore_Row_NN_8x48_AVX512F;
using GcCompInt8KBlock = jblas::gemm::kblock::GemmCore_Row_NN_3x48_AVX512_VNNI_KBLOCK;
using GcCompBf16 = jblas::gemm::GemmCore_Row_NN_16x64_AMX_BF16;
Expand Down Expand Up @@ -454,6 +456,96 @@ class FFNFusedInterface {
_SiluLauncher_T mActLauncher;
};

template <class _SiluLauncher_T, class _Launcher_T>
class FPFFNFusedInterface {
public:
static_assert(std::is_same<typename _Launcher_T::AParam, typename _SiluLauncher_T::AParam>::value,
"Prologue A param of the 2 Launcher (w/wo SILU) should be the same.");
struct Arguments {
const int Seq, Fin, FMid, FOut;
const typename _Launcher_T::AParam paramA;
const typename _SiluLauncher_T::BParam paramW1;
const typename _Launcher_T::BParam paramW2, paramW3;
const typename _SiluLauncher_T::EpiParam param1;
const typename _Launcher_T::EpiParam param2, param3;
};
using Config = typename _Launcher_T::ParallelConfig;
using ActConfig = typename _SiluLauncher_T::ParallelConfig;
using ActivationType = typename _Launcher_T::PrologueA;
using WeightType = typename _Launcher_T::PrologueB;
using GemmCore = typename _Launcher_T::GemmCore;
using LArguments = typename _Launcher_T::Param;
using CParam = typename _Launcher_T::EpiParam;
using Parallel = jblas::utils::parallel::Parallel2DGemmKBlockFixed<GemmCore>;
ActivationType* getActivationPtr() { return &mLauncher.mProA; }
// forward=packB+compute
JBLAS_CODE compute(const Arguments& _param) {
auto bptr = (jblas::prologue::weight_comp::gemm_kblcok::WeightBase*)(_param.paramW1.packedW);
if (bptr == nullptr) {
return JblasInvalidParam;
}
// dynamic quantization: Seq*Fin
auto cb = jblas::utils::CpuBase();

Parallel _paral = Parallel(); // w1&w3 from Seq* Fin=>FMid
Parallel _paral2 = Parallel(); // w2 from Seq* FMid=>Fout
_paral.update(_param.Seq, _param.FMid, _param.Fin, bptr->mBlockSize, cb.mNumThreads);
_paral2.update(_param.Seq, _param.FOut, _param.FMid, bptr->mBlockSize, cb.mNumThreads);

omp_set_num_threads(cb.mNumThreads);
#pragma omp parallel
{
int tidx = omp_get_thread_num();
{
int colidx, rowidx, rowsize, colsize;
_paral.getIndex(tidx, &rowidx, &colidx, &rowsize, &colsize);
if (rowsize > 0 && colsize > 0) {
ActConfig _actconfig{
rowidx, colidx, rowsize, colsize, _paral.getMStep(), _paral.getNStep(), _paral.getKStep(), cb.mL2Cache};
Config _config{rowidx, colidx, rowsize, colsize, _paral.getMStep(), _paral.getNStep(), _paral.getKStep(),
cb.mL2Cache};
mActLauncher.launch(
_actconfig, {_param.Seq, _param.FMid, _param.Fin, _param.paramA, _param.paramW1, _param.param1, NULL});
mLauncher.launch(_config,
{_param.Seq, _param.FMid, _param.Fin, _param.paramA, _param.paramW3, _param.param3, NULL});
int row_r = jblas::utils::remainsize(rowidx, _paral.mRows, rowsize);
int col_r = jblas::utils::remainsize(colidx, _paral.mCols, colsize);

// TODO(Yu): replace the naive inplace eltwise mul
for (int i = 0; i < row_r; i++) {
for (int j = 0; j < col_r; j++) {
_param.param1.C[(rowidx + i) * _param.param1.ldc + colidx + j] *=
_param.param3.C[(rowidx + i) * _param.param3.ldc + colidx + j];
}
}
}
}
#pragma omp barrier
{
int colidx, rowidx, rowsize, colsize;
_paral2.getIndex(tidx, &rowidx, &colidx, &rowsize, &colsize);
if (rowsize > 0 && colsize > 0) {
Config _config{
rowidx, colidx, rowsize, colsize, _paral2.getMStep(), _paral2.getNStep(), _paral2.getKStep(),
cb.mL2Cache};
mLauncher.launch(_config, {_param.Seq,
_param.FOut,
_param.FMid,
{_param.param1.C, _param.param1.ldc},
_param.paramW2,
_param.param2,
NULL});
}
}
}
return JblasSuccess;
}

protected:
_Launcher_T mLauncher;
_SiluLauncher_T mActLauncher;
};

template <class _SiluLauncher_T, class _Launcher_T>
class FFNFusedInterfacePerN {
public:
Expand Down Expand Up @@ -843,10 +935,29 @@ using DefaultGemmFp32 =
jblas::prologue::gemm::ActivationBase, ProB, Epi>;
using AddGeluGemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, custom::epilogue::Add_GeluFp32>;
using AddGemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, custom::epilogue::AddFp32>;
using GemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
using SiluGemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, custom::epilogue::SiluFp32>;

using AddGeluGemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, custom::epilogue::Add_GeluFp32>;
using AddGemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, custom::epilogue::AddFp32>;
using GemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
using SiluGemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, custom::epilogue::SiluFp32>;
} // namespace avx512f
namespace avx2 {
template <template <class GC, JBLAS_ISA ISA> class ProB, template <JBLAS_ISA ISA> class Epi>
using DefaultGemmFp32 =
jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight<JblasAVX2, jblas::gemm::GemmCore_Row_NN_2x48_AVX2,
jblas::prologue::gemm::ActivationBase, ProB, Epi>;
using AddGeluGemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, custom::epilogue::Add_GeluFp32>;
using AddGemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, custom::epilogue::AddFp32>;
using GemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
using SiluGemmS8KBlock = DefaultGemmFp32<WeiS8Fp32, custom::epilogue::SiluFp32>;

using AddGeluGemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, custom::epilogue::Add_GeluFp32>;
using AddGemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, custom::epilogue::AddFp32>;
using GemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, jblas::epilogue::gemm::AccumulatorWriteBackFp32>;
using SiluGemmS4KBlock = DefaultGemmFp32<WeiS4ClipFp32, custom::epilogue::SiluFp32>;
} // namespace avx2
namespace amx_bf16 {
template <template <class GC, JBLAS_ISA ISA> class ProB, template <JBLAS_ISA ISA> class Epi>
using DefaultGemmFp32 =
Expand Down
Loading

0 comments on commit ea69f9a

Please sign in to comment.