This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

QBits adapt to the latest BesTLA #1535

Merged · 16 commits · May 13, 2024
8 changes: 8 additions & 0 deletions docs/qbits.md
@@ -65,3 +65,11 @@ qbits.woq_linear(
activation, pack_weight, bias, output, n, add_bias, compute_type, weight_type, scale_type, asym)
```
Please see [here](https://github.com/intel/intel-extension-for-transformers/tree/main/intel_extension_for_transformers/transformers/llm/operator/csrc/qbits_ut) for more examples of QBits operator usage.

## PyTorch version constraint
To use QBits, the installed PyTorch version must meet the ITREX requirement listed below; a minimal version guard is sketched after the table.

| ITREX version | PyTorch version |
| :-----------: | :-------------: |
| v1.4 | 2.2.0+cpu |
| v1.4.1 | 2.2.0+cpu |
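
A minimal sketch of such a guard, assuming QBits is used with ITREX v1.4 or v1.4.1; raising an error (rather than warning) and the message text are illustrative choices, not ITREX behavior.

```python
# Minimal sketch: verify the installed PyTorch build before importing QBits.
# The required version follows the table above.
import torch

REQUIRED_TORCH = "2.2.0"  # per the ITREX v1.4 / v1.4.1 rows above

installed = torch.__version__.split("+")[0]  # strip local build tags such as "+cpu"
if installed != REQUIRED_TORCH:
    raise RuntimeError(
        f"QBits expects torch=={REQUIRED_TORCH}+cpu, found {torch.__version__}"
    )
```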
3 changes: 3 additions & 0 deletions intel_extension_for_transformers/qbits/CMakeLists.txt
@@ -37,12 +37,15 @@ find_package(PythonLibs 3 REQUIRED)
endif()

include(FindOpenMP)
set(BTLA_ENABLE_OPENMP ON CACHE BOOL "BesTLA enable compiling OpenMP threading")
add_subdirectory(dispatcher)
add_subdirectory(../transformers/runtime/third_party/pybind11 pybind11)

file(GLOB HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp)
file(GLOB qbits_src ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp)

add_compile_options(-flto=auto)

# Link against LibTorch
pybind11_add_module(qbits_py ${qbits_src})
target_compile_features(qbits_py PRIVATE cxx_std_14)
@@ -35,5 +35,5 @@ endif()

set_target_properties(bestla_dispatcher PROPERTIES POSITION_INDEPENDENT_CODE ON)
set_target_properties(bestla_dispatcher PROPERTIES LINKER_LANGUAGE CXX)
target_link_libraries(bestla_dispatcher OpenMP::OpenMP_CXX OpenMP::OpenMP_C "${TORCH_LIBRARIES}" bestla::bestla)
target_link_libraries(bestla_dispatcher OpenMP::OpenMP_CXX OpenMP::OpenMP_C "${TORCH_LIBRARIES}" bestla)
set_property(TARGET torch_cpu PROPERTY INTERFACE_COMPILE_OPTIONS "")
@@ -20,17 +20,17 @@

template <typename Param, typename DST_T, BTLA_ISA ISA_T>
inline BTLA_CODE alphabeta_dt_cvt_process(float* tmp_dst, const int cachestep, const int M_offset, const int N_offset,
const int M, const int N, const Param& _param) {
const int M, const int N, const Param& _param) {
auto DOffset = M_offset * _param.ldd + N_offset;
auto dptr = reinterpret_cast<float*>(_param.D) + DOffset;
bestla::kernel::wrapper::AlphaBetaF32F32::template forward<ISA_T>(_param.alpha, tmp_dst, cachestep, _param.beta, dptr,
_param.ldd, tmp_dst, cachestep, M, N);
_param.ldd, tmp_dst, cachestep, M, N);

auto COffset = M_offset * _param.ldc + N_offset;
auto cptr = reinterpret_cast<DST_T*>(_param.C) + COffset;
if constexpr (std::is_same_v<DST_T, float>) {
return bestla::kernel::wrapper::Memcpy2D::template forward<ISA_T, float, DST_T>(tmp_dst, cptr, M, N, cachestep,
_param.ldc, NULL);
_param.ldc, NULL);
}
if constexpr (std::is_same_v<DST_T, bestla::utils::bf16>) {
return bestla::kernel::wrapper::Memcpy2DFp32CvtBf16::template forward<ISA_T>(
@@ -47,8 +47,8 @@ class AlphaBetaProcess {
int ldc, ldd;
float alpha, beta;
};
BTLA_CODE forward(float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
const int N, const Param& _param, void* tmpcache = nullptr, size_t cachesize = -1) {
static BTLA_CODE forward(float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M,
const int N, const Param& _param, void* tmpcache = nullptr, size_t cachesize = -1) {
return alphabeta_dt_cvt_process<Param, DST_T, ISA_T>(cacheptr, cachestep, M_offset, N_offset, M, N, _param);
}
};
@@ -62,7 +62,8 @@ struct woq_runtime_ctx {

static std::map<std::string, BTLA_DTYPE> wei2bestladt_map{{"int8", BTLA_DTYPE::S8},
{"int4_clip", BTLA_DTYPE::S4_CLIP},
{"int4_fullrange", BTLA_DTYPE::S4_FULLRANGE},
{"int3_clip", BTLA_DTYPE::S3_CLIP},
{"int2_clip", BTLA_DTYPE::S2_CLIP},
{"nf4", BTLA_DTYPE::F4_NF4},
{"fp4_e2m1_bnb", BTLA_DTYPE::F4_BNB},
{"fp4_e2m1", BTLA_DTYPE::F4_E2M1},
@@ -26,10 +26,26 @@ inline bool check_avx_vnni() { return bestla::device::CpuDevice::getInstance()->
inline bool check_avx512f() { return bestla::device::CpuDevice::getInstance()->AVX512F(); }
inline bool check_avx2() { return bestla::device::CpuDevice::getInstance()->AVX2(); }

class qbits_threading {
public:
static bestla::parallel::IThreading* get() {
GetCPUDevice();
static bestla::parallel::StdThreading OptmizedThreading;
static bestla::parallel::OMPThreading DefaultThreading;
if (!_cd->isHybrid()) {
return &DefaultThreading;
}
return &OptmizedThreading;
}

static void set_threads(int n_thread) { get()->set_threads(n_thread); }
};

class env_initer {
public:
env_initer() {
if (check_amx()) bestla::utils::request_perm_xtile_data();
qbits_threading::set_threads(bestla::device::CpuDevice::getInstance()->getThreads());
verbose = std::getenv("QBITS_VERBOSE") != nullptr;
FLAGS_caffe2_log_level = 0;
}
@@ -56,7 +72,7 @@ class Timer {
high_resolution_clock::time_point m_end;
};
static Timer timer;
static bestla::parallel::OMPThreading DefaultThreading(bestla::device::CpuDevice::getInstance()->getThreads());

string get_torch_dt_name(torch::Tensor* tensor);

} // namespace dispatcher_utils
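
Aside: the `env_initer` constructor above enables verbose logging whenever the `QBITS_VERBOSE` environment variable is present, whatever its value. Below is a minimal sketch of turning that on from Python, assuming the variable is set before the native extension is loaded:

```python
# Sketch: enable QBits verbose logging via the QBITS_VERBOSE variable checked
# by env_initer; only its presence matters, the value is arbitrary.
# Set it before the qbits extension module is imported.
import os

os.environ["QBITS_VERBOSE"] = "1"
```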
@@ -1,5 +1,5 @@
set(NEURAL_SPEED_URL https://github.com/intel/neural-speed.git)
set(NEURAL_SPEED_TAG bestlav0.1)
set(NEURAL_SPEED_TAG 2f7943681e02c6e87a4c70c3925327f00194c78f)

FetchContent_Declare(
neural_speed
@@ -40,17 +40,17 @@ void do_gemm(bestla_gemm_runtime_ctx* ctx) {
packw.assign(tmpbuf);
if (ctx->matB_trans) {
launcher.mProB.packWeightTranspose(ctx->n, ctx->k, {reinterpret_cast<DT*>(ctx->matB->data_ptr()), ctx->k, &packw},
&dispatcher_utils::DefaultThreading);
dispatcher_utils::qbits_threading::get());
} else {
launcher.mProB.packWeight(ctx->n, ctx->k, {reinterpret_cast<DT*>(ctx->matB->data_ptr()), ctx->n, &packw},
&dispatcher_utils::DefaultThreading);
dispatcher_utils::qbits_threading::get());
}
bestla::utils::GemmProblem gp(1, ctx->m, ctx->n, ctx->k);
typename Launcher::Param args{gp,
{reinterpret_cast<DT*>(ctx->matA->data_ptr()), ctx->k},
{reinterpret_cast<DT*>(ctx->matB->data_ptr()), ctx->n, &packw},
{reinterpret_cast<DT*>(ctx->matC->data_ptr()), ctx->n}};
bestla::parallel::GemmRun<Parallel>(launcher, args, &dispatcher_utils::DefaultThreading);
bestla::parallel::GemmRun<Parallel>(launcher, args, dispatcher_utils::qbits_threading::get());
bestla::utils::afree(tmpbuf);
}

@@ -28,9 +28,9 @@ void execute_qpack(repack_quantized_weight_param* p, repack_quantized_weight_ctx
*(ctx->output) = torch::empty(qpackw.mSize, torch::kInt8);
qpackw.assign(ctx->output->data_ptr<int8_t>());
if (p->enable_act_shuffle)
ker.setShuffleIndices(ctx->g_idx->data_ptr<int>(), &qpackw, &dispatcher_utils::DefaultThreading);
ker.setShuffleIndices(ctx->g_idx->data_ptr<int>(), &qpackw, dispatcher_utils::qbits_threading::get());
ker.packQWeight(ctx->n, ctx->k, ctx->qweight->data_ptr<int8_t>(), ctx->n, ctx->scale->data_ptr<float>(),
p->asym ? ctx->zp->data_ptr<int8_t>() : nullptr, &qpackw, &dispatcher_utils::DefaultThreading);
p->asym ? ctx->zp->data_ptr<int8_t>() : nullptr, &qpackw, dispatcher_utils::qbits_threading::get());
}

std::string get_dtype_str(BTLA_DTYPE dtype) {
@@ -41,8 +41,10 @@ std::string get_dtype_str(BTLA_DTYPE dtype) {
return "bf16";
case BTLA_DTYPE::S4_CLIP:
return "int4_clip";
case BTLA_DTYPE::S4_FULLRANGE:
return "int4_fullrange";
case BTLA_DTYPE::S3_CLIP:
return "int3_clip";
case BTLA_DTYPE::S2_CLIP:
return "int2_clip";
case BTLA_DTYPE::F4_NF4:
return "nf4";
case BTLA_DTYPE::F4_E2M1:
@@ -66,7 +68,6 @@ std::string get_dtype_str(BTLA_DTYPE dtype) {
std::string get_cmpt_str(bestla::gemm::CompType cmpt) {
using bestla::gemm::CompType;
switch (cmpt) {
case CompType::COMP_INT8_US_INT32:
case CompType::COMP_INT8_US_FP32:
return "int8";
case CompType::COMP_FP32:
@@ -182,43 +183,34 @@ torch::Tensor get_packw_info(torch::Tensor& packw, PACKW_ACQUIRE_TYPE ACQ_T) {
}

void bestla_packq(repack_quantized_weight_param* p, repack_quantized_weight_ctx* ctx, WOQ_TASK task) {
TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int4_fullrange",
// TODO(zhe): elegant impl.
TORCH_CHECK(p->weight_type == "int8" || p->weight_type == "int4_clip" || p->weight_type == "int3_clip" ||
p->weight_type == "int2_clip",
"Qbits: only support Integer WOQ in PACKQ");

// NTILE & compute-dtype determine the padsize.
// in qbits:
// avx_vnni/avx512f_vnni/amx-int8 NTILE==48, compute-dtype=int8;
// avx2/avx512f NTILE==48, compute-dtype=fp32;
// amx-bf16 NTILE==64, compute-dtype=bf16.
if (task == WOQ_GET_PACKW_SIZE) {
if (p->compute_type == "int8")
return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
if (p->compute_type == "fp32")
return execute_qpack<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx, task);
if (p->compute_type == "bf16")
return execute_qpack<bestla::gemm::HCoreRowNAmxbf16<64, 16>, BTLA_ISA::AMX_BF16>(p, ctx, task);
}

if (p->compute_type == "int8") {
if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<48, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
if (dispatcher_utils::check_amx() && p->blocksize % bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAmxint8KBlock<64, 16>, BTLA_ISA::AMX_INT8>(p, ctx, task);
}
if (dispatcher_utils::check_avx512_vnni() &&
p->blocksize % bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAvx512vnniKBlock<48, 4>, BTLA_ISA::AVX512_VNNI>(p, ctx, task);
}
if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<48, 2>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAvxvnniKBlock<48, 2>, BTLA_ISA::AVX_VNNI>(p, ctx, task);
if (dispatcher_utils::check_avx_vnni() && p->blocksize % bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAvxvnniKBlock<24, 2>, BTLA_ISA::AVX_VNNI>(p, ctx, task);
}
if (dispatcher_utils::check_avx2() && p->blocksize % bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>::KTILE == 0) {
return execute_qpack<bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>, BTLA_ISA::AVX2>(p, ctx, task);
}
TORCH_CHECK(false, "Qbits: Illegal config in int8 compute_type, blocksize:", p->blocksize,
", ISA support vnni:", dispatcher_utils::check_avx_vnni());
", ISA support avx2:", dispatcher_utils::check_avx2());
}
if (p->compute_type == "fp32") {
if (dispatcher_utils::check_avx512f()) {
return execute_qpack<bestla::gemm::SCoreRowNAvx512f<48, 8>, BTLA_ISA::AVX512F>(p, ctx, task);
}
if (dispatcher_utils::check_avx2()) {
return execute_qpack<bestla::gemm::SCoreRowNAvx2<48, 2>, BTLA_ISA::AVX2>(p, ctx, task);
return execute_qpack<bestla::gemm::SCoreRowNAvx2<24, 4>, BTLA_ISA::AVX2>(p, ctx, task);
}
TORCH_CHECK(false, "Qbits: device ISA must support BTLA_ISA::AVX2 when compute_type==fp32");
}