Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
fix QBits actshuf buf overflow under large batch (#1473)
Browse files Browse the repository at this point in the history
Co-authored-by: changwangss <chang1.wang@intel.com>
  • Loading branch information
zhewang1-intc and changwangss authored Apr 15, 2024
1 parent 0ec83b1 commit a6f3ab3
Show file tree
Hide file tree
Showing 28 changed files with 39 additions and 34 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/script/formatScan/pylint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ python -m pylint -f json --disable=R,C,W,E1129 \
--max-line-length=120 \
--extension-pkg-whitelist=numpy,nltk \
--ignored-classes=TensorProto,NodeProto \
--ignored-modules=tensorflow,torch,torch.quantization,torch.tensor,torchvision,mxnet,onnx,onnxruntime,neural_compressor,neural_compressor.benchmark,intel_extension_for_transformers.neural_engine_py,cv2,PIL.Image \
--ignored-modules=tensorflow,torch,torch.quantization,torch.tensor,torchvision,mxnet,onnx,onnxruntime,neural_compressor,neural_compressor.benchmark,intel_extension_for_transformers.neural_engine_py,intel_extension_for_transformers.qbits,cv2,PIL.Image \
/intel-extension-for-transformers/intel_extension_for_transformers >${log_dir}/pylint.json
exit_code1=$?

Expand All @@ -51,7 +51,7 @@ python -m pylint -f json --disable=R,C,W,E1129 \
--disable=no-name-in-module,import-error,no-member,undefined-variable,no-value-for-parameter,unexpected-keyword-arg,not-callable,no-self-argument,too-many-format-args,invalid-unary-operand-type,too-many-function-args \
--extension-pkg-whitelist=numpy,nltk \
--ignored-classes=TensorProto,NodeProto \
--ignored-modules=tensorflow,torch,torch.quantization,torch.tensor,torchvision,mxnet,onnx,onnxruntime,neural_compressor,neural_compressor.benchmark,intel_extension_for_transformers.neural_engine_py,cv2,PIL.Image \
--ignored-modules=tensorflow,torch,torch.quantization,torch.tensor,torchvision,mxnet,onnx,onnxruntime,neural_compressor,neural_compressor.benchmark,intel_extension_for_transformers.neural_engine_py,intel_extension_for_transformers.qbits,cv2,PIL.Image \
/intel-extension-for-transformers/intel_extension_for_transformers >> ${log_dir}/pylint.json
exit_code2=$?

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@
if args.sq:
config.save_pretrained(args.output_dir)
user_model.save(args.output_dir)
elif args.mixed_precision or args.woq:
elif args.mixed_precision or (args.woq and not args.use_neural_speed):
# user_model will be changed.
user_model.save_pretrained(args.output_dir)
# loading saved woq model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
## See the License for the specific language governing permissions and
## limitations under the License.
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(qbits LANGUAGES C CXX)
project(qbits_py LANGUAGES C CXX)


set(QBITS_TORCH_PATH "" CACHE STRING "Torch install path")
Expand All @@ -31,17 +31,20 @@ endif()
find_package(Torch REQUIRED
PATHS ${torch_path}
NO_DEFAULT_PATH)

if(NOT WIN32)
find_package(PythonLibs 3 REQUIRED)
endif()

include(FindOpenMP)
add_subdirectory(dispatcher)
add_subdirectory(../../../runtime/third_party/pybind11 pybind11)
add_subdirectory(../transformers/runtime/third_party/pybind11 pybind11)

file(GLOB HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp)
file(GLOB qbits_src ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp)

# Link against LibTorch
pybind11_add_module(qbits ${qbits_src})
target_compile_features(qbits PRIVATE cxx_std_14)
target_link_directories(qbits PRIVATE ${torch_path}/lib)
target_link_libraries(qbits PRIVATE bestla_dispatcher torch_python)
pybind11_add_module(qbits_py ${qbits_src})
target_compile_features(qbits_py PRIVATE cxx_std_14)
target_link_directories(qbits_py PRIVATE ${torch_path}/lib)
target_link_libraries(qbits_py PRIVATE bestla_dispatcher torch_python)
19 changes: 19 additions & 0 deletions intel_extension_for_transformers/qbits/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from intel_extension_for_transformers.qbits_py import * # pylint: disable=E0401, E0611
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ endif()

set_target_properties(bestla_dispatcher PROPERTIES POSITION_INDEPENDENTBTLA_CODE ON)
set_target_properties(bestla_dispatcher PROPERTIES LINKER_LANGUAGE CXX)
target_link_libraries(bestla_dispatcher OpenMP::OpenMP_CXX OpenMP::OpenMP_C "${TORCH_LIBRARIES}" torch_python bestla::bestla)
target_link_libraries(bestla_dispatcher OpenMP::OpenMP_CXX OpenMP::OpenMP_C "${TORCH_LIBRARIES}" bestla::bestla)
set_property(TARGET torch_cpu PROPERTY INTERFACE_COMPILE_OPTIONS "")
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void quantize_to_packed_weight(woq_config_param* p, woq_runtime_ctx* ctx) {
}
}

