Skip to content

Commit

Permalink
[Feature] Add base transform interface (#1538)
Browse files Browse the repository at this point in the history
* Support deepcopy for Config (#1658)

* Support deepcopy for Config

* Iterate the `__dict__` of Config directly.

* Use __new__ to avoid unnecessary initialization.

* Improve according to comments

* [Feature] Add spconv ops from mmdet3d (#1581)

* add ops (spconv) of mmdet3d

* fix typo

* refactor code

* resolve comments in #1452

* fix compile error

* fix bugs

* fix bug

* transform from 'types.h' to 'extension.h'

* fix bug

* transform from 'types.h' to 'extension.h' in parrots

* add extension.h in pybind.cpp

* add unittest

* Recover code

* (1) Remove prettyprint.h
(2) Switch `T` to `scalar_t`
(3) Remove useless lines
(4) Refine example in docstring of sparse_modules.py

* (1) rename from `cu.h` to `cuh`
(2) remove useless files
(3) move cpu files to `pytorch/cpu`

* reorganize files

* Add docstring for sparse_functional.py

* use dispatcher

* remove template

* use dispatch in cuda ops

* resolve Segmentation fault

* remove useless files

* fix lint

* fix lint

* fix lint

* fix unittest in test_build_layers.py

* add tensorview into include_dirs when compiling

* recover all deleted files

* fix lint and comments

* recover setup.py

* replace tv::GPU as tv::TorchGPU & support device guard

* fix lint

Co-authored-by: hdc <hudingchang.vendor@sensetime.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>

* Imporve the docstring of imfrombytes and fix a deprecation-warning (#1731)

* [Refactor] Refactor the interface for RoIAlignRotated (#1662)

* fix interface for RoIAlignRotated

* Add a unit test for RoIAlignRotated

* Make a unit test for RoIAlignRotated concise

* fix interface for RoIAlignRotated

* Refactor ext_module.nms_rotated

* Lint cpp files

* add transforms

* add invoking time check for cacheable methods

* fix lint

* add unittest

* fix bug in non-strict input mapping

* fix ci

* fix ci

* fix compatibility with python<3.9

* fix typing compatibility

* fix import

* fix typing

* add alternative for nullcontext

* fix import

* fix import

* add docstrings

* add docstrings

* fix callable check

* resolve comments

* fix lint

* enrich unittest cases

* fix lint

* fix unittest

Co-authored-by: Ma Zerun <mzr1996@163.com>
Co-authored-by: Wenhao Wu <79644370+wHao-Wu@users.noreply.github.com>
Co-authored-by: hdc <hudingchang.vendor@sensetime.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>
Co-authored-by: Jiazhen Wang <47851024+teamwong111@users.noreply.github.com>
Co-authored-by: Hakjin Lee <nijkah@gmail.com>
  • Loading branch information
7 people authored and zhouzaida committed Jul 19, 2022
1 parent 8b47579 commit d00b0ce
Show file tree
Hide file tree
Showing 9 changed files with 1,036 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/en/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,8 @@ ops
------
.. automodule:: mmcv.ops
:members:

transform
---------
.. automodule:: mmcv.transform
:members:
5 changes: 5 additions & 0 deletions docs/zh_cn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,8 @@ ops
------
.. automodule:: mmcv.ops
:members:

transform
---------
.. automodule:: mmcv.transform
:members:
1 change: 1 addition & 0 deletions mmcv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .arraymisc import *
from .fileio import *
from .image import *
from .transform import *
from .utils import *
from .version import *
from .video import *
Expand Down
5 changes: 5 additions & 0 deletions mmcv/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import TRANSFORMS
from .wrappers import ApplyToMultiple, Compose, RandomChoice, Remap

__all__ = ['TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap']
27 changes: 27 additions & 0 deletions mmcv/transform/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict


class BaseTransform(metaclass=ABCMeta):

def __call__(self, results: Dict) -> Dict:

return self.transform(results)

@abstractmethod
def transform(self, results: Dict) -> Dict:
"""The transform function. All subclass of BaseTransform should
override this method.
This function takes the result dict as the input, and can add new
items to the dict or modify existing items in the dict. And the result
dict will be returned in the end, which allows to concate multiple
transforms into a pipeline.
Args:
results (dict): The result dict.
Returns:
dict: The result dict.
"""
4 changes: 4 additions & 0 deletions mmcv/transform/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ..utils.registry import Registry

TRANSFORMS = Registry('transform')
162 changes: 162 additions & 0 deletions mmcv/transform/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright (c) OpenMMLab. All rights reserved.

import functools
import inspect
import weakref
from collections import defaultdict
from collections.abc import Iterable
from contextlib import contextmanager
from typing import Callable, Union

from .base import BaseTransform


class cacheable_method:
"""Decorator that marks a method of a transform class as a cacheable
method.
This decorator is usually used together with the context-manager
:func`:cache_random_params`. In this context, a cacheable method will
cache its return value(s) at the first time of being invoked, and always
return the cached values when being invoked again.
.. note::
Only a instance method can be decorated as a cacheable_method.
"""

def __init__(self, func):

# Check `func` is to be bound as an instance method
if not inspect.isfunction(func):
raise TypeError('Unsupport callable to decorate with'
'@cacheable_method.')
func_args = inspect.getfullargspec(func).args
if len(func_args) == 0 or func_args[0] != 'self':
raise TypeError(
'@cacheable_method should only be used to decorate '
'instance methods (the first argument is `self`).')

functools.update_wrapper(self, func)
self.func = func
self.instance_ref = None

def __set_name__(self, owner, name):
# Maintain a record of decorated methods in the class
if not hasattr(owner, '_cacheable_methods'):
setattr(owner, '_cacheable_methods', [])
owner._cacheable_methods.append(self.__name__)

def __call__(self, *args, **kwargs):
# Get the transform instance whose method is decorated
# by cacheable_method
instance = self.instance_ref()
name = self.__name__

# Check the flag `self._cache_enabled`, which should be
# set by the contextmanagers like `cache_random_parameters`
cache_enabled = getattr(instance, '_cache_enabled', False)

if cache_enabled:
# Initialize the cache of the transform instances. The flag
# `cache_enabled` is set by contextmanagers like
# `cache_random_params`.
if not hasattr(instance, '_cache'):
setattr(instance, '_cache', {})

if name not in instance._cache:
instance._cache[name] = self.func(instance, *args, **kwargs)
# Return the cached value
return instance._cache[name]
else:
# Clear cache
if hasattr(instance, '_cache'):
del instance._cache
# Return function output
return self.func(instance, *args, **kwargs)

def __get__(self, obj, cls):
self.instance_ref = weakref.ref(obj)
return self


@contextmanager
def cache_random_params(transforms: Union[BaseTransform, Iterable]):
"""Context-manager that enables the cache of cacheable methods in
transforms.
In this mode, cacheable methods will cache their return values on the
first invoking, and always return the cached value afterward. This allow
to apply random transforms in a deterministic way. For example, apply same
transforms on multiple examples. See `cacheable_method` for more
information.
Args:
transforms (BaseTransform|list[BaseTransform]): The transforms to
enable cache.
"""

# key2method stores the original methods that are replaced by the wrapped
# ones. These methods will be restituted when exiting the context.
key2method = dict()

# key2counter stores the usage number of each cacheable_method. This is
# used to check that any cacheable_method is invoked once during processing
# on data sample.
key2counter = defaultdict(int)

def _add_counter(obj, method_name):
method = getattr(obj, method_name)
key = f'{id(obj)}.{method_name}'
key2method[key] = method

@functools.wraps(method)
def wrapped(*args, **kwargs):
key2counter[key] += 1
return method(*args, **kwargs)

return wrapped

def _start_cache(t: BaseTransform):
# Set cache enabled flag
setattr(t, '_cache_enabled', True)

# Store the original method and init the counter
if hasattr(t, '_cacheable_methods'):
setattr(t, 'transform', _add_counter(t, 'transform'))
for name in t._cacheable_methods:
setattr(t, name, _add_counter(t, name))

def _end_cache(t: BaseTransform):
# Remove cache enabled flag
del t._cache_enabled
if hasattr(t, '_cache'):
del t._cache

# Restore the original method
if hasattr(t, '_cacheable_methods'):
key_transform = f'{id(t)}.transform'
for name in t._cacheable_methods:
key = f'{id(t)}.{name}'
if key2counter[key] != key2counter[key_transform]:
raise RuntimeError(
'The cacheable method should be called once and only'
f'once during processing one data sample. {t} got'
f'unmatched number of {key2counter[key]} ({name}) vs'
f'{key2counter[key_transform]} (data samples)')
setattr(t, name, key2method[key])
setattr(t, 'transform', key2method[key_transform])

def _apply(t: Union[BaseTransform, Iterable],
func: Callable[[BaseTransform], None]):
if isinstance(t, BaseTransform):
if hasattr(t, '_cacheable_methods'):
func(t)
if isinstance(t, Iterable):
for _t in t:
_apply(_t, func)

try:
_apply(transforms, _start_cache)
yield
finally:
_apply(transforms, _end_cache)
Loading

0 comments on commit d00b0ce

Please sign in to comment.