
Commit

[LLM Runtime] Allow CompileBF16 on GCC11 (#655)
* Allow CompileBF16 on GCC11
DDEle authored Nov 10, 2023
1 parent 53e9133 commit d9e95da
Showing 8 changed files with 393 additions and 52 deletions.
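Background for the change (not part of the commit itself): bfloat16 is simply the upper 16 bits of an IEEE-754 binary32 value. That is why the GCC 11-compatible paths introduced below can widen bf16 to fp32 with a plain byte shift and narrow fp32 to bf16 with a packed conversion, instead of the scalar __bf16-based conversions the previous code (gated on GCC >= 13) relied on. A minimal scalar sketch of that relationship, for reference only:

#include <cstdint>
#include <cstring>

// Illustration only: bf16 <-> fp32 by bit manipulation. Widening is exact;
// the narrowing shown here truncates, whereas the committed vector path uses
// _mm_cvtneps_pbh, which rounds to nearest-even.
static inline float bf16_to_fp32(uint16_t x) {
  uint32_t bits = uint32_t(x) << 16;  // bf16 bits occupy the top half of fp32
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

static inline uint16_t fp32_to_bf16_truncate(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return uint16_t(bits >> 16);  // drop the low 16 mantissa bits
}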
1 change: 1 addition & 0 deletions .github/workflows/cpp-graph-test.yml
@@ -7,6 +7,7 @@ on:
- '.github/workflows/cpp-graph-test.yml'
- '.github/workflows/script/models/cpp_graph_inference.sh'
- 'intel_extension_for_transformers/llm/runtime/graph/**'
+- 'intel_extension_for_transformers/llm/library/jblas/**'
- '!intel_extension_for_transformers/llm/runtime/graph/README.md'
workflow_dispatch:
inputs:
@@ -37,7 +37,7 @@
#define CompileAVX512F() (defined(__GNUC__) && (__GNUC__ >= 6))
#define CompileAVX2() (defined(__GNUC__) && (__GNUC__ >= 5))
#define CompileAMX() (defined(__GNUC__) && (__GNUC__ >= 11))
-#define CompileBF16() (defined(__GNUC__) && (__GNUC__ >= 13))
+#define CompileBF16() (defined(__GNUC__) && (__GNUC__ >= 11))
#define CompileFP16() (defined(__GNUC__) && (__GNUC__ >= 13))
#define CompileAMXBF16() (CompileAMX())
#define CompileAMXINT8() (CompileAMX())
@@ -77,14 +77,28 @@ struct bf16 {

#if CompileBF16()
#pragma GCC target("avx512vl", "avx512bf16")
-explicit bf16(float vf32) : x(bit_cast<uint16_t>(_mm_cvtness_sbh(vf32))) {}
+static uint16_t f32_to_bf16(float v) {
+auto mm = _mm_load_ss(&v);
+auto mm2 = _mm_cvtneps_pbh(mm);
+uint16_t dst;
+_mm_storeu_si16(reinterpret_cast<uint16_t*>(&dst), reinterpret_cast<__m128i>(mm2));
+return dst;
+}
+
+explicit bf16(float vf32) : x(bit_cast<uint16_t>(f32_to_bf16(vf32))) {}
#else
explicit bf16(float vf32) { fromfloat(vf32); }
#endif

#if CompileBF16()
#pragma GCC target("avx512vl", "avx512bf16")
-float tofloat() const { return static_cast<float>(bit_cast<__bf16>(this->x)); }
+float tofloat() const {
+auto mm = _mm_loadu_si16(&(this->x));
+auto mm2 = _mm_bslli_si128(mm, 2);
+float dst;
+_mm_store_ss(&dst, reinterpret_cast<__m128>(mm2));
+return dst;
+}
#else
float tofloat() const {
bf16f32 tmp = {0.f};
@@ -103,7 +117,7 @@

void fromfloat(float _v) {
#if CompileBF16()
-x = bit_cast<uint16_t>(_mm_cvtness_sbh(_v));
+x = bit_cast<uint16_t>(f32_to_bf16(_v));
#else
bf16f32 tmp = {0.f};
tmp.f32 = _v;
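The net effect in this file: the previous AVX512-BF16 path depended on the scalar conversions (_mm_cvtness_sbh and the bit_cast<__bf16>/static_cast<float> pair), which is presumably why it was gated on GCC >= 13. The replacement uses the packed _mm_cvtneps_pbh conversion for fp32 -> bf16 and a byte-shift widening for bf16 -> fp32, both accepted by GCC 11 under the same target pragma, so the gate can drop to GCC >= 11. A round-trip usage sketch follows; the include path and namespace qualification are assumptions, not taken from this diff:

// Sketch only. Assumes "jit_blas_utils.h" is the header holding the struct
// bf16 shown above; the diff refers to the type as utils::bf16, and a jblas::
// prefix may be needed depending on the enclosing namespaces. On the
// CompileBF16() path this also requires a CPU with AVX512-BF16 at run time.
#include <cstdio>
#include "jit_blas_utils.h"  // assumed header name

int main() {
  utils::bf16 h(3.1415926f);  // fp32 -> bf16 (the AVX512-BF16 path rounds to nearest-even)
  float back = h.tofloat();   // bf16 -> fp32 (exact widening of the stored 16 bits)
  std::printf("%.7f -> %.7f\n", 3.1415926f, back);  // expect roughly 3.1406250
  return 0;
}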
106 changes: 106 additions & 0 deletions intel_extension_for_transformers/llm/library/jblas/jblas/kernel_avx2.h
@@ -150,6 +150,112 @@ static inline JBLAS_CODE dequant_kblock_s8_f32(int8_t* srcptr, float* dstptr, in
kblock, NPad);
}

static inline JBLAS_CODE dequant_s32_fp32(const int32_t* srcptr, const int srcstep, float* dstptr, const int dststep,
const int row, const int col, const float* scaleA, const int ldsa,
const float* scaleB) {
int col8 = utils::padto_le(col, 8);
for (int irow = 0; irow < row; irow++) {
auto scale = scaleA[irow * ldsa];
auto valpha = _mm256_set1_ps(scale);
int icol = 0;
for (; icol < col8; icol += 8) {
auto vwscale = _mm256_loadu_ps(scaleB + icol);
auto vscale = _mm256_mul_ps(valpha, vwscale);
auto vsrcd = _mm256_loadu_si256((__m256i*)(srcptr + irow * srcstep + icol));
auto vsrc = _mm256_cvtepi32_ps(vsrcd);
vsrc = _mm256_mul_ps(vsrc, vscale);
_mm256_storeu_ps(dstptr + irow * dststep + icol, vsrc);
}
for (; icol < col; icol += 1) {
dstptr[irow * dststep + icol] = scale * scaleB[icol] * srcptr[irow * srcstep + icol];
}
}
return JblasSuccess;
}

static inline JBLAS_CODE remove_act_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zps,
float* scales, int lds, const float* reduce) {
int constexpr VLen = 8;
auto col8 = utils::padto_le(col, VLen);
for (int i = 0; i < row; i++) {
auto zpf = float(zps[i * lds]) * scales[i * lds];
int j = 0;
auto vzp = _mm256_set1_ps(-zpf);
for (; j < col8; j += VLen) {
auto vreduce = _mm256_loadu_ps(reduce + j);
auto vacc = _mm256_loadu_ps(&accptr[i * ldacc + j]);
vacc = _mm256_fmadd_ps(vzp, vreduce, vacc);
_mm256_storeu_ps(&accptr[i * ldacc + j], vacc);
}
if (j < col) {
for (; j < col; j++) {
accptr[i * ldacc + j] -= zpf * reduce[j];
}
}
}
return JblasSuccess;
}

static inline JBLAS_CODE remove_wei_zeropoint_bias(float* accptr, int ldacc, int row, int col, int8_t* zps,
float* scales, int lds, const float* reduce) {
int constexpr VLen = 8;
auto col8 = utils::padto_le(col, VLen);
const int32_t mask[] = {-1, -1, 0, 0};
for (int i = 0; i < row; i++) {
auto vreduce = _mm256_set1_ps(-reduce[i * lds]);
int j = 0;
for (; j < col8; j += VLen) {
auto vzp_s32 = _mm256_cvtepi8_epi32(_mm_maskload_epi32((const int*)(zps + j), _mm_loadu_si128((__m128i*)mask)));
auto vzp_f32 = _mm256_cvtepi32_ps(vzp_s32);
auto vzp = _mm256_mul_ps(vzp_f32, _mm256_loadu_ps(scales + j));
auto vacc = _mm256_loadu_ps(&accptr[i * ldacc + j]);
vacc = _mm256_fmadd_ps(vzp, vreduce, vacc);
_mm256_storeu_ps(&accptr[i * ldacc + j], vacc);
}
if (j < col) {
for (; j < col; j++) {
accptr[i * ldacc + j] -= float(zps[j]) * scales[j] * reduce[i * lds];
}
}
}
return JblasSuccess;
}

static inline JBLAS_CODE remove_zeropoint_bias(float* accptr, int ldacc, int row, int col, uint8_t* zpa, int8_t* zpb,
float* scalea, float* scaleb, int lds, int k, const float* reducea,
const float* reduceb) {
int constexpr VLen = 8;
auto col8 = utils::padto_le(col, VLen);
auto vk = _mm256_set1_ps((float)(k));
const int32_t mask[] = {-1, -1, 0, 0};
for (int i = 0; i < row; i++) {
auto vreducea = _mm256_set1_ps(-reducea[i * lds]);
auto zpaf = float(zpa[i * lds]) * scalea[i * lds];
auto vzpa = _mm256_set1_ps(-zpaf);
int j = 0;
for (; j < col8; j += VLen) {
auto vzp_s32 = _mm256_cvtepi8_epi32(_mm_maskload_epi32((const int*)(zpb + j), _mm_loadu_si128((__m128i*)mask)));
auto vzp_f32 = _mm256_cvtepi32_ps(vzp_s32);
auto vzpb = _mm256_mul_ps(vzp_f32, _mm256_loadu_ps(scaleb + j));
auto vreduceb = _mm256_loadu_ps(reduceb + j);
auto vacc = _mm256_loadu_ps(&accptr[i * ldacc + j]);
vacc = _mm256_fmadd_ps(vzpa, vreduceb, vacc);
vacc = _mm256_fmadd_ps(vzpb, vreducea, vacc);
vzpb = _mm256_mul_ps(vzpb, vk);
vacc = _mm256_fmadd_ps(vzpa, vzpb, vacc);
_mm256_storeu_ps(&accptr[i * ldacc + j], vacc);
}
if (j < col) {
for (; j < col; j++) {
accptr[i * ldacc + j] -= float(zpb[j]) * scaleb[j] * reducea[i * lds];
accptr[i * ldacc + j] -= zpaf * reduceb[j];
accptr[i * ldacc + j] -= zpaf * float(zpb[j]) * scaleb[j] * k;
}
}
}
return JblasSuccess;
}

template <JBLAS_SIGN_INT_TYPE S4_T>
static inline JBLAS_CODE decompress_s4_s8(utils::int4x2* srcptr, int8_t* dstptr, int row, int col, int ld_src,
int ld_dst) {
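The new AVX2 helpers above are dense; their scalar tail loops already spell out the per-element math, and the sketch below (not part of the commit) restates it as plain reference code. remove_act_zeropoint_bias and remove_wei_zeropoint_bias are the corresponding one-sided cases, with only the activation or only the weight zero-point term.

#include <cstdint>

// Scalar reference (illustration only) for the per-element math of the new
// AVX2 kernels, transcribed from their tail loops.

// dequant_s32_fp32: dst[i][j] = scaleA[i] * scaleB[j] * src[i][j]
static void dequant_s32_fp32_ref(const int32_t* src, int srcstep, float* dst, int dststep,
                                 int row, int col, const float* scaleA, int ldsa,
                                 const float* scaleB) {
  for (int i = 0; i < row; i++)
    for (int j = 0; j < col; j++)
      dst[i * dststep + j] = scaleA[i * ldsa] * scaleB[j] * float(src[i * srcstep + j]);
}

// remove_zeropoint_bias: subtract the three zero-point cross terms from the
// accumulator, using the per-row (reducea) and per-column (reduceb) reduction
// buffers supplied by the caller.
static void remove_zeropoint_bias_ref(float* acc, int ldacc, int row, int col,
                                      const uint8_t* zpa, const int8_t* zpb,
                                      const float* scalea, const float* scaleb, int lds, int k,
                                      const float* reducea, const float* reduceb) {
  for (int i = 0; i < row; i++) {
    float zpaf = float(zpa[i * lds]) * scalea[i * lds];
    for (int j = 0; j < col; j++) {
      float zpbf = float(zpb[j]) * scaleb[j];
      acc[i * ldacc + j] -= zpbf * reducea[i * lds];  // weight zero-point term
      acc[i * ldacc + j] -= zpaf * reduceb[j];        // activation zero-point term
      acc[i * ldacc + j] -= zpaf * zpbf * float(k);   // cross term over the K dimension
    }
  }
}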
@@ -33,10 +33,13 @@ static inline JBLAS_CODE bf16_cvt_fp32_2D_write_back(const utils::bf16* src_ptr,
auto dst = dst_ptr + i * dst_step;
int j = 0;
for (; j < col_body; j += simd_proc_elt)
-_mm512_storeu_ps(dst + j, _mm512_cvtpbh_ps((__m256bh)_mm256_loadu_ps(reinterpret_cast<float*>(src + j))));
+_mm512_storeu_ps(
+dst + j, //
+reinterpret_cast<__m512>(_mm512_bslli_epi128(_mm512_cvtepu16_epi32(_mm256_loadu_epi16(src + j)), 2)));
if (col_tail > 0)
-_mm512_mask_storeu_ps(dst + j, tail_mask,
-_mm512_cvtpbh_ps((__m256bh)_mm256_loadu_ps(reinterpret_cast<float*>(src + j))));
+_mm512_mask_storeu_ps(
+dst + j, tail_mask,
+reinterpret_cast<__m512>(_mm512_bslli_epi128(_mm512_cvtepu16_epi32(_mm256_loadu_epi16(src + j)), 2)));
if (zeropadding && npadding) std::memset(dst + col, 0, npadding);
}
return JblasSuccess;
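A note on the widening used in this kernel: _mm512_bslli_epi128 shifts bytes within each 128-bit lane, so bytes do cross 32-bit element boundaries, but because _mm512_cvtepu16_epi32 zero-extends every bf16 value first, the bytes that spill into a neighboring element are always zero. Per element the net effect is just fp32_bits = uint32(bf16_bits) << 16, the same widening shown in the scalar sketch near the top of this page (the tofloat() change earlier uses the same trick on a single element, where _mm_loadu_si16 provides the zero fill).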
(Diffs for the remaining changed files are not shown.)
