Skip to content

Commit

Permalink
[Docs] Add docstring and unitest about custom tracer (open-mmlab#427)
Browse files Browse the repository at this point in the history
* rename QConfigHandler and QSchemeHandler

* add docstring about custom tracer

* add ut about custom tracer

* fix torch1.13 ci

* fix lint

* fix ci

* fix ci
  • Loading branch information
humu789 authored and humu789 committed Apr 11, 2023
1 parent ff56626 commit 35faed8
Show file tree
Hide file tree
Showing 8 changed files with 357 additions and 120 deletions.
13 changes: 7 additions & 6 deletions mmrazor/models/quantizers/academic_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch

from mmrazor.registry import MODELS
from mmrazor.structures.quantization import BackendConfigs, QConfigHander
from mmrazor.structures.quantization import BackendConfigs, QConfigHandler
from .base import BaseQuantizer

try:
Expand Down Expand Up @@ -75,24 +75,25 @@ def gen_qconfig_mapping(self, qconfig_mapping):
"""tmp."""
conf = QConfigMapping()
if GLOBAL_DICT_KEY in qconfig_mapping:
qconfig = QConfigHander(qconfig_mapping[GLOBAL_DICT_KEY]).convert()
qconfig = QConfigHandler(
qconfig_mapping[GLOBAL_DICT_KEY]).convert()
conf.set_global(qconfig)
for object_type, qconfig in qconfig_mapping.get(
OBJECT_TYPE_DICT_KEY, []):
qconfig = QConfigHander(qconfig).convert()
qconfig = QConfigHandler(qconfig).convert()
conf.set_object_type(object_type, qconfig)

for module_name_regex, qconfig in qconfig_mapping.get(
MODULE_NAME_REGEX_DICT_KEY, []):
qconfig = QConfigHander(qconfig).convert()
qconfig = QConfigHandler(qconfig).convert()
conf.set_module_name_regex(module_name_regex, qconfig)
for module_name, qconfig in qconfig_mapping.get(
MODULE_NAME_DICT_KEY, []):
qconfig = QConfigHander(qconfig).convert()
qconfig = QConfigHandler(qconfig).convert()
conf.set_module_name(module_name, qconfig)
for module_name, object_type, index, qconfig in qconfig_mapping.get(
MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []):
qconfig = QConfigHander(qconfig).convert()
qconfig = QConfigHandler(qconfig).convert()
conf.set_module_name_object_type_order(module_name, object_type,
index, qconfig)

Expand Down
4 changes: 2 additions & 2 deletions mmrazor/models/quantizers/native_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
del_fakequant_before_module, del_fakequant_before_op)
from mmrazor.models.utils import str2class
from mmrazor.registry import MODELS
from mmrazor.structures.quantization import BackendConfigs, QConfigHander
from mmrazor.structures.quantization import BackendConfigs, QConfigHandler
from .base import BaseQuantizer

if digit_version(torch.__version__) >= digit_version('1.13.0'):
Expand Down Expand Up @@ -108,7 +108,7 @@ def __init__(self,
extra_op_prev_wo_fakequant=tuple(),
extra_op_next_wo_fakequant=tuple())):
super().__init__(tracer)
self.qconfig = QConfigHander(global_qconfig)
self.qconfig = QConfigHandler(global_qconfig)
if self.qconfig.w_qscheme.is_per_channel:
w_mode = 'per_channel'
else:
Expand Down
204 changes: 137 additions & 67 deletions mmrazor/models/task_modules/tracer/fx/custom_tracer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import functools
from types import FunctionType, MethodType
from typing import Any, Callable, Dict, List, Optional, Type, Union
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -34,18 +34,24 @@


class UntracedMethodRegistry:
"""A `Descriptor` class which records untraced methods."""
"""A `Descriptor` class which records untraced methods. Thus, when the
class is traced with CustomTracer, the decorated method will be as a leaf
node, not be nested traced.
Example:
>>> # `imported_cls` is the owner of the untraced method;
>>> # `method_str` is the name of the untraced method.
>>> method_registry = UntracedMethodRegistry(method)
>>> method_registry.__set_name__(imported_cls, method_str)
Args:
method (FunctionType): Function to be registered.
"""
method_dict: Dict = dict()
tracer = None

def __init__(self, method):
"""_summary_
Args:
method (FunctionType): Function to be registered.
"""
def __init__(self, method: FunctionType):
self.method = method
self.instances: Dict = dict()
self.owner = None

def __set_name__(self, owner, name):
Expand All @@ -54,11 +60,6 @@ def __set_name__(self, owner, name):
wrapped = self.method_wrapper()
self.method_dict[name] = dict(mod=self.owner, wrapped=wrapped)

def __get__(self, instance, owner):
if instance is None:
return self.method
return MethodType(self.method, instance)

