From af6ed3105af3df5ba782b5ca3e1c7d8409189cd5 Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Tue, 17 Jan 2023 18:34:32 +0800 Subject: [PATCH] [Docs & Refactor] Add docstring and UT of other quantizers (#439) * add quantizer docstring and refactor the interface of AcademicQuantizer * add AcademicQuantizer unittest * add TensorRTQuantizer and OpenVINOQuantizer unittest & refactor prepare interface * adapt torch113 ci * fix import * fix lint * update some docstring * fix ci --- .../quantization/mm_architecture.py | 14 +- .../models/quantizers/academic_quantizer.py | 118 +++++++++---- mmrazor/models/quantizers/base.py | 32 +++- mmrazor/models/quantizers/native_quantizer.py | 50 +++--- .../models/quantizers/openvino_quantizer.py | 63 ++++--- .../models/quantizers/tensorrt_quantizer.py | 55 +++--- mmrazor/testing/_fx_models.py | 2 + .../test_academic_quantizer.py | 167 ++++++++++++++++++ .../test_quantizers/test_native_quantizer.py | 4 +- .../test_openvino_quantizer.py | 78 ++++++++ .../test_tensorrt_quantizer.py | 74 ++++++++ .../test_quantizers/test_trt_quantizer.py | 34 ---- .../test_task_modules/test_custom_tracer.py | 1 - 13 files changed, 532 insertions(+), 160 deletions(-) create mode 100644 tests/test_models/test_quantizers/test_academic_quantizer.py create mode 100644 tests/test_models/test_quantizers/test_openvino_quantizer.py create mode 100644 tests/test_models/test_quantizers/test_tensorrt_quantizer.py delete mode 100644 tests/test_models/test_quantizers/test_trt_quantizer.py diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index 9feb3fb53..afdd7799c 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -8,7 +8,6 @@ from mmengine.structures import BaseDataElement from torch import nn -from mmrazor.models.task_modules.tracer import build_graphmodule from mmrazor.registry import MODEL_WRAPPERS, MODELS from ..base import BaseAlgorithm, BaseModel @@ -156,20 +155,10 @@ def _build_qmodels(self, model: BaseModel): """ qmodels = nn.ModuleDict() - - self.quantizer.swap_ff_with_fxff(model) - tracer = self.quantizer.tracer - for mode in self.forward_modes: concrete_args = {'mode': mode} - traced_graph = tracer.trace(model, concrete_args=concrete_args) - graph_mopdule = build_graphmodule(model, traced_graph) - observed_module = self.quantizer.prepare(model, graph_mopdule) + observed_module = self.quantizer.prepare(model, concrete_args) qmodels[mode] = observed_module - # import pdb - # pdb.set_trace() - # dummy_input = torch.randn(self.input_shapes) - # qmodels['predict'](dummy_input, None, 'predict') return qmodels @@ -177,6 +166,7 @@ def forward(self, inputs: torch.Tensor, data_samples: Optional[List[BaseDataElement]] = None, mode: str = 'tensor') -> ForwardResults: + """Forward with qmodels in quantization.""" if mode in self.qmodels: qmodel = self.qmodels[mode] diff --git a/mmrazor/models/quantizers/academic_quantizer.py b/mmrazor/models/quantizers/academic_quantizer.py index 09cfc7944..a6cfc257c 100644 --- a/mmrazor/models/quantizers/academic_quantizer.py +++ b/mmrazor/models/quantizers/academic_quantizer.py @@ -1,6 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict, Optional + import torch +from mmrazor.models.task_modules.tracer import build_graphmodule +from mmrazor.models.utils import str2class from mmrazor.registry import MODELS from mmrazor.structures.quantization import BackendConfigs, QConfigHandler from .base import BaseQuantizer @@ -10,7 +14,6 @@ from torch.ao.quantization.fx.custom_config import (FuseCustomConfig, PrepareCustomConfig) from torch.ao.quantization.qconfig_mapping import QConfigMapping - from torch.ao.quantization.quant_type import _quant_type_from_str from torch.ao.quantization.quantize_fx import _fuse_fx except ImportError: from mmrazor.utils import get_placeholder @@ -18,37 +21,83 @@ FuseCustomConfig = get_placeholder('torch>=1.13') PrepareCustomConfig = get_placeholder('torch>=1.13') QConfigMapping = get_placeholder('torch>=1.13') - _quant_type_from_str = get_placeholder('torch>=1.13') _fuse_fx = get_placeholder('torch>=1.13') GLOBAL_DICT_KEY = '_global_' OBJECT_TYPE_DICT_KEY = 'object_type' -MODULE_NAME_REGEX_DICT_KEY = 'module_name_regex' MODULE_NAME_DICT_KEY = 'module_name' -MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = 'module_name_object_type_order' +# Keys that can be used in `prepare_custom_config` of `AcademicQuantizer`. FLOAT_TO_OBSERVED_DICT_KEY = 'float_to_observed_custom_module_class' PRESERVED_ATTRIBUTES_DICT_KEY = 'preserved_attributes' @MODELS.register_module() class AcademicQuantizer(BaseQuantizer): - """tmp.""" + """Quantizer for academic research. Unlike the quantizers intended for + deployment, `AcademicQuantizer` provides no deployment interfaces, but it + is more flexible for quantizing your model. With its help, you can + configure a custom qconfig for different OPs via `qconfig_mapping` to + implement customized experiments, such as using custom fakequants, trying + mixed-precision quantization, or comparing different quantization schemes. + + Args: + qconfig_mapping (Dict): Mapping from model ops to qconfig to configure + how a model is quantized. You can specify qconfigs using the + following keys (in increasing match priority): + ``_global_`` : sets the global (default) qconfig + ``object_type`` : sets the qconfig for a given module type, + function, or method name + ``module_name`` : sets the qconfig for modules matching the + given module name + tracer (Dict): Config of the tracer used to trace the float model and + generate the corresponding graph, so that the float model can be + prepared for quantization without modifying its code. Defaults to + `dict(type='mmrazor.CustomTracer')`. + prepare_custom_config (Optional[Dict]): Custom configuration for + :func:`~torch.ao.quantization.fx.prepare`. You can specify the + following: + ``float_to_observed_custom_module_class`` : a list of tuples + mapping float module classes to observed module + classes, e.g. + `[('FloatCustomModule', 'ObservedCustomModule')]` + ``preserved_attributes``: a list of attributes that persist + even if they are not used in ``forward``, e.g. + `['attr1', 'attr2']` + """ def __init__(self, - qconfig_mapping, - tracer=dict(type='mmrazor.CustomTracer'), - prepare_custom_config=None, - backend_config=BackendConfigs['academic']): + qconfig_mapping: Dict, + tracer: Dict = dict(type='mmrazor.CustomTracer'), + prepare_custom_config: Optional[Dict] = None): super().__init__(tracer) self.qconfig_mapping = self.gen_qconfig_mapping(qconfig_mapping) self.prepare_custom_config = self.gen_prepare_custom_config( prepare_custom_config) - self.backend_config = backend_config + self.backend_config = BackendConfigs[self.backend] self.example_inputs = (torch.randn(1, 3, 224, 224), ) - def prepare(self, model, graph_module): - """tmp.""" + @property + def backend(self): + """The key of the corresponding backend config.""" + return 'academic' + + def prepare(self, model, concrete_args=None): + """Prepare the model for quantization, which includes the following + steps: + + 1. Swap FloatFunctional with FXFloatFunctional; + 2. Trace the model to generate a `GraphModule`; + 3. Fuse some OP combinations, such as conv + bn, conv + relu and so on; + 4. Swap some conv or linear modules with QAT modules that contain + weight fakequant nodes; + 5. Insert the required fakequant nodes for activations. + Steps 4 and 5 are implemented in + :func:`~torch.ao.quantization.fx.prepare` + """ + self.swap_ff_with_fxff(model) + traced_graph = self.tracer.trace(model, concrete_args=concrete_args) + graph_module = build_graphmodule(model, traced_graph) preserved_attributes = self.prepare_custom_config.preserved_attributes for attr_name in preserved_attributes: setattr(graph_module, attr_name, getattr(model, attr_name)) @@ -71,51 +120,46 @@ def prepare(self, model, graph_module): return prepared - def gen_qconfig_mapping(self, qconfig_mapping): - """tmp.""" + def gen_qconfig_mapping(self, qconfig_mapping: Dict): + """Convert the `qconfig_mapping` in the config file to a + `QConfigMapping`. + + `QConfigMapping` is a custom class for mapping from model ops to + :class:`torch.ao.quantization.QConfig` objects. + """ conf = QConfigMapping() if GLOBAL_DICT_KEY in qconfig_mapping: qconfig = QConfigHandler( qconfig_mapping[GLOBAL_DICT_KEY]).convert() conf.set_global(qconfig) + for object_type, qconfig in qconfig_mapping.get( OBJECT_TYPE_DICT_KEY, []): qconfig = QConfigHandler(qconfig).convert() - conf.set_object_type(object_type, qconfig) + conf.set_object_type(str2class(object_type), qconfig) - for module_name_regex, qconfig in qconfig_mapping.get( - MODULE_NAME_REGEX_DICT_KEY, []): - qconfig = QConfigHandler(qconfig).convert() - conf.set_module_name_regex(module_name_regex, qconfig) for module_name, qconfig in qconfig_mapping.get( MODULE_NAME_DICT_KEY, []): qconfig = QConfigHandler(qconfig).convert() conf.set_module_name(module_name, qconfig) - for module_name, object_type, index, qconfig in qconfig_mapping.get( - MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []): - qconfig = QConfigHandler(qconfig).convert() - conf.set_module_name_object_type_order(module_name, object_type, - index, qconfig) return conf - def gen_prepare_custom_config(self, prepare_custom_config): - """tmp.""" + def gen_prepare_custom_config(self, prepare_custom_config: Optional[Dict]): + """Convert the `prepare_custom_config` in the config file to a + `PrepareCustomConfig`. + + `PrepareCustomConfig` is a custom class for configuring + :func:`~torch.ao.quantization.fx.prepare`.
+ """ conf = PrepareCustomConfig() if prepare_custom_config is None: return conf else: - for quant_type_name, custom_module_mapping in \ - prepare_custom_config.get( - FLOAT_TO_OBSERVED_DICT_KEY, {}).items(): - quant_type = _quant_type_from_str(quant_type_name) - mapping_items = custom_module_mapping.items() - for float_class_str, observed_class_str in mapping_items: - float_class = MODELS.get(float_class_str) - observed_class = MODELS.get(observed_class_str) - conf.set_float_to_observed_mapping(float_class, - observed_class, - quant_type) + for float_class_str, observed_class_str in prepare_custom_config.get( # noqa: E501 + FLOAT_TO_OBSERVED_DICT_KEY, []): + float_class = MODELS.get(float_class_str) + observed_class = MODELS.get(observed_class_str) + conf.set_float_to_observed_mapping(float_class, observed_class) conf.set_preserved_attributes( prepare_custom_config.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])) return conf diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py index 0f14917ac..866199735 100644 --- a/mmrazor/models/quantizers/base.py +++ b/mmrazor/models/quantizers/base.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import abstractmethod +from typing import Dict import torch from mmengine.model import BaseModule @@ -8,18 +9,37 @@ class BaseQuantizer(BaseModule): - """tmp.""" - - def __init__(self, tracer): + """Base class for quantizers. Its responsibilities for all subclasses are + as follows: + 1. Provide a tracer for tracing the model. + 2. Define common abstract methods, such as `prepare`. + 3. Provide common functional interfaces, such as `swap_ff_with_fxff`. + + Args: + tracer (Dict): Config of the tracer used to trace the float model and + generate the corresponding graph, so that the float model can be + prepared for quantization without modifying its code. + """ + + def __init__(self, tracer: Dict): super().__init__() self.tracer = TASK_UTILS.build(tracer) @abstractmethod - def prepare(self, model, graph_module): - """tmp.""" + def prepare(self, model): + """Prepare the model for quantization, which usually includes the + following steps: + + 1. Swap FloatFunctional with FXFloatFunctional; + 2. Trace the model to generate a `GraphModule`; + 3. Fuse some OP combinations, such as conv + bn, conv + relu and so on; + 4. Swap some conv or linear modules with QAT modules that contain + weight fakequant nodes; + 5. Insert the required fakequant nodes for activations; + 6. (Optional) Delete some redundant fakequant nodes according to the + special requirements of the backend for deployment.
+ """ pass - def swap_ff_with_fxff(self, model): + def swap_ff_with_fxff(self, model: torch.nn.Module): """Swap FloatFunctional with FXFloatFunctional.""" modules_to_swap = [] for name, module in model.named_children(): diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index 2b75cf29c..b5de1c028 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -26,6 +26,7 @@ qat_modules = get_package_placeholder('torch>=1.13') from mmrazor import digit_version +from mmrazor.models.task_modules.tracer import build_graphmodule from mmrazor.models.task_modules.tracer.fx import ( del_fakequant_after_function, del_fakequant_after_method, del_fakequant_after_module, del_fakequant_after_op, @@ -90,10 +91,6 @@ class NativeQuantizer(BaseQuantizer): ) """ - # backend: 'native' - # support_w_modes = ['per_tensor', 'per_channel'] - # support_a_modes = ['per_tensor'] - def __init__(self, global_qconfig: Union[Dict, Config], no_observer_modules: Optional[List] = None, @@ -135,25 +132,24 @@ def __init__(self, @property def backend(self): - """tmp.""" + """The key of the corresponding backend config.""" return 'native' @property def support_w_modes(self): - """tmp.""" - return ['per_tensor', 'per_channel'] + """Supported quantization modes for weights: per_tensor or + per_channel.""" + return ('per_tensor', 'per_channel') @property def support_a_modes(self): - """tmp.""" - return ['per_tensor'] + """Supported quantization modes for activations: per_tensor or + per_channel.""" + return ('per_tensor') - def prepare(self, model, graph_module): + def prepare(self, model, concrete_args=None): """prepare graph to ObservedGraphModule. - Args: - graph_module (_type_): GraphModules before fuse. - Returns: ObservedGraphModule: GraphModules after fuse and observer. @@ -170,7 +166,9 @@ fake_quant operations that we need it to be fused into our `SUPPORT_QAT_MODULES` type, which is a tricky way to deal with it.
""" - + self.swap_ff_with_fxff(model) + traced_graph = self.tracer.trace(model, concrete_args=concrete_args) + graph_module = build_graphmodule(model, traced_graph) graph_module = _fuse_fx( graph_module=graph_module, is_qat=True, @@ -319,40 +317,48 @@ def module_prev_wo_fakequant(self): @property def module_prev_wo_fakequant(self): - """tmp.""" + """Configurate the modules that their previous nodes are redundant + fakequants.""" return tuple() @property def module_next_wo_fakequant(self): - """tmp.""" + """Configurate the modules that their next nodes are redundant + fakequants.""" return tuple() @property def function_prev_wo_fakequant(self): - """tmp.""" + """Configurate the functions that their previous nodes are redundant + fakequants.""" return tuple() @property def function_next_wo_fakequant(self): - """tmp.""" + """Configurate the functions that their next nodes are redundant + fakequants.""" return tuple() @property def method_prev_wo_fakequant(self): - """tmp.""" + """Configurate the methods that their previous nodes are redundant + fakequants.""" return tuple() @property def method_next_wo_fakequant(self): - """tmp.""" + """Configurate the methods that their next nodes are redundant + fakequants.""" return tuple() @property def op_prev_wo_fakequant(self): - """tmp.""" + """Configurate the OPs that their previous nodes are redundant + fakequants.""" return tuple() @property def op_next_wo_fakequant(self): - """tmp.""" + """Configurate the OPs that their next nodes are redundant + fakequants.""" return tuple() diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py index 23abf40da..cb7d3084b 100644 --- a/mmrazor/models/quantizers/openvino_quantizer.py +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple import torch @@ -8,43 +9,57 @@ from mmrazor.utils import get_placeholder disable_observer = get_placeholder('torch>=1.13') -from mmrazor.models.task_modules.tracer.fx import build_graphmodule from mmrazor.registry import MODELS from .native_quantizer import NativeQuantizer @MODELS.register_module() class OpenVINOQuantizer(NativeQuantizer): - """Quantizer for Openvino backend.""" + """Quantizer for quantizing and deploying to Openvino backend. - # backend: 'openvino' - # support_w_mode = ['per_tensor', 'per_channel'] - # support_a_mode = ['per_tensor'] + Each backend has its own features, for reducing the gap of quantized + performance between before and after deployment as possible, we should + match the backend's features in quantization. 
+ + Some of OpenVINO's important quantization features are as follows: + * support_w_mode = ('per_tensor', 'per_channel') + * support_a_mode = ('per_tensor') + * the weight range should be symmetric, e.g. for int8 it is [-127, 127] + rather than [-128, 127] + """ @property def backend(self): - """tmp.""" + """The backend to deploy to, which is also the key of the corresponding + backend config.""" return 'openvino' @property def support_w_modes(self): - """tmp.""" - return ['per_tensor', 'per_channel'] + """Supported quantization modes for weights: per_tensor or + per_channel.""" + return ('per_tensor', 'per_channel') @property def support_a_modes(self): - """tmp.""" - return ['per_tensor'] + """Supported quantization modes for activations: per_tensor or + per_channel.""" + return ('per_tensor') def prepare_for_mmdeploy(self, - model, - dummy_input=(1, 3, 224, 224), - checkpoint=None): - """tmp.""" - self.swap_ff_with_fxff(model) - graph = self.tracer.trace(model) - graph_module = build_graphmodule(model, graph) - observed_model = self.prepare(model, graph_module) + model: torch.nn.Module, + dummy_input: Tuple = (1, 3, 224, 224), + checkpoint: Optional[str] = None): + """Prepare the model for deployment to the backend with mmdeploy. It is + called by mmdeploy and usually includes the following steps: + + 1. prepare the float model rewritten by mmdeploy; + 2. load the checkpoint consisting of float weights and quantized params + saved by mmrazor; + 3. post-process the weight fakequant so that the exported .onnx meets + the backend's requirements. + """ + observed_model = self.prepare(model) if dummy_input is not None: observed_model(torch.randn(dummy_input)) if checkpoint is not None: @@ -59,20 +74,24 @@ def prepare_for_mmdeploy(self, @property def module_prev_wo_fakequant(self): - """tmp.""" + """Modules whose previous fakequant nodes are redundant and will be + deleted.""" return (torch.nn.ReLU6, torch.nn.Identity) @property def module_next_wo_fakequant(self): - """tmp.""" + """Modules whose next fakequant nodes are redundant and will be + deleted.""" return (torch.nn.MaxPool2d, ) @property def method_next_wo_fakequant(self): - """tmp.""" + """Methods whose next fakequant nodes are redundant and will be + deleted.""" return ('flatten', ) @property def op_prev_wo_fakequant(self): - """tmp.""" + """OPs whose previous fakequant nodes are redundant and will be + deleted.""" return ('output', ) diff --git a/mmrazor/models/quantizers/tensorrt_quantizer.py b/mmrazor/models/quantizers/tensorrt_quantizer.py index 36e3f2be7..028c96a8c 100644 --- a/mmrazor/models/quantizers/tensorrt_quantizer.py +++ b/mmrazor/models/quantizers/tensorrt_quantizer.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + import torch try: @@ -7,50 +9,55 @@ from mmrazor.utils import get_placeholder disable_observer = get_placeholder('torch>=1.13') -from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ - build_graphmodule from mmrazor.registry import MODELS from .native_quantizer import NativeQuantizer @MODELS.register_module() class TensorRTQuantizer(NativeQuantizer): - """Quantizer for TensorRT backend.""" + """Quantizer for quantizing and deploying models to the TensorRT backend. - # backend: 'tensorrt' - # support_w_mode = ['per_tensor', 'per_channel'] - # support_a_mode = ['per_tensor'] + Each backend has its own features. To reduce the performance gap between + a quantized model before and after deployment as much as possible, we + should match the backend's features during quantization.
- - def __init__(self, - global_qconfig, - no_observer_modules=None, - tracer=dict(type='CustomTracer')): - super().__init__(global_qconfig, no_observer_modules, tracer) + + Some of TensorRT's important quantization features are as follows: + * support_w_mode = ('per_tensor', 'per_channel') + * support_a_mode = ('per_tensor') + """ @property def backend(self): - """tmp.""" + """The backend to deploy to, which is also the key of the corresponding + backend config.""" return 'tensorrt' @property def support_w_modes(self): - """tmp.""" - return ['per_tensor', 'per_channel'] + """Supported quantization modes for weights: per_tensor or + per_channel.""" + return ('per_tensor', 'per_channel') @property def support_a_modes(self): - """tmp.""" - return ['per_tensor'] + """Supported quantization modes for activations: per_tensor or + per_channel.""" + return ('per_tensor') def prepare_for_mmdeploy(self, - model, - dummy_input=(1, 3, 224, 224), - checkpoint=None): - """tmp.""" - self.swap_ff_with_fxff(model) - graph = self.tracer.trace(model) - graph_module = build_graphmodule(model, graph) - observed_model = self.prepare(model, graph_module) + model: torch.nn.Module, + dummy_input: Tuple = (1, 3, 224, 224), + checkpoint: Optional[str] = None): + """Prepare the model for deployment to the backend with mmdeploy. It is + called by mmdeploy and usually includes the following steps: + + 1. prepare the float model rewritten by mmdeploy; + 2. load the checkpoint consisting of float weights and quantized params + saved by mmrazor; + 3. post-process the weight fakequant so that the exported .onnx meets + the backend's requirements. + """ + observed_model = self.prepare(model) if dummy_input is not None: observed_model(torch.randn(dummy_input)) if checkpoint is not None: diff --git a/mmrazor/testing/_fx_models.py b/mmrazor/testing/_fx_models.py index 969c4792d..6bf42e16a 100644 --- a/mmrazor/testing/_fx_models.py +++ b/mmrazor/testing/_fx_models.py @@ -34,6 +34,8 @@ def __init__( stride, padding, dilation, groups, bias, conv_cfg, norm_cfg, act_cfg, inplace, with_spectral_norm, padding_mode, order) + self.toy_attr1 = 1 + self.toy_attr2 = 2 def forward(self, x): x = self.conv_module.conv(x) diff --git a/tests/test_models/test_quantizers/test_academic_quantizer.py b/tests/test_models/test_quantizers/test_academic_quantizer.py new file mode 100644 index 000000000..c95060a00 --- /dev/null +++ b/tests/test_models/test_quantizers/test_academic_quantizer.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved.
+from copy import copy +from unittest import TestCase + +import torch +from mmengine.model import BaseModule + +try: + from torch.ao.nn.intrinsic import ConvBnReLU2d + from torch.ao.quantization.backend_config import BackendConfig + from torch.ao.quantization.fx.custom_config import PrepareCustomConfig + from torch.ao.quantization.fx.graph_module import ObservedGraphModule + from torch.ao.quantization.qconfig_mapping import QConfigMapping + from torch.ao.quantization.quant_type import QuantType +except ImportError: + from mmrazor.utils import get_placeholder + ConvBnReLU2d = get_placeholder('torch>=1.13') + BackendConfig = get_placeholder('torch>=1.13') + PrepareCustomConfig = get_placeholder('torch>=1.13') + ObservedGraphModule = get_placeholder('torch>=1.13') + QConfigMapping = get_placeholder('torch>=1.13') + QuantType = get_placeholder('torch>=1.13') + +from mmrazor import digit_version +from mmrazor.models.quantizers import AcademicQuantizer +from mmrazor.models.quantizers.academic_quantizer import ( + FLOAT_TO_OBSERVED_DICT_KEY, GLOBAL_DICT_KEY, MODULE_NAME_DICT_KEY, + OBJECT_TYPE_DICT_KEY, PRESERVED_ATTRIBUTES_DICT_KEY) +from mmrazor.registry import MODELS +from mmrazor.testing import ConvBNReLU + + +@MODELS.register_module() +class ToyFloatModel(BaseModule): + + def __init__(self) -> None: + super().__init__() + + +@MODELS.register_module() +class ToyObservedModel(BaseModule): + + def __init__(self) -> None: + super().__init__() + + +class TestAcademicQuantizer(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + self.global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict(qdtype='qint8', bit=8, is_symmetry=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), + ) + self.qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict(qdtype='qint8', bit=4, is_symmetry=True), + a_qscheme=dict(qdtype='quint8', bit=4, is_symmetry=True), + ) + self.model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN')) + + def test_gen_qconfig_mapping(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + # test set GLOBAL_DICT_KEY by QConfigMapping + global_qconfig = copy(self.global_qconfig) + qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig} + quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'qconfig_mapping') + assert isinstance(quantizer.qconfig_mapping, QConfigMapping) + assert quantizer.qconfig_mapping.global_qconfig + + # test set OBJECT_TYPE_DICT_KEY by QConfigMapping + qconfig = copy(self.qconfig) + qconfig_mapping = { + OBJECT_TYPE_DICT_KEY: + [('torch.ao.nn.intrinsic.ConvBnReLU2d', qconfig)] + } + quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'qconfig_mapping') + assert isinstance(quantizer.qconfig_mapping, QConfigMapping) + assert quantizer.qconfig_mapping.object_type_qconfigs.get(ConvBnReLU2d) + + # test set MODULE_NAME_DICT_KEY by QConfigMapping + qconfig = copy(self.qconfig) + qconfig_mapping = { + MODULE_NAME_DICT_KEY: [('conv_module.conv', qconfig)] + } +
quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'qconfig_mapping') + assert isinstance(quantizer.qconfig_mapping, QConfigMapping) + assert quantizer.qconfig_mapping.module_name_qconfigs.get( + 'conv_module.conv') + + def test_gen_prepare_custom_config(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + # test prepare_custom_config is None + global_qconfig = copy(self.global_qconfig) + qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig} + quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'prepare_custom_config') + assert isinstance(quantizer.prepare_custom_config, PrepareCustomConfig) + + # test set FLOAT_TO_OBSERVED_DICT_KEY and PRESERVED_ATTRIBUTES_DICT_KEY + # by PrepareCustomConfig + global_qconfig = copy(self.global_qconfig) + qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig} + flop_to_observed_list = [('ToyFloatModel', 'ToyObservedModel')] + preserved_attributes_list = ['toy_attr1', 'toy_attr2'] + prepare_custom_config = { + FLOAT_TO_OBSERVED_DICT_KEY: flop_to_observed_list, + PRESERVED_ATTRIBUTES_DICT_KEY: preserved_attributes_list + } + quantizer = AcademicQuantizer( + qconfig_mapping=qconfig_mapping, + prepare_custom_config=prepare_custom_config) + + assert hasattr(quantizer, 'prepare_custom_config') + assert isinstance(quantizer.prepare_custom_config, PrepareCustomConfig) + mapping = quantizer.prepare_custom_config.float_to_observed_mapping[ + QuantType.STATIC] + assert mapping.get(ToyFloatModel) + assert mapping[ToyFloatModel] == ToyObservedModel + + attributes = quantizer.prepare_custom_config.preserved_attributes + assert attributes == preserved_attributes_list + + def test_init(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig} + quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'backend_config') + assert isinstance(quantizer.backend_config, BackendConfig) + + def test_prepare(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig} + preserved_attributes_list = ['toy_attr1', 'toy_attr2'] + prepare_custom_config = { + PRESERVED_ATTRIBUTES_DICT_KEY: preserved_attributes_list + } + quantizer = AcademicQuantizer( + qconfig_mapping=qconfig_mapping, + prepare_custom_config=prepare_custom_config) + model = copy(self.model) + prepared = quantizer.prepare(model) + assert isinstance(prepared, ObservedGraphModule) + assert hasattr(prepared, 'toy_attr1') + assert hasattr(prepared, 'toy_attr2') diff --git a/tests/test_models/test_quantizers/test_native_quantizer.py b/tests/test_models/test_quantizers/test_native_quantizer.py index afd6011ed..62052f66f 100644 --- a/tests/test_models/test_quantizers/test_native_quantizer.py +++ b/tests/test_models/test_quantizers/test_native_quantizer.py @@ -11,7 +11,7 @@ from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ build_graphmodule from mmrazor.registry import MODELS -from mmrazor.structures.quantization import BackendConfigs, QConfigHander +from mmrazor.structures.quantization import BackendConfigs, QConfigHandler try: from torch.ao.quantization.fx import prepare @@ -127,7 +127,7 @@ def setUp(self): self.q_kwargs = 
q_kwargs self.tracer = CustomTracer() self.backend_config = BackendConfigs['native'] - self.qconfig = QConfigHander(global_qconfig) + self.qconfig = QConfigHandler(global_qconfig) self.qconfig_mapping = QConfigMapping().set_global( self.qconfig.convert()) self.example_inputs = (torch.randn(1, 3, 224, 224), ) diff --git a/tests/test_models/test_quantizers/test_openvino_quantizer.py b/tests/test_models/test_quantizers/test_openvino_quantizer.py new file mode 100644 index 000000000..24fc81ca4 --- /dev/null +++ b/tests/test_models/test_quantizers/test_openvino_quantizer.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import shutil +import tempfile +from copy import copy +from unittest import TestCase + +import torch + +try: + from torch.ao.quantization.fx.graph_module import ObservedGraphModule +except ImportError: + from mmrazor.utils import get_placeholder + ObservedGraphModule = get_placeholder('torch>=1.13') + +from mmrazor import digit_version +from mmrazor.models.quantizers import OpenVINOQuantizer +from mmrazor.testing import ConvBNReLU + + +class TestOpenVINOQuantizer(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + self.global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict(qdtype='qint8', bit=8, is_symmetry=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), + ) + self.temp_dir = tempfile.mkdtemp() + self.model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN')) + + def tearDown(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + shutil.rmtree(self.temp_dir) + + def test_property(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + quantizer = OpenVINOQuantizer(global_qconfig=global_qconfig) + assert quantizer.backend == 'openvino' + assert quantizer.support_w_modes == ('per_tensor', 'per_channel') + assert quantizer.support_a_modes == ('per_tensor') + assert quantizer.module_prev_wo_fakequant + assert quantizer.module_next_wo_fakequant + assert quantizer.method_next_wo_fakequant + assert quantizer.op_prev_wo_fakequant + + def test_prepare_for_mmdeploy(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + quantizer = OpenVINOQuantizer(global_qconfig=global_qconfig) + model = copy(self.model) + + # test checkpoint is None + prepared_deploy = quantizer.prepare_for_mmdeploy(model=model) + assert isinstance(prepared_deploy, ObservedGraphModule) + + # test checkpoint is not None + ckpt_path = os.path.join(self.temp_dir, + 'test_prepare_for_mmdeploy.pth') + model = copy(self.model) + prepared = quantizer.prepare(model) + torch.save({'state_dict': prepared.state_dict()}, ckpt_path) + prepared_deploy = quantizer.prepare_for_mmdeploy( + model=model, checkpoint=ckpt_path) + assert isinstance(prepared_deploy, ObservedGraphModule) diff --git a/tests/test_models/test_quantizers/test_tensorrt_quantizer.py b/tests/test_models/test_quantizers/test_tensorrt_quantizer.py new file mode 100644 index 000000000..aeae311f3 --- /dev/null +++ 
b/tests/test_models/test_quantizers/test_tensorrt_quantizer.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import shutil +import tempfile +from copy import copy +from unittest import TestCase + +import torch + +try: + from torch.ao.quantization.fx.graph_module import ObservedGraphModule +except ImportError: + from mmrazor.utils import get_placeholder + ObservedGraphModule = get_placeholder('torch>=1.13') + +from mmrazor import digit_version +from mmrazor.models.quantizers import TensorRTQuantizer +from mmrazor.testing import ConvBNReLU + + +class TestTensorRTQuantizer(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + self.global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict(qdtype='qint8', bit=8, is_symmetry=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), + ) + self.temp_dir = tempfile.mkdtemp() + self.model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN')) + + def tearDown(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + shutil.rmtree(self.temp_dir) + + def test_property(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + quantizer = TensorRTQuantizer(global_qconfig=global_qconfig) + assert quantizer.backend == 'tensorrt' + assert quantizer.support_w_modes == ('per_tensor', 'per_channel') + assert quantizer.support_a_modes == ('per_tensor') + + def test_prepare_for_mmdeploy(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + quantizer = TensorRTQuantizer(global_qconfig=global_qconfig) + model = copy(self.model) + + # test checkpoint is None + prepared_deploy = quantizer.prepare_for_mmdeploy(model=model) + assert isinstance(prepared_deploy, ObservedGraphModule) + + # test checkpoint is not None + ckpt_path = os.path.join(self.temp_dir, + 'test_prepare_for_mmdeploy.pth') + model = copy(self.model) + prepared = quantizer.prepare(model) + torch.save({'state_dict': prepared.state_dict()}, ckpt_path) + prepared_deploy = quantizer.prepare_for_mmdeploy( + model=model, checkpoint=ckpt_path) + assert isinstance(prepared_deploy, ObservedGraphModule) diff --git a/tests/test_models/test_quantizers/test_trt_quantizer.py b/tests/test_models/test_quantizers/test_trt_quantizer.py deleted file mode 100644 index 9f85d1ecd..000000000 --- a/tests/test_models/test_quantizers/test_trt_quantizer.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from unittest import TestCase - -import torch.nn as nn - - -class ToyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - # TODO - - -class TestTRTQuantizer(TestCase): - """TODO. 
- - Args: - TestCase (_type_): _description_ - """ - - def test_init(self): - pass - - def test_prepare(self): - pass - - def test_convert(self): - pass - - def test_states(self): - pass - - def test_forward(self): - pass diff --git a/tests/test_models/test_task_modules/test_custom_tracer.py b/tests/test_models/test_task_modules/test_custom_tracer.py index 207e9ccad..fcb02f381 100644 --- a/tests/test_models/test_task_modules/test_custom_tracer.py +++ b/tests/test_models/test_task_modules/test_custom_tracer.py @@ -51,7 +51,6 @@ def test_init(self): method_registry = UntracedMethodRegistry(method) assert hasattr(method_registry, 'method') assert hasattr(method_registry, 'method_dict') - assert len(method_registry.method_dict) == 0 def test_registry_method(self): if digit_version(torch.__version__) < digit_version('1.13.0'):
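For reference, the snippet below is a minimal usage sketch of the `AcademicQuantizer` interface documented in this patch. It mirrors the setup in `test_academic_quantizer.py`: the observer/fakequant type names, the `ConvBNReLU` toy model and the `toy_attr*` attributes all come from the tests and `_fx_models.py` above, and torch>=1.13 is assumed; it is illustrative rather than part of the patch.

# Minimal usage sketch based on the unit tests in this patch (torch>=1.13).
import torch

from mmrazor.models.quantizers import AcademicQuantizer
from mmrazor.testing import ConvBNReLU

# Global 8-bit qconfig; per-OP overrides can be added under the
# 'object_type' or 'module_name' keys described in the class docstring.
global_qconfig = dict(
    w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
    a_observer=dict(type='mmrazor.MinMaxObserver'),
    w_fake_quant=dict(type='mmrazor.FakeQuantize'),
    a_fake_quant=dict(type='mmrazor.FakeQuantize'),
    w_qscheme=dict(qdtype='qint8', bit=8, is_symmetry=True),
    a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True),
)
qconfig_mapping = {'_global_': global_qconfig}

# Keep custom attributes on the traced GraphModule, as in test_prepare.
prepare_custom_config = {'preserved_attributes': ['toy_attr1', 'toy_attr2']}

quantizer = AcademicQuantizer(
    qconfig_mapping=qconfig_mapping,
    prepare_custom_config=prepare_custom_config)

# `prepare` swaps FloatFunctional, traces the model, fuses OPs and inserts
# the fakequant nodes; the result is an ObservedGraphModule.
model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN'))
prepared = quantizer.prepare(model)
prepared(torch.randn(1, 3, 224, 224))  # dummy forward through the observers

The `float_to_observed_custom_module_class` key takes the same pair format exercised in `test_gen_prepare_custom_config`, e.g. `[('ToyFloatModel', 'ToyObservedModel')]`, where both class names must be registered in `MODELS`.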