Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
[LLM Runtime] refactor itrex backend based on the latest Jblas (#769)
Browse files Browse the repository at this point in the history
Co-authored-by: luoyu-intel <yu.luo@intel.com>
Co-authored-by: Ding, Yi1 <yi1.ding@intel.com>
Co-authored-by: zhenwei-intel <zhenwei.liu@intel.com>
Co-authored-by: yuchengliu1 <yucheng.liu@intel.com>
Co-authored-by: Meng, Hengyu <hengyu.meng@intel.com>
  • Loading branch information
6 people authored Dec 13, 2023
1 parent c087c74 commit 43e30bc
Show file tree
Hide file tree
Showing 63 changed files with 13,774 additions and 12,749 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/script/formatScan/cpplint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@ log_path=${log_dir}/cpplint.log
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/deprecated/compile 2>&1 | tee ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/deprecated/executor 2>&1 | tee -a ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/deprecated/test 2>&1 | tee -a ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/graph/application 2>&1 | tee -a ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/graph 2>&1 | tee -a ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/library/kernels 2>&1 | tee -a ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/graph/models 2>&1 | tee -a ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/llm/runtime/graph/vectors 2>&1 | tee -a ${log_path}
cpplint --extensions cpp,hpp --filter=-build/include_subdir,-build/header_guard --recursive --quiet --linelength=120 ${REPO_DIR}/intel_extension_for_transformers/operator/csrc 2>&1 | tee -a ${log_path}
if [[ ! -f ${log_path} ]] || [[ $(grep -c "Total errors found:" ${log_path}) != 0 ]]; then
exit 1
fi
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Language: Cpp
BasedOnStyle: Google
DerivePointerAlignment: false
ColumnLimit: 120
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SortIncludes: false
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include <cstddef>
#include <type_traits>

#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"

Expand Down Expand Up @@ -50,6 +49,21 @@ class JitBase : protected Xbyak::CodeGenerator {
#endif
}

// Emits code that rounds _src down to a multiple of `padding`
// (_src = _src / padding * padding).
//
// _src:    64-bit register that is modified in place by the generated code.
// padding: must be 1 or a power of two in [2, 1 << 15]; any other value hits
//          assert(0) and emits nothing.
void padto_le(const Xbyak::Reg64& _src, int padding) {
  if (padding == 1) return;  // every value is already a multiple of 1

  // Search for the shift amount whose power of two equals `padding`.
  int shift = 1;
  while (shift < 16 && (1 << shift) != padding) ++shift;

  if (shift == 16) {
    assert(0);  // padding is not a supported power of two
    return;
  }

  // Shift right then left by the same amount to drop the low `shift` bits.
  shr(_src, shift);
  shl(_src, shift);
}

void generate_Nbitsmask(const Xbyak::Opmask& _msk, const Xbyak::Reg64& _pos, const Xbyak::Address& _total,
const Xbyak::Reg64& _tmp, const Xbyak::Reg64& _tmp1, int N) {
inLocalLabel();
Expand All @@ -59,9 +73,9 @@ class JitBase : protected Xbyak::CodeGenerator {
jb(".maskflag");
cmp(_tmp, 0);
jl(".zeroflag");
uint64_t allmask = ((uint64_t)1 << N) - 1;
uint64_t allmask = (static_cast<uint64_t>(1) << N) - 1;
if (N == 64) {
allmask = (uint64_t)-1;
allmask = static_cast<uint64_t>(-1);
}
mov(_tmp, allmask);
kmovq(_msk, _tmp);
Expand All @@ -87,13 +101,16 @@ class JitBase : protected Xbyak::CodeGenerator {
// JIT layer built on JitBase that fixes the vector-register parameters used by
// derived generators: 256-bit vectors held in Xbyak::Ymm registers.
class JitAvx : protected JitBase {
 protected:
  // Width of one vector register in bits.
  static constexpr int VBits = 256;
  // Width of one vector register in bytes (VBits / 8).
  static constexpr int VecBytes = VBits / 8;
  // Number of vector registers assumed available at this level.
  static constexpr int RegCount = 16;
  // Register type matching VBits.
  using vreg_t = Xbyak::Ymm;
};

class JitAvx2 : protected JitAvx {
protected:
static int constexpr VBits = 256;
typedef Xbyak::Ymm vreg_t;
void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxor(x1, x2, op); }

void loadbf16_f32(const Xbyak::Ymm& dst, const Xbyak::Address& addr) {
vpmovzxwd(dst, addr);
Expand All @@ -104,8 +121,12 @@ class JitAvx2 : protected JitAvx {
class JitAvx512f : protected JitAvx2 {
protected:
static int constexpr VBits = 512;
static int constexpr VecBytes = VBits / 8;
static int constexpr RegCount = 32;
typedef Xbyak::Zmm vreg_t;

void vxor(const vreg_t& x1, const vreg_t& x2, const Xbyak::Operand& op) { vpxorq(x1, x2, op); }

void interleave_2rows_4regs(Xbyak::Zmm* src_2regs, Xbyak::Zmm* tmp_2reg) {
vpunpcklwd(tmp_2reg[0], src_2regs[0], src_2regs[1]);
vpunpckhwd(tmp_2reg[1], src_2regs[0], src_2regs[1]);
Expand Down Expand Up @@ -192,18 +213,20 @@ class JitAvx512f : protected JitAvx2 {
}
};

// Named ISA layer (presumably for AVX512-BF16 code generators -- confirm against users);
// it currently adds no members of its own on top of JitAvx512f.
class JitAvx512_bf16 : protected JitAvx512f {};

// Named ISA layer (presumably for AVX512-FP16 code generators -- confirm against users);
// it currently adds no members of its own on top of JitAvx512f.
class JitAvx512_fp16 : protected JitAvx512f {};

class JitAvx512vnni : protected JitAvx512f {
protected:
void vpdpbusds_evex(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
vpdpbusds(x1, x2, op, Xbyak::EvexEncoding);
}
};

class JitAvxvnni : protected JitAvx2 {
protected:
void vpdpbusds_vex(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
void vpdpbusds_(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2, const Xbyak::Operand& op) {
vpdpbusds(x1, x2, op, Xbyak::VexEncoding);
}
};
Expand All @@ -216,6 +239,15 @@ class JitAmxtile : protected JitAvx512f {
uint16_t colb[16];
uint8_t rows[16];
};
static int constexpr TileCount = 8;

typedef long long (*configure_t)(void*);

// Emits, into the generator `g`, the body of a one-argument function that
// executes `ldtilecfg` on the memory addressed by that argument.
//
// g: code generator that receives the instructions. The StackFrame helper is
//    constructed with one parameter register, no temporaries and no stack bytes.
static void generate_config(Xbyak::CodeGenerator* g) {
  Xbyak::util::StackFrame frame(g, 1, 0, 0);
  // frame.p[0] is the register carrying the function's first argument.
  g->ldtilecfg(g->ptr[frame.p[0]]);
}

static void configure_tiles(tileconfig_t& tc, int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum,
int CNum) {
Expand All @@ -224,19 +256,19 @@ class JitAmxtile : protected JitAvx512f {
// Configure C tiles
int t = 0;
for (; t < CNum; ++t) {
tc.rows[t] = uint8_t(TILE_M);
tc.colb[t] = uint16_t(TILE_N * 4);
tc.rows[t] = static_cast<uint8_t>(TILE_M);
tc.colb[t] = static_cast<uint16_t>(TILE_N * 4);
}
// Configure A tiles
for (; t < CNum + ANum; ++t) {
tc.rows[t] = uint8_t(TILE_M);
tc.colb[t] = uint16_t(TILE_K * elesize);
tc.rows[t] = static_cast<uint8_t>(TILE_M);
tc.colb[t] = static_cast<uint16_t>(TILE_K * elesize);
}
// Configure B tile. B effectively has 64 rows and 16 columns.
int kpack = 4 / elesize;
for (; t < CNum + ANum + BNum; ++t) {
tc.rows[t] = uint8_t(TILE_K / kpack);
tc.colb[t] = uint16_t(TILE_N * 4);
tc.rows[t] = static_cast<uint8_t>(TILE_K / kpack);
tc.colb[t] = static_cast<uint16_t>(TILE_N * 4);
}
}
};
Expand Down
115 changes: 69 additions & 46 deletions intel_extension_for_transformers/llm/library/jblas/jblas/jit_blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,59 +15,82 @@
#include <stdint.h>
enum JBLAS_CODE {
JblasSuccess = 0,
JblasInvalidParam = -1,
JblasInvalidISA = -2,
JblasRuntimeError = -3,
JblasNotSupport = -4,
JblasInvalidParam = 1,
JblasInvalidISA = 2,
JblasRuntimeError = 4,
JblasNotSupport = 8,
};
enum JBLAS_ISA {
JblasNoSIMD = 10,
JblasAVX = 11,
JblasAVX2 = 12,
JblasAVX_VNNI = 13,
JblasAVX512F = 14,
JblasAVX512_VNNI = 15,
JblasAMX_BF16 = 16,
JblasAMX_INT8 = 17,
JblasAVX512_FP16 = 18,
// Instruction-set identifiers, stored in a single byte.
// Values are consecutive and start at 0; they are spelled out so the numeric
// value of each entry is visible without counting.
enum JBLAS_ISA : uint8_t {
  JblasNoSIMD = 0,
  JblasAVX = 1,
  JblasAVX2 = 2,
  JblasAVX_VNNI = 3,
  JblasAVX512F = 4,
  JblasAVX512_VNNI = 5,
  JblasAMX_BF16 = 6,
  JblasAMX_INT8 = 7,
  JblasAVX512_FP16 = 8,
  JblasAVX512_BF16 = 9,
};
enum JBLAS_DTYPE {
JblasF64 = 59,
JblasF32 = 60,
JblasBF16 = 61,
JblasS8 = 63,
JblasU8 = 64,
JblasF32F8 = 65,
};
enum JBLAS_FP8_ENCODING {
JblasFp8_e4m3 = 80,
JblasFp8_e5m2 = 81,
JblasFp8_e3m4 = 82,
// Data-type identifier packed into a 32-bit value made of three byte-wide fields:
//   bits  0..7  (EleBitsMask)  - element width in bits
//   bits  8..15 (TypeMask)     - type class (float or integer)
//   bits 16..23 (SubTypeMask)  - variant within the same width and class
// The mask/shift/field constants and the concrete data types share this one enum.
enum class JBLAS_DTYPE : uint32_t {
// Element-width field.
EleBitsMask = 0xff,
EleBitsShift = 0,
EleBitsUndef = 0,
EleBits4 = 4,
EleBits8 = 8,
EleBits16 = 16,
EleBits32 = 32,
EleBits64 = 64,
// Type-class field.
TypeMask = 0xff00,
TypeShift = 8,
TypeFloat = 0 << TypeShift,
TypeInt = 1 << TypeShift,
// Sub-type field; separates types that share both width and class.
SubTypeMask = 0xff0000,
SubTypeShift = 16,
SubType0 = 0 << SubTypeShift,
SubType1 = 1 << SubTypeShift,
SubType2 = 2 << SubTypeShift,
SubType3 = 3 << SubTypeShift,
// Concrete data types, each composed as width | class [| sub-type].
// Because TypeFloat and SubType0 are both 0, a plain float type has the same
// numeric value as its EleBits* constant (e.g. F32 == EleBits32).
F64 = EleBits64 | TypeFloat,
F32 = EleBits32 | TypeFloat,
F16 = EleBits16 | TypeFloat,
BF16 = EleBits16 | TypeFloat | SubType1,
F8_E4M3 = EleBits8 | TypeFloat,
F8_E5M2 = EleBits8 | TypeFloat | SubType1,
F8_E3M4 = EleBits8 | TypeFloat | SubType2,
F8_E8M0 = EleBits8 | TypeFloat | SubType3,
S8 = EleBits8 | TypeInt,
U8 = EleBits8 | TypeInt | SubType1,
S4_CLIP = EleBits4 | TypeInt,
S4_FULLRANGE = EleBits4 | TypeInt | SubType1,
F4_E2M1 = EleBits4 | TypeFloat,
F4_BNB = EleBits4 | TypeFloat | SubType1,
F4_NF4 = EleBits4 | TypeFloat | SubType2,
S32 = EleBits32 | TypeInt,
U32 = EleBits32 | TypeInt | SubType1,
};

// Matrix storage-order selector.
// NOTE(review): the numeric values look like the CBLAS layout constants -- confirm
// before relying on interchangeability.
enum JBLAS_LAYOUT {
  JblasRowMajor = 101,
  JblasColMajor = 102,
};
// Transpose-mode selector: none, transpose, or conjugate transpose.
// NOTE(review): the numeric values look like the CBLAS transpose constants -- confirm
// before relying on interchangeability.
enum JBLAS_TRANSPOSE { JblasNoTrans = 111, JblasTrans = 112, JblasConjTrans = 113 };
enum JBLAS_ELTWISEOP {
GELU,
SWISH,
TANH,
EXP,
LOW_PRECISION_EXP,
RELU,
LINEAR,
};
enum JBLAS_F4_TYPE {
F4_UNDEF,
FP4_BNB,
FP4_E2M1,
NF4,
};
enum JBLAS_SIGN_INT_TYPE {
S8,
S4_CLIP,
S4_FULLRANGE,
S4_UNDEF,
enum JBLAS_ELTWISEOP { GELU, SWISH, TANH, EXP, LOW_PRECISION_EXP, RELU, LINEAR };

// Identifiers for the B-matrix prologue kinds.
//
// The ids are laid out as two adjacent half-open ranges so a caller can
// classify an id with a range check:
//   [NormalBegin, NormalEnd)  - currently only WeightPack
//   [KBlockBegin, KBlockEnd)  - the WeightKBlock* variants
// Undef is the all-ones sentinel and lies outside [Begin, End).
enum class JBLAS_PROLOGUEB_IDS : uint32_t {
  Undef = static_cast<uint32_t>(-1),  // 0xFFFFFFFF sentinel
  Begin = 0,
  // "Normal" range.
  NormalBegin = Begin,
  WeightPack = NormalBegin,
  NormalEnd,
  // "KBlock" range; starts exactly where the normal range ends.
  KBlockBegin = NormalEnd,
  WeightKBlockNInteger = KBlockBegin,
  WeightKBlockNFloat,
  WeightKBlockS8,
  WeightKBlockS4,
  WeightKBlockF4,
  WeightKBlockF8,
  KBlockEnd,
  End,
};
Loading

0 comments on commit 43e30bc

Please sign in to comment.