void* get_workspace(int need_size) {
void* get_workspace(size_t need_size) {
void* tmpbuf = NULL;
void* workspace = woq_workspace == nullptr ? NULL : woq_workspace;
if (workspace != NULL) {
Expand All @@ -126,7 +126,7 @@ void do_compute(woq_config_param* p, woq_runtime_ctx* ctx, ParamA param_a) {
EpiParam param_epi = {ctx->output->data_ptr(), ctx->bias->data_ptr(), ctx->ldo, 0, ctx->alpha, ctx->beta};
using GemmCore = typename Launcher::GemmCore;
using StorageWeight = typename Launcher::PrologueB::StorageWeight;
int asym_size = 0, shuf_size = 0;
size_t asym_size = 0, shuf_size = 0;
int8_t* tmpbuf = nullptr;
if constexpr (GemmCore::ISA == BTLA_ISA::AMX_INT8 || GemmCore::ISA == BTLA_ISA::AVX512_VNNI ||
GemmCore::ISA == BTLA_ISA::AVX_VNNI) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ static bool check_isa_supported(std::string isa) {
return false;
}

PYBIND11_MODULE(qbits, m) {
PYBIND11_MODULE(qbits_py, m) {
m.def("quantize_to_packed_weight", &quantize_to_packed_weight);
m.def("woq_linear", &woq_linear);
m.def("dequantize_packed_weight", &dequantize_packed_weight);
Expand Down

This file was deleted.

This file was deleted.

3 changes: 3 additions & 0 deletions intel_extension_for_transformers/transformers/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ def post_init_runtime(self):
runtime_supported_weight_dtype = [
"int4",
"int4_clip", # int4_clip will merge to int4 in next release.
"int4_fullrange", # int4_fullrange will merge to int4 in next release.
"int8",
"fp8",
"fp8_e5m2",
Expand Down Expand Up @@ -467,6 +468,8 @@ def post_init_runtime(self):
self.weight_dtype = "int4"
elif self.weight_dtype == "int4_clip":
self.weight_dtype == "int4"
elif self.weight_dtype == "int4_fullrange":
self.weight_dtype == "int4"
elif self.weight_dtype == "fp8":
self.weight_dtype == "fp8_e4m3"
elif self.weight_dtype == "fp8":
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def check_submodules():
ext_modules = []
else:
ext_modules = [CMakeExtension(
"intel_extension_for_transformers.qbits", 'intel_extension_for_transformers/transformers/llm/operator/csrc/')]
"intel_extension_for_transformers.qbits_py", 'intel_extension_for_transformers/qbits/')]
if SKIP_RUNTIME:
subprocess.check_call(
["git", "submodule", "update", "--init", "intel_extension_for_transformers/transformers/runtime/third_party/pybind11"], cwd=cwd)
Expand Down

0 comments on commit a6f3ab3

Please sign in to comment.