
Add model export for QAT #3458

Merged
merged 8 commits into microsoft:master from qat_export on Mar 24, 2021

Conversation

linbinskn
Contributor

Add an export_model function for the quantization algorithm QAT. Users can save quantized weights and calibration parameters to a specified path. What's more, this function is a prerequisite for #3356 (support mixed-precision quantization speedup using TensorRT).

if hasattr(module, 'weight_bit'):
    delattr(module, 'weight_bit')
if hasattr(module, 'activation_bit'):
    delattr(module, 'activation_bit')
Contributor

How do you choose these attributes? Do different quantization algorithms have different subsets of these attributes? Is it possible that a new quantization algorithm has more attributes?

Contributor Author

This function is only for QAT now; other quantization algorithms might have other attributes, and we would need to define another function to handle them.

Contributor

then it is better to put it in QAT_Quantizer

Contributor

BTW, why only support export for QAT?

Contributor Author
@linbinskn Mar 19, 2021

Because QAT is enough for linear quantization simulation. Binarized quantization like BNN needs no calibration, and users can save its weights as usual.

Contributor

Could you refer to the model export implemented in pruners, and try to align the export feature of pruners and quantizers?

Contributor Author

Have aligned the export_model() function of pruners and quantizers.

assert model_path is not None, 'model_path must be specified'
self._unwrap_model()
calibration_config = {}
support_op = [torch.nn.Conv2d, torch.nn.Linear, torch.nn.ReLU6]
Contributor

why ReLU6?

Contributor Author

Because activation layers also need their bit width set in inference frameworks like TensorRT, we have to choose activation layers that we can support during calibration. I have tested ReLU6 in some examples and it is fully supported by TensorRT, but I am not sure whether other activation ops are supported.

Contributor

what types of ops does QAT support in our current implementation?

Contributor Author

Removed the support_op constraint in export_model().

@linbinskn requested a review from QuanluZhang March 20, 2021 06:22
layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
layer.module.register_buffer('tracked_min', torch.zeros(1))
layer.module.register_buffer('tracked_max_biased', torch.zeros(1))
layer.module.register_buffer('tracked_max', torch.zeros(1))

def del_simulated_attr(self, module):
Contributor

-> _del_simulated_attr

Contributor Author

Modified.

Comment on lines 169 to 188
if hasattr(module, 'old_weight'):
    delattr(module, 'old_weight')
if hasattr(module, 'ema_decay'):
    delattr(module, 'ema_decay')
if hasattr(module, 'tracked_min_biased'):
    delattr(module, 'tracked_min_biased')
if hasattr(module, 'tracked_max_biased'):
    delattr(module, 'tracked_max_biased')
if hasattr(module, 'tracked_min'):
    delattr(module, 'tracked_min')
if hasattr(module, 'tracked_max'):
    delattr(module, 'tracked_max')
if hasattr(module, 'scale'):
    delattr(module, 'scale')
if hasattr(module, 'zero_point'):
    delattr(module, 'zero_point')
if hasattr(module, 'weight_bit'):
    delattr(module, 'weight_bit')
if hasattr(module, 'activation_bit'):
    delattr(module, 'activation_bit')
Contributor
@QuanluZhang Mar 22, 2021

Suggest using a for loop:

to_del = ['old_weight', 'ema_decay', ...]
for attr in to_del:
    if hasattr(module, attr):
        delattr(module, attr)

Contributor Author

Good point! Modified.

device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
Contributor

Have you tested the ONNX export? It would also be better to write a test for this feature.

Contributor Author

Have tested the export_model function, including PyTorch state_dict() and ONNX export, in three algorithms.

if "weight" in config.get("quant_types", []):
layer.module.register_buffer('weight_bit', torch.zeros(1))

def del_simulated_attr(self, module):
Contributor

-> _del_simulated_attr

Contributor Author

Modified.

torch.save(self.bound_model.state_dict(), model_path)
logger.info('Model state_dict saved to %s', model_path)
if calibration_path is not None:
    logger.info('No calibration config will be saved because no calibration data in BNN quantizer')
Contributor

I think we should export the bit number even though they are all 1 bit, because the speedup module needs this information to use 1 bit. The speedup module does not know you are using BNNQuantizer.

Contributor Author

Makes sense. Have added.

for name, module in self.bound_model.named_modules():
    if hasattr(module, 'weight_bit'):
        calibration_config[name] = {}
        calibration_config[name]['weight_bit'] = int(module.weight_bit)
Contributor

so this quantizer does not calibrate activation?

Contributor Author

In our current implementation of Dorefa, it does not quantize activation, so we don't need to calibrate activation.

Contributor

Could you double check the paper? Does the paper mention how to calibrate activation?

Contributor Author

After discussion, we reached an agreement that the refactor of Dorefa should start after a survey and will be done in another PR. What's more, a UT related to export_model() has been added to the code.

device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
Contributor

Other implementations return a dict; would adding return {} be better here?

Also, all export_model() implementations seem to have mostly the same logic. Can we use the implementation in QAT_Quantizer for all quantizers, or just specify how to construct calibration_config in each quantizer?

Contributor Author

Useful suggestions! Have modified.


for config, quantize_algorithm in zip(config_set, quantize_algorithm_set):
    model = TorchModel()
    model.relu = torch.nn.ReLU()
Contributor

what is this line used for?

Contributor Author

This follows the code of test_torch_QAT_quantizer() in the UT. Changing the relu op type to ReLU makes it match the type in config_list.

@QuanluZhang
Contributor

@linbinskn, looks good! Please update the doc accordingly.

@QuanluZhang merged commit f51d985 into microsoft:master Mar 24, 2021
@linbinskn deleted the qat_export branch March 24, 2021 12:11