Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
eliminate bad smell code
Browse files Browse the repository at this point in the history
  • Loading branch information
zhewang1-intc committed May 9, 2024
1 parent 3cfca1e commit 400b9eb
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ concept quant_PrologueA = requires {
requires !std::is_same_v<T, bestla::utils::bf16>;
};

/// Compile-time predicate over a GEMM core type.
///
/// Yields true when `GemmCore::ISA` is one of `BTLA_ISA::AMX_INT8`,
/// `BTLA_ISA::AVX512_VNNI` or `BTLA_ISA::AVX_VNNI`, or when `GemmCore` is exactly
/// `bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>`; false otherwise.
///
/// Intended for use in `if constexpr` conditions, replacing the hand-written
/// four-way check that was previously repeated at each dispatch site.
/// NOTE(review): "int8_cmpt" presumably stands for "int8 compute" -- confirm.
///
/// @tparam GemmCore GEMM core type exposing a static constant `ISA` comparable with `BTLA_ISA`.
/// @return true if the core matches any of the cases listed above.
template <class GemmCore>
constexpr bool is_int8_cmpt_gemmcore() {
  // Cores identified purely through their ISA tag.
  constexpr bool matches_by_isa = GemmCore::ISA == BTLA_ISA::AMX_INT8 ||
                                  GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
                                  GemmCore::ISA == BTLA_ISA::AVX_VNNI;
  // One specific core is matched by its exact type rather than by its ISA tag.
  constexpr bool matches_by_type = std::is_same_v<GemmCore, bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>>;
  return matches_by_isa || matches_by_type;
}

template <class Launcher>
void dequantize_packed_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
if (dispatcher_utils::initer.verbose) dispatcher_utils::timer.start();
Expand All @@ -63,7 +69,6 @@ void dequantize_packed_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
}
}

// TODO(zhe): weight+scale combination check.
template <class Launcher>
void quantize_to_packed_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
if (dispatcher_utils::initer.verbose) dispatcher_utils::timer.start();
Expand Down Expand Up @@ -128,9 +133,7 @@ void do_compute(woq_config_param* p, woq_runtime_ctx* ctx, ParamA param_a) {
using StorageWeight = typename Launcher::PrologueB::StorageWeight;
size_t asym_size = 0, shuf_size = 0;
int8_t* tmpbuf = nullptr;
if constexpr (GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
GemmCore::ISA == BTLA_ISA::AVX_VNNI ||
std::is_same_v<GemmCore, bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>>) {
if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
using Parallel = bestla::parallel::gemm::SchedulerKBlockS<GemmCore>;
bestla::utils::GemmProblem gp(1, ctx->m, ctx->n, ctx->k, p->blocksize);
StorageWeight* packedw = dynamic_cast<StorageWeight*>(ctx->deseries_wei);
Expand Down Expand Up @@ -233,9 +236,7 @@ void execute_task(woq_config_param* p, woq_runtime_ctx* ctx) {
template <WOQ_TASK TASK, class GemmCore, template <class _T, BTLA_ISA> class PrologueB,
template <class _T, BTLA_ISA> class PrologueA, template <BTLA_ISA> class Epilogue>
void parse_launcher(woq_config_param* p, woq_runtime_ctx* ctx) {
if constexpr (GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
GemmCore::ISA == BTLA_ISA::AVX_VNNI ||
std::is_same_v<GemmCore, bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>>) {
if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
using Launcher = bestla::wrapper::gemm::LauncherIntKBlock<GemmCore::ISA, GemmCore, PrologueA, PrologueB, Epilogue>;
return execute_task<TASK, Launcher>(p, ctx);
} else {
Expand All @@ -259,9 +260,7 @@ template <WOQ_TASK TASK, class GemmCore, template <class _T, BTLA_ISA> class Pro
void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) {
using namespace bestla::prologue_a::gemm;
if (p->src_dt == dispatcher_utils::QBITS_FP32) {
if constexpr (GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
GemmCore::ISA == BTLA_ISA::AVX_VNNI ||
std::is_same_v<GemmCore, bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>>) {
if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
return parse_store<TASK, GemmCore, PrologueB, ShuffleActivationKBlockQuantizeF32, dispatcher_utils::QBITS_FP32>(
p, ctx);
} else {
Expand All @@ -270,9 +269,7 @@ void parse_activation(woq_config_param* p, woq_runtime_ctx* ctx) {
}
}
if (p->src_dt == dispatcher_utils::QBITS_BF16) {
if constexpr (GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
GemmCore::ISA == BTLA_ISA::AVX_VNNI ||
std::is_same_v<GemmCore, bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>>) {
if constexpr (is_int8_cmpt_gemmcore<GemmCore>()) {
return parse_store<TASK, GemmCore, PrologueB, ShuffleActivationKBlockQuantizeBf16, dispatcher_utils::QBITS_BF16>(
p, ctx);
} else {
Expand All @@ -292,9 +289,7 @@ void parse_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
if (p->weight_type == "nf4" || p->weight_type == "fp4_e2m1_bnb" || p->weight_type == "fp4_e2m1" ||
p->weight_type == "fp8_e4m3" || p->weight_type == "fp8_e5m2") {
TORCH_CHECK(!p->asym, "Qbits: float-weight unsupports asym quantization.");
if constexpr (GemmCore::ISA != BTLA_ISA::AMX_INT8 && GemmCore::ISA != BTLA_ISA::AVX512_VNNI &&
GemmCore::ISA != BTLA_ISA::AVX_VNNI &&
!std::is_same_v<GemmCore, bestla::gemm::ICoreRowNAvx2vnniKBlock<24, 2>>)
if constexpr (!is_int8_cmpt_gemmcore<GemmCore>())
return parse_activation<TASK, GemmCore, WeightKBlockNFloat>(p, ctx);
}
TORCH_CHECK(false,
Expand Down Expand Up @@ -391,6 +386,8 @@ void parse_gemm_core(woq_config_param* p, woq_runtime_ctx* ctx) {
}

void dispatch_woq_task(woq_config_param* p, woq_runtime_ctx* ctx, WOQ_TASK task) {
TORCH_CHECK(!(p->asym && (p->compute_type == "int8" && weight_type == "int8")),
"QBits: unsupported bestla_config, asym quantization in int8 compute_type with int8 weight_type.");
switch (task) {
case WOQ_QUANTIZE:
return parse_gemm_core<WOQ_QUANTIZE>(p, ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
def test(m, n, k, blocksize, compute_type, weight_type, scale_type, asym, transpose, add_bias, src_dt, dst_dt, dump_tensor_info=True):
if compute_type not in cmpt_configs[weight_type] or scale_type not in scale_configs[weight_type]:
pytest.skip()
# TODO(zhe): add constrain in QBits backend.
if asym and (weight_type not in asym_configs or (compute_type == "int8" and weight_type == "int8")):
pytest.skip()
torch.manual_seed(0)
Expand Down

0 comments on commit 400b9eb

Please sign in to comment.