def method_wrapper(self):

@functools.wraps(self.method)
Expand All @@ -73,33 +74,12 @@ def method(*args, **kwargs):
return wrapped_method


def custom_symbolic_trace(root: Union[nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None):
"""Modified `symbolic_trace` function.
Args:
root (Union[nn.Module, Callable]): Module or function to be
traced and converted into a Graph representation.
concrete_args (Optional[Dict[str, any]]): Inputs to be partially
specialized.
Returns:
_type_: _description_
"""
tracer = CustomTracer()
graph = tracer.trace(root, concrete_args)
name = root.__class__.__name__ if isinstance(root,
nn.Module) else root.__name__
return GraphModule(tracer.root, graph, name)


def _prepare_module_dict(model: nn.Module, fx_graph):
def _prepare_module_dict(model: torch.nn.Module, fx_graph):
"""If there is a class method that can not be traced by the symbolic
tracer, a ``call_method`` ``Node`` will be inserted into the ``Graph`` in
``CustomTracer``.
For example,
```
Example:
>>> class Model:
... def __init__(self):
... self.head = ClsHead()
Expand All @@ -123,7 +103,7 @@ def _prepare_module_dict(model: nn.Module, fx_graph):
... xxx
... losses = xxx
... return losses
```
As the ``_get_loss`` can not be traced by torch.fx, ``Toy._get_loss`` need
to be added to ``skipped_methods`` in ``CustomTracer``. Hence the code
above will product the following Graph::
Expand All @@ -140,8 +120,10 @@ def _prepare_module_dict(model: nn.Module, fx_graph):
the original model.
Args:
model (nn.Module): The original model.
fx_graph (Graph): The fx Graph traced by fx tracer.
model (torch.nn.Module): Module or function to be
traced and converted into a Graph representation.
fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer. It
contains the nodes this GraphModule should use for code generation.
"""

def _get_attrs(target, attrs):
Expand Down Expand Up @@ -170,7 +152,32 @@ def _get_attrs(target, attrs):
return module_dict


def build_graphmodule(model: nn.Module, fx_graph, name: str = 'GraphModule'):
def build_graphmodule(model: torch.nn.Module,
fx_graph,
name: str = 'GraphModule'):
"""To build GraphModule with the generated graph by CustomTracer. The
implement of skipping methods in CustomTracer will cause the confliction of
that a node is both a leaf node and non-leaf node, which will lead that the
modification to the ``graph`` also change the original ``forward``.
Args:
model (torch.nn.Module): Module or function to be
traced and converted into a Graph representation.
fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer. It
contains the nodes this GraphModule should use for code generation.
name (str): The name of generated GraphModule.
Returns:
GraphModule: GraphModule is an nn.Module generated from an fx.Graph.
Graphmodule has a ``graph`` attribute, as well as ``code`` and
``forward`` attributes generated from that ``graph``.
.. warning::
When ``graph`` is reassigned, ``code`` and ``forward`` will be
automatically regenerated. However, if you edit the contents of the
``graph`` without reassigning the ``graph`` attribute itself, you must
call ``recompile()`` to update the generated code.
"""
modules = dict(model.named_modules())
module_dict = _prepare_module_dict(model, fx_graph)
modules.update(module_dict)
Expand All @@ -179,23 +186,25 @@ def build_graphmodule(model: nn.Module, fx_graph, name: str = 'GraphModule'):

@TASK_UTILS.register_module()
class CustomTracer(QuantizationTracer):
"""Custom tracer based on QuantizationTracer of pytorch. It can not only
skip some modules and classes while tracing, but also skip some methods
untraced by torch.fx.Tracer.
Args:
skipped_methods (List[str], optional): Methods to be skipped while
tracing. Defaults to None.
skipped_module_names (List[str], optional): Modules to be skipped
while tracing. Defaults to None.
skipped_module_classes (List[Callable], optional): Class to be skipped
while tracing. Defaults to None.
"""

def __init__(self,
skipped_methods: List[str] = [],
skipped_module_names: List[str] = [],
skipped_module_classes: List[Callable] = [],
*args,
**kwargs):
"""_summary_
Args:
skipped_methods (List[str], optional): Methods to be skipped while
tracing. Defaults to None.
skipped_module_names (List[str], optional): Modules to be skipped
while tracing. Defaults to None.
skipped_module_classes (List[str], optional): Class to be skipped
while tracing. Defaults to None.
"""
super(CustomTracer, self).__init__(skipped_module_names,
skipped_module_classes)
UntracedMethodRegistry.tracer = self # type: ignore
Expand All @@ -214,6 +223,7 @@ def _check_valid_source(source):
'source must have at least one `.`'

def register_skipped_methods(self):
"""Register skipped methods to UntracedMethodRegistry.method_dict."""
if not isinstance(self.skipped_methods, list):
self.skipped_methods = [self.skipped_methods]
for s_method in self.skipped_methods:
Expand All @@ -239,7 +249,8 @@ def register_skipped_methods(self):
method_registry = UntracedMethodRegistry(method)
method_registry.__set_name__(imported_cls, method_str)

def call_method(self, m: nn.Module, name, method, args, kwargs):
def call_method(self, m: torch.nn.Module, name: str, method: Callable,
args: Tuple, kwargs: Dict):
"""Method that specifies the behavior of this ``Tracer`` when it
encounters a call to an ``nn.Module`` instance.
Expand All @@ -254,15 +265,13 @@ def call_method(self, m: nn.Module, name, method, args, kwargs):
``Module`` boundaries.
Args:
m (Module): The module for which a call is being emitted
forward (Callable): The forward() method of the ``Module`` to be
invoked
m (torch.nn.Module): The module for which a call is being emitted
name (str): The name of proxy to be created.
method (Callable): The method of the ``Module`` to be invoked
args (Tuple): args of the module callsite
kwargs (Dict): kwargs of the module callsite
Return:
The return value from the Module call. In the case that a
``call_module`` node was emitted, this is a ``Proxy`` value.
Otherwise, it is whatever value was returned from the ``Module``
Expand All @@ -271,16 +280,37 @@ def call_method(self, m: nn.Module, name, method, args, kwargs):
# module_qualified_name = self.path_of_module(m)
if not self.is_skipped_method(m):
return method(*args, **kwargs)
args = list(args)
args.insert(0, m)
args = tuple(args)
args_l = list(args)
args_l.insert(0, m)
args = tuple(args_l)
return self.create_proxy('call_method', name, args, kwargs)

def trace(self, root, concrete_args=None):
if isinstance(root, nn.Module):
def trace(self,
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
"""Trace ``root`` and return the corresponding FX ``Graph``
representation. ``root`` can either be an ``nn.Module`` instance or a
Python callable. Note that after this call, ``self.root`` may be
different from the ``root`` passed in here. For example, when a free
function is passed to ``trace()``, we will create an ``nn.Module``
instance to use as the root and add embedded constants to.
Args:
root (Union[Module, Callable]): Either a ``Module`` or a function
to be traced through. Backwards-compatibility for this
parameter is guaranteed.
concrete_args (Optional[Dict[str, any]]): Concrete arguments that
should not be treated as Proxies. This parameter is
experimental and its backwards-compatibility is *NOT*
guaranteed.
Returns:
A ``Graph`` representing the semantics of the passed-in ``root``.
"""
if isinstance(root, torch.nn.Module):
self.root = root
fn = type(root).forward
self.submodule_paths = {
self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = {
mod: name
for name, mod in root.named_modules()
}
Expand Down Expand Up @@ -364,13 +394,53 @@ def forward(*args, **kwargs):

return self.graph

def is_skipped_method(self, m):
def is_skipped_method(self, m: torch.nn.Module):
"""Judge if ``m`` is registered skipped method."""
mods = tuple(value['mod']
for value in UntracedMethodRegistry.method_dict.values())
custom = isinstance(m, mods)
return custom

def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
# return super().is_leaf_module(m, module_qualified_name)
def is_leaf_module(self, m: torch.nn.Module,
module_qualified_name: str) -> bool:
"""A method to specify whether a given ``nn.Module`` is a "leaf"
module. Leaf modules are the atomic units that appear in the IR,
referenced by ``call_module`` calls. By default, Modules in the PyTorch
standard library namespace (torch.nn) are leaf modules. All other
modules are traced through and their constituent ops are recorded,
unless specified otherwise via this parameter.
Args:
m (Module): The module being queried about
module_qualified_name (str): The path to root of this module.
For example, if you have a module hierarchy where submodule
``foo`` contains submodule ``bar``, which contains submodule
``baz``, that module will appear with the qualified name
``foo.bar.baz`` here.
"""
leaf = super().is_leaf_module(m, module_qualified_name)
return leaf


def custom_symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:
"""Modified `symbolic_trace` function in pytorch. Given an ``nn.Module`` or
function instance ``root``, this function will return a ``GraphModule``
constructed by recording operations seen while tracing through ``root``.
Args:
root (torch.nn.Module): Module or function to be
traced and converted into a Graph representation.
concrete_args (Optional[Dict[str, any]]): Inputs to be partially
specialized.
Returns:
GraphModule: a Module created from the recorded operations from
``root``.
"""
tracer = CustomTracer()
graph = tracer.trace(root, concrete_args)
name = root.__class__.__name__ if isinstance(
root, torch.nn.Module) else root.__name__
return GraphModule(tracer.root, graph, name)
Loading

0 comments on commit 35faed8

Please sign in to comment.