This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Add weight_only support for PyTorch framework #297

Merged
25 commits merged on Sep 15, 2023
Commits
2e35c81  Add weight_only support for PyTorch framework (PenghuiCheng, Sep 4, 2023)
268e068  update isa check on no-amx-supported platform (zhewang1-intc, Sep 5, 2023)
a8c09bc  Fixed building issue and get_version issue (PenghuiCheng, Sep 6, 2023)
1246d55  Fixed UT error (PenghuiCheng, Sep 6, 2023)
2b29630  Add weight_only support for PyTorch framework (PenghuiCheng, Sep 4, 2023)
8df521f  update isa check on no-amx-supported platform (zhewang1-intc, Sep 5, 2023)
39096e9  Fixed building issue and get_version issue (PenghuiCheng, Sep 6, 2023)
f69d252  Fixed UT error (PenghuiCheng, Sep 6, 2023)
25f68b9  fix pylint (VincyZhang, Sep 7, 2023)
7495c1f  Merge branch 'penghuic/qbits_porting' of https://github.com/intel/int… (VincyZhang, Sep 7, 2023)
26b90e7  Fixed pylint error (PenghuiCheng, Sep 8, 2023)
ef70865  merge main branch (PenghuiCheng, Sep 8, 2023)
75de4d6  Fixed pylint error (PenghuiCheng, Sep 8, 2023)
f3611d0  Update install_binary.sh (VincyZhang, Sep 8, 2023)
2819294  Merge main branch (PenghuiCheng, Sep 8, 2023)
9a244aa  Merge main branch (PenghuiCheng, Sep 12, 2023)
7b2834a  Fixed UT error with master version neural-compressor (PenghuiCheng, Sep 12, 2023)
74e76b9  Merge remote-tracking branch 'origin' into penghuic/qbits_porting (PenghuiCheng, Sep 12, 2023)
4dd5d55  Add UT for weight-only quantization (PenghuiCheng, Sep 12, 2023)
bec8576  Remove modeling_causal.py since modeling_auto replace it (PenghuiCheng, Sep 12, 2023)
77af939  Fixed import error when loading libweight_only_jblasop.so (PenghuiCheng, Sep 14, 2023)
79a1842  merge main branch (PenghuiCheng, Sep 15, 2023)
dd3ebfb  Update jblass and weight-only UT (PenghuiCheng, Sep 15, 2023)
96180e0  Support gcc 8.4 version to build jblass (PenghuiCheng, Sep 15, 2023)
0130bb4  Merge branch 'main' into penghuic/qbits_porting (VincyZhang, Sep 15, 2023)
3 changes: 2 additions & 1 deletion .github/workflows/cpp-graph-test.yml
@@ -52,7 +52,8 @@ jobs:
 cd ${{ github.workspace }}
 conda activate cpp-graph-test || source activate cpp-graph-test
 pip install build --upgrade
-python -m build -s -w
+pip install -r requirements.txt
+python setup.py sdist bdist_wheel
 pip install dist/intel_extension_for_transformers*.whl
 pip list
3 changes: 2 additions & 1 deletion .github/workflows/llm-test.yml
@@ -52,7 +52,8 @@ jobs:
 cd ${{ github.workspace }}
 conda activate llm-test || source activate llm-test
 pip install build --upgrade
-python -m build -s -w
+pip install -r requirements.txt
+python setup.py sdist bdist_wheel
 pip install dist/intel_extension_for_transformers*.whl
 pip list
4 changes: 2 additions & 2 deletions .github/workflows/script/formatScan/pylint.sh
@@ -27,7 +27,7 @@ else
 echo "Not found requirements.txt file."
 fi
 # install packages
-pip install accelerate nlpaug nltk optimum-intel
+pip install accelerate nlpaug nltk schema optimum-intel
 pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@83dbfbf6070324f3e5872f63e49d49ff7ef4c9b3
 
 echo "[DEBUG] list pipdeptree..."
@@ -39,7 +39,7 @@ python -m pylint -f json --disable=R,C,W,E1129 \
 --max-line-length=120 \
 --extension-pkg-whitelist=numpy,nltk \
 --ignored-classes=TensorProto,NodeProto \
---ignored-modules=tensorflow,torch,torch.quantization,torch.tensor,torchvision,mxnet,onnx,onnxruntime,neural_compressor,neural_compressor.benchmark,intel_extension_for_transformers.transformers.modeling.modeling_causal,intel_extension_for_transformers.neural_engine_py \
+--ignored-modules=tensorflow,torch,torch.quantization,torch.tensor,torchvision,mxnet,onnx,onnxruntime,neural_compressor,neural_compressor.benchmark,intel_extension_for_transformers.neural_engine_py \
 /intel-extension-for-transformers/intel_extension_for_transformers >${log_dir}/pylint.json
 exit_code=$?
3 changes: 1 addition & 2 deletions .github/workflows/script/install_binary.sh
@@ -7,8 +7,7 @@ git config --global --add safe.directory "*"
 git submodule update --init --recursive
 
 $BOLD_YELLOW && echo "---------------- run python setup.py sdist bdist_wheel -------------" && $RESET
-pip install build --upgrade
-python3 -m build -s -w
+python setup.py sdist bdist_wheel
 
 $BOLD_YELLOW && echo "---------------- pip install binary -------------" && $RESET
 pip install dist/intel_extension_for_transformers*.whl
1 change: 1 addition & 0 deletions .gitignore
@@ -15,6 +15,7 @@ tags
 build/
 _build
 dist/
+.cache/
 
 # build / dist files
 /intel_extension_for_transformers/intel_extension_for_transformers[.-]*/
21 changes: 17 additions & 4 deletions intel_extension_for_transformers/__init__.py
@@ -15,7 +15,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-try:
-    from ._version import __version__  # load _version file generated by setuptools_scm
-except ModuleNotFoundError:
-    __version__ = "1.1"
+def _get_version(default='x.x.x.dev'):
+    try:
+        from pkg_resources import DistributionNotFound, get_distribution
+    except ImportError:
+        return default
+    else:
+        try:
+            return get_distribution(__package__).version
+        except DistributionNotFound:  # Run without install
+            return default
+        except ValueError:  # Python 3 setup
+            return default
+        except TypeError:  # Python 2 setup
+            return default
+
+
+__version__ = _get_version()
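
For reference, a minimal sketch of how this version fallback behaves at import time; importing an installed wheel is the assumed scenario, and the fallback string comes from the default above:

import intel_extension_for_transformers as itrex

# With the wheel installed, pkg_resources reports the real distribution version;
# when the package is imported from an uninstalled source tree, _get_version()
# falls back to the 'x.x.x.dev' default instead of raising.
print(itrex.__version__)
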
@@ -24,22 +24,22 @@
 template <JBLAS_ISA> class _Epilogue_T> \
 class Launcher
 
-class env_initer {
- public:
-  env_initer() { jblas::utils::request_perm_xtile_data(); }
-};
-static env_initer initer;
-
 inline bool check_amx() {
   return jblas::utils::parallel::CpuDevice::getInstance()->AMX_BF16();
 }
-inline bool check_vnni() {
-  return jblas::utils::parallel::CpuDevice::getInstance()->AVX_VNNI();
+inline bool check_avx512_vnni() {
+  return jblas::utils::parallel::CpuDevice::getInstance()->AVX512_VNNI();
 }
 inline bool check_avx512f() {
   return jblas::utils::parallel::CpuDevice::getInstance()->AVX512F();
 }
 
+class env_initer {
+ public:
+  env_initer() { if(check_amx()) jblas::utils::request_perm_xtile_data(); }
+};
+static env_initer initer;
+
 inline void set_nk(qbits_runtime_ctx* ctx, torch::Tensor* tensor) {
   ctx->n = ctx->transpose ? tensor->sizes()[0] : tensor->sizes()[1];
   ctx->k = ctx->transpose ? tensor->sizes()[1] : tensor->sizes()[0];
@@ -216,7 +216,7 @@ void parse_gemm_core_online(qbits_config_param* p, qbits_runtime_ctx* ctx) {
         jblas::utils::parallel::Parallel2DGemmKBlockFixed,
         JblasAVX512_VNNI>(p, ctx);
   }
-  if (check_vnni()) {
+  if (check_avx512_vnni()) {
     return parse_weight<
         TASK, jblas::wrapper::gemm_kblock::GemmInterfaceKBlockPackWeight,
         jblas::wrapper::gemm_kblock::GemmSLauncherKBlockPackWeight,
@@ -278,7 +278,7 @@ void parse_gemm_core_offline(qbits_config_param* p, qbits_runtime_ctx* ctx) {
         jblas::gemm::kblock::GemmCore_Row_NN_16x48_AMX_INT8_KBLOCK,
         jblas::utils::parallel::Parallel2DGemmKBlockFixed, JblasAMX_INT8>(
         p, ctx);
-  } else if (check_vnni() &&
+  } else if (check_avx512_vnni() &&
              blocksize %
                  (jblas::gemm::kblock::
                       GemmCore_Row_NN_3x48_AVX512_VNNI_KBLOCK::KTILE *
@@ -293,7 +293,7 @@
   }
   TORCH_CHECK(false,
               "Illegal config in int8 compute_type: blocksize:", blocksize,
-              " ISA largger than vnni:", check_vnni());
+              " ISA largger than vnni:", check_avx512_vnni());
   break;
 case jblas::gemm::GemmCoreType::AVX512F_8X48:
   assert(p->compute_type == "fp32");
@@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .functions import matmul_4bit
@@ -0,0 +1,92 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import operator
import torch
from functools import reduce
from torch import Tensor
from typing import Tuple, Optional, List

def prod(iterable):
    return reduce(operator.mul, iterable, 1)

class MatMul4Bit(torch.autograd.Function):
    # forward is the same, but we added the fallback for pre-turing GPUs
    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=None):
        # default of pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            ctx.bias = bias
            B_shape = state[1]
            if A.shape[-1] == B_shape[0]:
                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)


        # 1. Dequantize
        # 2. MatmulnN
        # torch.ops.weight_only_jblasop.jblas_symqdq_weight(B, False, 4, 32) # TODO: replace with dequantize
        output = torch.nn.functional.linear(A, B.to(A.dtype), bias)

        # 3. Save state
        ctx.state = state
        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

        if any(ctx.needs_input_grad[:2]):
            ctx.tensors = (A, B)
        else:
            ctx.tensors = (None, None)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

        req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
        A, B = ctx.tensors
        state = ctx.state

        grad_A, grad_B, grad_bias = None, None, None

        if req_gradBias:
            # compute grad_bias first before changing grad_output dtype
            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

        # not supported by PyTorch. TODO: create work-around
        # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
        # torch.ops.weight_only_jblasop.jblas_symqdq_weight(B, False, 4, 32) # TODO: replace with dequantize
        if req_gradA: grad_A = torch.matmul(grad_output, B.to(grad_output.dtype))

        return grad_A, grad_B, None, grad_bias, None

def matmul_4bit(A: Tensor, B: Tensor, quant_state: List = None, out: Tensor = None, bias=None, do_dequant=True):
    # assert quant_state is not None
    if do_dequant:
        return MatMul4Bit.apply(A, B, out, bias, quant_state)
    else:
        return MatMul4Bit.apply(A, B, out, bias, quant_state)  # TODO: replace with 4bit matmul
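
Below is a minimal usage sketch of matmul_4bit as defined above. The shapes, bias, and quant_state contents are illustrative assumptions (this revision only consults state[1], the original weight shape, on the empty-input fast path), the weight is a plain fp32 tensor because the jblas dequantize call is still a TODO here, and the import path is hypothetical since the new file's location is not shown in this view:

import torch
# hypothetical import path; the module's actual package location is not visible in this diff
from functions import matmul_4bit

A = torch.randn(2, 8, requires_grad=True)   # activations
B = torch.randn(16, 8)                      # weight stored as (out_features, in_features)
bias = torch.randn(16)
quant_state = (None, B.shape)               # only state[1] (the weight shape) is read in this revision

out = matmul_4bit(A, B, quant_state=quant_state, bias=bias)  # dequantize-then-linear path
out.sum().backward()                        # grad_A is produced; grad_B is intentionally left None
print(out.shape, A.grad.shape)              # torch.Size([2, 16]) torch.Size([2, 8])
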
@@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .quantization_config import WeightOnlyConfig