forked from open-mmlab/mmcv
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add base transform interface (open-mmlab#1538)
* Support deepcopy for Config (open-mmlab#1658) * Support deepcopy for Config * Iterate the `__dict__` of Config directly. * Use __new__ to avoid unnecessary initialization. * Improve according to comments * [Feature] Add spconv ops from mmdet3d (open-mmlab#1581) * add ops (spconv) of mmdet3d * fix typo * refactor code * resolve comments in open-mmlab#1452 * fix compile error * fix bugs * fix bug * transform from 'types.h' to 'extension.h' * fix bug * transform from 'types.h' to 'extension.h' in parrots * add extension.h in pybind.cpp * add unittest * Recover code * (1) Remove prettyprint.h (2) Switch `T` to `scalar_t` (3) Remove useless lines (4) Refine example in docstring of sparse_modules.py * (1) rename from `cu.h` to `cuh` (2) remove useless files (3) move cpu files to `pytorch/cpu` * reorganize files * Add docstring for sparse_functional.py * use dispatcher * remove template * use dispatch in cuda ops * resolve Segmentation fault * remove useless files * fix lint * fix lint * fix lint * fix unittest in test_build_layers.py * add tensorview into include_dirs when compiling * recover all deleted files * fix lint and comments * recover setup.py * replace tv::GPU as tv::TorchGPU & support device guard * fix lint Co-authored-by: hdc <hudingchang.vendor@sensetime.com> Co-authored-by: grimoire <yaoqian@sensetime.com> * Imporve the docstring of imfrombytes and fix a deprecation-warning (open-mmlab#1731) * [Refactor] Refactor the interface for RoIAlignRotated (open-mmlab#1662) * fix interface for RoIAlignRotated * Add a unit test for RoIAlignRotated * Make a unit test for RoIAlignRotated concise * fix interface for RoIAlignRotated * Refactor ext_module.nms_rotated * Lint cpp files * add transforms * add invoking time check for cacheable methods * fix lint * add unittest * fix bug in non-strict input mapping * fix ci * fix ci * fix compatibility with python<3.9 * fix typing compatibility * fix import * fix typing * add alternative for nullcontext * fix import * fix import * add docstrings * add docstrings * fix callable check * resolve comments * fix lint * enrich unittest cases * fix lint * fix unittest Co-authored-by: Ma Zerun <mzr1996@163.com> Co-authored-by: Wenhao Wu <79644370+wHao-Wu@users.noreply.github.com> Co-authored-by: hdc <hudingchang.vendor@sensetime.com> Co-authored-by: grimoire <yaoqian@sensetime.com> Co-authored-by: Jiazhen Wang <47851024+teamwong111@users.noreply.github.com> Co-authored-by: Hakjin Lee <nijkah@gmail.com>
- Loading branch information
Showing
9 changed files
with
1,036 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,3 +47,8 @@ ops | |
------ | ||
.. automodule:: mmcv.ops | ||
:members: | ||
|
||
transform | ||
--------- | ||
.. automodule:: mmcv.transform | ||
:members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,3 +47,8 @@ ops | |
------ | ||
.. automodule:: mmcv.ops | ||
:members: | ||
|
||
transform | ||
--------- | ||
.. automodule:: mmcv.transform | ||
:members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .builder import TRANSFORMS | ||
from .wrappers import ApplyToMultiple, Compose, RandomChoice, Remap | ||
|
||
__all__ = ['TRANSFORMS', 'ApplyToMultiple', 'Compose', 'RandomChoice', 'Remap'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from abc import ABCMeta, abstractmethod | ||
from typing import Dict | ||
|
||
|
||
class BaseTransform(metaclass=ABCMeta): | ||
|
||
def __call__(self, results: Dict) -> Dict: | ||
|
||
return self.transform(results) | ||
|
||
@abstractmethod | ||
def transform(self, results: Dict) -> Dict: | ||
"""The transform function. All subclass of BaseTransform should | ||
override this method. | ||
This function takes the result dict as the input, and can add new | ||
items to the dict or modify existing items in the dict. And the result | ||
dict will be returned in the end, which allows to concate multiple | ||
transforms into a pipeline. | ||
Args: | ||
results (dict): The result dict. | ||
Returns: | ||
dict: The result dict. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from ..utils.registry import Registry | ||
|
||
TRANSFORMS = Registry('transform') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
import functools | ||
import inspect | ||
import weakref | ||
from collections import defaultdict | ||
from collections.abc import Iterable | ||
from contextlib import contextmanager | ||
from typing import Callable, Union | ||
|
||
from .base import BaseTransform | ||
|
||
|
||
class cacheable_method: | ||
"""Decorator that marks a method of a transform class as a cacheable | ||
method. | ||
This decorator is usually used together with the context-manager | ||
:func`:cache_random_params`. In this context, a cacheable method will | ||
cache its return value(s) at the first time of being invoked, and always | ||
return the cached values when being invoked again. | ||
.. note:: | ||
Only a instance method can be decorated as a cacheable_method. | ||
""" | ||
|
||
def __init__(self, func): | ||
|
||
# Check `func` is to be bound as an instance method | ||
if not inspect.isfunction(func): | ||
raise TypeError('Unsupport callable to decorate with' | ||
'@cacheable_method.') | ||
func_args = inspect.getfullargspec(func).args | ||
if len(func_args) == 0 or func_args[0] != 'self': | ||
raise TypeError( | ||
'@cacheable_method should only be used to decorate ' | ||
'instance methods (the first argument is `self`).') | ||
|
||
functools.update_wrapper(self, func) | ||
self.func = func | ||
self.instance_ref = None | ||
|
||
def __set_name__(self, owner, name): | ||
# Maintain a record of decorated methods in the class | ||
if not hasattr(owner, '_cacheable_methods'): | ||
setattr(owner, '_cacheable_methods', []) | ||
owner._cacheable_methods.append(self.__name__) | ||
|
||
def __call__(self, *args, **kwargs): | ||
# Get the transform instance whose method is decorated | ||
# by cacheable_method | ||
instance = self.instance_ref() | ||
name = self.__name__ | ||
|
||
# Check the flag `self._cache_enabled`, which should be | ||
# set by the contextmanagers like `cache_random_parameters` | ||
cache_enabled = getattr(instance, '_cache_enabled', False) | ||
|
||
if cache_enabled: | ||
# Initialize the cache of the transform instances. The flag | ||
# `cache_enabled` is set by contextmanagers like | ||
# `cache_random_params`. | ||
if not hasattr(instance, '_cache'): | ||
setattr(instance, '_cache', {}) | ||
|
||
if name not in instance._cache: | ||
instance._cache[name] = self.func(instance, *args, **kwargs) | ||
# Return the cached value | ||
return instance._cache[name] | ||
else: | ||
# Clear cache | ||
if hasattr(instance, '_cache'): | ||
del instance._cache | ||
# Return function output | ||
return self.func(instance, *args, **kwargs) | ||
|
||
def __get__(self, obj, cls): | ||
self.instance_ref = weakref.ref(obj) | ||
return self | ||
|
||
|
||
@contextmanager | ||
def cache_random_params(transforms: Union[BaseTransform, Iterable]): | ||
"""Context-manager that enables the cache of cacheable methods in | ||
transforms. | ||
In this mode, cacheable methods will cache their return values on the | ||
first invoking, and always return the cached value afterward. This allow | ||
to apply random transforms in a deterministic way. For example, apply same | ||
transforms on multiple examples. See `cacheable_method` for more | ||
information. | ||
Args: | ||
transforms (BaseTransform|list[BaseTransform]): The transforms to | ||
enable cache. | ||
""" | ||
|
||
# key2method stores the original methods that are replaced by the wrapped | ||
# ones. These methods will be restituted when exiting the context. | ||
key2method = dict() | ||
|
||
# key2counter stores the usage number of each cacheable_method. This is | ||
# used to check that any cacheable_method is invoked once during processing | ||
# on data sample. | ||
key2counter = defaultdict(int) | ||
|
||
def _add_counter(obj, method_name): | ||
method = getattr(obj, method_name) | ||
key = f'{id(obj)}.{method_name}' | ||
key2method[key] = method | ||
|
||
@functools.wraps(method) | ||
def wrapped(*args, **kwargs): | ||
key2counter[key] += 1 | ||
return method(*args, **kwargs) | ||
|
||
return wrapped | ||
|
||
def _start_cache(t: BaseTransform): | ||
# Set cache enabled flag | ||
setattr(t, '_cache_enabled', True) | ||
|
||
# Store the original method and init the counter | ||
if hasattr(t, '_cacheable_methods'): | ||
setattr(t, 'transform', _add_counter(t, 'transform')) | ||
for name in t._cacheable_methods: | ||
setattr(t, name, _add_counter(t, name)) | ||
|
||
def _end_cache(t: BaseTransform): | ||
# Remove cache enabled flag | ||
del t._cache_enabled | ||
if hasattr(t, '_cache'): | ||
del t._cache | ||
|
||
# Restore the original method | ||
if hasattr(t, '_cacheable_methods'): | ||
key_transform = f'{id(t)}.transform' | ||
for name in t._cacheable_methods: | ||
key = f'{id(t)}.{name}' | ||
if key2counter[key] != key2counter[key_transform]: | ||
raise RuntimeError( | ||
'The cacheable method should be called once and only' | ||
f'once during processing one data sample. {t} got' | ||
f'unmatched number of {key2counter[key]} ({name}) vs' | ||
f'{key2counter[key_transform]} (data samples)') | ||
setattr(t, name, key2method[key]) | ||
setattr(t, 'transform', key2method[key_transform]) | ||
|
||
def _apply(t: Union[BaseTransform, Iterable], | ||
func: Callable[[BaseTransform], None]): | ||
if isinstance(t, BaseTransform): | ||
if hasattr(t, '_cacheable_methods'): | ||
func(t) | ||
if isinstance(t, Iterable): | ||
for _t in t: | ||
_apply(_t, func) | ||
|
||
try: | ||
_apply(transforms, _start_cache) | ||
yield | ||
finally: | ||
_apply(transforms, _end_cache) |
Oops, something went wrong.