How to support new models

We provide several tools to support model conversion.

Function Rewriter

PyTorch networks are written in Python, which eases algorithm development. However, Python control flow and third-party libraries make it difficult to export a network to an intermediate representation. We provide a 'monkey patch' tool that rewrites an unsupported function into one that can be exported. Here is an example:

from mmdeploy.core import FUNCTION_REWRITER

@FUNCTION_REWRITER.register_rewriter(
    func_name='torch.Tensor.repeat', backend='tensorrt')
def repeat_static(ctx, input, *size):
    origin_func = ctx.origin_func
    if input.dim() == 1 and len(size) == 1:
        return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
    else:
        return origin_func(input, *size)

It is easy to use the function rewriter. Just add a decorator with arguments:

  • func_name is the function to override. It can be either a PyTorch function or a custom function. Methods in modules can also be overridden by this tool.
  • backend is the inference engine. The function will be overridden when the model is exported to this engine. If it is not given, the rewrite is registered as the default rewrite, which is used whenever no rewrite exists for the given backend.

The rewrite function takes the same arguments as the original function, plus a context ctx as the first argument. The context provides useful information such as the deployment config (ctx.cfg) and the original, overridden function (ctx.origin_func).
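
The dispatch rules described above (per-backend registration, fallback to the default rewrite, and a context passed as the first argument) can be illustrated with a small pure-Python sketch. This is not mmdeploy's implementation: ToyFunctionRewriter, its lookup and apply methods, and the ops.scale function are hypothetical names introduced only to show the idea.

```python
from types import SimpleNamespace


class ToyFunctionRewriter:
    """Hypothetical stand-in for a rewriter registry (not mmdeploy's code)."""

    DEFAULT = 'default'

    def __init__(self):
        # Maps func_name -> {backend: rewrite function}.
        self._registry = {}

    def register_rewriter(self, func_name, backend=DEFAULT):
        def decorator(rewrite):
            self._registry.setdefault(func_name, {})[backend] = rewrite
            return rewrite
        return decorator

    def lookup(self, func_name, backend):
        # Prefer the backend-specific rewrite, else fall back to the default.
        candidates = self._registry.get(func_name, {})
        return candidates.get(backend, candidates.get(self.DEFAULT))

    def apply(self, owner, attr, func_name, backend, cfg=None):
        """Monkey-patch owner.attr with the rewrite selected for backend."""
        rewrite = self.lookup(func_name, backend)
        if rewrite is None:
            return
        # The context carries the config and the function being replaced.
        ctx = SimpleNamespace(cfg=cfg, origin_func=getattr(owner, attr))

        def patched(*args, **kwargs):
            return rewrite(ctx, *args, **kwargs)

        setattr(owner, attr, patched)


# A hypothetical "library" with one function to override.
ops = SimpleNamespace(scale=lambda x, factor: x * factor)
REWRITER = ToyFunctionRewriter()


@REWRITER.register_rewriter('ops.scale')
def scale_default(ctx, x, factor):
    # Default rewrite: delegate to the original function unchanged.
    return ctx.origin_func(x, factor)


@REWRITER.register_rewriter('ops.scale', backend='tensorrt')
def scale_tensorrt(ctx, x, factor):
    # Backend-specific rewrite: truncate the factor before delegating.
    return ctx.origin_func(x, int(factor))


REWRITER.apply(ops, 'scale', 'ops.scale', backend='tensorrt')
result = ops.scale(2, 3.7)  # routed through scale_tensorrt
```

In this sketch a lookup for an unregistered backend returns scale_default, which mirrors the fallback behaviour described for the backend argument.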

Module Rewriter

If you want to replace a whole module with another one, we have another rewriter as follows:

import torch.nn as nn

from mmdeploy.core import MODULE_REWRITER


@MODULE_REWRITER.register_rewrite_module(
    'mmedit.models.backbones.sr_backbones.SRCNN', backend='tensorrt')
class SRCNNWrapper(nn.Module):

    def __init__(self,
                 module,
                 cfg,
                 channels=(3, 64, 32, 3),
                 kernel_sizes=(9, 1, 5),
                 upscale_factor=4):
        super(SRCNNWrapper, self).__init__()

        self._module = module

        module.img_upsampler = nn.Upsample(
            scale_factor=module.upscale_factor,
            mode='bilinear',
            align_corners=False)

    def forward(self, *args, **kwargs):
        """Run forward."""
        return self._module(*args, **kwargs)

    def init_weights(self, *args, **kwargs):
        """Initialize weights."""
        return self._module.init_weights(*args, **kwargs)

As with the function rewriter, add a decorator with arguments:

  • module_type is the module class to rewrite.
  • backend is the inference engine. The module will be replaced when the model is exported to this engine. If it is not given, the rewrite is registered as the default rewrite, which is used whenever no rewrite exists for the given backend.

All instances of the module in the network will be replaced with instances of this new class. The original module and the deployment config will be passed as the first two arguments.
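
The replacement step can be pictured as a recursive walk over the network that wraps every matching child. The following is a minimal sketch in plain Python, not mmdeploy's code: ToyModule, Upsampler, Net, UpsamplerWrapper and replace_modules are hypothetical names, and a real implementation would operate on torch.nn.Module children instead.

```python
class ToyModule:
    """Hypothetical minimal module: children are plain attributes."""

    def children(self):
        return {
            name: value
            for name, value in vars(self).items()
            if isinstance(value, ToyModule)
        }


class Upsampler(ToyModule):
    def __init__(self, mode):
        self.mode = mode


class Net(ToyModule):
    def __init__(self):
        self.up1 = Upsampler('bicubic')
        self.up2 = Upsampler('bicubic')


class UpsamplerWrapper(ToyModule):
    """Receives the original module and the config as its first two arguments."""

    def __init__(self, module, cfg):
        self._module = module
        # Adjust the wrapped module so that it becomes export-friendly.
        module.mode = cfg.get('mode', 'bilinear')


def replace_modules(root, rewrites, cfg):
    """Replace each child whose type is in `rewrites` with its wrapper."""
    for name, child in list(root.children().items()):
        # Handle nested modules first, then wrap the child itself.
        replace_modules(child, rewrites, cfg)
        wrapper_cls = rewrites.get(type(child))
        if wrapper_cls is not None:
            setattr(root, name, wrapper_cls(child, cfg))
    return root


net = replace_modules(Net(), {Upsampler: UpsamplerWrapper}, {'mode': 'bilinear'})
```

After the call, both net.up1 and net.up2 are UpsamplerWrapper instances that hold the original Upsampler objects. For simplicity the sketch never replaces the root module itself.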

Custom Symbolic

PyTorch defines the mapping from its operators to ONNX nodes with symbolic functions. A custom symbolic function lets us avoid ONNX nodes that the inference engine does not support.

import torch.onnx.symbolic_helper as sym_help

from mmdeploy.core import SYMBOLIC_REWRITER


@SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
def squeeze_default(ctx, g, self, dim=None):
    if dim is None:
        dims = []
        for i, size in enumerate(self.type().sizes()):
            if size == 1:
                dims.append(i)
    else:
        dims = [sym_help._get_const(dim, 'i', 'dim')]
    return g.op('Squeeze', self, axes_i=dims)

The decorator arguments:

  • func_name is the name of the function to add a symbolic to. Use the full path if it is a custom torch.autograd.Function, or just the name if it is a PyTorch built-in function.
  • backend is the inference engine. The symbolic will be overridden when the model is exported to this engine. If it is not given, the rewrite is registered as the default rewrite, which is used whenever no rewrite exists for the given backend.
  • is_pytorch is True if the function is a PyTorch built-in function.
  • arg_descriptors are the descriptors of the symbolic function arguments. They will be fed to torch.onnx.symbolic_helper._parse_arg.

As with the function rewriter, a context ctx is passed as the first argument. It provides useful information such as the deployment config (ctx.cfg) and the original, overridden function (ctx.origin_func). Note that ctx.origin_func can be used only when is_pytorch is False.
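
To see which axes the squeeze symbolic above would pass to the ONNX Squeeze node, the axis selection can be mirrored in plain Python. This is only an illustrative sketch: squeeze_axes is a hypothetical helper, and it takes a shape tuple and an integer dim in place of self.type().sizes() and the constant extracted by sym_help._get_const.

```python
def squeeze_axes(shape, dim=None):
    """Return the axes list that the symbolic would emit as `axes_i`."""
    if dim is None:
        # No dim given: squeeze every axis whose static size is 1.
        return [i for i, size in enumerate(shape) if size == 1]
    # An explicit dim is passed through as a single-element list.
    return [dim]


all_unit_axes = squeeze_axes((1, 3, 1, 5))       # [0, 2]
single_axis = squeeze_axes((1, 3, 1, 5), dim=2)  # [2]
```

Because the axes are computed from static sizes at export time, this approach assumes that the input shape is known when the model is exported.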