Skip to content

Commit

Permalink
fix bf16 symbolic_trace bug (#1892)
Browse files Browse the repository at this point in the history
Description: fix bf16 symbolic_trace bug,

- cause abnormal recursive calling.
- missing necessary attributes
- By moving BF16 fallback ahead of quantization and removing bf16_symbolic_trace, we fix it.

---------

Signed-off-by: xin3he <xin3.he@intel.com>
Co-authored-by: Sun, Xuehao <xuehao.sun@intel.com>
  • Loading branch information
xin3he and XuehaoSun authored Jul 9, 2024
1 parent e080e06 commit 3fe2fd9
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 49 deletions.
20 changes: 12 additions & 8 deletions neural_compressor/adaptor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,8 @@ def _combine_capability(self, bf16_ops, q_capability):
q_capability["opwise"][bf16_op] = [bf16_config, fp32_config]
if bf16_op[1] not in q_capability["optypewise"]:
q_capability["optypewise"][bf16_op[1]] = [bf16_config, fp32_config]
if bf16_op[1] in q_capability["optypewise"] and bf16_config not in q_capability["optypewise"][bf16_op[1]]:
q_capability["optypewise"][bf16_op[1]].append(bf16_config)
return q_capability

def get_fused_list(self, model):
Expand Down Expand Up @@ -3579,6 +3581,16 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
return q_model

self.tune_cfg["fx_sub_module_list"] = self.sub_module_list

# BF16 fallback
if (
len(self.tune_cfg["bf16_ops_list"]) > 0
and self.version.release >= Version("1.11.0").release
and self.use_bf16
and (CpuInfo().bf16 or os.getenv("FORCE_BF16") == "1")
): # pragma: no cover
q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)

if self.approach == "quant_aware_training":
q_model._model.train()
if self.sub_module_list is None:
Expand Down Expand Up @@ -3665,14 +3677,6 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
self.sub_module_list, q_model._model, prefix="", custom_config=self.prepare_custom_config_dict
)

if (
len(self.tune_cfg["bf16_ops_list"]) > 0
and self.version.release >= Version("1.11.0").release
and self.use_bf16
and (CpuInfo().bf16 or os.getenv("FORCE_BF16") == "1")
): # pragma: no cover
q_model._model = torch_utils.bf16_convert.Convert(q_model._model, self.tune_cfg)

self.fused_dict = self.get_fused_list(q_model.model)
q_model.is_quantized = True
q_model.q_config = copy.deepcopy(self.tune_cfg)
Expand Down
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/pytorch_cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
name: '1.11'

bf16: ['Linear', 'bmm', 'mm', 'baddbmm', 'addmm', 'addbmm',
'_convolution', 'LSTM', 'LSTMCell', 'GRU', 'GRUCell']
'Conv1d', 'Conv2d', 'Conv3d', 'LSTM', 'LSTMCell', 'GRU', 'GRUCell']
fp32: ['*'] # `*` means all op types.
int8: &1_11_capabilities {
'static': &cap_s8_1_11 {
Expand Down
35 changes: 4 additions & 31 deletions neural_compressor/adaptor/torch_utils/bf16_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""Bf16 Convert for Torch Utils."""
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

from ...utils import logger

Expand All @@ -28,6 +27,7 @@ class BF16ModuleWrapper(nn.Module):
def __init__(self, module):
"""Init a BF16ModuleWrapper object."""
super(BF16ModuleWrapper, self).__init__()
module = module.bfloat16()
self.add_module("module", module)
self.train(module.training)
# WA for TransformerEncoder to access its Linear's weights and bias
Expand All @@ -38,7 +38,6 @@ def __init__(self, module):
def forward(self, X):
"""Convert dtype."""
X = X.to(torch.bfloat16)
self.module.bfloat16()
X = self.module(X)
return X.float()

Expand All @@ -54,44 +53,18 @@ def Convert(model, tune_cfg):
mixed_precision_model (object): model with mixed precision.
"""
bf16_ops_list = tune_cfg["bf16_ops_list"]
fx_sub_module_list = tune_cfg["fx_sub_module_list"] if "fx_sub_module_list" in tune_cfg.keys() else []
if len(bf16_ops_list) > 0:
logger.info("Convert operators to bfloat16")
mixed_precision_model = _bf16_wrapper_model(model, bf16_ops_list)
if fx_sub_module_list is not None and len(fx_sub_module_list) > 0:
mixed_precision_model = bf16_symbolic_trace(mixed_precision_model, fx_sub_module_list)
return mixed_precision_model


def _bf16_wrapper_model(model, bf16_ops_list, prefix=""):
for name, child in model.named_children():
op_name = prefix + "." + name if prefix != "" else name
for bf16_op_name in bf16_ops_list:
if op_name == bf16_op_name[0]:
if op_name == bf16_op_name[0] or op_name == bf16_op_name[0].split(".module")[0]:
child = BF16ModuleWrapper(child)
else:
_bf16_wrapper_model(child, bf16_ops_list, op_name)
setattr(model, name, child)
return model


def bf16_symbolic_trace(model, fx_sub_module_list, prefix=""):
"""Symbolic trace for bf16 models.
Args:
model (object): the input model.
fx_sub_module_list (list): _description_
prefix (str): prefix of op name.
Returns:
model (object)
"""
for name, child in model.named_children():
op_name = prefix + "." + name if prefix != "" else name
for fx_sub_module_name in fx_sub_module_list:
if op_name == fx_sub_module_name:
child = symbolic_trace(child)
else:
bf16_symbolic_trace(child, fx_sub_module_list, op_name)
setattr(model, name, child)
setattr(model, name, child)
_bf16_wrapper_model(child, bf16_ops_list, op_name)
return model
12 changes: 3 additions & 9 deletions test/adaptor/pytorch_adaptor/test_adaptor_pytorch_2x.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,21 +392,15 @@ def test_fx_sub_module_quant(self):
"Please use PyTroch 1.11 or higher version for mixed precision with pytorch_fx or pytorch backend",
)
def test_mix_precision(self):
os.environ["FORCE_BF16"] = "1"
model_origin = DynamicControlModel()
# run fx_quant in neural_compressor and save the quantized GraphModule
dataset = Datasets("pytorch")["dummy"]((100, 3, 224, 224))
dataloader = DataLoader("pytorch", dataset)
set_workspace("./saved")
# fx mode usually has .module suffix due to tracing of the entire model fails, so use conv.* to leverage re.match
ptq_fx_op_name_list["conv.*"] = {"weight": {"dtype": "bf16"}, "activation": {"dtype": "bf16"}}
conf = PostTrainingQuantConfig(op_name_dict=ptq_fx_op_name_list)
q_model = quantization.fit(model_origin, conf, calib_dataloader=dataloader, calib_func=eval_func)
tune_cfg = q_model.q_config
tune_cfg["op"][("conv.module", "Conv2d")].clear()
tune_cfg["op"][("conv.module", "Conv2d")] = {"weight": {"dtype": "bf16"}, "activation": {"dtype": "bf16"}}
tune_cfg["bf16_ops_list"].append(("conv.module", "Conv2d"))
from neural_compressor.adaptor.torch_utils.bf16_convert import Convert

q_model._model = Convert(q_model._model, tune_cfg)

self.assertEqual(q_model._model.conv.module.module.weight.dtype, torch.bfloat16)
self.assertEqual(q_model._model.conv.module.module.bias.dtype, torch.bfloat16)

Expand Down

0 comments on commit 3fe2fd9

Please sign in to comment.