diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/README.md b/intel_extension_for_transformers/llm/runtime/graph/core/README.md
index efbb4a28fb3..0b049d7060a 100644
--- a/intel_extension_for_transformers/llm/runtime/graph/core/README.md
+++ b/intel_extension_for_transformers/llm/runtime/graph/core/README.md
@@ -48,12 +48,12 @@ We support three kinds of kernel fusion for transformer models: QKV, MHA (multi-
QKV |
GPT-J LLaMA |
- AMX_INT8, AVX512_VNNI |
+ AMX_INT8, AVX512_VNNI, AVX_VNNI |
FFN |
GPT-J LLaMA BLOOM ChatGLM Falcon MPT |
- AMX_INT8, AVX512_VNNI, AVX512F and AMX_BF16 |
+ AMX_INT8, AVX512_VNNI, AVX512F, AMX_BF16, AVX_VNNI, AVX2 |
MHA |
@@ -71,4 +71,6 @@ codename | weight config | runtime ISA
Sapphire Rapids | any int4
group size=-1
compute type=int8 | AMX_INT8
Ice Lake
Cascade Lake
Cooper Lake
Tiger Lake
Rocket Lake | any int4
group size=-1
compute type=int8 | AVX512_VNNI
Skylake | any 4bits
group size=-1
compute type=fp32 | AVX512F
+Alder Lake (12th Gen)
Raptor Lake (13th and 14th Gen) | any 4bits
group size=-1
compute type=int8 | AVX_VNNI
+Older architecture (before 12th Gen) | any 4bits
group size=-1
compute type=fp32 | AVX2
diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/inner_product.cpp b/intel_extension_for_transformers/llm/runtime/graph/core/layers/inner_product.cpp
index 065ad048df4..8e9b6b7a1f9 100644
--- a/intel_extension_for_transformers/llm/runtime/graph/core/layers/inner_product.cpp
+++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/inner_product.cpp
@@ -92,6 +92,24 @@ using PerNFp32Fp32 = jblas::wrapper::gemm_pack_weight::GemmInterfaceParallelAB<
jblas::utils::parallel::Parallel2DGemm>;
} // namespace avx512_vnni
+namespace avx_vnni {
+JBLAS_ISA constexpr DefaultISA = JblasAVX_VNNI;
+
+template <template <class GC, JBLAS_ISA ISA> class ProB>
+using KBlockFp32Fp32 = jblas::wrapper::gemm_kblock::GemmInterfaceKBlockPackWeight<
+ jblas::wrapper::gemm_kblock::GemmLauncherKBlock<
+ DefaultISA, jblas::gemm::kblock::GemmCore_Row_NN_1x48_AVX_VNNI_KBLOCK,
+ jblas::prologue::gemm::ActivationF32U8KBlockQuantize, ProB, jblas::epilogue::gemm::AccumulatorWriteBackFp32>,
+ jblas::utils::parallel::Parallel2DGemmKBlockFixed>;
+
+template <template <class GC, JBLAS_ISA ISA> class ProB>
+using PerNFp32Fp32 = jblas::wrapper::gemm_pack_weight::GemmInterfaceParallelAB<
+ jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight<DefaultISA, jblas::gemm::GemmCore_Row_NN_2x48_AVX_VNNI, jblas::prologue::gemm::ActivationFp32AsymU8Quantize, ProB, jblas::epilogue::gemm::ZpDequantInt32ToFp32>,
+ jblas::utils::parallel::Parallel2DGemm>;
+} // namespace avx_vnni
+
namespace avx2 {
JBLAS_ISA constexpr DefaultISA = JblasAVX2;
template class ProB>
@@ -114,19 +132,29 @@ static JBLAS_CODE jblas_s4fp32kblock_f32f32_forward(float* activation, SS4Fp32*
auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
quanA.assign((int8_t*)workspace);
ret = kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, ldo});
- } else if (_cd->AVX512_VNNI() && weiptr->mBlockSize % 8 == 0) {
- if (_m <= 32) {
- using GemmKernel = avx512_vnni::KBlockFp32Fp32Next;
- static GemmKernel kernel;
- auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
- quanA.assign((int8_t*)workspace);
- ret = kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, ldo});
- } else {
- using GemmKernel = avx512_vnni::KBlockFp32Fp32;
- static GemmKernel kernel;
- auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
- quanA.assign((int8_t*)workspace);
- ret = kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, ldo});
+ } else {
+ if (weiptr->mBlockSize % 8 == 0) {
+ if (_cd->AVX512_VNNI()) {
+ if (_m <= 32) {
+ using GemmKernel = avx512_vnni::KBlockFp32Fp32Next;
+ static GemmKernel kernel;
+ auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
+ quanA.assign((int8_t*)workspace);
+ ret = kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, ldo});
+ } else {
+ using GemmKernel = avx512_vnni::KBlockFp32Fp32;
+ static GemmKernel kernel;
+ auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
+ quanA.assign((int8_t*)workspace);
+ ret = kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, ldo});
+ }
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = avx_vnni::KBlockFp32Fp32<jblas::prologue::weight_comp::gemm_kblcok::WeightS4ClipScaleFp32>;
+ static GemmKernel kernel;
+ auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
+ quanA.assign((int8_t*)workspace);
+ ret = kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, ldo});
+ }
}
}
} else if (weiptr->mCoreType == GcCompFp32::TYPE) {
@@ -166,6 +194,12 @@ static JBLAS_CODE jblas_s8fp32kblock_f32f32_forward(float* activation, SS8Fp32*
auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
quanA.assign((int8_t*)workspace);
ret = kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, ldo});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = avx_vnni::KBlockFp32Fp32<jblas::prologue::weight_comp::gemm_kblcok::WeightS8ScaleFp32>;
+ static GemmKernel kernel;
+ auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
+ quanA.assign((int8_t*)workspace);
+ ret = kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, ldo});
}
} else if (weiptr->mCoreType == GcCompFp32::TYPE) {
if (_cd->AVX512F()) {
@@ -208,6 +242,20 @@ static JBLAS_CODE jblas_s8fp32perN_f32f32_forward(float* activation, SS8Fp32PerN
&quanA,
weiptr,
{output, ldo, quanA.mCStep, quanA.mSPtr, weiptr->mSPtr, quanA.mZPtr, weiptr->mRPtr}});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = avx_vnni::PerNFp32Fp32<jblas::prologue::weight_comp::gemm_kblcok::WeightS8ScaleFp32PerChannelN>;
+ static GemmKernel kernel;
+ auto quanA = kernel.getActivationPtr()->createStorage(_m, _k);
+ quanA.assign((int8_t*)workspace);
+ ret = kernel.compute(
+ {_m,
+ _n,
+ _k,
+ activation,
+ lda,
+ &quanA,
+ weiptr,
+ {output, ldo, quanA.mCStep, quanA.mSPtr, weiptr->mSPtr, quanA.mZPtr, weiptr->mRPtr}});
}
}
return ret;
@@ -240,6 +288,20 @@ static JBLAS_CODE jblas_s4fp32perN_f32f32_forward(float* activation, SS4Fp32PerN
&quanA,
weiptr,
{output, ldo, quanA.mCStep, quanA.mSPtr, weiptr->mSPtr, quanA.mZPtr, weiptr->mRPtr}});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = avx_vnni::PerNFp32Fp32<jblas::prologue::weight_comp::gemm_kblcok::WeightS4ClipScaleFp32PerN>;
+ static GemmKernel kernel;
+ auto quanA = kernel.getActivationPtr()->createStorage(_m, _k);
+ quanA.assign((int8_t*)workspace);
+ ret = kernel.compute(
+ {_m,
+ _n,
+ _k,
+ activation,
+ lda,
+ &quanA,
+ weiptr,
+ {output, ldo, quanA.mCStep, quanA.mSPtr, weiptr->mSPtr, quanA.mZPtr, weiptr->mRPtr}});
}
}
return ret;
@@ -307,19 +369,30 @@ JBLAS_CODE jblas_fusion_add_s4fp32_f32f32_forward(float* activation, SS4Fp32* we
auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
quanA.assign((int8_t*)workspace);
ret = kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, bias, ldo, broadcast_bias ? 0 : ldo});
- } else if (_cd->AVX512_VNNI() && weiptr->mBlockSize % 8 == 0) {
- if (_m <= 32) {
+ } else if (weiptr->mBlockSize % 8 == 0) {
+ if (_cd->AVX512_VNNI()) {
+ if (_m <= 32) {
+ using GemmKernel = jblas::wrapper::gemm_kblock::GemmInterfaceKBlockPackWeight<
+ custom::wrapper::kblock::avx512_vnni::AddGemmSKernelDynamicS4KBlockNext,
+ jblas::utils::parallel::Parallel2DGemmKBlockFixed>;
+ static GemmKernel kernel;
+ auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
+ quanA.assign((int8_t*)workspace);
+ ret = kernel.compute(
+ {_m, _n, _k, activation, lda, &quanA, weiptr, output, bias, ldo, broadcast_bias ? 0 : ldo});
+ } else {
+ using GemmKernel = jblas::wrapper::gemm_kblock::GemmInterfaceKBlockPackWeight<
+ custom::wrapper::kblock::avx512_vnni::AddGemmSKernelDynamicS4KBlock,
+ jblas::utils::parallel::Parallel2DGemmKBlockFixed>;
+ static GemmKernel kernel;
+ auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
+ quanA.assign((int8_t*)workspace);
+ ret = kernel.compute(
+ {_m, _n, _k, activation, lda, &quanA, weiptr, output, bias, ldo, broadcast_bias ? 0 : ldo});
+ }
+ } else if (_cd->AVX_VNNI()) {
using GemmKernel = jblas::wrapper::gemm_kblock::GemmInterfaceKBlockPackWeight<
- custom::wrapper::kblock::avx512_vnni::AddGemmSKernelDynamicS4KBlockNext,
- jblas::utils::parallel::Parallel2DGemmKBlockFixed>;
- static GemmKernel kernel;
- auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
- quanA.assign((int8_t*)workspace);
- ret =
- kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, bias, ldo, broadcast_bias ? 0 : ldo});
- } else {
- using GemmKernel = jblas::wrapper::gemm_kblock::GemmInterfaceKBlockPackWeight<
- custom::wrapper::kblock::avx512_vnni::AddGemmSKernelDynamicS4KBlock,
+ custom::wrapper::kblock::avx_vnni::AddGemmSKernelDynamicS4KBlock,
jblas::utils::parallel::Parallel2DGemmKBlockFixed>;
static GemmKernel kernel;
auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
@@ -354,6 +427,14 @@ JBLAS_CODE jblas_fusion_add_s8fp32_f32f32_forward(float* activation, SS8Fp32* we
auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
quanA.assign((int8_t*)workspace);
ret = kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, bias, ldo, broadcast_bias ? 0 : ldo});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = jblas::wrapper::gemm_kblock::GemmInterfaceKBlockPackWeight<
+ custom::wrapper::kblock::avx_vnni::AddGemmSKernelDynamicS8KBlock,
+ jblas::utils::parallel::Parallel2DGemmKBlockFixed>;
+ static GemmKernel kernel;
+ auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, weiptr->mBlockSize);
+ quanA.assign((int8_t*)workspace);
+ ret = kernel.compute({_m, _n, _k, activation, lda, &quanA, weiptr, output, bias, ldo, broadcast_bias ? 0 : ldo});
}
}
return ret;
@@ -391,6 +472,23 @@ JBLAS_CODE jblas_fusion_add_s8fp32pern_f32f32_forward(float* activation, SS8Fp32
bias,
broadcast_bias ? 0 : ldo}});
- }
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = jblas::wrapper::gemm_pack_weight::GemmInterfaceParallelAB<
+ custom::wrapper::kblock::avx_vnni::AddGemmDynamicS8PerN, jblas::utils::parallel::Parallel2DGemm>;
+ static GemmKernel kernel;
+ auto quanA = kernel.getActivationPtr()->createStorage(_m, _k);
+ quanA.assign((int8_t*)workspace);
+ ret = kernel.compute(
+ {_m,
+ _n,
+ _k,
+ activation,
+ lda,
+ &quanA,
+ weiptr,
+ {{output, ldo, quanA.mCStep, quanA.mSPtr, weiptr->mSPtr, quanA.mZPtr, weiptr->mRPtr},
+ bias,
+ broadcast_bias ? 0 : ldo}});
+ }
}
return ret;
}
diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_ffn.cpp b/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_ffn.cpp
index 940ea3d7fb3..db952e12cc2 100644
--- a/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_ffn.cpp
+++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_ffn.cpp
@@ -37,10 +37,12 @@ bool jblas_fusion_FFN_SiLu_f32f32_support(void* w1ptr, void* w2ptr, void* w3ptr,
if (sameKernel) {
if (w1tmp->mPrologueID == int(WeightCompType::WeightS4ClipScaleFp32) ||
w1tmp->mPrologueID == int(WeightCompType::WeightS8ScaleFp32)) {
- constexpr jblas::gemm::GemmCoreType cores[] = {
- jblas::gemm::GemmCoreType::AMX_INT8_16x48_KBLOCK, jblas::gemm::GemmCoreType::AVX512_VNNI_3x48_KBLOCK,
- jblas::gemm::GemmCoreType::AVX512F_8x48, jblas::gemm::GemmCoreType::AMX_BF16_16x48,
- jblas::gemm::GemmCoreType::AVX2_2X48};
+ constexpr jblas::gemm::GemmCoreType cores[] = {jblas::gemm::GemmCoreType::AMX_INT8_16x48_KBLOCK,
+ jblas::gemm::GemmCoreType::AVX512_VNNI_3x48_KBLOCK,
+ jblas::gemm::GemmCoreType::AVX512F_8x48,
+ jblas::gemm::GemmCoreType::AMX_BF16_16x48,
+ jblas::gemm::GemmCoreType::AVX2_2X48,
+ jblas::gemm::GemmCoreType::AVX_VNNI_1x48_KBLOCK};
constexpr size_t EleNum = sizeof(cores) / sizeof(cores[0]);
support = contains(w1tmp->mCoreType, cores, EleNum);
support &= hasISA(cores, EleNum);
@@ -82,27 +84,46 @@ JBLAS_CODE jblas_fusion_FFN_SiLu_s4fp32_f32f32_forward(float* activation, SS4Fp3
ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, tmp1, ldtmp1, &quanA2, w1ptr,
w2ptr, w3ptr, tmp1, ldtmp1, output, ldo, NULL, tmp2, ldtmp2, NULL});
- } else if (_cd->AVX512_VNNI() && w1ptr->mBlockSize % 8 == 0) {
- if (seq <= 32) {
- using GemmKernel = custom::wrapper::kblock::avx512_vnni::GemmSKernelDynamicS4KBlockNext;
- using SiluGemmKernel = custom::wrapper::kblock::avx512_vnni::SiluGemmSKernelDynamicS4KBlockNext;
- using FusedInter = custom::wrapper::transformer::FFNFusedInterface;
- using DQuantParam = GemmKernel::PrologueA::QParam;
- static FusedInter finter;
- int lda = fin;
- int ldtmp1 = fmid;
- int ldtmp2 = fmid;
- int ldo = fout;
- auto quanA1 = finter.getActivationPtr()->createStorage(seq, fin, w1ptr->mBlockSize);
- quanA1.assign((int8_t*)workspace);
- auto offset = workspace == NULL ? 0 : quanA1.mSize;
- auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid, w2ptr->mBlockSize);
- quanA2.assign((int8_t*)workspace + offset);
- ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, tmp1, ldtmp1, &quanA2, w1ptr,
- w2ptr, w3ptr, tmp1, ldtmp1, output, ldo, NULL, tmp2, ldtmp2, NULL});
- } else {
- using GemmKernel = custom::wrapper::kblock::avx512_vnni::GemmSKernelDynamicS4KBlock;
- using SiluGemmKernel = custom::wrapper::kblock::avx512_vnni::SiluGemmSKernelDynamicS4KBlock;
+ } else if (w1ptr->mBlockSize % 8 == 0) {
+ if (_cd->AVX512_VNNI()) {
+ if (seq <= 32) {
+ using GemmKernel = custom::wrapper::kblock::avx512_vnni::GemmSKernelDynamicS4KBlockNext;
+ using SiluGemmKernel = custom::wrapper::kblock::avx512_vnni::SiluGemmSKernelDynamicS4KBlockNext;
+ using FusedInter = custom::wrapper::transformer::FFNFusedInterface;
+ using DQuantParam = GemmKernel::PrologueA::QParam;
+ static FusedInter finter;
+ int lda = fin;
+ int ldtmp1 = fmid;
+ int ldtmp2 = fmid;
+ int ldo = fout;
+ auto quanA1 = finter.getActivationPtr()->createStorage(seq, fin, w1ptr->mBlockSize);
+ quanA1.assign((int8_t*)workspace);
+ auto offset = workspace == NULL ? 0 : quanA1.mSize;
+ auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid, w2ptr->mBlockSize);
+ quanA2.assign((int8_t*)workspace + offset);
+ ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, tmp1, ldtmp1, &quanA2, w1ptr,
+ w2ptr, w3ptr, tmp1, ldtmp1, output, ldo, NULL, tmp2, ldtmp2, NULL});
+ } else {
+ using GemmKernel = custom::wrapper::kblock::avx512_vnni::GemmSKernelDynamicS4KBlock;
+ using SiluGemmKernel = custom::wrapper::kblock::avx512_vnni::SiluGemmSKernelDynamicS4KBlock;
+ using FusedInter = custom::wrapper::transformer::FFNFusedInterface;
+ using DQuantParam = GemmKernel::PrologueA::QParam;
+ static FusedInter finter;
+ int lda = fin;
+ int ldtmp1 = fmid;
+ int ldtmp2 = fmid;
+ int ldo = fout;
+ auto quanA1 = finter.getActivationPtr()->createStorage(seq, fin, w1ptr->mBlockSize);
+ quanA1.assign((int8_t*)workspace);
+ auto offset = workspace == NULL ? 0 : quanA1.mSize;
+ auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid, w2ptr->mBlockSize);
+ quanA2.assign((int8_t*)workspace + offset);
+ ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, tmp1, ldtmp1, &quanA2, w1ptr,
+ w2ptr, w3ptr, tmp1, ldtmp1, output, ldo, NULL, tmp2, ldtmp2, NULL});
+ }
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = custom::wrapper::kblock::avx_vnni::GemmSKernelDynamicS4KBlock;
+ using SiluGemmKernel = custom::wrapper::kblock::avx_vnni::SiluGemmSKernelDynamicS4KBlock;
using FusedInter = custom::wrapper::transformer::FFNFusedInterface;
using DQuantParam = GemmKernel::PrologueA::QParam;
static FusedInter finter;
@@ -182,23 +203,42 @@ JBLAS_CODE jblas_fusion_FFN_SiLu_s8fp32_f32f32_forward(float* activation, SS8Fp3
quanA2.assign((int8_t*)workspace + offset);
ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, tmp1, ldtmp1, &quanA2, w1ptr,
w2ptr, w3ptr, tmp1, ldtmp1, output, ldo, NULL, tmp2, ldtmp2, NULL});
- } else if (_cd->AVX512_VNNI() && w1ptr->mBlockSize % 4 == 0) {
- using GemmKernel = custom::wrapper::kblock::avx512_vnni::GemmSKernelDynamicS8KBlock;
- using SiluGemmKernel = custom::wrapper::kblock::avx512_vnni::SiluGemmSKernelDynamicS8KBlock;
- using FusedInter = custom::wrapper::transformer::FFNFusedInterface;
- using DQuantParam = GemmKernel::PrologueA::QParam;
- static FusedInter finter;
- int lda = fin;
- int ldtmp1 = fmid;
- int ldtmp2 = fmid;
- int ldo = fout;
- auto quanA1 = finter.getActivationPtr()->createStorage(seq, fin, w1ptr->mBlockSize);
- quanA1.assign((int8_t*)workspace);
- auto offset = workspace == NULL ? 0 : quanA1.mSize;
- auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid, w2ptr->mBlockSize);
- quanA2.assign((int8_t*)workspace + offset);
- ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, tmp1, ldtmp1, &quanA2, w1ptr,
- w2ptr, w3ptr, tmp1, ldtmp1, output, ldo, NULL, tmp2, ldtmp2, NULL});
+ } else if (w1ptr->mBlockSize % 4 == 0) {
+ if (_cd->AVX512_VNNI()) {
+ using GemmKernel = custom::wrapper::kblock::avx512_vnni::GemmSKernelDynamicS8KBlock;
+ using SiluGemmKernel = custom::wrapper::kblock::avx512_vnni::SiluGemmSKernelDynamicS8KBlock;
+ using FusedInter = custom::wrapper::transformer::FFNFusedInterface;
+ using DQuantParam = GemmKernel::PrologueA::QParam;
+ static FusedInter finter;
+ int lda = fin;
+ int ldtmp1 = fmid;
+ int ldtmp2 = fmid;
+ int ldo = fout;
+ auto quanA1 = finter.getActivationPtr()->createStorage(seq, fin, w1ptr->mBlockSize);
+ quanA1.assign((int8_t*)workspace);
+ auto offset = workspace == NULL ? 0 : quanA1.mSize;
+ auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid, w2ptr->mBlockSize);
+ quanA2.assign((int8_t*)workspace + offset);
+ ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, tmp1, ldtmp1, &quanA2, w1ptr,
+ w2ptr, w3ptr, tmp1, ldtmp1, output, ldo, NULL, tmp2, ldtmp2, NULL});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = custom::wrapper::kblock::avx_vnni::GemmSKernelDynamicS8KBlock;
+ using SiluGemmKernel = custom::wrapper::kblock::avx_vnni::SiluGemmSKernelDynamicS8KBlock;
+ using FusedInter = custom::wrapper::transformer::FFNFusedInterface;
+ using DQuantParam = GemmKernel::PrologueA::QParam;
+ static FusedInter finter;
+ int lda = fin;
+ int ldtmp1 = fmid;
+ int ldtmp2 = fmid;
+ int ldo = fout;
+ auto quanA1 = finter.getActivationPtr()->createStorage(seq, fin, w1ptr->mBlockSize);
+ quanA1.assign((int8_t*)workspace);
+ auto offset = workspace == NULL ? 0 : quanA1.mSize;
+ auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid, w2ptr->mBlockSize);
+ quanA2.assign((int8_t*)workspace + offset);
+ ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, tmp1, ldtmp1, &quanA2, w1ptr,
+ w2ptr, w3ptr, tmp1, ldtmp1, output, ldo, NULL, tmp2, ldtmp2, NULL});
+ }
}
} else if (w1ptr->mCoreType == GcCompFp32::TYPE) {
if (_cd->AVX512F()) {
@@ -296,6 +336,36 @@ JBLAS_CODE jblas_fusion_FFN_SiLu_s8fp32pern_f32f32_forward(float* activation, SS
{tmp1, ldtmp1, quanA1.mCStep, quanA1.mSPtr, w1ptr->mSPtr, quanA1.mZPtr, w1ptr->mRPtr},
{output, ldo, quanA2.mCStep, quanA2.mSPtr, w2ptr->mSPtr, quanA2.mZPtr, w2ptr->mRPtr},
{tmp2, ldtmp2, quanA1.mCStep, quanA1.mSPtr, w3ptr->mSPtr, quanA1.mZPtr, w3ptr->mRPtr}});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = custom::wrapper::kblock::avx_vnni::GemmDynamicS8PerN;
+ using SiluGemmKernel = custom::wrapper::kblock::avx_vnni::SiluGemmDynamicS8PerN;
+ using FusedInter = custom::wrapper::transformer::FFNFusedInterfacePerN;
+ static FusedInter finter;
+ int lda = fin;
+ int ldtmp1 = fmid;
+ int ldtmp2 = fmid;
+ int ldo = fout;
+ auto quanA1 = finter.getActivationPtr()->createStorage(seq, fin);
+ quanA1.assign((int8_t*)workspace);
+ auto offset = workspace == NULL ? 0 : quanA1.mSize;
+ auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid);
+ quanA2.assign((int8_t*)workspace + offset);
+ ret = finter.compute({seq,
+ fin,
+ fmid,
+ fout,
+ activation,
+ lda,
+ &quanA1,
+ tmp1,
+ ldtmp1,
+ &quanA2,
+ w1ptr,
+ w2ptr,
+ w3ptr,
+ {tmp1, ldtmp1, quanA1.mCStep, quanA1.mSPtr, w1ptr->mSPtr, quanA1.mZPtr, w1ptr->mRPtr},
+ {output, ldo, quanA2.mCStep, quanA2.mSPtr, w2ptr->mSPtr, quanA2.mZPtr, w2ptr->mRPtr},
+ {tmp2, ldtmp2, quanA1.mCStep, quanA1.mSPtr, w3ptr->mSPtr, quanA1.mZPtr, w3ptr->mRPtr}});
}
}
return ret;
@@ -358,6 +428,36 @@ JBLAS_CODE jblas_fusion_FFN_SiLu_s4clipfp32pern_f32f32_forward(float* activation
{tmp1, ldtmp1, quanA1.mCStep, quanA1.mSPtr, w1ptr->mSPtr, quanA1.mZPtr, w1ptr->mRPtr},
{output, ldo, quanA2.mCStep, quanA2.mSPtr, w2ptr->mSPtr, quanA2.mZPtr, w2ptr->mRPtr},
{tmp2, ldtmp2, quanA1.mCStep, quanA1.mSPtr, w3ptr->mSPtr, quanA1.mZPtr, w3ptr->mRPtr}});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = custom::wrapper::kblock::avx_vnni::GemmDynamicS4ClipPerN;
+ using SiluGemmKernel = custom::wrapper::kblock::avx_vnni::SiluGemmDynamicS4ClipPerN;
+ using FusedInter = custom::wrapper::transformer::FFNFusedInterfacePerN;
+ static FusedInter finter;
+ int lda = fin;
+ int ldtmp1 = fmid;
+ int ldtmp2 = fmid;
+ int ldo = fout;
+ auto quanA1 = finter.getActivationPtr()->createStorage(seq, fin);
+ quanA1.assign((int8_t*)workspace);
+ auto offset = workspace == NULL ? 0 : quanA1.mSize;
+ auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid);
+ quanA2.assign((int8_t*)workspace + offset);
+ ret = finter.compute({seq,
+ fin,
+ fmid,
+ fout,
+ activation,
+ lda,
+ &quanA1,
+ tmp1,
+ ldtmp1,
+ &quanA2,
+ w1ptr,
+ w2ptr,
+ w3ptr,
+ {tmp1, ldtmp1, quanA1.mCStep, quanA1.mSPtr, w1ptr->mSPtr, quanA1.mZPtr, w1ptr->mRPtr},
+ {output, ldo, quanA2.mCStep, quanA2.mSPtr, w2ptr->mSPtr, quanA2.mZPtr, w2ptr->mRPtr},
+ {tmp2, ldtmp2, quanA1.mCStep, quanA1.mSPtr, w3ptr->mSPtr, quanA1.mZPtr, w3ptr->mRPtr}});
}
}
return ret;
@@ -407,7 +507,8 @@ bool jblas_fusion_FFN_GeLu_f32f32_support(void* w1ptr, void* w2ptr, int seq, int
w1tmp->mPrologueID == int(WeightCompType::WeightS8ScaleFp32)) {
constexpr jblas::gemm::GemmCoreType cores[] = {
jblas::gemm::GemmCoreType::AMX_INT8_16x48_KBLOCK, jblas::gemm::GemmCoreType::AVX512_VNNI_3x48_KBLOCK,
- jblas::gemm::GemmCoreType::AVX512F_8x48, jblas::gemm::GemmCoreType::AMX_BF16_16x48};
+ jblas::gemm::GemmCoreType::AVX512F_8x48, jblas::gemm::GemmCoreType::AMX_BF16_16x48,
+ jblas::gemm::GemmCoreType::AVX_VNNI_1x48_KBLOCK};
constexpr size_t EleNum = sizeof(cores) / sizeof(cores[0]);
support = contains(w1tmp->mCoreType, cores, EleNum);
support &= hasISA(cores, EleNum);
@@ -460,6 +561,21 @@ JBLAS_CODE jblas_fusion_FFN_GeLu_s4fp32_f32f32_forward(float* activation, SS4Fp3
quanA2.assign((int8_t*)workspace + offset);
ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, tmp1, ldtmp1, &quanA2, w1tmp, w2tmp, tmp1,
ldtmp1, output, ldo});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = custom::wrapper::kblock::avx_vnni::GemmSKernelDynamicS4KBlock;
+ using GeluGemmKernel = custom::wrapper::kblock::avx_vnni::GeluGemmSKernelDynamicS4KBlock;
+ using FusedInter = custom::wrapper::transformer::GeluFusedInterface;
+ static FusedInter finter;
+ int lda = fin;
+ int ldtmp1 = fmid;
+ int ldo = fout;
+ auto quanA1 = finter.getActivationPtr()->createStorage(seq, fin, w1tmp->mBlockSize);
+ quanA1.assign((int8_t*)workspace);
+ auto offset = workspace == NULL ? 0 : quanA1.mSize;
+ auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid, w2tmp->mBlockSize);
+ quanA2.assign((int8_t*)workspace + offset);
+ ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, tmp1, ldtmp1, &quanA2, w1tmp, w2tmp, tmp1,
+ ldtmp1, output, ldo});
}
}
return ret;
@@ -501,6 +617,21 @@ JBLAS_CODE jblas_fusion_FFN_GeLu_s8fp32_f32f32_forward(float* activation, SS8Fp3
quanA2.assign((int8_t*)workspace + offset);
ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, tmp1, ldtmp1, &quanA2, w1tmp, w2tmp, tmp1,
ldtmp1, output, ldo});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = custom::wrapper::kblock::avx_vnni::GemmSKernelDynamicS8KBlock;
+ using GeluGemmKernel = custom::wrapper::kblock::avx_vnni::GeluGemmSKernelDynamicS8KBlock;
+ using FusedInter = custom::wrapper::transformer::GeluFusedInterface;
+ static FusedInter finter;
+ int lda = fin;
+ int ldtmp1 = fmid;
+ int ldo = fout;
+ auto quanA1 = finter.getActivationPtr()->createStorage(seq, fin, w1tmp->mBlockSize);
+ quanA1.assign((int8_t*)workspace);
+ auto offset = workspace == NULL ? 0 : quanA1.mSize;
+ auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid, w2tmp->mBlockSize);
+ quanA2.assign((int8_t*)workspace + offset);
+ ret = finter.compute({seq, fin, fmid, fout, activation, lda, &quanA1, tmp1, ldtmp1, &quanA2, w1tmp, w2tmp, tmp1,
+ ldtmp1, output, ldo});
}
}
return ret;
@@ -534,10 +665,12 @@ bool jblas_fusion_FFN_Add_GeLu_f32f32_support(void* w1ptr, void* w2ptr, int seq,
if (sameKernel) {
if (w1tmp->mPrologueID == int(WeightCompType::WeightS4ClipScaleFp32) ||
w1tmp->mPrologueID == int(WeightCompType::WeightS8ScaleFp32)) {
- constexpr jblas::gemm::GemmCoreType cores[] = {
- jblas::gemm::GemmCoreType::AMX_INT8_16x48_KBLOCK, jblas::gemm::GemmCoreType::AVX512_VNNI_3x48_KBLOCK,
- jblas::gemm::GemmCoreType::AVX512F_8x48, jblas::gemm::GemmCoreType::AMX_BF16_16x48,
- jblas::gemm::GemmCoreType::AVX2_2X48};
+ constexpr jblas::gemm::GemmCoreType cores[] = {jblas::gemm::GemmCoreType::AMX_INT8_16x48_KBLOCK,
+ jblas::gemm::GemmCoreType::AVX512_VNNI_3x48_KBLOCK,
+ jblas::gemm::GemmCoreType::AVX512F_8x48,
+ jblas::gemm::GemmCoreType::AMX_BF16_16x48,
+ jblas::gemm::GemmCoreType::AVX2_2X48,
+ jblas::gemm::GemmCoreType::AVX_VNNI_1x48_KBLOCK};
constexpr size_t EleNum = sizeof(cores) / sizeof(cores[0]);
support = contains(w1tmp->mCoreType, cores, EleNum);
support &= hasISA(cores, EleNum);
@@ -605,6 +738,24 @@ JBLAS_CODE jblas_fusion_FFN_Add_GeLu_s4fp32_f32f32_forward(float* activation, SS
ldtmp1, &quanA2, w1tmp, w2tmp,
tmp1, b1ptr, ldtmp1, broadcast_bias ? 0 : ldtmp1,
output, b2ptr, ldo, broadcast_bias ? 0 : ldo});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = custom::wrapper::kblock::avx_vnni::AddGemmSKernelDynamicS4KBlock;
+ using GeluGemmKernel = custom::wrapper::kblock::avx_vnni::AddGeluGemmSKernelDynamicS4KBlock;
+ using FusedInter = custom::wrapper::transformer::GeluFusedInterface;
+ static FusedInter finter;
+ int lda = fin;
+ int ldtmp1 = fmid;
+ int ldo = fout;
+ auto quanA1 = finter.getActivationPtr()->createStorage(seq, fin, w1tmp->mBlockSize);
+ quanA1.assign((int8_t*)workspace);
+ auto offset = workspace == NULL ? 0 : quanA1.mSize;
+ auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid, w2tmp->mBlockSize);
+ quanA2.assign((int8_t*)workspace + offset);
+ ret = finter.compute({seq, fin, fmid, fout,
+ activation, lda, &quanA1, tmp1,
+ ldtmp1, &quanA2, w1tmp, w2tmp,
+ tmp1, b1ptr, ldtmp1, broadcast_bias ? 0 : ldtmp1,
+ output, b2ptr, ldo, broadcast_bias ? 0 : ldo});
}
} else if (w1tmp->mCoreType == GcCompFp32::TYPE) {
if (_cd->AVX512F()) {
@@ -695,6 +846,28 @@ JBLAS_CODE jblas_fusion_FFN_Add_GeLu_s8fp32_f32f32_forward(float* activation, SS
ldtmp1, &quanA2, w1tmp, w2tmp,
tmp1, b1ptr, ldtmp1, broadcast_bias ? 0 : ldtmp1,
output, b2ptr, ldo, broadcast_bias ? 0 : ldo});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = custom::wrapper::kblock::avx_vnni::AddGemmSKernelDynamicS8KBlock;
+ using GeluGemmKernel = custom::wrapper::kblock::avx_vnni::AddGeluGemmSKernelDynamicS8KBlock;
+ using FusedInter = custom::wrapper::transformer::GeluFusedInterface;
+ static FusedInter finter;
+ int lda = fin;
+ int ldtmp1 = fmid;
+ int ldo = fout;
+ // FusedInter::Arguments::paramA paramA={activation, lda};
+ // FusedInter::Arguments::paramW1 paramW1={w1tmp};
+ // FusedInter::Arguments::paramW2 paramW2={w2tmp};
+ // FusedInter::Arguments::param1 param1={tmp1, b1ptr, ldtmp1, ldtmp1};
+ auto quanA1 = finter.getActivationPtr()->createStorage(seq, fin, w1tmp->mBlockSize);
+ quanA1.assign((int8_t*)workspace);
+ auto offset = workspace == NULL ? 0 : quanA1.mSize;
+ auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid, w2tmp->mBlockSize);
+ quanA2.assign((int8_t*)workspace + offset);
+ ret = finter.compute({seq, fin, fmid, fout,
+ activation, lda, &quanA1, tmp1,
+ ldtmp1, &quanA2, w1tmp, w2tmp,
+ tmp1, b1ptr, ldtmp1, broadcast_bias ? 0 : ldtmp1,
+ output, b2ptr, ldo, broadcast_bias ? 0 : ldo});
}
} else if (w1tmp->mCoreType == GcCompFp32::TYPE) {
if (_cd->AVX512F()) {
@@ -807,6 +980,41 @@ JBLAS_CODE jblas_fusion_FFN_Add_GeLu_s8fp32pern_f32f32_forward(float* activation
{{output, ldo, quanA2.mCStep, quanA2.mSPtr, w2tmp->mSPtr, quanA2.mZPtr, w2tmp->mRPtr},
b2ptr,
broadcast_bias ? 0 : ldo}});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = custom::wrapper::kblock::avx_vnni::AddGemmDynamicS8PerN;
+ using GeluGemmKernel = custom::wrapper::kblock::avx_vnni::AddGeluGemmDynamicS8PerN;
+ using FusedInter = custom::wrapper::transformer::GeluFusedInterfacePerN;
+ static FusedInter finter;
+ int lda = fin;
+ int ldtmp1 = fmid;
+ int ldo = fout;
+ // FusedInter::Arguments::paramA paramA={activation, lda};
+ // FusedInter::Arguments::paramW1 paramW1={w1tmp};
+ // FusedInter::Arguments::paramW2 paramW2={w2tmp};
+ // FusedInter::Arguments::param1 param1={tmp1, b1ptr, ldtmp1, ldtmp1};
+ auto quanA1 = finter.getActivationPtr()->createStorage(seq, fin);
+ quanA1.assign((int8_t*)workspace);
+ auto offset = workspace == NULL ? 0 : quanA1.mSize;
+ auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid);
+ quanA2.assign((int8_t*)workspace + offset);
+ ret = finter.compute({seq,
+ fin,
+ fmid,
+ fout,
+ activation,
+ lda,
+ &quanA1,
+ tmp1,
+ ldtmp1,
+ &quanA2,
+ w1tmp,
+ w2tmp,
+ {{tmp1, ldtmp1, quanA1.mCStep, quanA1.mSPtr, w1tmp->mSPtr, quanA1.mZPtr, w1tmp->mRPtr},
+ b1ptr,
+ broadcast_bias ? 0 : ldtmp1},
+ {{output, ldo, quanA2.mCStep, quanA2.mSPtr, w2tmp->mSPtr, quanA2.mZPtr, w2tmp->mRPtr},
+ b2ptr,
+ broadcast_bias ? 0 : ldo}});
}
}
return ret;
@@ -898,6 +1106,41 @@ JBLAS_CODE jblas_fusion_FFN_Add_GeLu_s4clipfp32pern_f32f32_forward(float* activa
{{output, ldo, quanA2.mCStep, quanA2.mSPtr, w2tmp->mSPtr, quanA2.mZPtr, w2tmp->mRPtr},
b2ptr,
broadcast_bias ? 0 : ldo}});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = custom::wrapper::kblock::avx_vnni::AddGemmDynamicS4ClipPerN;
+ using GeluGemmKernel = custom::wrapper::kblock::avx_vnni::AddGeluGemmDynamicS4ClipPerN;
+ using FusedInter = custom::wrapper::transformer::GeluFusedInterfacePerN;
+ static FusedInter finter;
+ int lda = fin;
+ int ldtmp1 = fmid;
+ int ldo = fout;
+ // FusedInter::Arguments::paramA paramA={activation, lda};
+ // FusedInter::Arguments::paramW1 paramW1={w1tmp};
+ // FusedInter::Arguments::paramW2 paramW2={w2tmp};
+ // FusedInter::Arguments::param1 param1={tmp1, b1ptr, ldtmp1, ldtmp1};
+ auto quanA1 = finter.getActivationPtr()->createStorage(seq, fin);
+ quanA1.assign((int8_t*)workspace);
+ auto offset = workspace == NULL ? 0 : quanA1.mSize;
+ auto quanA2 = finter.getActivationPtr()->createStorage(seq, fmid);
+ quanA2.assign((int8_t*)workspace + offset);
+ ret = finter.compute({seq,
+ fin,
+ fmid,
+ fout,
+ activation,
+ lda,
+ &quanA1,
+ tmp1,
+ ldtmp1,
+ &quanA2,
+ w1tmp,
+ w2tmp,
+ {{tmp1, ldtmp1, quanA1.mCStep, quanA1.mSPtr, w1tmp->mSPtr, quanA1.mZPtr, w1tmp->mRPtr},
+ b1ptr,
+ broadcast_bias ? 0 : ldtmp1},
+ {{output, ldo, quanA2.mCStep, quanA2.mSPtr, w2tmp->mSPtr, quanA2.mZPtr, w2tmp->mRPtr},
+ b2ptr,
+ broadcast_bias ? 0 : ldo}});
}
}
return ret;
diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_qkv.cpp b/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_qkv.cpp
index fe7cd60453c..3c0c49e6643 100644
--- a/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_qkv.cpp
+++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/ip_fusion_qkv.cpp
@@ -61,6 +61,34 @@ using QKVGemmDynamicS4ClipFp32PerN = jblas::wrapper::transformer::QKVGemmInterfa
jblas::epilogue::gemm::ZpDequantInt32ToFp32>,
jblas::utils::parallel::Parallel2DGemm>;
} // namespace avx512_vnni
+namespace avx_vnni {
+static JBLAS_ISA constexpr DefaultISA = JblasAVX_VNNI;
+using QKVGemmDynamicS4Fp32KBlock = jblas::wrapper::transformer::QKVGemmInterfaceKBlockPackWeight<
+ jblas::wrapper::gemm_kblock::GemmLauncherKBlock,
+ jblas::utils::parallel::Parallel2DGemmKBlockFixed>;
+using QKVGemmDynamicS8Fp32KBlock = jblas::wrapper::transformer::QKVGemmInterfaceKBlockPackWeight<
+ jblas::wrapper::gemm_kblock::GemmLauncherKBlock<
+ DefaultISA, jblas::gemm::kblock::GemmCore_Row_NN_1x48_AVX_VNNI_KBLOCK,
+ jblas::prologue::gemm::ActivationF32U8KBlockQuantize,
+ jblas::prologue::weight_comp::gemm_kblcok::WeightS8ScaleFp32, jblas::epilogue::gemm::AccumulatorWriteBackFp32>,
+ jblas::utils::parallel::Parallel2DGemmKBlockFixed>;
+using QKVGemmDynamicS8Fp32PerN = jblas::wrapper::transformer::QKVGemmInterfacePackWeightParallelAB<
+ jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight<
+ DefaultISA, jblas::gemm::GemmCore_Row_NN_2x48_AVX_VNNI, jblas::prologue::gemm::ActivationFp32AsymU8Quantize,
+ jblas::prologue::weight_comp::gemm_kblcok::WeightS8ScaleFp32PerChannelN,
+ jblas::epilogue::gemm::ZpDequantInt32ToFp32>,
+ jblas::utils::parallel::Parallel2DGemm>;
+using QKVGemmDynamicS4ClipFp32PerN = jblas::wrapper::transformer::QKVGemmInterfacePackWeightParallelAB<
+ jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight<
+ DefaultISA, jblas::gemm::GemmCore_Row_NN_2x48_AVX_VNNI, jblas::prologue::gemm::ActivationFp32AsymU8Quantize,
+ jblas::prologue::weight_comp::gemm_kblcok::WeightS4ClipScaleFp32PerN,
+ jblas::epilogue::gemm::ZpDequantInt32ToFp32>,
+ jblas::utils::parallel::Parallel2DGemm>;
+} // namespace avx_vnni
namespace amx_int8 {
static JBLAS_ISA constexpr DefaultISA = JblasAMX_INT8;
using QKVGemmDynamicS4Fp32KBlock = jblas::wrapper::transformer::QKVGemmInterfaceKBlockPackWeight<
@@ -141,25 +169,43 @@ JBLAS_CODE jblas_QKVs4fp32_f32f32_forward(float* activation, SS4Fp32* wqptr, SS4
auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, wqptr->mBlockSize);
quanA.assign((int8_t*)workspace);
ret = kernel.compute({_m, _n, _k, 3, activation, lda, &quanA, wparams, oparams, NULL});
- } else if (_cd->AVX512_VNNI() && wqptr->mBlockSize % 8 == 0) {
- if (_m <= 32) {
- using GemmKernel = transformer::avx512_vnni::QKVGemmDynamicS4Fp32KBlockNext;
- static GemmKernel kernel;
- GemmKernel::WeightType::Param wparams[3]{
- wqptr,
- wkptr,
- wvptr,
- };
- GemmKernel::CParam oparams[3]{
- {output, ldo},
- {output + _m * _n, ldo},
- {output + 2 * _m * _n, ldo},
- };
- auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, wqptr->mBlockSize);
- quanA.assign((int8_t*)workspace);
- ret = kernel.compute({_m, _n, _k, 3, activation, lda, &quanA, wparams, oparams, NULL});
- } else {
- using GemmKernel = transformer::avx512_vnni::QKVGemmDynamicS4Fp32KBlock;
+ } else if (wqptr->mBlockSize % 8 == 0) {
+ if (_cd->AVX512_VNNI()) {
+ if (_m <= 32) {
+ using GemmKernel = transformer::avx512_vnni::QKVGemmDynamicS4Fp32KBlockNext;
+ static GemmKernel kernel;
+ GemmKernel::WeightType::Param wparams[3]{
+ wqptr,
+ wkptr,
+ wvptr,
+ };
+ GemmKernel::CParam oparams[3]{
+ {output, ldo},
+ {output + _m * _n, ldo},
+ {output + 2 * _m * _n, ldo},
+ };
+ auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, wqptr->mBlockSize);
+ quanA.assign((int8_t*)workspace);
+ ret = kernel.compute({_m, _n, _k, 3, activation, lda, &quanA, wparams, oparams, NULL});
+ } else {
+ using GemmKernel = transformer::avx512_vnni::QKVGemmDynamicS4Fp32KBlock;
+ static GemmKernel kernel;
+ GemmKernel::WeightType::Param wparams[3]{
+ wqptr,
+ wkptr,
+ wvptr,
+ };
+ GemmKernel::CParam oparams[3]{
+ {output, ldo},
+ {output + _m * _n, ldo},
+ {output + 2 * _m * _n, ldo},
+ };
+ auto quanA = kernel.getActivationPtr()->createStorage(_m, _k, wqptr->mBlockSize);
+ quanA.assign((int8_t*)workspace);
+ ret = kernel.compute({_m, _n, _k, 3, activation, lda, &quanA, wparams, oparams, NULL});
+ }
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = transformer::avx_vnni::QKVGemmDynamicS4Fp32KBlock;
static GemmKernel kernel;
GemmKernel::WeightType::Param wparams[3]{
wqptr,
@@ -260,6 +306,22 @@ JBLAS_CODE jblas_QKVs8fp32pern_f32f32_forward(float* activation, SS8Fp32PerN* wq
{output + 2 * _m * _n, ldo, quanA.mCStep, quanA.mSPtr, wvptr->mSPtr, quanA.mZPtr, wvptr->mRPtr},
};
ret = kernel.compute({_m, _n, _k, 3, activation, lda, &quanA, wparams, oparams, NULL});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = transformer::avx_vnni::QKVGemmDynamicS8Fp32PerN;
+ static GemmKernel kernel;
+ auto quanA = kernel.getActivationPtr()->createStorage(_m, _k);
+ quanA.assign((int8_t*)workspace);
+ GemmKernel::WeightType::Param wparams[3]{
+ wqptr,
+ wkptr,
+ wvptr,
+ };
+ GemmKernel::CParam oparams[3]{
+ {output, ldo, quanA.mCStep, quanA.mSPtr, wqptr->mSPtr, quanA.mZPtr, wqptr->mRPtr},
+ {output + _m * _n, ldo, quanA.mCStep, quanA.mSPtr, wkptr->mSPtr, quanA.mZPtr, wkptr->mRPtr},
+ {output + 2 * _m * _n, ldo, quanA.mCStep, quanA.mSPtr, wvptr->mSPtr, quanA.mZPtr, wvptr->mRPtr},
+ };
+ ret = kernel.compute({_m, _n, _k, 3, activation, lda, &quanA, wparams, oparams, NULL});
}
}
return ret;
@@ -303,6 +365,22 @@ JBLAS_CODE jblas_QKVs4clipfp32pern_f32f32_forward(float* activation, SS4Fp32PerN
{output + 2 * _m * _n, ldo, quanA.mCStep, quanA.mSPtr, wvptr->mSPtr, quanA.mZPtr, wvptr->mRPtr},
};
ret = kernel.compute({_m, _n, _k, 3, activation, lda, &quanA, wparams, oparams, NULL});
+ } else if (_cd->AVX_VNNI()) {
+ using GemmKernel = transformer::avx_vnni::QKVGemmDynamicS4ClipFp32PerN;
+ static GemmKernel kernel;
+ auto quanA = kernel.getActivationPtr()->createStorage(_m, _k);
+ quanA.assign((int8_t*)workspace);
+ GemmKernel::WeightType::Param wparams[3]{
+ wqptr,
+ wkptr,
+ wvptr,
+ };
+ GemmKernel::CParam oparams[3]{
+ {output, ldo, quanA.mCStep, quanA.mSPtr, wqptr->mSPtr, quanA.mZPtr, wqptr->mRPtr},
+ {output + _m * _n, ldo, quanA.mCStep, quanA.mSPtr, wkptr->mSPtr, quanA.mZPtr, wkptr->mRPtr},
+ {output + 2 * _m * _n, ldo, quanA.mCStep, quanA.mSPtr, wvptr->mSPtr, quanA.mZPtr, wvptr->mRPtr},
+ };
+ ret = kernel.compute({_m, _n, _k, 3, activation, lda, &quanA, wparams, oparams, NULL});
}
}
return ret;
diff --git a/intel_extension_for_transformers/llm/runtime/graph/core/layers/jblas_common.hpp b/intel_extension_for_transformers/llm/runtime/graph/core/layers/jblas_common.hpp
index 87d5f936614..d2d93b4e8f1 100644
--- a/intel_extension_for_transformers/llm/runtime/graph/core/layers/jblas_common.hpp
+++ b/intel_extension_for_transformers/llm/runtime/graph/core/layers/jblas_common.hpp
@@ -122,10 +122,12 @@ using GcCompFp16 = jblas::gemm::GemmCore_Row_NN_8x64_AVX512_FP16;
using GcCompInt8 = jblas::gemm::GemmCore_Row_NN_8x48_AVX512_VNNI;
constexpr jblas::gemm::GemmCoreType GcCompInt8KBlockSet[] = {jblas::gemm::GemmCoreType::AMX_INT8_16x48_KBLOCK,
- jblas::gemm::GemmCoreType::AVX512_VNNI_3x48_KBLOCK};
+ jblas::gemm::GemmCoreType::AVX512_VNNI_3x48_KBLOCK,
+ jblas::gemm::GemmCoreType::AVX_VNNI_1x48_KBLOCK};
constexpr jblas::gemm::GemmCoreType GcCompInt8Set[] = {jblas::gemm::GemmCoreType::AMX_INT8_16x48_SS,
- jblas::gemm::GemmCoreType::AVX512_VNNI_8x48};
+ jblas::gemm::GemmCoreType::AVX512_VNNI_8x48,
+ jblas::gemm::GemmCoreType::AVX_VNNI_2x48};
namespace custom {
namespace epilogue {
@@ -894,6 +896,40 @@ using AddGeluGemmDynamicS4ClipPerN = DynamicGemmPerN;
} // namespace avx512_vnni
+namespace avx_vnni {
+template class ProB, template class Epi>
+using DynamicGemm =
+ jblas::wrapper::gemm_kblock::GemmLauncherKBlock;
+
+template class ProB, template class Epi>
+using DynamicGemmPerN =
+ jblas::wrapper::gemm_pack_weight::GemmLauncherPackWeight;
+using GemmSKernelDynamicS4KBlock = DynamicGemm;
+using SiluGemmSKernelDynamicS4KBlock = DynamicGemm;
+using GeluGemmSKernelDynamicS4KBlock = DynamicGemm;
+using AddGeluGemmSKernelDynamicS4KBlock = DynamicGemm;
+using AddGemmSKernelDynamicS4KBlock = DynamicGemm;
+
+using GemmSKernelDynamicS8KBlock = DynamicGemm;
+using SiluGemmSKernelDynamicS8KBlock = DynamicGemm;
+using GeluGemmSKernelDynamicS8KBlock = DynamicGemm;
+using AddGeluGemmSKernelDynamicS8KBlock = DynamicGemm;
+using AddGemmSKernelDynamicS8KBlock = DynamicGemm;
+
+using GemmDynamicS8PerN = DynamicGemmPerN;
+using SiluGemmDynamicS8PerN = DynamicGemmPerN;
+using AddGeluGemmDynamicS8PerN = DynamicGemmPerN;
+using AddGemmDynamicS8PerN = DynamicGemmPerN;
+
+using GemmDynamicS4ClipPerN = DynamicGemmPerN;
+using SiluGemmDynamicS4ClipPerN = DynamicGemmPerN;
+using AddGeluGemmDynamicS4ClipPerN = DynamicGemmPerN;
+using AddGemmDynamicS4ClipPerN = DynamicGemmPerN;
+} // namespace avx_vnni
namespace amx_int8 {
template class ProB, template class Epi>
using DynamicGemm = jblas::wrapper::gemm_kblock::GemmSLauncherKBlockPackWeight<