[Docs & Refactor] Add docstring and UT of other quantizers (open-mmlab#439)

* add quantizer docstring and refactor the interface of AcademicQuantizer

* add AcademicQuantizer unittest

* add TensorRTQuantizer and OpenVINOQuantizer unittest & refactor prepare interface

* adapt torch113 ci

* fix import

* fix lint

* update some docstring

* fix ci
humu789 authored and humu789 committed Apr 11, 2023
1 parent 35faed8 commit af6ed31
Showing 13 changed files with 532 additions and 160 deletions.
14 changes: 2 additions & 12 deletions mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -8,7 +8,6 @@
from mmengine.structures import BaseDataElement
from torch import nn

from mmrazor.models.task_modules.tracer import build_graphmodule
from mmrazor.registry import MODEL_WRAPPERS, MODELS
from ..base import BaseAlgorithm, BaseModel

@@ -156,27 +155,18 @@ def _build_qmodels(self, model: BaseModel):
"""

qmodels = nn.ModuleDict()

self.quantizer.swap_ff_with_fxff(model)
tracer = self.quantizer.tracer

for mode in self.forward_modes:
concrete_args = {'mode': mode}
traced_graph = tracer.trace(model, concrete_args=concrete_args)
graph_mopdule = build_graphmodule(model, traced_graph)
observed_module = self.quantizer.prepare(model, graph_mopdule)
observed_module = self.quantizer.prepare(model, concrete_args)
qmodels[mode] = observed_module
# import pdb
# pdb.set_trace()
# dummy_input = torch.randn(self.input_shapes)
# qmodels['predict'](dummy_input, None, 'predict')

return qmodels

def forward(self,
inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
mode: str = 'tensor') -> ForwardResults:
"""Forward with qmodels in quantization."""

if mode in self.qmodels:
qmodel = self.qmodels[mode]
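Aside: the refactor above moves tracing inside `quantizer.prepare`, with `concrete_args` pinning the non-tensor `mode` argument during symbolic tracing so each forward mode gets its own specialized graph. A minimal sketch of that mechanism with plain `torch.fx` (the toy model is illustrative, not code from this repo):

import torch
import torch.fx
from torch import nn


class ToyModel(nn.Module):
    """Illustrative model whose forward branches on a non-tensor arg."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)

    def forward(self, x, mode='tensor'):
        out = self.conv(x)
        if mode == 'predict':  # control flow on a non-tensor arg
            out = out.sigmoid()
        return out


model = ToyModel()
# Pinning `mode` via concrete_args bakes exactly one branch into each
# traced graph, mirroring the per-mode loop in `_build_qmodels` above.
for mode in ('tensor', 'predict'):
    traced = torch.fx.symbolic_trace(model, concrete_args={'mode': mode})
    assert isinstance(traced, torch.fx.GraphModule)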
118 changes: 81 additions & 37 deletions mmrazor/models/quantizers/academic_quantizer.py
@@ -1,6 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional

import torch

from mmrazor.models.task_modules.tracer import build_graphmodule
from mmrazor.models.utils import str2class
from mmrazor.registry import MODELS
from mmrazor.structures.quantization import BackendConfigs, QConfigHandler
from .base import BaseQuantizer
@@ -10,45 +14,90 @@
from torch.ao.quantization.fx.custom_config import (FuseCustomConfig,
PrepareCustomConfig)
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization.quant_type import _quant_type_from_str
from torch.ao.quantization.quantize_fx import _fuse_fx
except ImportError:
from mmrazor.utils import get_placeholder
prepare = get_placeholder('torch>=1.13')
FuseCustomConfig = get_placeholder('torch>=1.13')
PrepareCustomConfig = get_placeholder('torch>=1.13')
QConfigMapping = get_placeholder('torch>=1.13')
_quant_type_from_str = get_placeholder('torch>=1.13')
_fuse_fx = get_placeholder('torch>=1.13')

GLOBAL_DICT_KEY = '_global_'
OBJECT_TYPE_DICT_KEY = 'object_type'
MODULE_NAME_REGEX_DICT_KEY = 'module_name_regex'
MODULE_NAME_DICT_KEY = 'module_name'
MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = 'module_name_object_type_order'

# keys can be used in `prepare_custom_config` of `AcademicQuantizer`.
FLOAT_TO_OBSERVED_DICT_KEY = 'float_to_observed_custom_module_class'
PRESERVED_ATTRIBUTES_DICT_KEY = 'preserved_attributes'


@MODELS.register_module()
class AcademicQuantizer(BaseQuantizer):
"""tmp."""
"""Quantizer for academic researching. Different from some quantizers for
deploying, `AcademicQuantizer` is without the interfaces for deployment,
but it has more flexible functions for quantizing your model. With its
help, you can custom configuration qconfig for differenet OP by
`qconfig_mapping` to implement customized experiments, including using
custom fakquant, trying mixed precision quantization, comparing different
quantization scheme and so on.
Args:
qconfig_mapping (Dict): Mapping from model ops to qconfig to configure
how a model is quantized. You can specify qconfigs using the
following keys (in increasing match priority):
``_global_`` : sets the global (default) qconfig
``object_type`` : sets the qconfig for a given module type,
function, or method name
``module_name`` : sets the qconfig for modules matching the
given module name
tracer (Dict): Config for the tracer used to trace the float model and
generate the corresponding graph, which helps prepare the float
model for quantization without modifying its code. Defaults to
`dict(type='mmrazor.CustomTracer')`.
prepare_custom_config (Optional[Dict]): Custom configuration for
:func:`~torch.ao.quantization.fx.prepare`. You can specify the
following:
``float_to_observed_custom_module_class`` : a list of
(float class, observed class) name pairs mapping float module
classes to observed module classes, e.g.
`[('FloatCustomModule', 'ObservedCustomModule')]`
``preserved_attributes``: a list of attributes that persist
even if they are not used in ``forward``, e.g.
`['attr1', 'attr2']`
"""

def __init__(self,
qconfig_mapping,
tracer=dict(type='mmrazor.CustomTracer'),
prepare_custom_config=None,
backend_config=BackendConfigs['academic']):
qconfig_mapping: Dict,
tracer: Dict = dict(type='mmrazor.CustomTracer'),
prepare_custom_config: Optional[Dict] = None):
super().__init__(tracer)
self.qconfig_mapping = self.gen_qconfig_mapping(qconfig_mapping)
self.prepare_custom_config = self.gen_prepare_custom_config(
prepare_custom_config)
self.backend_config = backend_config
self.backend_config = BackendConfigs[self.backend]
self.example_inputs = (torch.randn(1, 3, 224, 224), )

def prepare(self, model, graph_module):
"""tmp."""
@property
def backend(self):
"""The key of the corresponding backend config."""
return 'academic'

def prepare(self, model, concrete_args=None):
"""Prepare for quantizing model, which includes as follows:
1. Swap floatfunctional with FXFloatFunctional;
2. Trace model to generate `GraphModule`;
2. Fuse some OPs combination, such as conv + bn, conv + relu and so on;
3. Swap some conv or linear module with QAT Modules which contain
weight fakequant nodes;
4. Insert required fakequant nodes for activation.
step 3 and step 4 are implemented in
:func:`~torch.ao.quantization.fx.prepare`
"""
self.swap_ff_with_fxff(model)
traced_graph = self.tracer.trace(model, concrete_args=concrete_args)
graph_module = build_graphmodule(model, traced_graph)
preserved_attributes = self.prepare_custom_config.preserved_attributes
for attr_name in preserved_attributes:
setattr(graph_module, attr_name, getattr(model, attr_name))
@@ -71,51 +120,46 @@ def prepare(self, model, graph_module):

return prepared

def gen_qconfig_mapping(self, qconfig_mapping):
"""tmp."""
def gen_qconfig_mapping(self, qconfig_mapping: Dict):
"""Convert qconfig_mapping in config file to `QConfigMapping`.
`QConfigMapping` is a custom class for mapping from model ops to
:class:`torch.ao.quantization.QConfig` s.
"""
conf = QConfigMapping()
if GLOBAL_DICT_KEY in qconfig_mapping:
qconfig = QConfigHandler(
qconfig_mapping[GLOBAL_DICT_KEY]).convert()
conf.set_global(qconfig)

for object_type, qconfig in qconfig_mapping.get(
OBJECT_TYPE_DICT_KEY, []):
qconfig = QConfigHandler(qconfig).convert()
conf.set_object_type(object_type, qconfig)
conf.set_object_type(str2class(object_type), qconfig)

for module_name_regex, qconfig in qconfig_mapping.get(
MODULE_NAME_REGEX_DICT_KEY, []):
qconfig = QConfigHandler(qconfig).convert()
conf.set_module_name_regex(module_name_regex, qconfig)
for module_name, qconfig in qconfig_mapping.get(
MODULE_NAME_DICT_KEY, []):
qconfig = QConfigHandler(qconfig).convert()
conf.set_module_name(module_name, qconfig)
for module_name, object_type, index, qconfig in qconfig_mapping.get(
MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []):
qconfig = QConfigHandler(qconfig).convert()
conf.set_module_name_object_type_order(module_name, object_type,
index, qconfig)

return conf

def gen_prepare_custom_config(self, prepare_custom_config):
"""tmp."""
def gen_prepare_custom_config(self, prepare_custom_config: Optional[Dict]):
"""Convert prepare_custom_config in config file to
`PrepareCustomConfig`.
`PrepareCustomConfig` is a custom class for custom configurating
:func:`~torch.ao.quantization.fx.prepare`.
"""
conf = PrepareCustomConfig()
if prepare_custom_config is None:
return conf
else:
for quant_type_name, custom_module_mapping in \
prepare_custom_config.get(
FLOAT_TO_OBSERVED_DICT_KEY, {}).items():
quant_type = _quant_type_from_str(quant_type_name)
mapping_items = custom_module_mapping.items()
for float_class_str, observed_class_str in mapping_items:
float_class = MODELS.get(float_class_str)
observed_class = MODELS.get(observed_class_str)
conf.set_float_to_observed_mapping(float_class,
observed_class,
quant_type)
for float_class_str, observed_class_str in prepare_custom_config.get( # noqa: E501
FLOAT_TO_OBSERVED_DICT_KEY, []):
float_class = MODELS.get(float_class_str)
observed_class = MODELS.get(observed_class_str)
conf.set_float_to_observed_mapping(float_class, observed_class)
conf.set_preserved_attributes(
prepare_custom_config.get(PRESERVED_ATTRIBUTES_DICT_KEY, []))
return conf
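Taken together, the docstrings above imply a config of roughly the following shape. A hypothetical sketch for illustration only: the qconfig bodies and the custom module class names are placeholders, not values taken from this commit:

# Hypothetical AcademicQuantizer config; qconfig bodies and custom module
# class names below are placeholders.
toy_qconfig = dict(
    w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
    a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
    w_fake_quant=dict(type='mmrazor.FakeQuantize'),
    a_fake_quant=dict(type='mmrazor.FakeQuantize'),
    w_qscheme=dict(qdtype='qint8', bit=8, is_symmetry=True),
    a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True),
)

quantizer = dict(
    type='mmrazor.AcademicQuantizer',
    qconfig_mapping=dict(
        # lowest priority: default qconfig for every op
        _global_=toy_qconfig,
        # per-type qconfigs, matched by module class or function name
        object_type=[('torch.nn.Conv2d', toy_qconfig)],
        # highest priority: per-module qconfigs, matched by submodule name
        module_name=[('backbone.layer1.0.conv1', toy_qconfig)],
    ),
    prepare_custom_config=dict(
        float_to_observed_custom_module_class=[
            ('FloatCustomModule', 'ObservedCustomModule')],
        preserved_attributes=['attr1', 'attr2'],
    ),
)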
32 changes: 26 additions & 6 deletions mmrazor/models/quantizers/base.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from typing import Dict

import torch
from mmengine.model import BaseModule
@@ -8,18 +9,37 @@


class BaseQuantizer(BaseModule):
"""tmp."""

def __init__(self, tracer):
"""Base class for quantizers. Its role for several subclass is as follows:
1. Provide tracer for tracing model for all subclass.
2. Define some common abstract methods, such as `prepare`.
3. Provide some common functional interfaces, such as `swap_ff_with_fxff`.
Args:
tracer (Dict): Config for the tracer used to trace the float model and
generate the corresponding graph, which helps prepare the float
model for quantization without modifying its code.
"""

def __init__(self, tracer: Dict):
super().__init__()
self.tracer = TASK_UTILS.build(tracer)

@abstractmethod
def prepare(self, model, graph_module):
"""tmp."""
def prepare(self, model):
"""Prepare for quantizing model, which usually includes as follows:
1. Swap floatfunctional with FXFloatFunctional;
2. Trace model to generate `GraphModule`;
2. Fuse some OPs combination, such as conv + bn, conv + relu and so on;
3. Swap some conv or linear module with QAT Modules which contain
weight fakequant nodes;
4. Insert required fakequant nodes for activation.
5. (Optional) Delete some redundant fakequant nodes according to the
special requirement of the backend for deployment.
"""
pass

def swap_ff_with_fxff(self, model):
def swap_ff_with_fxff(self, model: torch.nn.Module):
"""Swap FloatFunctional with FXFloatFunctional."""
modules_to_swap = []
for name, module in model.named_children():
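The method body is truncated above; a self-contained sketch of the usual recursive swap (assumed from the visible lines, not copied verbatim from this commit):

import torch.nn as nn
from torch.nn.quantized import FloatFunctional, FXFloatFunctional


def swap_ff_with_fxff(model: nn.Module) -> None:
    # FX tracing cannot look inside FloatFunctional, so replace each one
    # with the FX-aware FXFloatFunctional, recursing into child modules.
    modules_to_swap = []
    for name, module in model.named_children():
        if isinstance(module, FloatFunctional):
            modules_to_swap.append(name)
        else:
            swap_ff_with_fxff(module)
    for name in modules_to_swap:
        setattr(model, name, FXFloatFunctional())


class Toy(nn.Module):

    def __init__(self):
        super().__init__()
        self.ff = FloatFunctional()

    def forward(self, x, y):
        return self.ff.add(x, y)


m = Toy()
swap_ff_with_fxff(m)
assert isinstance(m.ff, FXFloatFunctional)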
50 changes: 28 additions & 22 deletions mmrazor/models/quantizers/native_quantizer.py
@@ -26,6 +26,7 @@
qat_modules = get_package_placeholder('torch>=1.13')

from mmrazor import digit_version
from mmrazor.models.task_modules.tracer import build_graphmodule
from mmrazor.models.task_modules.tracer.fx import (
del_fakequant_after_function, del_fakequant_after_method,
del_fakequant_after_module, del_fakequant_after_op,
@@ -90,10 +91,6 @@ class NativeQuantizer(BaseQuantizer):
)
"""

# backend: 'native'
# support_w_modes = ['per_tensor', 'per_channel']
# support_a_modes = ['per_tensor']

def __init__(self,
global_qconfig: Union[Dict, Config],
no_observer_modules: Optional[List] = None,
@@ -135,25 +132,24 @@ def __init__(self,

@property
def backend(self):
"""tmp."""
"""The key of the corresponding backend config."""
return 'native'

@property
def support_w_modes(self):
"""tmp."""
return ['per_tensor', 'per_channel']
"""Supported quantization modes for weight about per_tensor or
per_channel."""
return ('per_tensor', 'per_channel')

@property
def support_a_modes(self):
"""tmp."""
return ['per_tensor']
"""Supported quantization modes for activation about per_tensor or
per_channel."""
return ('per_tensor')

def prepare(self, model, graph_module):
def prepare(self, model, concrete_args=None):
"""prepare graph to ObservedGraphModule.
Args:
graph_module (_type_): GraphModules before fuse.
Returns:
ObservedGraphModule: GraphModules after fuse and observer.
@@ -170,7 +166,9 @@ def prepare(self, model, graph_module):
fake_quant operations that we need it to be fused into our
`SUPPORT_QAT_MODULES` type, which is a tricky way to deal with it.
"""

self.swap_ff_with_fxff(model)
traced_graph = self.tracer.trace(model, concrete_args=concrete_args)
graph_module = build_graphmodule(model, traced_graph)
graph_module = _fuse_fx(
graph_module=graph_module,
is_qat=True,
@@ -319,40 +317,48 @@ def module_prev_wo_fakequant(self):

@property
def module_prev_wo_fakequant(self):
"""tmp."""
"""Configurate the modules that their previous nodes are redundant
fakequants."""
return tuple()

@property
def module_next_wo_fakequant(self):
"""tmp."""
"""Configurate the modules that their next nodes are redundant
fakequants."""
return tuple()

@property
def function_prev_wo_fakequant(self):
"""tmp."""
"""Configurate the functions that their previous nodes are redundant
fakequants."""
return tuple()

@property
def function_next_wo_fakequant(self):
"""tmp."""
"""Configurate the functions that their next nodes are redundant
fakequants."""
return tuple()

@property
def method_prev_wo_fakequant(self):
"""tmp."""
"""Configurate the methods that their previous nodes are redundant
fakequants."""
return tuple()

@property
def method_next_wo_fakequant(self):
"""tmp."""
"""Configurate the methods that their next nodes are redundant
fakequants."""
return tuple()

@property
def op_prev_wo_fakequant(self):
"""tmp."""
"""Configurate the OPs that their previous nodes are redundant
fakequants."""
return tuple()

@property
def op_next_wo_fakequant(self):
"""tmp."""
"""Configurate the OPs that their next nodes are redundant
fakequants."""
return tuple()
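These hook properties let backend-specific subclasses prune redundant fakequants declaratively. A hypothetical subclass for illustration (the class name and the chosen modules/methods are assumptions, not part of this commit):

import torch.nn as nn

from mmrazor.models.quantizers import NativeQuantizer
from mmrazor.registry import MODELS


@MODELS.register_module()
class ToyBackendQuantizer(NativeQuantizer):
    # Imaginary backend that takes float inputs to ReLU6/Identity and
    # keeps tensors after `detach` in float, so those fakequants are
    # pruned from the prepared graph.

    @property
    def module_prev_wo_fakequant(self):
        # Delete the fakequant node in front of these module types.
        return (nn.ReLU6, nn.Identity)

    @property
    def method_next_wo_fakequant(self):
        # Delete the fakequant node right after calls to Tensor.detach.
        return ('detach', )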