diff --git a/.dev_scripts/visualize_lr.py b/.dev_scripts/visualize_lr.py deleted file mode 100644 index 5ca9aaa116..0000000000 --- a/.dev_scripts/visualize_lr.py +++ /dev/null @@ -1,230 +0,0 @@ -import argparse -import json -import os -import os.path as osp -import time -import warnings -from collections import OrderedDict -from unittest.mock import patch - -import matplotlib.pyplot as plt -import numpy as np -import torch.nn as nn -from torch.optim import SGD -from torch.utils.data import DataLoader - -import mmcv -from mmcv.runner import build_runner -from mmcv.utils import get_logger - - -def parse_args(): - parser = argparse.ArgumentParser(description='Visualize the given config' - 'of learning rate and momentum, and this' - 'script will overwrite the log_config') - parser.add_argument('config', help='train config file path') - parser.add_argument( - '--work-dir', default='./', help='the dir to save logs and models') - parser.add_argument( - '--num-iters', default=300, help='The number of iters per epoch') - parser.add_argument( - '--num-epochs', default=300, help='Only used in EpochBasedRunner') - parser.add_argument( - '--window-size', - default='12*14', - help='Size of the window to display images, in format of "$W*$H".') - parser.add_argument( - '--log-interval', default=10, help='The interval of TextLoggerHook') - args = parser.parse_args() - return args - - -class SimpleModel(nn.Module): - - def __init__(self): - super().__init__() - self.conv = nn.Conv2d(1, 1, 1) - - def train_step(self, *args, **kwargs): - return dict() - - def val_step(self, *args, **kwargs): - return dict() - - -def iter_train(self, data_loader, **kwargs): - self.mode = 'train' - self.data_loader = data_loader - self.call_hook('before_train_iter') - self.call_hook('after_train_iter') - self._inner_iter += 1 - self._iter += 1 - - -def epoch_train(self, data_loader, **kwargs): - self.model.train() - self.mode = 'train' - self.data_loader = data_loader - self._max_iters = self._max_epochs * len(self.data_loader) - self.call_hook('before_train_epoch') - for i, data_batch in enumerate(self.data_loader): - self._inner_iter = i - self.call_hook('before_train_iter') - self.call_hook('after_train_iter') - self._iter += 1 - self.call_hook('after_train_epoch') - self._epoch += 1 - - -def log(self, runner): - cur_iter = self.get_iter(runner, inner_iter=True) - - log_dict = OrderedDict( - mode=self.get_mode(runner), - epoch=self.get_epoch(runner), - iter=cur_iter) - - # only record lr of the first param group - cur_lr = runner.current_lr() - if isinstance(cur_lr, list): - log_dict['lr'] = cur_lr[0] - else: - assert isinstance(cur_lr, dict) - log_dict['lr'] = {} - for k, lr_ in cur_lr.items(): - assert isinstance(lr_, list) - log_dict['lr'].update({k: lr_[0]}) - - cur_momentum = runner.current_momentum() - if isinstance(cur_momentum, list): - log_dict['momentum'] = cur_momentum[0] - else: - assert isinstance(cur_momentum, dict) - log_dict['momentum'] = {} - for k, lr_ in cur_momentum.items(): - assert isinstance(lr_, list) - log_dict['momentum'].update({k: lr_[0]}) - log_dict = dict(log_dict, **runner.log_buffer.output) - self._log_info(log_dict, runner) - self._dump_log(log_dict, runner) - return log_dict - - -@patch('torch.cuda.is_available', lambda: False) -@patch('mmcv.runner.EpochBasedRunner.train', epoch_train) -@patch('mmcv.runner.IterBasedRunner.train', iter_train) -@patch('mmcv.runner.hooks.TextLoggerHook.log', log) -def run(cfg, logger): - momentum_config = cfg.get('momentum_config') - lr_config = cfg.get('lr_config') - - model = SimpleModel() - optimizer = SGD(model.parameters(), 0.1, momentum=0.8) - cfg.work_dir = cfg.get('work_dir', './') - workflow = [('train', 1)] - - if cfg.get('runner') is None: - cfg.runner = { - 'type': 'EpochBasedRunner', - 'max_epochs': cfg.get('total_epochs', cfg.num_epochs) - } - warnings.warn( - 'config is now expected to have a `runner` section, ' - 'please set `runner` in your config.', UserWarning) - batch_size = 1 - data = cfg.get('data') - if data: - batch_size = data.get('samples_per_gpu') - fake_dataloader = DataLoader( - list(range(cfg.num_iters)), batch_size=batch_size) - runner = build_runner( - cfg.runner, - default_args=dict( - model=model, - batch_processor=None, - optimizer=optimizer, - work_dir=cfg.work_dir, - logger=logger, - meta=None)) - log_config = dict( - interval=cfg.log_interval, hooks=[ - dict(type='TextLoggerHook'), - ]) - - runner.register_training_hooks(lr_config, log_config=log_config) - runner.register_momentum_hook(momentum_config) - runner.run([fake_dataloader], workflow) - - -def plot_lr_curve(json_file, cfg): - data_dict = dict(LearningRate=[], Momentum=[]) - assert os.path.isfile(json_file) - with open(json_file) as f: - for line in f: - log = json.loads(line.strip()) - data_dict['LearningRate'].append(log['lr']) - data_dict['Momentum'].append(log['momentum']) - - wind_w, wind_h = (int(size) for size in cfg.window_size.split('*')) - # if legend is None, use {filename}_{key} as legend - fig, axes = plt.subplots(2, 1, figsize=(wind_w, wind_h)) - plt.subplots_adjust(hspace=0.5) - font_size = 20 - for index, (updater_type, data_list) in enumerate(data_dict.items()): - ax = axes[index] - if cfg.runner.type == 'EpochBasedRunner': - ax.plot(data_list, linewidth=1) - ax.xaxis.tick_top() - ax.set_xlabel('Iters', fontsize=font_size) - ax.xaxis.set_label_position('top') - sec_ax = ax.secondary_xaxis( - 'bottom', - functions=(lambda x: x / cfg.num_iters * cfg.log_interval, - lambda y: y * cfg.num_iters / cfg.log_interval)) - sec_ax.tick_params(labelsize=font_size) - sec_ax.set_xlabel('Epochs', fontsize=font_size) - else: - # plt.subplot(2, 1, index + 1) - x_list = np.arange(len(data_list)) * cfg.log_interval - ax.plot(x_list, data_list) - ax.set_xlabel('Iters', fontsize=font_size) - ax.set_ylabel(updater_type, fontsize=font_size) - if updater_type == 'LearningRate': - if cfg.get('lr_config'): - title = cfg.lr_config.type - else: - title = 'No learning rate scheduler' - else: - if cfg.get('momentum_config'): - title = cfg.momentum_config.type - else: - title = 'No momentum scheduler' - ax.set_title(title, fontsize=font_size) - ax.grid() - # set tick font size - ax.tick_params(labelsize=font_size) - save_path = osp.join(cfg.work_dir, 'visualization-result') - plt.savefig(save_path) - print(f'The learning rate graph is saved at {save_path}.png') - plt.show() - - -def main(): - args = parse_args() - timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) - cfg = mmcv.Config.fromfile(args.config) - cfg['num_iters'] = args.num_iters - cfg['num_epochs'] = args.num_epochs - cfg['log_interval'] = args.log_interval - cfg['window_size'] = args.window_size - - log_path = osp.join(cfg.get('work_dir', './'), f'{timestamp}.log') - json_path = log_path + '.json' - logger = get_logger('mmcv', log_path) - - run(cfg, logger) - plot_lr_curve(json_path, cfg) - - -if __name__ == '__main__': - main() diff --git a/docs/en/index.rst b/docs/en/index.rst index 1e5193ac30..fbfe9c5b7b 100644 --- a/docs/en/index.rst +++ b/docs/en/index.rst @@ -15,13 +15,10 @@ You can switch between Chinese and English documents in the lower-left corner of :maxdepth: 2 :caption: Understand MMCV - understand_mmcv/config.md - understand_mmcv/registry.md understand_mmcv/data_process.md understand_mmcv/visualization.md understand_mmcv/cnn.md understand_mmcv/ops.md - understand_mmcv/utils.md .. toctree:: :maxdepth: 2 diff --git a/docs/en/understand_mmcv/config.md b/docs/en/understand_mmcv/config.md deleted file mode 100644 index 9626dbe2c3..0000000000 --- a/docs/en/understand_mmcv/config.md +++ /dev/null @@ -1,200 +0,0 @@ -## Config - -`Config` class is used for manipulating config and config files. It supports -loading configs from multiple file formats including **python**, **json** and **yaml**. -It provides dict-like apis to get and set values. - -Here is an example of the config file `test.py`. - -```python -a = 1 -b = dict(b1=[0, 1, 2], b2=None) -c = (1, 2) -d = 'string' -``` - -To load and use configs - -```python ->>> cfg = Config.fromfile('test.py') ->>> print(cfg) ->>> dict(a=1, -... b=dict(b1=[0, 1, 2], b2=None), -... c=(1, 2), -... d='string') -``` - -For all format configs, some predefined variables are supported. It will convert the variable in `{{ var }}` with its real value. - -Currently, it supports four predefined variables: - -`{{ fileDirname }}` - the current opened file's dirname, e.g. /home/your-username/your-project/folder - -`{{ fileBasename }}` - the current opened file's basename, e.g. file.ext - -`{{ fileBasenameNoExtension }}` - the current opened file's basename with no file extension, e.g. file - -`{{ fileExtname }}` - the current opened file's extension, e.g. .ext - -These variable names are referred from [VS Code](https://code.visualstudio.com/docs/editor/variables-reference). - -Here is one examples of config with predefined variables. - -`config_a.py` - -```python -a = 1 -b = './work_dir/{{ fileBasenameNoExtension }}' -c = '{{ fileExtname }}' -``` - -```python ->>> cfg = Config.fromfile('./config_a.py') ->>> print(cfg) ->>> dict(a=1, -... b='./work_dir/config_a', -... c='.py') -``` - -For all format configs, inheritance is supported. To reuse fields in other config files, -specify `_base_='./config_a.py'` or a list of configs `_base_=['./config_a.py', './config_b.py']`. -Here are 4 examples of config inheritance. - -`config_a.py` - -```python -a = 1 -b = dict(b1=[0, 1, 2], b2=None) -``` - -### Inherit from base config without overlapped keys - -`config_b.py` - -```python -_base_ = './config_a.py' -c = (1, 2) -d = 'string' -``` - -```python ->>> cfg = Config.fromfile('./config_b.py') ->>> print(cfg) ->>> dict(a=1, -... b=dict(b1=[0, 1, 2], b2=None), -... c=(1, 2), -... d='string') -``` - -New fields in `config_b.py` are combined with old fields in `config_a.py` - -### Inherit from base config with overlapped keys - -`config_c.py` - -```python -_base_ = './config_a.py' -b = dict(b2=1) -c = (1, 2) -``` - -```python ->>> cfg = Config.fromfile('./config_c.py') ->>> print(cfg) ->>> dict(a=1, -... b=dict(b1=[0, 1, 2], b2=1), -... c=(1, 2)) -``` - -`b.b2=None` in `config_a` is replaced with `b.b2=1` in `config_c.py`. - -### Inherit from base config with ignored fields - -`config_d.py` - -```python -_base_ = './config_a.py' -b = dict(_delete_=True, b2=None, b3=0.1) -c = (1, 2) -``` - -```python ->>> cfg = Config.fromfile('./config_d.py') ->>> print(cfg) ->>> dict(a=1, -... b=dict(b2=None, b3=0.1), -... c=(1, 2)) -``` - -You may also set `_delete_=True` to ignore some fields in base configs. All old keys `b1, b2, b3` in `b` are replaced with new keys `b2, b3`. - -### Inherit from multiple base configs (the base configs should not contain the same keys) - -`config_e.py` - -```python -c = (1, 2) -d = 'string' -``` - -`config_f.py` - -```python -_base_ = ['./config_a.py', './config_e.py'] -``` - -```python ->>> cfg = Config.fromfile('./config_f.py') ->>> print(cfg) ->>> dict(a=1, -... b=dict(b1=[0, 1, 2], b2=None), -... c=(1, 2), -... d='string') -``` - -### Reference variables from base - -You can reference variables defined in base using the following grammar. - -`base.py` - -```python -item1 = 'a' -item2 = dict(item3 = 'b') -``` - -`config_g.py` - -```python -_base_ = ['./base.py'] -item = dict(a = {{ _base_.item1 }}, b = {{ _base_.item2.item3 }}) -``` - -```python ->>> cfg = Config.fromfile('./config_g.py') ->>> print(cfg.pretty_text) -item1 = 'a' -item2 = dict(item3='b') -item = dict(a='a', b='b') -``` - -### Add deprecation information in configs - -Deprecation information can be added in a config file, which will trigger a `UserWarning` when this config file is loaded. - -`deprecated_cfg.py` - -```python -_base_ = 'expected_cfg.py' - -_deprecation_ = dict( - expected = 'expected_cfg.py', # optional to show expected config path in the warning information - reference = 'url to related PR' # optional to show reference link in the warning information -) -``` - -```python ->>> cfg = Config.fromfile('./deprecated_cfg.py') - -UserWarning: The config file deprecated_cfg.py will be deprecated in the future. Please use expected_cfg.py instead. More information can be found at https://github.com/open-mmlab/mmcv/pull/1275 -``` diff --git a/docs/en/understand_mmcv/registry.md b/docs/en/understand_mmcv/registry.md deleted file mode 100644 index 6f3c767fcb..0000000000 --- a/docs/en/understand_mmcv/registry.md +++ /dev/null @@ -1,179 +0,0 @@ -## Registry - -MMCV implements [registry](https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py) to manage different modules that share similar functionalities, e.g., backbones, head, and necks, in detectors. -Most projects in OpenMMLab use registry to manage modules of datasets and models, such as [MMDetection](https://github.com/open-mmlab/mmdetection), [MMDetection3D](https://github.com/open-mmlab/mmdetection3d), [MMClassification](https://github.com/open-mmlab/mmclassification), [MMEditing](https://github.com/open-mmlab/mmediting), etc. - -```{note} -In v1.5.1 and later, the Registry supports registering functions and calling them. -``` - -### What is registry - -In MMCV, registry can be regarded as a mapping that maps a class or function to a string. -These classes or functions contained by a single registry usually have similar APIs but implement different algorithms or support different datasets. -With the registry, users can find the class or function through its corresponding string, and instantiate the corresponding module or call the function to obtain the result according to needs. -One typical example is the config systems in most OpenMMLab projects, which use the registry to create hooks, runners, models, and datasets, through configs. -The API reference could be found [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.Registry). - -To manage your modules in the codebase by `Registry`, there are three steps as below. - -1. Create a build method (optional, in most cases you can just use the default one). -2. Create a registry. -3. Use this registry to manage the modules. - -`build_func` argument of `Registry` is to customize how to instantiate the class instance or how to call the function to obtain the result, the default one is `build_from_cfg` implemented [here](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.build_from_cfg). - -### A Simple Example - -Here we show a simple example of using registry to manage modules in a package. -You can find more practical examples in OpenMMLab projects. - -Assuming we want to implement a series of Dataset Converter for converting different formats of data to the expected data format. -We create a directory as a package named `converters`. -In the package, we first create a file to implement builders, named `converters/builder.py`, as below - -```python -from mmcv.utils import Registry -# create a registry for converters -CONVERTERS = Registry('converters') -``` - -Then we can implement different converters that is class or function in the package. For example, implement `Converter1` in `converters/converter1.py`, and `converter2` in `converters/converter2.py`. - -```python - -from .builder import CONVERTERS - -# use the registry to manage the module -@CONVERTERS.register_module() -class Converter1(object): - def __init__(self, a, b): - self.a = a - self.b = b -``` - -```python -# converter2.py -from .builder import CONVERTERS -from .converter1 import Converter1 - -# 使用注册器管理模块 -@CONVERTERS.register_module() -def converter2(a, b) - return Converter1(a, b) -``` - -The key step to use registry for managing the modules is to register the implemented module into the registry `CONVERTERS` through -`@CONVERTERS.register_module()` when you are creating the module. By this way, a mapping between a string and the class (function) is built and maintained by `CONVERTERS` as below - -```python -'Converter1' -> -'converter2' -> -``` - -```{note} -The registry mechanism will be triggered only when the file where the module is located is imported. -So you need to import that file somewhere. More details can be found at https://github.com/open-mmlab/mmdetection/issues/5974. -``` - -If the module is successfully registered, you can use this converter through configs as - -```python -converter1_cfg = dict(type='Converter1', a=a_value, b=b_value) -converter2_cfg = dict(type='converter2', a=a_value, b=b_value) -converter1 = CONVERTERS.build(converter1_cfg) -# returns the calling result -result = CONVERTERS.build(converter2_cfg) -``` - -### Customize Build Function - -Suppose we would like to customize how `converters` are built, we could implement a customized `build_func` and pass it into the registry. - -```python -from mmcv.utils import Registry - -# create a build function -def build_converter(cfg, registry, *args, **kwargs): - cfg_ = cfg.copy() - converter_type = cfg_.pop('type') - if converter_type not in registry: - raise KeyError(f'Unrecognized converter type {converter_type}') - else: - converter_cls = registry.get(converter_type) - - converter = converter_cls(*args, **kwargs, **cfg_) - return converter - -# create a registry for converters and pass ``build_converter`` function -CONVERTERS = Registry('converter', build_func=build_converter) -``` - -```{note} -In this example, we demonstrate how to use the `build_func` argument to customize the way to build a class instance. -The functionality is similar to the default `build_from_cfg`. In most cases, default one would be sufficient. -`build_model_from_cfg` is also implemented to build PyTorch module in `nn.Sequential`, you may directly use them instead of implementing by yourself. -``` - -### Hierarchy Registry - -You could also build modules from more than one OpenMMLab frameworks, e.g. you could use all backbones in [MMClassification](https://github.com/open-mmlab/mmclassification) for object detectors in [MMDetection](https://github.com/open-mmlab/mmdetection), you may also combine an object detection model in [MMDetection](https://github.com/open-mmlab/mmdetection) and semantic segmentation model in [MMSegmentation](https://github.com/open-mmlab/mmsegmentation). - -All `MODELS` registries of downstream codebases are children registries of MMCV's `MODELS` registry. -Basically, there are two ways to build a module from child or sibling registries. - -1. Build from children registries. - - For example: - - In MMDetection we define: - - ```python - from mmengine.registry import Registry - from mmengine.registry import MODELS as MMENGINE_MODELS - MODELS = Registry('model', parent=MMENGINE_MODELS) - - @MODELS.register_module() - class NetA(nn.Module): - def forward(self, x): - return x - ``` - - In MMClassification we define: - - ```python - from mmengine.registry import Registry - from mmengine.registry import MODELS as MMENGINE_MODELS - MODELS = Registry('model', parent=MMENGINE_MODELS) - - @MODELS.register_module() - class NetB(nn.Module): - def forward(self, x): - return x + 1 - ``` - - We could build two net in either MMDetection or MMClassification by: - - ```python - from mmdet.models import MODELS - net_a = MODELS.build(cfg=dict(type='NetA')) - net_b = MODELS.build(cfg=dict(type='mmcls.NetB')) - ``` - - or - - ```python - from mmcls.models import MODELS - net_a = MODELS.build(cfg=dict(type='mmdet.NetA')) - net_b = MODELS.build(cfg=dict(type='NetB')) - ``` - -2. Build from parent registry. - - The shared `MODELS` registry in MMCV is the parent registry for all downstream codebases (root registry): - - ```python - from mmengine.registry import MODELS as MMENGINE_MODELS - net_a = MMENGINE_MODELS.build(cfg=dict(type='mmdet.NetA')) - net_b = MMENGINE_MODELS.build(cfg=dict(type='mmcls.NetB')) - ``` diff --git a/docs/en/understand_mmcv/utils.md b/docs/en/understand_mmcv/utils.md deleted file mode 100644 index 5d5e0adf9b..0000000000 --- a/docs/en/understand_mmcv/utils.md +++ /dev/null @@ -1,74 +0,0 @@ -## Utils - -### ProgressBar - -If you want to apply a method to a list of items and track the progress, `track_progress` -is a good choice. It will display a progress bar to tell the progress and ETA. - -```python -import mmcv - -def func(item): - # do something - pass - -tasks = [item_1, item_2, ..., item_n] - -mmcv.track_progress(func, tasks) -``` - -The output is like the following. - -![progress](../_static/progress.*) - -There is another method `track_parallel_progress`, which wraps multiprocessing and -progress visualization. - -```python -mmcv.track_parallel_progress(func, tasks, 8) # 8 workers -``` - -![progress](../_static/parallel_progress.*) - -If you want to iterate or enumerate a list of items and track the progress, `track_iter_progress` -is a good choice. It will display a progress bar to tell the progress and ETA. - -```python -import mmcv - -tasks = [item_1, item_2, ..., item_n] - -for task in mmcv.track_iter_progress(tasks): - # do something like print - print(task) - -for i, task in enumerate(mmcv.track_iter_progress(tasks)): - # do something like print - print(i) - print(task) -``` - -### Timer - -It is convenient to compute the runtime of a code block with `Timer`. - -```python -import time - -with mmcv.Timer(): - # simulate some code block - time.sleep(1) -``` - -or try with `since_start()` and `since_last_check()`. This former can -return the runtime since the timer starts and the latter will return the time -since the last time checked. - -```python -timer = mmcv.Timer() -# code block 1 here -print(timer.since_start()) -# code block 2 here -print(timer.since_last_check()) -print(timer.since_start()) -``` diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst index 5c067a9eb6..3bf1d9eda4 100644 --- a/docs/zh_cn/index.rst +++ b/docs/zh_cn/index.rst @@ -15,14 +15,11 @@ :maxdepth: 2 :caption: 深入理解 MMCV - understand_mmcv/config.md - understand_mmcv/registry.md understand_mmcv/data_process.md understand_mmcv/data_transform.md understand_mmcv/visualization.md understand_mmcv/cnn.md understand_mmcv/ops.md - understand_mmcv/utils.md .. toctree:: :maxdepth: 2 diff --git a/docs/zh_cn/understand_mmcv/config.md b/docs/zh_cn/understand_mmcv/config.md deleted file mode 100644 index 52d7ab37b4..0000000000 --- a/docs/zh_cn/understand_mmcv/config.md +++ /dev/null @@ -1,179 +0,0 @@ -## 配置 - -`Config` 类用于操作配置文件,它支持从多种文件格式中加载配置,包括 **python**, **json** 和 **yaml**。 -它提供了类似字典对象的接口来获取和设置值。 - -以配置文件 `test.py` 为例 - -```python -a = 1 -b = dict(b1=[0, 1, 2], b2=None) -c = (1, 2) -d = 'string' -``` - -加载与使用配置文件 - -```python ->>> cfg = Config.fromfile('test.py') ->>> print(cfg) ->>> dict(a=1, -... b=dict(b1=[0, 1, 2], b2=None), -... c=(1, 2), -... d='string') -``` - -对于所有格式的配置文件,都支持一些预定义变量。它会将 `{{ var }}` 替换为实际值。 - -目前支持以下四个预定义变量: - -`{{ fileDirname }}` - 当前打开文件的目录名,例如 /home/your-username/your-project/folder - -`{{ fileBasename }}` - 当前打开文件的文件名,例如 file.ext - -`{{ fileBasenameNoExtension }}` - 当前打开文件不包含扩展名的文件名,例如 file - -`{{ fileExtname }}` - 当前打开文件的扩展名,例如 .ext - -这些变量名引用自 [VS Code](https://code.visualstudio.com/docs/editor/variables-reference)。 - -这里是一个带有预定义变量的配置文件的例子。 - -`config_a.py` - -```python -a = 1 -b = './work_dir/{{ fileBasenameNoExtension }}' -c = '{{ fileExtname }}' -``` - -```python ->>> cfg = Config.fromfile('./config_a.py') ->>> print(cfg) ->>> dict(a=1, -... b='./work_dir/config_a', -... c='.py') -``` - -对于所有格式的配置文件, 都支持继承。为了重用其他配置文件的字段, -需要指定 `_base_='./config_a.py'` 或者一个包含配置文件的列表 `_base_=['./config_a.py', './config_b.py']`。 - -这里有 4 个配置继承关系的例子。 - -`config_a.py` 作为基类配置文件 - -```python -a = 1 -b = dict(b1=[0, 1, 2], b2=None) -``` - -### 不含重复键值对从基类配置文件继承 - -`config_b.py` - -```python -_base_ = './config_a.py' -c = (1, 2) -d = 'string' -``` - -```python ->>> cfg = Config.fromfile('./config_b.py') ->>> print(cfg) ->>> dict(a=1, -... b=dict(b1=[0, 1, 2], b2=None), -... c=(1, 2), -... d='string') -``` - -在`config_b.py`里的新字段与在`config_a.py`里的旧字段拼接 - -### 含重复键值对从基类配置文件继承 - -`config_c.py` - -```python -_base_ = './config_a.py' -b = dict(b2=1) -c = (1, 2) -``` - -```python ->>> cfg = Config.fromfile('./config_c.py') ->>> print(cfg) ->>> dict(a=1, -... b=dict(b1=[0, 1, 2], b2=1), -... c=(1, 2)) -``` - -在基类配置文件:`config_a` 里的 `b.b2=None`被配置文件:`config_c.py`里的 `b.b2=1`替代。 - -### 从具有忽略字段的配置文件继承 - -`config_d.py` - -```python -_base_ = './config_a.py' -b = dict(_delete_=True, b2=None, b3=0.1) -c = (1, 2) -``` - -```python ->>> cfg = Config.fromfile('./config_d.py') ->>> print(cfg) ->>> dict(a=1, -... b=dict(b2=None, b3=0.1), -... c=(1, 2)) -``` - -您还可以设置 `_delete_=True`忽略基类配置文件中的某些字段。所有在`b`中的旧键 `b1, b2, b3` 将会被新键 `b2, b3` 所取代。 - -### 从多个基类配置文件继承(基类配置文件不应包含相同的键) - -`config_e.py` - -```python -c = (1, 2) -d = 'string' -``` - -`config_f.py` - -```python -_base_ = ['./config_a.py', './config_e.py'] -``` - -```python ->>> cfg = Config.fromfile('./config_f.py') ->>> print(cfg) ->>> dict(a=1, -... b=dict(b1=[0, 1, 2], b2=None), -... c=(1, 2), -... d='string') -``` - -### 从基类引用变量 - -您可以使用以下语法引用在基类中定义的变量。 - -`base.py` - -```python -item1 = 'a' -item2 = dict(item3 = 'b') -``` - -`config_g.py` - -```python -_base_ = ['./base.py'] -item = dict(a = {{ _base_.item1 }}, b = {{ _base_.item2.item3 }}) -``` - -```python ->>> cfg = Config.fromfile('./config_g.py') ->>> print(cfg.pretty_text) -item1 = 'a' -item2 = dict(item3='b') -item = dict(a='a', b='b') -``` diff --git a/docs/zh_cn/understand_mmcv/registry.md b/docs/zh_cn/understand_mmcv/registry.md deleted file mode 100644 index bd89fa3417..0000000000 --- a/docs/zh_cn/understand_mmcv/registry.md +++ /dev/null @@ -1,176 +0,0 @@ -## 注册器 - -MMCV 使用 [注册器](https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py) 来管理具有相似功能的不同模块, 例如, 检测器中的主干网络、头部、和模型颈部。 -在 OpenMMLab 家族中的绝大部分开源项目使用注册器去管理数据集和模型的模块,例如 [MMDetection](https://github.com/open-mmlab/mmdetection), [MMDetection3D](https://github.com/open-mmlab/mmdetection3d), [MMClassification](https://github.com/open-mmlab/mmclassification), [MMEditing](https://github.com/open-mmlab/mmediting) 等。 - -```{note} -在 v1.5.1 版本开始支持注册函数的功能。 -``` - -### 什么是注册器 - -在MMCV中,注册器可以看作类或函数到字符串的映射。 -一个注册器中的类或函数通常有相似的接口,但是可以实现不同的算法或支持不同的数据集。 -借助注册器,用户可以通过使用相应的字符串查找类或函数,并根据他们的需要实例化对应模块或调用函数获取结果。 -一个典型的案例是,OpenMMLab 中的大部分开源项目的配置系统,这些系统通过配置文件来使用注册器创建钩子、执行器、模型和数据集。 -可以在[这里](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.Registry)找到注册器接口使用文档。 - -使用 `registry`(注册器)管理代码库中的模型,需要以下三个步骤。 - -1. 创建一个构建方法(可选,在大多数情况下您可以只使用默认方法) -2. 创建注册器 -3. 使用此注册器来管理模块 - -`Registry`(注册器)的参数 `build_func`(构建函数) 用来自定义如何实例化类的实例或如何调用函数获取结果,默认使用 [这里](https://mmcv.readthedocs.io/en/latest/api.html?highlight=registry#mmcv.utils.build_from_cfg) 实现的`build_from_cfg`。 - -### 一个简单的例子 - -这里是一个使用注册器管理包中模块的简单示例。您可以在 OpenMMLab 开源项目中找到更多实例。 - -假设我们要实现一系列数据集转换器(Dataset Converter),用于将不同格式的数据转换为标准数据格式。我们先创建一个名为converters的目录作为包,在包中我们创建一个文件来实现构建器(builder),命名为converters/builder.py,如下 - -```python -from mmengine.registry import Registry -# 创建转换器(converter)的注册器(registry) -CONVERTERS = Registry('converter') -``` - -然后我们在包中可以实现不同的转换器(converter),其可以为类或函数。例如,在 `converters/converter1.py` 中实现 `Converter1`,在 `converters/converter2.py` 中实现 `converter2`。 - -```python -# converter1.py -from .builder import CONVERTERS - -# 使用注册器管理模块 -@CONVERTERS.register_module() -class Converter1(object): - def __init__(self, a, b): - self.a = a - self.b = b -``` - -```python -# converter2.py -from .builder import CONVERTERS -from .converter1 import Converter1 - -# 使用注册器管理模块 -@CONVERTERS.register_module() -def converter2(a, b) - return Converter1(a, b) -``` - -使用注册器管理模块的关键步骤是,将实现的模块注册到注册表 `CONVERTERS` 中。通过 `@CONVERTERS.register_module()` 装饰所实现的模块,字符串到类或函数之间的映射就可以由 `CONVERTERS` 构建和维护,如下所示: - -通过这种方式,就可以通过 `CONVERTERS` 建立字符串与类或函数之间的映射,如下所示: - -```python -'Converter1' -> -'converter2' -> -``` - -```{note} -只有模块所在的文件被导入时,注册机制才会被触发,所以您需要在某处导入该文件。更多详情请查看 https://github.com/open-mmlab/mmdetection/issues/5974。 -``` - -如果模块被成功注册了,你可以通过配置文件使用这个转换器(converter),如下所示: - -```python -converter1_cfg = dict(type='Converter1', a=a_value, b=b_value) -converter2_cfg = dict(type='converter2', a=a_value, b=b_value) -converter1 = CONVERTERS.build(converter1_cfg) -# returns the calling result -result = CONVERTERS.build(converter2_cfg) -``` - -### 自定义构建函数 - -假设我们想自定义 `converters` 的构建流程,我们可以实现一个自定义的 `build_func` (构建函数)并将其传递到注册器中。 - -```python -from mmcv.utils import Registry - -# 创建一个构建函数 -def build_converter(cfg, registry, *args, **kwargs): - cfg_ = cfg.copy() - converter_type = cfg_.pop('type') - if converter_type not in registry: - raise KeyError(f'Unrecognized converter type {converter_type}') - else: - converter_cls = registry.get(converter_type) - - converter = converter_cls(*args, **kwargs, **cfg_) - return converter - -# 创建一个用于转换器(converters)的注册器,并传递(registry)``build_converter`` 函数 -CONVERTERS = Registry('converter', build_func=build_converter) -``` - -```{note} -注:在这个例子中,我们演示了如何使用参数:`build_func` 自定义构建类的实例的方法。 -该功能类似于默认的`build_from_cfg`。在大多数情况下,默认就足够了。 -``` - -`build_model_from_cfg`也实现了在`nn.Sequential`中构建PyTorch模块,你可以直接使用它们。 - -### 注册器层结构 - -你也可以从多个 OpenMMLab 开源框架中构建模块,例如,你可以把所有 [MMClassification](https://github.com/open-mmlab/mmclassification) 中的主干网络(backbone)用到 [MMDetection](https://github.com/open-mmlab/mmdetection) 的目标检测中,你也可以融合 [MMDetection](https://github.com/open-mmlab/mmdetection) 中的目标检测模型 和 [MMSegmentation](https://github.com/open-mmlab/mmsegmentation) 语义分割模型。 - -下游代码库中所有 `MODELS` 注册器都是MMCV `MODELS` 注册器的子注册器。基本上,使用以下两种方法从子注册器或相邻兄弟注册器构建模块。 - -1. 从子注册器中构建 - - 例如: - - 我们在 MMDetection 中定义: - - ```python - from mmengine.resgitry import Registry - from mmengine.resgitry import MODELS as MMENGINE_MODELS - MODELS = Registry('model', parent=MMENGINE_MODELS) - - @MODELS.register_module() - class NetA(nn.Module): - def forward(self, x): - return x - ``` - - 我们在 MMClassification 中定义: - - ```python - from mmengine.registry import Registry - from mmengine.registry import MODELS as MMENGINE_MODELS - MODELS = Registry('model', parent=MMENGINE_MODELS) - - @MODELS.register_module() - class NetB(nn.Module): - def forward(self, x): - return x + 1 - ``` - - 我们可以通过以下代码在 MMDetection 或 MMClassification 中构建两个网络: - - ```python - from mmdet.models import MODELS - net_a = MODELS.build(cfg=dict(type='NetA')) - net_b = MODELS.build(cfg=dict(type='mmcls.NetB')) - ``` - - 或 - - ```python - from mmcls.models import MODELS - net_a = MODELS.build(cfg=dict(type='mmdet.NetA')) - net_b = MODELS.build(cfg=dict(type='NetB')) - ``` - -2. 从父注册器中构建 - - MMCV中的共享`MODELS`注册器是所有下游代码库的父注册器(根注册器): - - ```python - from mmengine.registry import MODELS as MMENGINE_MODELS - net_a = MMENGINE_MODELS.build(cfg=dict(type='mmdet.NetA')) - net_b = MMENGINE_MODELS.build(cfg=dict(type='mmcls.NetB')) - ``` diff --git a/docs/zh_cn/understand_mmcv/utils.md b/docs/zh_cn/understand_mmcv/utils.md deleted file mode 100644 index c02e5203a4..0000000000 --- a/docs/zh_cn/understand_mmcv/utils.md +++ /dev/null @@ -1,68 +0,0 @@ -## 辅助函数 - -### 进度条 - -如果你想跟踪函数批处理任务的进度,可以使用 `track_progress` 。它能以进度条的形式展示任务的完成情况以及剩余任务所需的时间(内部实现为for循环)。 - -```python -import mmcv - -def func(item): - # 执行相关操作 - pass - -tasks = [item_1, item_2, ..., item_n] - -mmcv.track_progress(func, tasks) -``` - -效果如下 -![progress](../../en/_static/progress.*) - -如果你想可视化多进程任务的进度,你可以使用 `track_parallel_progress` 。 - -```python -mmcv.track_parallel_progress(func, tasks, 8) # 8 workers -``` - -![progress](../../_static/parallel_progress.*) - -如果你想要迭代或枚举数据列表并可视化进度,你可以使用 `track_iter_progress` 。 - -```python -import mmcv - -tasks = [item_1, item_2, ..., item_n] - -for task in mmcv.track_iter_progress(tasks): - # do something like print - print(task) - -for i, task in enumerate(mmcv.track_iter_progress(tasks)): - # do something like print - print(i) - print(task) -``` - -### 计时器 - -mmcv提供的 `Timer` 可以很方便地计算代码块的执行时间。 - -```python -import time - -with mmcv.Timer(): - # simulate some code block - time.sleep(1) -``` - -你也可以使用 `since_start()` 和 `since_last_check()` 。前者返回计时器启动后的运行时长,后者返回最近一次查看计时器后的运行时长。 - -```python -timer = mmcv.Timer() -# code block 1 here -print(timer.since_start()) -# code block 2 here -print(timer.since_last_check()) -print(timer.since_start()) -``` diff --git a/mmcv/__init__.py b/mmcv/__init__.py index 36bfa336d5..2410ea555e 100644 --- a/mmcv/__init__.py +++ b/mmcv/__init__.py @@ -3,7 +3,6 @@ from .arraymisc import * from .image import * from .transforms import * -from .utils import * from .version import * from .video import * from .visualization import * @@ -11,3 +10,4 @@ # The following modules are not imported to this level, so mmcv may be used # without PyTorch. # - op +# - utils diff --git a/mmcv/cnn/bricks/conv_module.py b/mmcv/cnn/bricks/conv_module.py index de12f309db..54db494f9e 100644 --- a/mmcv/cnn/bricks/conv_module.py +++ b/mmcv/cnn/bricks/conv_module.py @@ -6,8 +6,8 @@ import torch.nn as nn from mmengine.model.utils import constant_init, kaiming_init from mmengine.registry import MODELS +from mmengine.utils.parrots_wrapper import _BatchNorm, _InstanceNorm -from mmcv.utils import _BatchNorm, _InstanceNorm from .activation import build_activation_layer from .conv import build_conv_layer from .norm import build_norm_layer diff --git a/mmcv/cnn/bricks/hswish.py b/mmcv/cnn/bricks/hswish.py index 975deab14c..b4e6af937f 100644 --- a/mmcv/cnn/bricks/hswish.py +++ b/mmcv/cnn/bricks/hswish.py @@ -2,8 +2,7 @@ import torch import torch.nn as nn from mmengine.registry import MODELS - -from mmcv.utils import TORCH_VERSION, digit_version +from mmengine.utils import TORCH_VERSION, digit_version class HSwish(nn.Module): diff --git a/mmcv/cnn/bricks/norm.py b/mmcv/cnn/bricks/norm.py index 193d8596b1..83c956cef8 100644 --- a/mmcv/cnn/bricks/norm.py +++ b/mmcv/cnn/bricks/norm.py @@ -4,9 +4,9 @@ import torch.nn as nn from mmengine.registry import MODELS - -from mmcv.utils import is_tuple_of -from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm +from mmengine.utils import is_tuple_of +from mmengine.utils.parrots_wrapper import (SyncBatchNorm, _BatchNorm, + _InstanceNorm) MODELS.register_module('BN', module=nn.BatchNorm2d) MODELS.register_module('BN1d', module=nn.BatchNorm1d) diff --git a/mmcv/cnn/bricks/transformer.py b/mmcv/cnn/bricks/transformer.py index fbdfe87451..32e453deca 100644 --- a/mmcv/cnn/bricks/transformer.py +++ b/mmcv/cnn/bricks/transformer.py @@ -10,10 +10,10 @@ from mmengine import ConfigDict from mmengine.model import BaseModule, ModuleList, Sequential from mmengine.registry import MODELS +from mmengine.utils import deprecated_api_warning, to_2tuple from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer, build_norm_layer) -from mmcv.utils import deprecated_api_warning, to_2tuple from .drop import build_dropout # Avoid BC-breaking of importing MultiScaleDeformableAttention from this file diff --git a/mmcv/cnn/utils/__init__.py b/mmcv/cnn/utils/__init__.py index 9b8b7cde50..cdec9399f6 100644 --- a/mmcv/cnn/utils/__init__.py +++ b/mmcv/cnn/utils/__init__.py @@ -1,8 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .flops_counter import get_model_complexity_info from .fuse_conv_bn import fuse_conv_bn -from .sync_bn import revert_sync_batchnorm -__all__ = [ - 'get_model_complexity_info', 'fuse_conv_bn', 'revert_sync_batchnorm' -] +__all__ = ['get_model_complexity_info', 'fuse_conv_bn'] diff --git a/mmcv/cnn/utils/sync_bn.py b/mmcv/cnn/utils/sync_bn.py deleted file mode 100644 index c534fc0e17..0000000000 --- a/mmcv/cnn/utils/sync_bn.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -import torch.nn as nn - -import mmcv - - -class _BatchNormXd(nn.modules.batchnorm._BatchNorm): - """A general BatchNorm layer without input dimension check. - - Reproduced from @kapily's work: - (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) - The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc - is `_check_input_dim` that is designed for tensor sanity checks. - The check has been bypassed in this class for the convenience of converting - SyncBatchNorm. - """ - - def _check_input_dim(self, input: torch.Tensor): - return - - -def revert_sync_batchnorm(module: nn.Module) -> nn.Module: - """Helper function to convert all `SyncBatchNorm` (SyncBN) and - `mmcv.ops.sync_bn.SyncBatchNorm`(MMSyncBN) layers in the model to - `BatchNormXd` layers. - - Adapted from @kapily's work: - (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) - - Args: - module (nn.Module): The module containing `SyncBatchNorm` layers. - - Returns: - module_output: The converted module with `BatchNormXd` layers. - """ - module_output = module - module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm] - if hasattr(mmcv, 'ops'): - module_checklist.append(mmcv.ops.SyncBatchNorm) - if isinstance(module, tuple(module_checklist)): - module_output = _BatchNormXd(module.num_features, module.eps, - module.momentum, module.affine, - module.track_running_stats) - if module.affine: - # no_grad() may not be needed here but - # just to be consistent with `convert_sync_batchnorm()` - with torch.no_grad(): - module_output.weight = module.weight - module_output.bias = module.bias - module_output.running_mean = module.running_mean - module_output.running_var = module.running_var - module_output.num_batches_tracked = module.num_batches_tracked - module_output.training = module.training - # qconfig exists in quantized models - if hasattr(module, 'qconfig'): - module_output.qconfig = module.qconfig - for name, child in module.named_children(): - module_output.add_module(name, revert_sync_batchnorm(child)) - del module - return module_output diff --git a/mmcv/image/geometric.py b/mmcv/image/geometric.py index 066f539882..59a93a32ae 100644 --- a/mmcv/image/geometric.py +++ b/mmcv/image/geometric.py @@ -5,8 +5,8 @@ import cv2 import numpy as np +from mmengine.utils import to_2tuple -from ..utils import to_2tuple from .io import imread_backend try: diff --git a/mmcv/image/io.py b/mmcv/image/io.py index b8f3a277c6..af13d38b66 100644 --- a/mmcv/image/io.py +++ b/mmcv/image/io.py @@ -9,8 +9,7 @@ from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION, IMREAD_UNCHANGED) from mmengine.fileio import FileClient - -from mmcv.utils import is_filepath, is_str +from mmengine.utils import is_filepath, is_str try: from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG diff --git a/mmcv/image/photometric.py b/mmcv/image/photometric.py index 2f2cfd0941..12cbb90822 100644 --- a/mmcv/image/photometric.py +++ b/mmcv/image/photometric.py @@ -4,9 +4,9 @@ import cv2 import numpy as np +from mmengine.utils import is_tuple_of from PIL import Image, ImageEnhance -from ..utils import is_tuple_of from .colorspace import bgr2gray, gray2bgr from .io import imread_backend diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py index 7b0953001c..4f349d50cf 100644 --- a/mmcv/ops/deform_conv.py +++ b/mmcv/ops/deform_conv.py @@ -6,12 +6,12 @@ import torch.nn.functional as F from mmengine import print_log from mmengine.registry import MODELS +from mmengine.utils import deprecated_api_warning from torch import Tensor from torch.autograd import Function from torch.autograd.function import once_differentiable from torch.nn.modules.utils import _pair, _single -from mmcv.utils import deprecated_api_warning from ..utils import ext_loader ext_module = ext_loader.load_ext('_ext', [ diff --git a/mmcv/ops/modulated_deform_conv.py b/mmcv/ops/modulated_deform_conv.py index c372419ec4..01478edd02 100644 --- a/mmcv/ops/modulated_deform_conv.py +++ b/mmcv/ops/modulated_deform_conv.py @@ -6,11 +6,11 @@ import torch.nn as nn from mmengine import print_log from mmengine.registry import MODELS +from mmengine.utils import deprecated_api_warning from torch.autograd import Function from torch.autograd.function import once_differentiable from torch.nn.modules.utils import _pair, _single -from mmcv.utils import deprecated_api_warning from ..utils import ext_loader ext_module = ext_loader.load_ext( diff --git a/mmcv/ops/multi_scale_deform_attn.py b/mmcv/ops/multi_scale_deform_attn.py index 3f153b778f..9823aad98c 100644 --- a/mmcv/ops/multi_scale_deform_attn.py +++ b/mmcv/ops/multi_scale_deform_attn.py @@ -3,16 +3,16 @@ import warnings from typing import Optional, no_type_check +import mmengine import torch import torch.nn as nn import torch.nn.functional as F from mmengine.model import BaseModule from mmengine.model.utils import constant_init, xavier_init from mmengine.registry import MODELS +from mmengine.utils import deprecated_api_warning from torch.autograd.function import Function, once_differentiable -import mmcv -from mmcv import deprecated_api_warning from ..utils import ext_loader ext_module = ext_loader.load_ext( @@ -193,7 +193,7 @@ def __init__(self, dropout: float = 0.1, batch_first: bool = False, norm_cfg: Optional[dict] = None, - init_cfg: Optional[mmcv.ConfigDict] = None): + init_cfg: Optional[mmengine.ConfigDict] = None): super().__init__(init_cfg) if embed_dims % num_heads != 0: raise ValueError(f'embed_dims must be divisible by num_heads, ' diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py index d41b1ac966..06282ab8a8 100644 --- a/mmcv/ops/nms.py +++ b/mmcv/ops/nms.py @@ -3,9 +3,9 @@ import numpy as np import torch +from mmengine.utils import deprecated_api_warning from torch import Tensor -from mmcv.utils import deprecated_api_warning from ..utils import ext_loader ext_module = ext_loader.load_ext( diff --git a/mmcv/ops/riroi_align_rotated.py b/mmcv/ops/riroi_align_rotated.py index 1de810cc5f..c4e5a542f2 100644 --- a/mmcv/ops/riroi_align_rotated.py +++ b/mmcv/ops/riroi_align_rotated.py @@ -3,9 +3,10 @@ import torch import torch.nn as nn +from mmengine.utils import is_tuple_of from torch.autograd import Function -from ..utils import ext_loader, is_tuple_of +from ..utils import ext_loader ext_module = ext_loader.load_ext( '_ext', ['riroi_align_rotated_forward', 'riroi_align_rotated_backward']) diff --git a/mmcv/ops/roi_align.py b/mmcv/ops/roi_align.py index ca802f60cd..8d26ad9481 100644 --- a/mmcv/ops/roi_align.py +++ b/mmcv/ops/roi_align.py @@ -3,11 +3,12 @@ import torch import torch.nn as nn +from mmengine.utils import deprecated_api_warning from torch.autograd import Function from torch.autograd.function import once_differentiable from torch.nn.modules.utils import _pair -from ..utils import deprecated_api_warning, ext_loader +from ..utils import ext_loader ext_module = ext_loader.load_ext('_ext', ['roi_align_forward', 'roi_align_backward']) diff --git a/mmcv/ops/roi_align_rotated.py b/mmcv/ops/roi_align_rotated.py index f970ef4d8a..38e6ea3d32 100644 --- a/mmcv/ops/roi_align_rotated.py +++ b/mmcv/ops/roi_align_rotated.py @@ -3,10 +3,11 @@ import torch import torch.nn as nn +from mmengine.utils import deprecated_api_warning from torch.autograd import Function from torch.nn.modules.utils import _pair -from ..utils import deprecated_api_warning, ext_loader +from ..utils import ext_loader ext_module = ext_loader.load_ext( '_ext', ['roi_align_rotated_forward', 'roi_align_rotated_backward']) diff --git a/mmcv/ops/roiaware_pool3d.py b/mmcv/ops/roiaware_pool3d.py index 9a09049b55..728f246809 100644 --- a/mmcv/ops/roiaware_pool3d.py +++ b/mmcv/ops/roiaware_pool3d.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Any, Tuple, Union +import mmengine import torch from torch import nn as nn from torch.autograd import Function -import mmcv from ..utils import ext_loader ext_module = ext_loader.load_ext( @@ -86,7 +86,7 @@ def forward(ctx: Any, rois: torch.Tensor, pts: torch.Tensor, out_x = out_y = out_z = out_size else: assert len(out_size) == 3 - assert mmcv.is_tuple_of(out_size, int) + assert mmengine.is_tuple_of(out_size, int) out_x, out_y, out_z = out_size num_rois = rois.shape[0] diff --git a/mmcv/ops/saconv.py b/mmcv/ops/saconv.py index 60ab78bc24..ec0d09e0b8 100644 --- a/mmcv/ops/saconv.py +++ b/mmcv/ops/saconv.py @@ -4,10 +4,10 @@ import torch.nn.functional as F from mmengine.model.utils import constant_init from mmengine.registry import MODELS +from mmengine.utils import TORCH_VERSION, digit_version from mmcv.cnn import ConvAWS2d from mmcv.ops.deform_conv import deform_conv2d -from mmcv.utils import TORCH_VERSION, digit_version @MODELS.register_module(name='SAC') diff --git a/mmcv/ops/upfirdn2d.py b/mmcv/ops/upfirdn2d.py index 434238359a..574d4d315b 100644 --- a/mmcv/ops/upfirdn2d.py +++ b/mmcv/ops/upfirdn2d.py @@ -98,10 +98,10 @@ from typing import Any, List, Tuple, Union import torch +from mmengine.utils import to_2tuple from torch.autograd import Function from torch.nn import functional as F -from mmcv.utils import to_2tuple from ..utils import ext_loader upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d']) diff --git a/mmcv/transforms/formatting.py b/mmcv/transforms/formatting.py index 2a9bdbe44c..02089215e1 100644 --- a/mmcv/transforms/formatting.py +++ b/mmcv/transforms/formatting.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Sequence, Union +import mmengine import numpy as np import torch -import mmcv from .base import BaseTransform from .builder import TRANSFORMS @@ -29,7 +29,7 @@ def to_tensor( return data elif isinstance(data, np.ndarray): return torch.from_numpy(data) - elif isinstance(data, Sequence) and not mmcv.is_str(data): + elif isinstance(data, Sequence) and not mmengine.is_str(data): return torch.tensor(data) elif isinstance(data, int): return torch.LongTensor([data]) diff --git a/mmcv/transforms/processing.py b/mmcv/transforms/processing.py index 275a1c34f6..76ae33794f 100644 --- a/mmcv/transforms/processing.py +++ b/mmcv/transforms/processing.py @@ -3,6 +3,7 @@ import warnings from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union +import mmengine import numpy as np import mmcv @@ -797,7 +798,7 @@ def __init__( if scales is not None: self.scales = scales if isinstance(scales, list) else [scales] self.scale_key = 'scale' - assert mmcv.is_list_of(self.scales, tuple) + assert mmengine.is_list_of(self.scales, tuple) else: # if ``scales`` and ``scale_factor`` both be ``None`` if scale_factor is None: @@ -812,7 +813,7 @@ def __init__( self.allow_flip = allow_flip self.flip_direction = flip_direction if isinstance( flip_direction, list) else [flip_direction] - assert mmcv.is_list_of(self.flip_direction, str) + assert mmengine.is_list_of(self.flip_direction, str) if not self.allow_flip and self.flip_direction != ['horizontal']: warnings.warn( 'flip_direction has no effect when flip is set to False') @@ -934,7 +935,7 @@ def __init__( self.scales = scales else: self.scales = [scales] - assert mmcv.is_list_of(self.scales, tuple) + assert mmengine.is_list_of(self.scales, tuple) self.resize_cfg = dict(type=resize_type, **resize_kwargs) # create a empty Resize object @@ -950,7 +951,7 @@ def _random_select(self) -> Tuple[int, int]: ``scale_idx`` is the selected index in the given candidates. """ - assert mmcv.is_list_of(self.scales, tuple) + assert mmengine.is_list_of(self.scales, tuple) scale_idx = np.random.randint(len(self.scales)) scale = self.scales[scale_idx] return scale, scale_idx @@ -1033,7 +1034,7 @@ def __init__( direction: Union[str, Sequence[Optional[str]]] = 'horizontal') -> None: if isinstance(prob, list): - assert mmcv.is_list_of(prob, float) + assert mmengine.is_list_of(prob, float) assert 0 <= sum(prob) <= 1 elif isinstance(prob, float): assert 0 <= prob <= 1 @@ -1046,7 +1047,7 @@ def __init__( if isinstance(direction, str): assert direction in valid_directions elif isinstance(direction, list): - assert mmcv.is_list_of(direction, str) + assert mmengine.is_list_of(direction, str) assert set(direction).issubset(set(valid_directions)) else: raise ValueError(f'direction must be either str or list of str, \ @@ -1308,7 +1309,7 @@ def _random_sample(scales: Sequence[Tuple[int, int]]) -> tuple: tuple: The targeted scale of the image to be resized. """ - assert mmcv.is_list_of(scales, tuple) and len(scales) == 2 + assert mmengine.is_list_of(scales, tuple) and len(scales) == 2 scale_0 = [scales[0][0], scales[1][0]] scale_1 = [scales[0][1], scales[1][1]] edge_0 = np.random.randint(min(scale_0), max(scale_0) + 1) @@ -1350,12 +1351,12 @@ def _random_scale(self) -> tuple: tuple: The targeted scale of the image to be resized. """ - if mmcv.is_tuple_of(self.scale, int): + if mmengine.is_tuple_of(self.scale, int): assert self.ratio_range is not None and len(self.ratio_range) == 2 scale = self._random_sample_ratio( self.scale, # type: ignore self.ratio_range) - elif mmcv.is_seq_of(self.scale, tuple): + elif mmengine.is_seq_of(self.scale, tuple): scale = self._random_sample(self.scale) # type: ignore else: raise NotImplementedError('Do not support sampling function ' diff --git a/mmcv/transforms/wrappers.py b/mmcv/transforms/wrappers.py index 89ee48eda3..132ddcc4f9 100644 --- a/mmcv/transforms/wrappers.py +++ b/mmcv/transforms/wrappers.py @@ -2,9 +2,9 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union +import mmengine import numpy as np -import mmcv from .base import BaseTransform from .builder import TRANSFORMS from .utils import cache_random_params, cache_randomness @@ -569,7 +569,7 @@ def __init__(self, super().__init__() if prob is not None: - assert mmcv.is_seq_of(prob, float) + assert mmengine.is_seq_of(prob, float) assert len(transforms) == len(prob), \ '``transforms`` and ``prob`` must have same lengths. ' \ f'Got {len(transforms)} vs {len(prob)}.' diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 6bd3d3c8b6..242665a611 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -1,80 +1,9 @@ -# flake8: noqa # Copyright (c) OpenMMLab. All rights reserved. -from .config import Config, ConfigDict, DictAction -from .misc import (check_prerequisites, concat_list, deprecated_api_warning, - has_method, import_modules_from_strings, is_list_of, - is_method_overridden, is_seq_of, is_str, is_tuple_of, - iter_cast, list_cast, requires_executable, requires_package, - slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple, - to_ntuple, tuple_cast) -from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, - scandir, symlink) -from .progressbar import (ProgressBar, track_iter_progress, - track_parallel_progress, track_progress) -from .testing import (assert_attrs_equal, assert_dict_contains_subset, - assert_dict_has_keys, assert_is_norm_layer, - assert_keys_equal, assert_params_all_zeros, - check_python_script) -from .timer import Timer, TimerError, check_time -from .version_utils import digit_version, get_git_hash +from .device_type import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE +from .env import collect_env +from .parrots_jit import jit, skip_no_elena -try: - import torch -except ImportError: - __all__ = [ - 'Config', 'ConfigDict', 'DictAction', 'is_str', 'iter_cast', - 'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of', - 'slice_list', 'concat_list', 'check_prerequisites', 'requires_package', - 'requires_executable', 'is_filepath', 'fopen', 'check_file_exist', - 'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar', - 'track_progress', 'track_iter_progress', 'track_parallel_progress', - 'Timer', 'TimerError', 'check_time', 'deprecated_api_warning', - 'digit_version', 'get_git_hash', 'import_modules_from_strings', - 'assert_dict_contains_subset', 'assert_attrs_equal', - 'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script', - 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple', - 'is_method_overridden', 'has_method' - ] -else: - from .device_type import (IS_IPU_AVAILABLE, IS_MLU_AVAILABLE, - IS_MPS_AVAILABLE) - from .env import collect_env - from .logging import get_logger, print_log - from .parrots_jit import jit, skip_no_elena - # yapf: disable - from .parrots_wrapper import (IS_CUDA_AVAILABLE, TORCH_VERSION, - BuildExtension, CppExtension, CUDAExtension, - DataLoader, PoolDataLoader, SyncBatchNorm, - _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, - _AvgPoolNd, _BatchNorm, _ConvNd, - _ConvTransposeMixin, _get_cuda_home, - _InstanceNorm, _MaxPoolNd, get_build_config, - is_rocm_pytorch) - # yapf: enable - from .registry import Registry, build_from_cfg - from .seed import worker_init_fn - from .torch_ops import torch_meshgrid - from .trace import is_jit_tracing - __all__ = [ - 'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger', - 'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast', - 'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list', - 'check_prerequisites', 'requires_package', 'requires_executable', - 'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', - 'symlink', 'scandir', 'ProgressBar', 'track_progress', - 'track_iter_progress', 'track_parallel_progress', 'Registry', - 'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm', - '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm', - '_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd', - 'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension', - 'DataLoader', 'PoolDataLoader', 'TORCH_VERSION', - 'deprecated_api_warning', 'digit_version', 'get_git_hash', - 'import_modules_from_strings', 'jit', 'skip_no_elena', - 'assert_dict_contains_subset', 'assert_attrs_equal', - 'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer', - 'assert_params_all_zeros', 'check_python_script', - 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', - '_get_cuda_home', 'has_method', 'IS_CUDA_AVAILABLE', 'worker_init_fn', - 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE', 'IS_MPS_AVAILABLE', - 'torch_meshgrid' - ] +__all__ = [ + 'IS_MLU_AVAILABLE', 'IS_MPS_AVAILABLE', 'IS_CUDA_AVAILABLE', 'collect_env', + 'jit', 'skip_no_elena' +] diff --git a/mmcv/utils/config.py b/mmcv/utils/config.py deleted file mode 100644 index f5e9f1d979..0000000000 --- a/mmcv/utils/config.py +++ /dev/null @@ -1,740 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import ast -import copy -import os -import os.path as osp -import platform -import shutil -import sys -import tempfile -import types -import uuid -import warnings -from argparse import Action, ArgumentParser -from collections import abc -from importlib import import_module -from pathlib import Path - -import mmengine -from addict import Dict -from yapf.yapflib.yapf_api import FormatCode - -from .misc import import_modules_from_strings -from .path import check_file_exist - -if platform.system() == 'Windows': - import regex as re # type: ignore -else: - import re # type: ignore - -BASE_KEY = '_base_' -DELETE_KEY = '_delete_' -DEPRECATION_KEY = '_deprecation_' -RESERVED_KEYS = ['filename', 'text', 'pretty_text'] - - -class ConfigDict(Dict): - - def __missing__(self, name): - raise KeyError(name) - - def __getattr__(self, name): - try: - value = super().__getattr__(name) - except KeyError: - ex = AttributeError(f"'{self.__class__.__name__}' object has no " - f"attribute '{name}'") - except Exception as e: - ex = e - else: - return value - raise ex - - -def add_args(parser, cfg, prefix=''): - for k, v in cfg.items(): - if isinstance(v, str): - parser.add_argument('--' + prefix + k) - elif isinstance(v, int): - parser.add_argument('--' + prefix + k, type=int) - elif isinstance(v, float): - parser.add_argument('--' + prefix + k, type=float) - elif isinstance(v, bool): - parser.add_argument('--' + prefix + k, action='store_true') - elif isinstance(v, dict): - add_args(parser, v, prefix + k + '.') - elif isinstance(v, abc.Iterable): - parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+') - else: - print(f'cannot parse key {prefix + k} of type {type(v)}') - return parser - - -class Config: - """A facility for config and config files. - - It supports common file formats as configs: python/json/yaml. The interface - is the same as a dict object and also allows access config values as - attributes. - - Example: - >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) - >>> cfg.a - 1 - >>> cfg.b - {'b1': [0, 1]} - >>> cfg.b.b1 - [0, 1] - >>> cfg = Config.fromfile('tests/data/config/a.py') - >>> cfg.filename - "/home/kchen/projects/mmcv/tests/data/config/a.py" - >>> cfg.item4 - 'test' - >>> cfg - "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " - "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" - """ - - @staticmethod - def _validate_py_syntax(filename): - with open(filename, encoding='utf-8') as f: - # Setting encoding explicitly to resolve coding issue on windows - content = f.read() - try: - ast.parse(content) - except SyntaxError as e: - raise SyntaxError('There are syntax errors in config ' - f'file {filename}: {e}') - - @staticmethod - def _substitute_predefined_vars(filename, temp_config_name): - file_dirname = osp.dirname(filename) - file_basename = osp.basename(filename) - file_basename_no_extension = osp.splitext(file_basename)[0] - file_extname = osp.splitext(filename)[1] - support_templates = dict( - fileDirname=file_dirname, - fileBasename=file_basename, - fileBasenameNoExtension=file_basename_no_extension, - fileExtname=file_extname) - with open(filename, encoding='utf-8') as f: - # Setting encoding explicitly to resolve coding issue on windows - config_file = f.read() - for key, value in support_templates.items(): - regexp = r'\{\{\s*' + str(key) + r'\s*\}\}' - value = value.replace('\\', '/') - config_file = re.sub(regexp, value, config_file) - with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: - tmp_config_file.write(config_file) - - @staticmethod - def _pre_substitute_base_vars(filename, temp_config_name): - """Substitute base variable placehoders to string, so that parsing - would work.""" - with open(filename, encoding='utf-8') as f: - # Setting encoding explicitly to resolve coding issue on windows - config_file = f.read() - base_var_dict = {} - regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}' - base_vars = set(re.findall(regexp, config_file)) - for base_var in base_vars: - randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}' - base_var_dict[randstr] = base_var - regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}' - config_file = re.sub(regexp, f'"{randstr}"', config_file) - with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: - tmp_config_file.write(config_file) - return base_var_dict - - @staticmethod - def _substitute_base_vars(cfg, base_var_dict, base_cfg): - """Substitute variable strings to their actual values.""" - cfg = copy.deepcopy(cfg) - - if isinstance(cfg, dict): - for k, v in cfg.items(): - if isinstance(v, str) and v in base_var_dict: - new_v = base_cfg - for new_k in base_var_dict[v].split('.'): - new_v = new_v[new_k] - cfg[k] = new_v - elif isinstance(v, (list, tuple, dict)): - cfg[k] = Config._substitute_base_vars( - v, base_var_dict, base_cfg) - elif isinstance(cfg, tuple): - cfg = tuple( - Config._substitute_base_vars(c, base_var_dict, base_cfg) - for c in cfg) - elif isinstance(cfg, list): - cfg = [ - Config._substitute_base_vars(c, base_var_dict, base_cfg) - for c in cfg - ] - elif isinstance(cfg, str) and cfg in base_var_dict: - new_v = base_cfg - for new_k in base_var_dict[cfg].split('.'): - new_v = new_v[new_k] - cfg = new_v - - return cfg - - @staticmethod - def _file2dict(filename, use_predefined_variables=True): - filename = osp.abspath(osp.expanduser(filename)) - check_file_exist(filename) - fileExtname = osp.splitext(filename)[1] - if fileExtname not in ['.py', '.json', '.yaml', '.yml']: - raise OSError('Only py/yml/yaml/json type are supported now!') - - with tempfile.TemporaryDirectory() as temp_config_dir: - temp_config_file = tempfile.NamedTemporaryFile( - dir=temp_config_dir, suffix=fileExtname) - if platform.system() == 'Windows': - temp_config_file.close() - temp_config_name = osp.basename(temp_config_file.name) - # Substitute predefined variables - if use_predefined_variables: - Config._substitute_predefined_vars(filename, - temp_config_file.name) - else: - shutil.copyfile(filename, temp_config_file.name) - # Substitute base variables from placeholders to strings - base_var_dict = Config._pre_substitute_base_vars( - temp_config_file.name, temp_config_file.name) - - if filename.endswith('.py'): - temp_module_name = osp.splitext(temp_config_name)[0] - sys.path.insert(0, temp_config_dir) - Config._validate_py_syntax(filename) - mod = import_module(temp_module_name) - sys.path.pop(0) - cfg_dict = { - name: value - for name, value in mod.__dict__.items() - if not name.startswith('__') - and not isinstance(value, types.ModuleType) - and not isinstance(value, types.FunctionType) - } - # delete imported module - del sys.modules[temp_module_name] - elif filename.endswith(('.yml', '.yaml', '.json')): - cfg_dict = mmengine.load(temp_config_file.name) - # close temp file - temp_config_file.close() - - # check deprecation information - if DEPRECATION_KEY in cfg_dict: - deprecation_info = cfg_dict.pop(DEPRECATION_KEY) - warning_msg = f'The config file {filename} will be deprecated ' \ - 'in the future.' - if 'expected' in deprecation_info: - warning_msg += f' Please use {deprecation_info["expected"]} ' \ - 'instead.' - if 'reference' in deprecation_info: - warning_msg += ' More information can be found at ' \ - f'{deprecation_info["reference"]}' - warnings.warn(warning_msg, DeprecationWarning) - - cfg_text = filename + '\n' - with open(filename, encoding='utf-8') as f: - # Setting encoding explicitly to resolve coding issue on windows - cfg_text += f.read() - - if BASE_KEY in cfg_dict: - cfg_dir = osp.dirname(filename) - base_filename = cfg_dict.pop(BASE_KEY) - base_filename = base_filename if isinstance( - base_filename, list) else [base_filename] - - cfg_dict_list = list() - cfg_text_list = list() - for f in base_filename: - _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) - cfg_dict_list.append(_cfg_dict) - cfg_text_list.append(_cfg_text) - - base_cfg_dict = dict() - for c in cfg_dict_list: - duplicate_keys = base_cfg_dict.keys() & c.keys() - if len(duplicate_keys) > 0: - raise KeyError('Duplicate key is not allowed among bases. ' - f'Duplicate keys: {duplicate_keys}') - base_cfg_dict.update(c) - - # Substitute base variables from strings to their actual values - cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, - base_cfg_dict) - - base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) - cfg_dict = base_cfg_dict - - # merge cfg_text - cfg_text_list.append(cfg_text) - cfg_text = '\n'.join(cfg_text_list) - - return cfg_dict, cfg_text - - @staticmethod - def _merge_a_into_b(a, b, allow_list_keys=False): - """merge dict ``a`` into dict ``b`` (non-inplace). - - Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid - in-place modifications. - - Args: - a (dict): The source dict to be merged into ``b``. - b (dict): The origin dict to be fetch keys from ``a``. - allow_list_keys (bool): If True, int string keys (e.g. '0', '1') - are allowed in source ``a`` and will replace the element of the - corresponding index in b if b is a list. Default: False. - - Returns: - dict: The modified dict of ``b`` using ``a``. - - Examples: - # Normally merge a into b. - >>> Config._merge_a_into_b( - ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) - {'obj': {'a': 2}} - - # Delete b first and merge a into b. - >>> Config._merge_a_into_b( - ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) - {'obj': {'a': 2}} - - # b is a list - >>> Config._merge_a_into_b( - ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) - [{'a': 2}, {'b': 2}] - """ - b = b.copy() - for k, v in a.items(): - if allow_list_keys and k.isdigit() and isinstance(b, list): - k = int(k) - if len(b) <= k: - raise KeyError(f'Index {k} exceeds the length of list {b}') - b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) - elif isinstance(v, dict): - if k in b and not v.pop(DELETE_KEY, False): - allowed_types = (dict, list) if allow_list_keys else dict - if not isinstance(b[k], allowed_types): - raise TypeError( - f'{k}={v} in child config cannot inherit from ' - f'base because {k} is a dict in the child config ' - f'but is of type {type(b[k])} in base config. ' - f'You may set `{DELETE_KEY}=True` to ignore the ' - f'base config.') - b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) - else: - b[k] = ConfigDict(v) - else: - b[k] = v - return b - - @staticmethod - def fromfile(filename, - use_predefined_variables=True, - import_custom_modules=True): - if isinstance(filename, Path): - filename = str(filename) - cfg_dict, cfg_text = Config._file2dict(filename, - use_predefined_variables) - if import_custom_modules and cfg_dict.get('custom_imports', None): - import_modules_from_strings(**cfg_dict['custom_imports']) - return Config(cfg_dict, cfg_text=cfg_text, filename=filename) - - @staticmethod - def fromstring(cfg_str, file_format): - """Generate config from config str. - - Args: - cfg_str (str): Config str. - file_format (str): Config file format corresponding to the - config str. Only py/yml/yaml/json type are supported now! - - Returns: - :obj:`Config`: Config obj. - """ - if file_format not in ['.py', '.json', '.yaml', '.yml']: - raise OSError('Only py/yml/yaml/json type are supported now!') - if file_format != '.py' and 'dict(' in cfg_str: - # check if users specify a wrong suffix for python - warnings.warn( - 'Please check "file_format", the file format may be .py') - with tempfile.NamedTemporaryFile( - 'w', encoding='utf-8', suffix=file_format, - delete=False) as temp_file: - temp_file.write(cfg_str) - # on windows, previous implementation cause error - # see PR 1077 for details - cfg = Config.fromfile(temp_file.name) - os.remove(temp_file.name) - return cfg - - @staticmethod - def auto_argparser(description=None): - """Generate argparser from config file automatically (experimental)""" - partial_parser = ArgumentParser(description=description) - partial_parser.add_argument('config', help='config file path') - cfg_file = partial_parser.parse_known_args()[0].config - cfg = Config.fromfile(cfg_file) - parser = ArgumentParser(description=description) - parser.add_argument('config', help='config file path') - add_args(parser, cfg) - return parser, cfg - - def __init__(self, cfg_dict=None, cfg_text=None, filename=None): - if cfg_dict is None: - cfg_dict = dict() - elif not isinstance(cfg_dict, dict): - raise TypeError('cfg_dict must be a dict, but ' - f'got {type(cfg_dict)}') - for key in cfg_dict: - if key in RESERVED_KEYS: - raise KeyError(f'{key} is reserved for config file') - - if isinstance(filename, Path): - filename = str(filename) - - super().__setattr__('_cfg_dict', ConfigDict(cfg_dict)) - super().__setattr__('_filename', filename) - if cfg_text: - text = cfg_text - elif filename: - with open(filename) as f: - text = f.read() - else: - text = '' - super().__setattr__('_text', text) - - @property - def filename(self): - return self._filename - - @property - def text(self): - return self._text - - @property - def pretty_text(self): - - indent = 4 - - def _indent(s_, num_spaces): - s = s_.split('\n') - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(num_spaces * ' ') + line for line in s] - s = '\n'.join(s) - s = first + '\n' + s - return s - - def _format_basic_types(k, v, use_mapping=False): - if isinstance(v, str): - v_str = f"'{v}'" - else: - v_str = str(v) - - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: {v_str}' - else: - attr_str = f'{str(k)}={v_str}' - attr_str = _indent(attr_str, indent) - - return attr_str - - def _format_list(k, v, use_mapping=False): - # check if all items in the list are dict - if all(isinstance(_, dict) for _ in v): - v_str = '[\n' - v_str += '\n'.join( - f'dict({_indent(_format_dict(v_), indent)}),' - for v_ in v).rstrip(',') - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: {v_str}' - else: - attr_str = f'{str(k)}={v_str}' - attr_str = _indent(attr_str, indent) + ']' - else: - attr_str = _format_basic_types(k, v, use_mapping) - return attr_str - - def _contain_invalid_identifier(dict_str): - contain_invalid_identifier = False - for key_name in dict_str: - contain_invalid_identifier |= \ - (not str(key_name).isidentifier()) - return contain_invalid_identifier - - def _format_dict(input_dict, outest_level=False): - r = '' - s = [] - - use_mapping = _contain_invalid_identifier(input_dict) - if use_mapping: - r += '{' - for idx, (k, v) in enumerate(input_dict.items()): - is_last = idx >= len(input_dict) - 1 - end = '' if outest_level or is_last else ',' - if isinstance(v, dict): - v_str = '\n' + _format_dict(v) - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: dict({v_str}' - else: - attr_str = f'{str(k)}=dict({v_str}' - attr_str = _indent(attr_str, indent) + ')' + end - elif isinstance(v, list): - attr_str = _format_list(k, v, use_mapping) + end - else: - attr_str = _format_basic_types(k, v, use_mapping) + end - - s.append(attr_str) - r += '\n'.join(s) - if use_mapping: - r += '}' - return r - - cfg_dict = self._cfg_dict.to_dict() - text = _format_dict(cfg_dict, outest_level=True) - # copied from setup.cfg - yapf_style = dict( - based_on_style='pep8', - blank_line_before_nested_class_or_def=True, - split_before_expression_after_opening_paren=True) - text, _ = FormatCode(text, style_config=yapf_style, verify=True) - - return text - - def __repr__(self): - return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' - - def __len__(self): - return len(self._cfg_dict) - - def __getattr__(self, name): - return getattr(self._cfg_dict, name) - - def __getitem__(self, name): - return self._cfg_dict.__getitem__(name) - - def __setattr__(self, name, value): - if isinstance(value, dict): - value = ConfigDict(value) - self._cfg_dict.__setattr__(name, value) - - def __setitem__(self, name, value): - if isinstance(value, dict): - value = ConfigDict(value) - self._cfg_dict.__setitem__(name, value) - - def __iter__(self): - return iter(self._cfg_dict) - - def __getstate__(self): - return (self._cfg_dict, self._filename, self._text) - - def __copy__(self): - cls = self.__class__ - other = cls.__new__(cls) - other.__dict__.update(self.__dict__) - - return other - - def __deepcopy__(self, memo): - cls = self.__class__ - other = cls.__new__(cls) - memo[id(self)] = other - - for key, value in self.__dict__.items(): - super(Config, other).__setattr__(key, copy.deepcopy(value, memo)) - - return other - - def __setstate__(self, state): - _cfg_dict, _filename, _text = state - super().__setattr__('_cfg_dict', _cfg_dict) - super().__setattr__('_filename', _filename) - super().__setattr__('_text', _text) - - def dump(self, file=None): - """Dumps config into a file or returns a string representation of the - config. - - If a file argument is given, saves the config to that file using the - format defined by the file argument extension. - - Otherwise, returns a string representing the config. The formatting of - this returned string is defined by the extension of `self.filename`. If - `self.filename` is not defined, returns a string representation of a - dict (lowercased and using ' for strings). - - Examples: - >>> cfg_dict = dict(item1=[1, 2], item2=dict(a=0), - ... item3=True, item4='test') - >>> cfg = Config(cfg_dict=cfg_dict) - >>> dump_file = "a.py" - >>> cfg.dump(dump_file) - - Args: - file (str, optional): Path of the output file where the config - will be dumped. Defaults to None. - """ - cfg_dict = super().__getattribute__('_cfg_dict').to_dict() - if file is None: - if self.filename is None or self.filename.endswith('.py'): - return self.pretty_text - else: - file_format = self.filename.split('.')[-1] - return mmengine.dump(cfg_dict, file_format=file_format) - elif file.endswith('.py'): - with open(file, 'w', encoding='utf-8') as f: - f.write(self.pretty_text) - else: - file_format = file.split('.')[-1] - return mmengine.dump(cfg_dict, file=file, file_format=file_format) - - def merge_from_dict(self, options, allow_list_keys=True): - """Merge list into cfg_dict. - - Merge the dict parsed by MultipleKVAction into this cfg. - - Examples: - >>> options = {'model.backbone.depth': 50, - ... 'model.backbone.with_cp':True} - >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) - >>> cfg.merge_from_dict(options) - >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') - >>> assert cfg_dict == dict( - ... model=dict(backbone=dict(depth=50, with_cp=True))) - - >>> # Merge list element - >>> cfg = Config(dict(pipeline=[ - ... dict(type='LoadImage'), dict(type='LoadAnnotations')])) - >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')}) - >>> cfg.merge_from_dict(options, allow_list_keys=True) - >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') - >>> assert cfg_dict == dict(pipeline=[ - ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')]) - - Args: - options (dict): dict of configs to merge from. - allow_list_keys (bool): If True, int string keys (e.g. '0', '1') - are allowed in ``options`` and will replace the element of the - corresponding index in the config if the config is a list. - Default: True. - """ - option_cfg_dict = {} - for full_key, v in options.items(): - d = option_cfg_dict - key_list = full_key.split('.') - for subkey in key_list[:-1]: - d.setdefault(subkey, ConfigDict()) - d = d[subkey] - subkey = key_list[-1] - d[subkey] = v - - cfg_dict = super().__getattribute__('_cfg_dict') - super().__setattr__( - '_cfg_dict', - Config._merge_a_into_b( - option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys)) - - -class DictAction(Action): - """ - argparse action to split an argument into KEY=VALUE form - on the first = and append to a dictionary. List options can - be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit - brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build - list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' - """ - - @staticmethod - def _parse_int_float_bool(val): - try: - return int(val) - except ValueError: - pass - try: - return float(val) - except ValueError: - pass - if val.lower() in ['true', 'false']: - return True if val.lower() == 'true' else False - if val == 'None': - return None - return val - - @staticmethod - def _parse_iterable(val): - """Parse iterable values in the string. - - All elements inside '()' or '[]' are treated as iterable values. - - Args: - val (str): Value string. - - Returns: - list | tuple: The expanded list or tuple from the string. - - Examples: - >>> DictAction._parse_iterable('1,2,3') - [1, 2, 3] - >>> DictAction._parse_iterable('[a, b, c]') - ['a', 'b', 'c'] - >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]') - [(1, 2, 3), ['a', 'b'], 'c'] - """ - - def find_next_comma(string): - """Find the position of next comma in the string. - - If no ',' is found in the string, return the string length. All - chars inside '()' and '[]' are treated as one element and thus ',' - inside these brackets are ignored. - """ - assert (string.count('(') == string.count(')')) and ( - string.count('[') == string.count(']')), \ - f'Imbalanced brackets exist in {string}' - end = len(string) - for idx, char in enumerate(string): - pre = string[:idx] - # The string before this ',' is balanced - if ((char == ',') and (pre.count('(') == pre.count(')')) - and (pre.count('[') == pre.count(']'))): - end = idx - break - return end - - # Strip ' and " characters and replace whitespace. - val = val.strip('\'\"').replace(' ', '') - is_tuple = False - if val.startswith('(') and val.endswith(')'): - is_tuple = True - val = val[1:-1] - elif val.startswith('[') and val.endswith(']'): - val = val[1:-1] - elif ',' not in val: - # val is a single value - return DictAction._parse_int_float_bool(val) - - values = [] - while len(val) > 0: - comma_idx = find_next_comma(val) - element = DictAction._parse_iterable(val[:comma_idx]) - values.append(element) - val = val[comma_idx + 1:] - if is_tuple: - values = tuple(values) - return values - - def __call__(self, parser, namespace, values, option_string=None): - options = {} - for kv in values: - key, val = kv.split('=', maxsplit=1) - options[key] = self._parse_iterable(val) - setattr(namespace, self.dest, options) diff --git a/mmcv/utils/device_type.py b/mmcv/utils/device_type.py index d42ff72e9f..84b185e8e7 100644 --- a/mmcv/utils/device_type.py +++ b/mmcv/utils/device_type.py @@ -1,40 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. - - -def is_ipu_available() -> bool: - try: - import poptorch - return poptorch.ipuHardwareIsAvailable() - except ImportError: - return False - - -IS_IPU_AVAILABLE = is_ipu_available() - - -def is_mlu_available() -> bool: - try: - import torch - return (hasattr(torch, 'is_mlu_available') - and torch.is_mlu_available()) - except Exception: - return False - +from mmengine.device import (is_cuda_available, is_mlu_available, + is_mps_available) IS_MLU_AVAILABLE = is_mlu_available() - - -def is_mps_available() -> bool: - """Return True if mps devices exist. - - It's specialized for mac m1 chips and require torch version 1.12 or higher. - """ - try: - import torch - return hasattr(torch.backends, - 'mps') and torch.backends.mps.is_available() - except Exception: - return False - - IS_MPS_AVAILABLE = is_mps_available() +IS_CUDA_AVAILABLE = is_cuda_available() diff --git a/mmcv/utils/env.py b/mmcv/utils/env.py index 511332506f..bf0a9f45f1 100644 --- a/mmcv/utils/env.py +++ b/mmcv/utils/env.py @@ -1,16 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. """This file holding some environment constant for sharing by other files.""" -import os.path as osp -import subprocess -import sys -from collections import defaultdict - -import cv2 -import torch +from mmengine.utils import collect_env as mmengine_collect_env import mmcv -from .parrots_wrapper import get_build_config def collect_env(): @@ -32,80 +25,12 @@ def collect_env(): ``torch.__config__.show()``. - TorchVision (optional): TorchVision version. - OpenCV: OpenCV version. + - MMEngine: MMEngine version. - MMCV: MMCV version. - MMCV Compiler: The GCC version for compiling MMCV ops. - MMCV CUDA Compiler: The CUDA version for compiling MMCV ops. """ - env_info = {} - env_info['sys.platform'] = sys.platform - env_info['Python'] = sys.version.replace('\n', '') - - cuda_available = torch.cuda.is_available() - env_info['CUDA available'] = cuda_available - - if cuda_available: - devices = defaultdict(list) - for k in range(torch.cuda.device_count()): - devices[torch.cuda.get_device_name(k)].append(str(k)) - for name, device_ids in devices.items(): - env_info['GPU ' + ','.join(device_ids)] = name - - from mmcv.utils.parrots_wrapper import _get_cuda_home - CUDA_HOME = _get_cuda_home() - env_info['CUDA_HOME'] = CUDA_HOME - - if CUDA_HOME is not None and osp.isdir(CUDA_HOME): - try: - nvcc = osp.join(CUDA_HOME, 'bin/nvcc') - nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True) - nvcc = nvcc.decode('utf-8').strip() - release = nvcc.rfind('Cuda compilation tools') - build = nvcc.rfind('Build ') - nvcc = nvcc[release:build].strip() - except subprocess.SubprocessError: - nvcc = 'Not Available' - env_info['NVCC'] = nvcc - - try: - # Check C++ Compiler. - # For Unix-like, sysconfig has 'CC' variable like 'gcc -pthread ...', - # indicating the compiler used, we use this to get the compiler name - import sysconfig - cc = sysconfig.get_config_var('CC') - if cc: - cc = osp.basename(cc.split()[0]) - cc_info = subprocess.check_output(f'{cc} --version', shell=True) - env_info['GCC'] = cc_info.decode('utf-8').partition( - '\n')[0].strip() - else: - # on Windows, cl.exe is not in PATH. We need to find the path. - # distutils.ccompiler.new_compiler() returns a msvccompiler - # object and after initialization, path to cl.exe is found. - import locale - import os - from distutils.ccompiler import new_compiler - ccompiler = new_compiler() - ccompiler.initialize() - cc = subprocess.check_output( - f'{ccompiler.cc}', stderr=subprocess.STDOUT, shell=True) - encoding = os.device_encoding( - sys.stdout.fileno()) or locale.getpreferredencoding() - env_info['MSVC'] = cc.decode(encoding).partition('\n')[0].strip() - env_info['GCC'] = 'n/a' - except subprocess.CalledProcessError: - env_info['GCC'] = 'n/a' - - env_info['PyTorch'] = torch.__version__ - env_info['PyTorch compiling details'] = get_build_config() - - try: - import torchvision - env_info['TorchVision'] = torchvision.__version__ - except ModuleNotFoundError: - pass - - env_info['OpenCV'] = cv2.__version__ - + env_info = mmengine_collect_env() env_info['MMCV'] = mmcv.__version__ try: diff --git a/mmcv/utils/logging.py b/mmcv/utils/logging.py deleted file mode 100644 index 5a90aac8b2..0000000000 --- a/mmcv/utils/logging.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import logging - -import torch.distributed as dist - -logger_initialized: dict = {} - - -def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): - """Initialize and get a logger by name. - - If the logger has not been initialized, this method will initialize the - logger by adding one or two handlers, otherwise the initialized logger will - be directly returned. During initialization, a StreamHandler will always be - added. If `log_file` is specified and the process rank is 0, a FileHandler - will also be added. - - Args: - name (str): Logger name. - log_file (str | None): The log filename. If specified, a FileHandler - will be added to the logger. - log_level (int): The logger level. Note that only the process of - rank 0 is affected, and other processes will set the level to - "Error" thus be silent most of the time. - file_mode (str): The file mode used in opening log file. - Defaults to 'w'. - - Returns: - logging.Logger: The expected logger. - """ - logger = logging.getLogger(name) - if name in logger_initialized: - return logger - # handle hierarchical names - # e.g., logger "a" is initialized, then logger "a.b" will skip the - # initialization since it is a child of "a". - for logger_name in logger_initialized: - if name.startswith(logger_name): - return logger - - # handle duplicate logs to the console - # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) - # to the root logger. As logger.propagate is True by default, this root - # level handler causes logging messages from rank>0 processes to - # unexpectedly show up on the console, creating much unwanted clutter. - # To fix this issue, we set the root logger's StreamHandler, if any, to log - # at the ERROR level. - for handler in logger.root.handlers: - if type(handler) is logging.StreamHandler: - handler.setLevel(logging.ERROR) - - stream_handler = logging.StreamHandler() - handlers = [stream_handler] - - if dist.is_available() and dist.is_initialized(): - rank = dist.get_rank() - else: - rank = 0 - - # only rank 0 will add a FileHandler - if rank == 0 and log_file is not None: - # Here, the default behaviour of the official logger is 'a'. Thus, we - # provide an interface to change the file mode to the default - # behaviour. - file_handler = logging.FileHandler(log_file, file_mode) - handlers.append(file_handler) - - formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s') - for handler in handlers: - handler.setFormatter(formatter) - handler.setLevel(log_level) - logger.addHandler(handler) - - if rank == 0: - logger.setLevel(log_level) - else: - logger.setLevel(logging.ERROR) - - logger_initialized[name] = True - - return logger - - -def print_log(msg, logger=None, level=logging.INFO): - """Print a log message. - - Args: - msg (str): The message to be logged. - logger (logging.Logger | str | None): The logger to be used. - Some special loggers are: - - - "silent": no message will be printed. - - other str: the logger obtained with `get_root_logger(logger)`. - - None: The `print()` method will be used to print log messages. - level (int): Logging level. Only available when `logger` is a Logger - object or "root". - """ - if logger is None: - print(msg) - elif isinstance(logger, logging.Logger): - logger.log(level, msg) - elif logger == 'silent': - pass - elif isinstance(logger, str): - _logger = get_logger(logger) - _logger.log(level, msg) - else: - raise TypeError( - 'logger should be either a logging.Logger object, str, ' - f'"silent" or None, but got {type(logger)}') diff --git a/mmcv/utils/misc.py b/mmcv/utils/misc.py deleted file mode 100644 index 7957ea89b7..0000000000 --- a/mmcv/utils/misc.py +++ /dev/null @@ -1,377 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import collections.abc -import functools -import itertools -import subprocess -import warnings -from collections import abc -from importlib import import_module -from inspect import getfullargspec -from itertools import repeat - - -# From PyTorch internals -def _ntuple(n): - - def parse(x): - if isinstance(x, collections.abc.Iterable): - return x - return tuple(repeat(x, n)) - - return parse - - -to_1tuple = _ntuple(1) -to_2tuple = _ntuple(2) -to_3tuple = _ntuple(3) -to_4tuple = _ntuple(4) -to_ntuple = _ntuple - - -def is_str(x): - """Whether the input is an string instance. - - Note: This method is deprecated since python 2 is no longer supported. - """ - return isinstance(x, str) - - -def import_modules_from_strings(imports, allow_failed_imports=False): - """Import modules from the given list of strings. - - Args: - imports (list | str | None): The given module names to be imported. - allow_failed_imports (bool): If True, the failed imports will return - None. Otherwise, an ImportError is raise. Default: False. - - Returns: - list[module] | module | None: The imported modules. - - Examples: - >>> osp, sys = import_modules_from_strings( - ... ['os.path', 'sys']) - >>> import os.path as osp_ - >>> import sys as sys_ - >>> assert osp == osp_ - >>> assert sys == sys_ - """ - if not imports: - return - single_import = False - if isinstance(imports, str): - single_import = True - imports = [imports] - if not isinstance(imports, list): - raise TypeError( - f'custom_imports must be a list but got type {type(imports)}') - imported = [] - for imp in imports: - if not isinstance(imp, str): - raise TypeError( - f'{imp} is of type {type(imp)} and cannot be imported.') - try: - imported_tmp = import_module(imp) - except ImportError: - if allow_failed_imports: - warnings.warn(f'{imp} failed to import and is ignored.', - UserWarning) - imported_tmp = None - else: - raise ImportError - imported.append(imported_tmp) - if single_import: - imported = imported[0] - return imported - - -def iter_cast(inputs, dst_type, return_type=None): - """Cast elements of an iterable object into some type. - - Args: - inputs (Iterable): The input object. - dst_type (type): Destination type. - return_type (type, optional): If specified, the output object will be - converted to this type, otherwise an iterator. - - Returns: - iterator or specified type: The converted object. - """ - if not isinstance(inputs, abc.Iterable): - raise TypeError('inputs must be an iterable object') - if not isinstance(dst_type, type): - raise TypeError('"dst_type" must be a valid type') - - out_iterable = map(dst_type, inputs) - - if return_type is None: - return out_iterable - else: - return return_type(out_iterable) - - -def list_cast(inputs, dst_type): - """Cast elements of an iterable object into a list of some type. - - A partial method of :func:`iter_cast`. - """ - return iter_cast(inputs, dst_type, return_type=list) - - -def tuple_cast(inputs, dst_type): - """Cast elements of an iterable object into a tuple of some type. - - A partial method of :func:`iter_cast`. - """ - return iter_cast(inputs, dst_type, return_type=tuple) - - -def is_seq_of(seq, expected_type, seq_type=None): - """Check whether it is a sequence of some type. - - Args: - seq (Sequence): The sequence to be checked. - expected_type (type): Expected type of sequence items. - seq_type (type, optional): Expected sequence type. - - Returns: - bool: Whether the sequence is valid. - """ - if seq_type is None: - exp_seq_type = abc.Sequence - else: - assert isinstance(seq_type, type) - exp_seq_type = seq_type - if not isinstance(seq, exp_seq_type): - return False - for item in seq: - if not isinstance(item, expected_type): - return False - return True - - -def is_list_of(seq, expected_type): - """Check whether it is a list of some type. - - A partial method of :func:`is_seq_of`. - """ - return is_seq_of(seq, expected_type, seq_type=list) - - -def is_tuple_of(seq, expected_type): - """Check whether it is a tuple of some type. - - A partial method of :func:`is_seq_of`. - """ - return is_seq_of(seq, expected_type, seq_type=tuple) - - -def slice_list(in_list, lens): - """Slice a list into several sub lists by a list of given length. - - Args: - in_list (list): The list to be sliced. - lens(int or list): The expected length of each out list. - - Returns: - list: A list of sliced list. - """ - if isinstance(lens, int): - assert len(in_list) % lens == 0 - lens = [lens] * int(len(in_list) / lens) - if not isinstance(lens, list): - raise TypeError('"indices" must be an integer or a list of integers') - elif sum(lens) != len(in_list): - raise ValueError('sum of lens and list length does not ' - f'match: {sum(lens)} != {len(in_list)}') - out_list = [] - idx = 0 - for i in range(len(lens)): - out_list.append(in_list[idx:idx + lens[i]]) - idx += lens[i] - return out_list - - -def concat_list(in_list): - """Concatenate a list of list into a single list. - - Args: - in_list (list): The list of list to be merged. - - Returns: - list: The concatenated flat list. - """ - return list(itertools.chain(*in_list)) - - -def check_prerequisites( - prerequisites, - checker, - msg_tmpl='Prerequisites "{}" are required in method "{}" but not ' - 'found, please install them first.'): # yapf: disable - """A decorator factory to check if prerequisites are satisfied. - - Args: - prerequisites (str of list[str]): Prerequisites to be checked. - checker (callable): The checker method that returns True if a - prerequisite is meet, False otherwise. - msg_tmpl (str): The message template with two variables. - - Returns: - decorator: A specific decorator. - """ - - def wrap(func): - - @functools.wraps(func) - def wrapped_func(*args, **kwargs): - requirements = [prerequisites] if isinstance( - prerequisites, str) else prerequisites - missing = [] - for item in requirements: - if not checker(item): - missing.append(item) - if missing: - print(msg_tmpl.format(', '.join(missing), func.__name__)) - raise RuntimeError('Prerequisites not meet.') - else: - return func(*args, **kwargs) - - return wrapped_func - - return wrap - - -def _check_py_package(package): - try: - import_module(package) - except ImportError: - return False - else: - return True - - -def _check_executable(cmd): - if subprocess.call(f'which {cmd}', shell=True) != 0: - return False - else: - return True - - -def requires_package(prerequisites): - """A decorator to check if some python packages are installed. - - Example: - >>> @requires_package('numpy') - >>> func(arg1, args): - >>> return numpy.zeros(1) - array([0.]) - >>> @requires_package(['numpy', 'non_package']) - >>> func(arg1, args): - >>> return numpy.zeros(1) - ImportError - """ - return check_prerequisites(prerequisites, checker=_check_py_package) - - -def requires_executable(prerequisites): - """A decorator to check if some executable files are installed. - - Example: - >>> @requires_executable('ffmpeg') - >>> func(arg1, args): - >>> print(1) - 1 - """ - return check_prerequisites(prerequisites, checker=_check_executable) - - -def deprecated_api_warning(name_dict, cls_name=None): - """A decorator to check if some arguments are deprecate and try to replace - deprecate src_arg_name to dst_arg_name. - - Args: - name_dict(dict): - key (str): Deprecate argument names. - val (str): Expected argument names. - - Returns: - func: New function. - """ - - def api_warning_wrapper(old_func): - - @functools.wraps(old_func) - def new_func(*args, **kwargs): - # get the arg spec of the decorated method - args_info = getfullargspec(old_func) - # get name of the function - func_name = old_func.__name__ - if cls_name is not None: - func_name = f'{cls_name}.{func_name}' - if args: - arg_names = args_info.args[:len(args)] - for src_arg_name, dst_arg_name in name_dict.items(): - if src_arg_name in arg_names: - warnings.warn( - f'"{src_arg_name}" is deprecated in ' - f'`{func_name}`, please use "{dst_arg_name}" ' - 'instead', DeprecationWarning) - arg_names[arg_names.index(src_arg_name)] = dst_arg_name - if kwargs: - for src_arg_name, dst_arg_name in name_dict.items(): - if src_arg_name in kwargs: - - assert dst_arg_name not in kwargs, ( - f'The expected behavior is to replace ' - f'the deprecated key `{src_arg_name}` to ' - f'new key `{dst_arg_name}`, but got them ' - f'in the arguments at the same time, which ' - f'is confusing. `{src_arg_name} will be ' - f'deprecated in the future, please ' - f'use `{dst_arg_name}` instead.') - - warnings.warn( - f'"{src_arg_name}" is deprecated in ' - f'`{func_name}`, please use "{dst_arg_name}" ' - 'instead', DeprecationWarning) - kwargs[dst_arg_name] = kwargs.pop(src_arg_name) - - # apply converted arguments to the decorated method - output = old_func(*args, **kwargs) - return output - - return new_func - - return api_warning_wrapper - - -def is_method_overridden(method, base_class, derived_class): - """Check if a method of base class is overridden in derived class. - - Args: - method (str): the method name to check. - base_class (type): the class of the base class. - derived_class (type | Any): the class or instance of the derived class. - """ - assert isinstance(base_class, type), \ - "base_class doesn't accept instance, Please pass class instead." - - if not isinstance(derived_class, type): - derived_class = derived_class.__class__ - - base_method = getattr(base_class, method) - derived_method = getattr(derived_class, method) - return derived_method != base_method - - -def has_method(obj: object, method: str) -> bool: - """Check whether the object has a method. - - Args: - method (str): The method name to check. - obj (object): The object to check. - - Returns: - bool: True if the object has the method else False. - """ - return hasattr(obj, method) and callable(getattr(obj, method)) diff --git a/mmcv/utils/parrots_jit.py b/mmcv/utils/parrots_jit.py index 61873f6dbb..2b51c039ca 100644 --- a/mmcv/utils/parrots_jit.py +++ b/mmcv/utils/parrots_jit.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import os -from .parrots_wrapper import TORCH_VERSION +from mmengine.utils.parrots_wrapper import TORCH_VERSION parrots_jit_option = os.getenv('PARROTS_JIT_OPTION') diff --git a/mmcv/utils/parrots_wrapper.py b/mmcv/utils/parrots_wrapper.py deleted file mode 100644 index cf2c7e5ce0..0000000000 --- a/mmcv/utils/parrots_wrapper.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from functools import partial - -import torch - -TORCH_VERSION = torch.__version__ - - -def is_cuda_available() -> bool: - return torch.cuda.is_available() - - -IS_CUDA_AVAILABLE = is_cuda_available() - - -def is_rocm_pytorch() -> bool: - is_rocm = False - if TORCH_VERSION != 'parrots': - try: - from torch.utils.cpp_extension import ROCM_HOME - is_rocm = True if ((torch.version.hip is not None) and - (ROCM_HOME is not None)) else False - except ImportError: - pass - return is_rocm - - -def _get_cuda_home(): - if TORCH_VERSION == 'parrots': - from parrots.utils.build_extension import CUDA_HOME - else: - if is_rocm_pytorch(): - from torch.utils.cpp_extension import ROCM_HOME - CUDA_HOME = ROCM_HOME - else: - from torch.utils.cpp_extension import CUDA_HOME - return CUDA_HOME - - -def get_build_config(): - if TORCH_VERSION == 'parrots': - from parrots.config import get_build_info - return get_build_info() - else: - return torch.__config__.show() - - -def _get_conv(): - if TORCH_VERSION == 'parrots': - from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin - else: - from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin - return _ConvNd, _ConvTransposeMixin - - -def _get_dataloader(): - if TORCH_VERSION == 'parrots': - from torch.utils.data import DataLoader, PoolDataLoader - else: - from torch.utils.data import DataLoader - PoolDataLoader = DataLoader - return DataLoader, PoolDataLoader - - -def _get_extension(): - if TORCH_VERSION == 'parrots': - from parrots.utils.build_extension import BuildExtension, Extension - CppExtension = partial(Extension, cuda=False) - CUDAExtension = partial(Extension, cuda=True) - else: - from torch.utils.cpp_extension import (BuildExtension, CppExtension, - CUDAExtension) - return BuildExtension, CppExtension, CUDAExtension - - -def _get_pool(): - if TORCH_VERSION == 'parrots': - from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd, - _AdaptiveMaxPoolNd, _AvgPoolNd, - _MaxPoolNd) - else: - from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, - _AdaptiveMaxPoolNd, _AvgPoolNd, - _MaxPoolNd) - return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd - - -def _get_norm(): - if TORCH_VERSION == 'parrots': - from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm - SyncBatchNorm_ = torch.nn.SyncBatchNorm2d - else: - from torch.nn.modules.batchnorm import _BatchNorm - from torch.nn.modules.instancenorm import _InstanceNorm - SyncBatchNorm_ = torch.nn.SyncBatchNorm - return _BatchNorm, _InstanceNorm, SyncBatchNorm_ - - -_ConvNd, _ConvTransposeMixin = _get_conv() -DataLoader, PoolDataLoader = _get_dataloader() -BuildExtension, CppExtension, CUDAExtension = _get_extension() -_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm() -_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool() - - -class SyncBatchNorm(SyncBatchNorm_): # type: ignore - - def _check_input_dim(self, input): - if TORCH_VERSION == 'parrots': - if input.dim() < 2: - raise ValueError( - f'expected at least 2D input (got {input.dim()}D input)') - else: - super()._check_input_dim(input) diff --git a/mmcv/utils/path.py b/mmcv/utils/path.py deleted file mode 100644 index 5680818377..0000000000 --- a/mmcv/utils/path.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -import os.path as osp -from pathlib import Path - -from .misc import is_str - - -def is_filepath(x): - return is_str(x) or isinstance(x, Path) - - -def fopen(filepath, *args, **kwargs): - if is_str(filepath): - return open(filepath, *args, **kwargs) - elif isinstance(filepath, Path): - return filepath.open(*args, **kwargs) - raise ValueError('`filepath` should be a string or a Path') - - -def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): - if not osp.isfile(filename): - raise FileNotFoundError(msg_tmpl.format(filename)) - - -def mkdir_or_exist(dir_name, mode=0o777): - if dir_name == '': - return - dir_name = osp.expanduser(dir_name) - os.makedirs(dir_name, mode=mode, exist_ok=True) - - -def symlink(src, dst, overwrite=True, **kwargs): - if os.path.lexists(dst) and overwrite: - os.remove(dst) - os.symlink(src, dst, **kwargs) - - -def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True): - """Scan a directory to find the interested files. - - Args: - dir_path (str | :obj:`Path`): Path of the directory. - suffix (str | tuple(str), optional): File suffix that we are - interested in. Default: None. - recursive (bool, optional): If set to True, recursively scan the - directory. Default: False. - case_sensitive (bool, optional) : If set to False, ignore the case of - suffix. Default: True. - - Returns: - A generator for all the interested files with relative paths. - """ - if isinstance(dir_path, (str, Path)): - dir_path = str(dir_path) - else: - raise TypeError('"dir_path" must be a string or Path object') - - if (suffix is not None) and not isinstance(suffix, (str, tuple)): - raise TypeError('"suffix" must be a string or tuple of strings') - - if suffix is not None and not case_sensitive: - suffix = suffix.lower() if isinstance(suffix, str) else tuple( - item.lower() for item in suffix) - - root = dir_path - - def _scandir(dir_path, suffix, recursive, case_sensitive): - for entry in os.scandir(dir_path): - if not entry.name.startswith('.') and entry.is_file(): - rel_path = osp.relpath(entry.path, root) - _rel_path = rel_path if case_sensitive else rel_path.lower() - if suffix is None or _rel_path.endswith(suffix): - yield rel_path - elif recursive and os.path.isdir(entry.path): - # scan recursively if entry.path is a directory - yield from _scandir(entry.path, suffix, recursive, - case_sensitive) - - return _scandir(dir_path, suffix, recursive, case_sensitive) - - -def find_vcs_root(path, markers=('.git', )): - """Finds the root directory (including itself) of specified markers. - - Args: - path (str): Path of directory or file. - markers (list[str], optional): List of file or directory names. - - Returns: - The directory contained one of the markers or None if not found. - """ - if osp.isfile(path): - path = osp.dirname(path) - - prev, cur = None, osp.abspath(osp.expanduser(path)) - while cur != prev: - if any(osp.exists(osp.join(cur, marker)) for marker in markers): - return cur - prev, cur = cur, osp.split(cur)[0] - return None diff --git a/mmcv/utils/progressbar.py b/mmcv/utils/progressbar.py deleted file mode 100644 index 0062f670dd..0000000000 --- a/mmcv/utils/progressbar.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import sys -from collections.abc import Iterable -from multiprocessing import Pool -from shutil import get_terminal_size - -from .timer import Timer - - -class ProgressBar: - """A progress bar which can print the progress.""" - - def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout): - self.task_num = task_num - self.bar_width = bar_width - self.completed = 0 - self.file = file - if start: - self.start() - - @property - def terminal_width(self): - width, _ = get_terminal_size() - return width - - def start(self): - if self.task_num > 0: - self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, ' - 'elapsed: 0s, ETA:') - else: - self.file.write('completed: 0, elapsed: 0s') - self.file.flush() - self.timer = Timer() - - def update(self, num_tasks=1): - assert num_tasks > 0 - self.completed += num_tasks - elapsed = self.timer.since_start() - if elapsed > 0: - fps = self.completed / elapsed - else: - fps = float('inf') - if self.task_num > 0: - percentage = self.completed / float(self.task_num) - eta = int(elapsed * (1 - percentage) / percentage + 0.5) - msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \ - f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \ - f'ETA: {eta:5}s' - - bar_width = min(self.bar_width, - int(self.terminal_width - len(msg)) + 2, - int(self.terminal_width * 0.6)) - bar_width = max(2, bar_width) - mark_width = int(bar_width * percentage) - bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width) - self.file.write(msg.format(bar_chars)) - else: - self.file.write( - f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,' - f' {fps:.1f} tasks/s') - self.file.flush() - - -def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs): - """Track the progress of tasks execution with a progress bar. - - Tasks are done with a simple for-loop. - - Args: - func (callable): The function to be applied to each task. - tasks (list or tuple[Iterable, int]): A list of tasks or - (tasks, total num). - bar_width (int): Width of progress bar. - - Returns: - list: The task results. - """ - if isinstance(tasks, tuple): - assert len(tasks) == 2 - assert isinstance(tasks[0], Iterable) - assert isinstance(tasks[1], int) - task_num = tasks[1] - tasks = tasks[0] - elif isinstance(tasks, Iterable): - task_num = len(tasks) - else: - raise TypeError( - '"tasks" must be an iterable object or a (iterator, int) tuple') - prog_bar = ProgressBar(task_num, bar_width, file=file) - results = [] - for task in tasks: - results.append(func(task, **kwargs)) - prog_bar.update() - prog_bar.file.write('\n') - return results - - -def init_pool(process_num, initializer=None, initargs=None): - if initializer is None: - return Pool(process_num) - elif initargs is None: - return Pool(process_num, initializer) - else: - if not isinstance(initargs, tuple): - raise TypeError('"initargs" must be a tuple') - return Pool(process_num, initializer, initargs) - - -def track_parallel_progress(func, - tasks, - nproc, - initializer=None, - initargs=None, - bar_width=50, - chunksize=1, - skip_first=False, - keep_order=True, - file=sys.stdout): - """Track the progress of parallel task execution with a progress bar. - - The built-in :mod:`multiprocessing` module is used for process pools and - tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. - - Args: - func (callable): The function to be applied to each task. - tasks (list or tuple[Iterable, int]): A list of tasks or - (tasks, total num). - nproc (int): Process (worker) number. - initializer (None or callable): Refer to :class:`multiprocessing.Pool` - for details. - initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for - details. - chunksize (int): Refer to :class:`multiprocessing.Pool` for details. - bar_width (int): Width of progress bar. - skip_first (bool): Whether to skip the first sample for each worker - when estimating fps, since the initialization step may takes - longer. - keep_order (bool): If True, :func:`Pool.imap` is used, otherwise - :func:`Pool.imap_unordered` is used. - - Returns: - list: The task results. - """ - if isinstance(tasks, tuple): - assert len(tasks) == 2 - assert isinstance(tasks[0], Iterable) - assert isinstance(tasks[1], int) - task_num = tasks[1] - tasks = tasks[0] - elif isinstance(tasks, Iterable): - task_num = len(tasks) - else: - raise TypeError( - '"tasks" must be an iterable object or a (iterator, int) tuple') - pool = init_pool(nproc, initializer, initargs) - start = not skip_first - task_num -= nproc * chunksize * int(skip_first) - prog_bar = ProgressBar(task_num, bar_width, start, file=file) - results = [] - if keep_order: - gen = pool.imap(func, tasks, chunksize) - else: - gen = pool.imap_unordered(func, tasks, chunksize) - for result in gen: - results.append(result) - if skip_first: - if len(results) < nproc * chunksize: - continue - elif len(results) == nproc * chunksize: - prog_bar.start() - continue - prog_bar.update() - prog_bar.file.write('\n') - pool.close() - pool.join() - return results - - -def track_iter_progress(tasks, bar_width=50, file=sys.stdout): - """Track the progress of tasks iteration or enumeration with a progress - bar. - - Tasks are yielded with a simple for-loop. - - Args: - tasks (list or tuple[Iterable, int]): A list of tasks or - (tasks, total num). - bar_width (int): Width of progress bar. - - Yields: - list: The task results. - """ - if isinstance(tasks, tuple): - assert len(tasks) == 2 - assert isinstance(tasks[0], Iterable) - assert isinstance(tasks[1], int) - task_num = tasks[1] - tasks = tasks[0] - elif isinstance(tasks, Iterable): - task_num = len(tasks) - else: - raise TypeError( - '"tasks" must be an iterable object or a (iterator, int) tuple') - prog_bar = ProgressBar(task_num, bar_width, file=file) - for task in tasks: - yield task - prog_bar.update() - prog_bar.file.write('\n') diff --git a/mmcv/utils/registry.py b/mmcv/utils/registry.py deleted file mode 100644 index a7db6bd442..0000000000 --- a/mmcv/utils/registry.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import inspect -import warnings -from functools import partial -from typing import Any, Dict, Optional - -from .misc import deprecated_api_warning, is_seq_of - - -def build_from_cfg(cfg: Dict, - registry: 'Registry', - default_args: Optional[Dict] = None) -> Any: - """Build a module from config dict when it is a class configuration, or - call a function from config dict when it is a function configuration. - - Example: - >>> MODELS = Registry('models') - >>> @MODELS.register_module() - >>> class ResNet: - >>> pass - >>> resnet = build_from_cfg(dict(type='Resnet'), MODELS) - >>> # Returns an instantiated object - >>> @MODELS.register_module() - >>> def resnet50(): - >>> pass - >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS) - >>> # Return a result of the calling function - - Args: - cfg (dict): Config dict. It should at least contain the key "type". - registry (:obj:`Registry`): The registry to search the type from. - default_args (dict, optional): Default initialization arguments. - - Returns: - object: The constructed object. - """ - if not isinstance(cfg, dict): - raise TypeError(f'cfg must be a dict, but got {type(cfg)}') - if 'type' not in cfg: - if default_args is None or 'type' not in default_args: - raise KeyError( - '`cfg` or `default_args` must contain the key "type", ' - f'but got {cfg}\n{default_args}') - if not isinstance(registry, Registry): - raise TypeError('registry must be an mmcv.Registry object, ' - f'but got {type(registry)}') - if not (isinstance(default_args, dict) or default_args is None): - raise TypeError('default_args must be a dict or None, ' - f'but got {type(default_args)}') - - args = cfg.copy() - - if default_args is not None: - for name, value in default_args.items(): - args.setdefault(name, value) - - obj_type = args.pop('type') - if isinstance(obj_type, str): - obj_cls = registry.get(obj_type) - if obj_cls is None: - raise KeyError( - f'{obj_type} is not in the {registry.name} registry') - elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): - obj_cls = obj_type - else: - raise TypeError( - f'type must be a str or valid type, but got {type(obj_type)}') - try: - return obj_cls(**args) - except Exception as e: - # Normal TypeError does not print class name. - raise type(e)(f'{obj_cls.__name__}: {e}') - - -class Registry: - """A registry to map strings to classes or functions. - - Registered object could be built from registry. Meanwhile, registered - functions could be called from registry. - - Example: - >>> MODELS = Registry('models') - >>> @MODELS.register_module() - >>> class ResNet: - >>> pass - >>> resnet = MODELS.build(dict(type='ResNet')) - >>> @MODELS.register_module() - >>> def resnet50(): - >>> pass - >>> resnet = MODELS.build(dict(type='resnet50')) - - Please refer to - https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for - advanced usage. - - Args: - name (str): Registry name. - build_func(func, optional): Build function to construct instance from - Registry, func:`build_from_cfg` is used if neither ``parent`` or - ``build_func`` is specified. If ``parent`` is specified and - ``build_func`` is not given, ``build_func`` will be inherited - from ``parent``. Default: None. - parent (Registry, optional): Parent registry. The class registered in - children registry could be built from parent. Default: None. - scope (str, optional): The scope of registry. It is the key to search - for children registry. If not specified, scope will be the name of - the package where class is defined, e.g. mmdet, mmcls, mmseg. - Default: None. - """ - - def __init__(self, name, build_func=None, parent=None, scope=None): - self._name = name - self._module_dict = dict() - self._children = dict() - self._scope = self.infer_scope() if scope is None else scope - - # self.build_func will be set with the following priority: - # 1. build_func - # 2. parent.build_func - # 3. build_from_cfg - if build_func is None: - if parent is not None: - self.build_func = parent.build_func - else: - self.build_func = build_from_cfg - else: - self.build_func = build_func - if parent is not None: - assert isinstance(parent, Registry) - parent._add_children(self) - self.parent = parent - else: - self.parent = None - - def __len__(self): - return len(self._module_dict) - - def __contains__(self, key): - return self.get(key) is not None - - def __repr__(self): - format_str = self.__class__.__name__ + \ - f'(name={self._name}, ' \ - f'items={self._module_dict})' - return format_str - - @staticmethod - def infer_scope(): - """Infer the scope of registry. - - The name of the package where registry is defined will be returned. - - Example: - >>> # in mmdet/models/backbone/resnet.py - >>> MODELS = Registry('models') - >>> @MODELS.register_module() - >>> class ResNet: - >>> pass - The scope of ``ResNet`` will be ``mmdet``. - - Returns: - str: The inferred scope name. - """ - # We access the caller using inspect.currentframe() instead of - # inspect.stack() for performance reasons. See details in PR #1844 - frame = inspect.currentframe() - # get the frame where `infer_scope()` is called - infer_scope_caller = frame.f_back.f_back - filename = inspect.getmodule(infer_scope_caller).__name__ - split_filename = filename.split('.') - return split_filename[0] - - @staticmethod - def split_scope_key(key): - """Split scope and key. - - The first scope will be split from key. - - Examples: - >>> Registry.split_scope_key('mmdet.ResNet') - 'mmdet', 'ResNet' - >>> Registry.split_scope_key('ResNet') - None, 'ResNet' - - Return: - tuple[str | None, str]: The former element is the first scope of - the key, which can be ``None``. The latter is the remaining key. - """ - split_index = key.find('.') - if split_index != -1: - return key[:split_index], key[split_index + 1:] - else: - return None, key - - @property - def name(self): - return self._name - - @property - def scope(self): - return self._scope - - @property - def module_dict(self): - return self._module_dict - - @property - def children(self): - return self._children - - def get(self, key): - """Get the registry record. - - Args: - key (str): The class name in string format. - - Returns: - class: The corresponding class. - """ - scope, real_key = self.split_scope_key(key) - if scope is None or scope == self._scope: - # get from self - if real_key in self._module_dict: - return self._module_dict[real_key] - else: - # get from self._children - if scope in self._children: - return self._children[scope].get(real_key) - else: - # goto root - parent = self.parent - while parent.parent is not None: - parent = parent.parent - return parent.get(key) - - def build(self, *args, **kwargs): - return self.build_func(*args, **kwargs, registry=self) - - def _add_children(self, registry): - """Add children for a registry. - - The ``registry`` will be added as children based on its scope. - The parent registry could build objects from children registry. - - Example: - >>> models = Registry('models') - >>> mmdet_models = Registry('models', parent=models) - >>> @mmdet_models.register_module() - >>> class ResNet: - >>> pass - >>> resnet = models.build(dict(type='mmdet.ResNet')) - """ - - assert isinstance(registry, Registry) - assert registry.scope is not None - assert registry.scope not in self.children, \ - f'scope {registry.scope} exists in {self.name} registry' - self.children[registry.scope] = registry - - @deprecated_api_warning(name_dict=dict(module_class='module')) - def _register_module(self, module, module_name=None, force=False): - if not inspect.isclass(module) and not inspect.isfunction(module): - raise TypeError('module must be a class or a function, ' - f'but got {type(module)}') - - if module_name is None: - module_name = module.__name__ - if isinstance(module_name, str): - module_name = [module_name] - for name in module_name: - if not force and name in self._module_dict: - raise KeyError(f'{name} is already registered ' - f'in {self.name}') - self._module_dict[name] = module - - def deprecated_register_module(self, cls=None, force=False): - warnings.warn( - 'The old API of register_module(module, force=False) ' - 'is deprecated and will be removed, please use the new API ' - 'register_module(name=None, force=False, module=None) instead.', - DeprecationWarning) - if cls is None: - return partial(self.deprecated_register_module, force=force) - self._register_module(cls, force=force) - return cls - - def register_module(self, name=None, force=False, module=None): - """Register a module. - - A record will be added to `self._module_dict`, whose key is the class - name or the specified name, and value is the class itself. - It can be used as a decorator or a normal function. - - Example: - >>> backbones = Registry('backbone') - >>> @backbones.register_module() - >>> class ResNet: - >>> pass - - >>> backbones = Registry('backbone') - >>> @backbones.register_module(name='mnet') - >>> class MobileNet: - >>> pass - - >>> backbones = Registry('backbone') - >>> class ResNet: - >>> pass - >>> backbones.register_module(ResNet) - - Args: - name (str | None): The module name to be registered. If not - specified, the class name will be used. - force (bool, optional): Whether to override an existing class with - the same name. Default: False. - module (type): Module class or function to be registered. - """ - if not isinstance(force, bool): - raise TypeError(f'force must be a boolean, but got {type(force)}') - # NOTE: This is a walkaround to be compatible with the old api, - # while it may introduce unexpected bugs. - if isinstance(name, type): - return self.deprecated_register_module(name, force=force) - - # raise the error ahead of time - if not (name is None or isinstance(name, str) or is_seq_of(name, str)): - raise TypeError( - 'name must be either of None, an instance of str or a sequence' - f' of str, but got {type(name)}') - - # use it as a normal method: x.register_module(module=SomeClass) - if module is not None: - self._register_module(module=module, module_name=name, force=force) - return module - - # use it as a decorator: @x.register_module() - def _register(module): - self._register_module(module=module, module_name=name, force=force) - return module - - return _register diff --git a/mmcv/utils/seed.py b/mmcv/utils/seed.py deleted file mode 100644 index 003f923677..0000000000 --- a/mmcv/utils/seed.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import random - -import numpy as np -import torch - - -def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): - """Function to initialize each worker. - - The seed of each worker equals to - ``num_worker * rank + worker_id + user_seed``. - - Args: - worker_id (int): Id for each worker. - num_workers (int): Number of workers. - rank (int): Rank in distributed training. - seed (int): Random seed. - """ - worker_seed = num_workers * rank + worker_id + seed - np.random.seed(worker_seed) - random.seed(worker_seed) - torch.manual_seed(worker_seed) diff --git a/mmcv/utils/testing.py b/mmcv/utils/testing.py deleted file mode 100644 index 7b64e8fae3..0000000000 --- a/mmcv/utils/testing.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) Open-MMLab. -import sys -from collections.abc import Iterable -from runpy import run_path -from shlex import split -from typing import Any, Dict, List -from unittest.mock import patch - - -def check_python_script(cmd): - """Run the python cmd script with `__main__`. The difference between - `os.system` is that, this function exectues code in the current process, so - that it can be tracked by coverage tools. Currently it supports two forms: - - - ./tests/data/scripts/hello.py zz - - python tests/data/scripts/hello.py zz - """ - args = split(cmd) - if args[0] == 'python': - args = args[1:] - with patch.object(sys, 'argv', args): - run_path(args[0], run_name='__main__') - - -def _any(judge_result): - """Since built-in ``any`` works only when the element of iterable is not - iterable, implement the function.""" - if not isinstance(judge_result, Iterable): - return judge_result - - try: - for element in judge_result: - if _any(element): - return True - except TypeError: - # Maybe encounter the case: torch.tensor(True) | torch.tensor(False) - if judge_result: - return True - return False - - -def assert_dict_contains_subset(dict_obj: Dict[Any, Any], - expected_subset: Dict[Any, Any]) -> bool: - """Check if the dict_obj contains the expected_subset. - - Args: - dict_obj (Dict[Any, Any]): Dict object to be checked. - expected_subset (Dict[Any, Any]): Subset expected to be contained in - dict_obj. - - Returns: - bool: Whether the dict_obj contains the expected_subset. - """ - - for key, value in expected_subset.items(): - if key not in dict_obj.keys() or _any(dict_obj[key] != value): - return False - return True - - -def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool: - """Check if attribute of class object is correct. - - Args: - obj (object): Class object to be checked. - expected_attrs (Dict[str, Any]): Dict of the expected attrs. - - Returns: - bool: Whether the attribute of class object is correct. - """ - for attr, value in expected_attrs.items(): - if not hasattr(obj, attr) or _any(getattr(obj, attr) != value): - return False - return True - - -def assert_dict_has_keys(obj: Dict[str, Any], - expected_keys: List[str]) -> bool: - """Check if the obj has all the expected_keys. - - Args: - obj (Dict[str, Any]): Object to be checked. - expected_keys (List[str]): Keys expected to contained in the keys of - the obj. - - Returns: - bool: Whether the obj has the expected keys. - """ - return set(expected_keys).issubset(set(obj.keys())) - - -def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool: - """Check if target_keys is equal to result_keys. - - Args: - result_keys (List[str]): Result keys to be checked. - target_keys (List[str]): Target keys to be checked. - - Returns: - bool: Whether target_keys is equal to result_keys. - """ - return set(result_keys) == set(target_keys) - - -def assert_is_norm_layer(module) -> bool: - """Check if the module is a norm layer. - - Args: - module (nn.Module): The module to be checked. - - Returns: - bool: Whether the module is a norm layer. - """ - from torch.nn import GroupNorm, LayerNorm - - from .parrots_wrapper import _BatchNorm, _InstanceNorm - norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm) - return isinstance(module, norm_layer_candidates) - - -def assert_params_all_zeros(module) -> bool: - """Check if the parameters of the module is all zeros. - - Args: - module (nn.Module): The module to be checked. - - Returns: - bool: Whether the parameters of the module is all zeros. - """ - weight_data = module.weight.data - is_weight_zero = weight_data.allclose( - weight_data.new_zeros(weight_data.size())) - - if hasattr(module, 'bias') and module.bias is not None: - bias_data = module.bias.data - is_bias_zero = bias_data.allclose( - bias_data.new_zeros(bias_data.size())) - else: - is_bias_zero = True - - return is_weight_zero and is_bias_zero diff --git a/mmcv/utils/timer.py b/mmcv/utils/timer.py deleted file mode 100644 index 087a969cfa..0000000000 --- a/mmcv/utils/timer.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from time import time - - -class TimerError(Exception): - - def __init__(self, message): - self.message = message - super().__init__(message) - - -class Timer: - """A flexible Timer class. - - Examples: - >>> import time - >>> import mmcv - >>> with mmcv.Timer(): - >>> # simulate a code block that will run for 1s - >>> time.sleep(1) - 1.000 - >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'): - >>> # simulate a code block that will run for 1s - >>> time.sleep(1) - it takes 1.0 seconds - >>> timer = mmcv.Timer() - >>> time.sleep(0.5) - >>> print(timer.since_start()) - 0.500 - >>> time.sleep(0.5) - >>> print(timer.since_last_check()) - 0.500 - >>> print(timer.since_start()) - 1.000 - """ - - def __init__(self, start=True, print_tmpl=None): - self._is_running = False - self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}' - if start: - self.start() - - @property - def is_running(self): - """bool: indicate whether the timer is running""" - return self._is_running - - def __enter__(self): - self.start() - return self - - def __exit__(self, type, value, traceback): - print(self.print_tmpl.format(self.since_last_check())) - self._is_running = False - - def start(self): - """Start the timer.""" - if not self._is_running: - self._t_start = time() - self._is_running = True - self._t_last = time() - - def since_start(self): - """Total time since the timer is started. - - Returns: - float: Time in seconds. - """ - if not self._is_running: - raise TimerError('timer is not running') - self._t_last = time() - return self._t_last - self._t_start - - def since_last_check(self): - """Time since the last checking. - - Either :func:`since_start` or :func:`since_last_check` is a checking - operation. - - Returns: - float: Time in seconds. - """ - if not self._is_running: - raise TimerError('timer is not running') - dur = time() - self._t_last - self._t_last = time() - return dur - - -_g_timers = {} # global timers - - -def check_time(timer_id): - """Add check points in a single line. - - This method is suitable for running a task on a list of items. A timer will - be registered when the method is called for the first time. - - Examples: - >>> import time - >>> import mmcv - >>> for i in range(1, 6): - >>> # simulate a code block - >>> time.sleep(i) - >>> mmcv.check_time('task1') - 2.000 - 3.000 - 4.000 - 5.000 - - Args: - str: Timer identifier. - """ - if timer_id not in _g_timers: - _g_timers[timer_id] = Timer() - return 0 - else: - return _g_timers[timer_id].since_last_check() diff --git a/mmcv/utils/torch_ops.py b/mmcv/utils/torch_ops.py deleted file mode 100644 index b4f2213a43..0000000000 --- a/mmcv/utils/torch_ops.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch - -from .parrots_wrapper import TORCH_VERSION -from .version_utils import digit_version - -_torch_version_meshgrid_indexing = ( - 'parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0')) - - -def torch_meshgrid(*tensors): - """A wrapper of torch.meshgrid to compat different PyTorch versions. - - Since PyTorch 1.10.0a0, torch.meshgrid supports the arguments ``indexing``. - So we implement a wrapper here to avoid warning when using high-version - PyTorch and avoid compatibility issues when using previous versions of - PyTorch. - - Args: - tensors (List[Tensor]): List of scalars or 1 dimensional tensors. - - Returns: - Sequence[Tensor]: Sequence of meshgrid tensors. - """ - if _torch_version_meshgrid_indexing: - return torch.meshgrid(*tensors, indexing='ij') - else: - return torch.meshgrid(*tensors) # Uses indexing='ij' by default diff --git a/mmcv/utils/trace.py b/mmcv/utils/trace.py deleted file mode 100644 index 45423bd055..0000000000 --- a/mmcv/utils/trace.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import warnings - -import torch - -from mmcv.utils import digit_version - - -def is_jit_tracing() -> bool: - if (torch.__version__ != 'parrots' - and digit_version(torch.__version__) >= digit_version('1.6.0')): - on_trace = torch.jit.is_tracing() - # In PyTorch 1.6, torch.jit.is_tracing has a bug. - # Refers to https://github.com/pytorch/pytorch/issues/42448 - if isinstance(on_trace, bool): - return on_trace - else: - return torch._C._is_tracing() - else: - warnings.warn( - 'torch.jit.is_tracing is only supported after v1.6.0. ' - 'Therefore is_tracing returns False automatically. Please ' - 'set on_trace manually if you are using trace.', UserWarning) - return False diff --git a/mmcv/utils/version_utils.py b/mmcv/utils/version_utils.py deleted file mode 100644 index 77c41f6084..0000000000 --- a/mmcv/utils/version_utils.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -import subprocess -import warnings - -from packaging.version import parse - - -def digit_version(version_str: str, length: int = 4): - """Convert a version string into a tuple of integers. - - This method is usually used for comparing two versions. For pre-release - versions: alpha < beta < rc. - - Args: - version_str (str): The version string. - length (int): The maximum number of version levels. Default: 4. - - Returns: - tuple[int]: The version info in digits (integers). - """ - assert 'parrots' not in version_str - version = parse(version_str) - assert version.release, f'failed to parse version {version_str}' - release = list(version.release) - release = release[:length] - if len(release) < length: - release = release + [0] * (length - len(release)) - if version.is_prerelease: - mapping = {'a': -3, 'b': -2, 'rc': -1} - val = -4 - # version.pre can be None - if version.pre: - if version.pre[0] not in mapping: - warnings.warn(f'unknown prerelease version {version.pre[0]}, ' - 'version checking may go wrong') - else: - val = mapping[version.pre[0]] - release.extend([val, version.pre[-1]]) - else: - release.extend([val, 0]) - - elif version.is_postrelease: - release.extend([1, version.post]) # type: ignore - else: - release.extend([0, 0]) - return tuple(release) - - -def _minimal_ext_cmd(cmd): - # construct minimal environment - env = {} - for k in ['SYSTEMROOT', 'PATH', 'HOME']: - v = os.environ.get(k) - if v is not None: - env[k] = v - # LANGUAGE is used on win32 - env['LANGUAGE'] = 'C' - env['LANG'] = 'C' - env['LC_ALL'] = 'C' - out = subprocess.Popen( - cmd, stdout=subprocess.PIPE, env=env).communicate()[0] - return out - - -def get_git_hash(fallback='unknown', digits=None): - """Get the git hash of the current repo. - - Args: - fallback (str, optional): The fallback string when git hash is - unavailable. Defaults to 'unknown'. - digits (int, optional): kept digits of the hash. Defaults to None, - meaning all digits are kept. - - Returns: - str: Git commit hash. - """ - - if digits is not None and not isinstance(digits, int): - raise TypeError('digits must be None or an integer') - - try: - out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) - sha = out.strip().decode('ascii') - if digits is not None: - sha = sha[:digits] - except OSError: - sha = fallback - - return sha diff --git a/mmcv/video/io.py b/mmcv/video/io.py index 09fa770db3..0ecc5eabba 100644 --- a/mmcv/video/io.py +++ b/mmcv/video/io.py @@ -6,9 +6,8 @@ from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT, CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH, CAP_PROP_POS_FRAMES, VideoWriter_fourcc) - -from mmcv.utils import (check_file_exist, mkdir_or_exist, scandir, - track_progress) +from mmengine.utils import (check_file_exist, mkdir_or_exist, scandir, + track_progress) class Cache: diff --git a/mmcv/video/optflow.py b/mmcv/video/optflow.py index 91ce004570..edd3e42069 100644 --- a/mmcv/video/optflow.py +++ b/mmcv/video/optflow.py @@ -4,10 +4,10 @@ import cv2 import numpy as np +from mmengine.utils import is_str from mmcv.arraymisc import dequantize, quantize from mmcv.image import imread, imwrite -from mmcv.utils import is_str def flowread(flow_or_path: Union[np.ndarray, str], diff --git a/mmcv/video/processing.py b/mmcv/video/processing.py index 90e2a4c022..4962e08a9e 100644 --- a/mmcv/video/processing.py +++ b/mmcv/video/processing.py @@ -5,7 +5,7 @@ import tempfile from typing import List, Optional, Union -from mmcv.utils import requires_executable +from mmengine.utils import requires_executable @requires_executable('ffmpeg') diff --git a/mmcv/visualization/color.py b/mmcv/visualization/color.py index 2cc0b523e0..05796a80c3 100644 --- a/mmcv/visualization/color.py +++ b/mmcv/visualization/color.py @@ -3,8 +3,7 @@ from typing import Union import numpy as np - -from mmcv.utils import is_str +from mmengine.utils import is_str class Color(Enum): diff --git a/tests/test_cnn/test_build_layers.py b/tests/test_cnn/test_build_layers.py index d4f8413c50..c3bf91f78b 100644 --- a/tests/test_cnn/test_build_layers.py +++ b/tests/test_cnn/test_build_layers.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn from mmengine.registry import MODELS +from mmengine.utils.parrots_wrapper import _BatchNorm from mmcv.cnn.bricks import (build_activation_layer, build_conv_layer, build_norm_layer, build_padding_layer, @@ -13,7 +14,6 @@ from mmcv.cnn.bricks.norm import infer_abbr as infer_norm_abbr from mmcv.cnn.bricks.plugin import infer_abbr as infer_plugin_abbr from mmcv.cnn.bricks.upsample import PixelShufflePack -from mmcv.utils.parrots_wrapper import _BatchNorm def test_build_conv_layer(): diff --git a/tests/test_cnn/test_conv_module.py b/tests/test_cnn/test_conv_module.py index c44e8998ab..568e3527d7 100644 --- a/tests/test_cnn/test_conv_module.py +++ b/tests/test_cnn/test_conv_module.py @@ -6,9 +6,9 @@ import torch import torch.nn as nn from mmengine.registry import MODELS +from mmengine.utils import TORCH_VERSION, digit_version from mmcv.cnn.bricks import ConvModule, HSigmoid, HSwish -from mmcv.utils import TORCH_VERSION, digit_version @MODELS.register_module() diff --git a/tests/test_cnn/test_revert_syncbn.py b/tests/test_cnn/test_revert_syncbn.py deleted file mode 100644 index 187c2a6d0b..0000000000 --- a/tests/test_cnn/test_revert_syncbn.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -import platform - -import numpy as np -import pytest -import torch -import torch.distributed as dist - -from mmcv.cnn.bricks import ConvModule -from mmcv.cnn.utils import revert_sync_batchnorm - -if platform.system() == 'Windows': - import regex as re -else: - import re - - -@pytest.mark.skipif( - torch.__version__ == 'parrots', reason='not supported in parrots now') -def test_revert_syncbn(): - conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')) - x = torch.randn(1, 3, 10, 10) - # Expect a ValueError prompting that SyncBN is not supported on CPU - with pytest.raises(ValueError): - y = conv(x) - conv = revert_sync_batchnorm(conv) - y = conv(x) - assert y.shape == (1, 8, 9, 9) - - -def test_revert_mmsyncbn(): - if 'SLURM_NTASKS' not in os.environ or int(os.environ['SLURM_NTASKS']) < 2: - print('Must run on slurm with more than 1 process!\n' - 'srun -p test --gres=gpu:2 -n2') - return - rank = int(os.environ['SLURM_PROCID']) - world_size = int(os.environ['SLURM_NTASKS']) - local_rank = int(os.environ['SLURM_LOCALID']) - node_list = str(os.environ['SLURM_NODELIST']) - - node_parts = re.findall('[0-9]+', node_list) - os.environ['MASTER_ADDR'] = (f'{node_parts[1]}.{node_parts[2]}' + - f'.{node_parts[3]}.{node_parts[4]}') - os.environ['MASTER_PORT'] = '12341' - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['RANK'] = str(rank) - - dist.init_process_group('nccl') - torch.cuda.set_device(local_rank) - x = torch.randn(1, 3, 10, 10).cuda() - dist.broadcast(x, src=0) - conv = ConvModule(3, 8, 2, norm_cfg=dict(type='MMSyncBN')).cuda() - conv.eval() - y_mmsyncbn = conv(x).detach().cpu().numpy() - conv = revert_sync_batchnorm(conv) - y_bn = conv(x).detach().cpu().numpy() - assert np.all(np.isclose(y_bn, y_mmsyncbn, 1e-3)) - conv, x = conv.to('cpu'), x.to('cpu') - y_bn_cpu = conv(x).detach().numpy() - assert np.all(np.isclose(y_bn, y_bn_cpu, 1e-3)) diff --git a/tests/test_ops/test_deform_conv.py b/tests/test_ops/test_deform_conv.py index e77b5f9753..9f973ef48b 100644 --- a/tests/test_ops/test_deform_conv.py +++ b/tests/test_ops/test_deform_conv.py @@ -2,8 +2,7 @@ import numpy as np import pytest import torch - -from mmcv.utils import TORCH_VERSION, digit_version +from mmengine.utils import TORCH_VERSION, digit_version try: # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast diff --git a/tests/test_ops/test_modulated_deform_conv.py b/tests/test_ops/test_modulated_deform_conv.py index 3b9070491a..94b2c12033 100644 --- a/tests/test_ops/test_modulated_deform_conv.py +++ b/tests/test_ops/test_modulated_deform_conv.py @@ -4,8 +4,7 @@ import numpy import pytest import torch - -from mmcv.utils import TORCH_VERSION, digit_version +from mmengine.utils import TORCH_VERSION, digit_version try: # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast diff --git a/tests/test_utils/test_config.py b/tests/test_utils/test_config.py deleted file mode 100644 index 4490a900d4..0000000000 --- a/tests/test_utils/test_config.py +++ /dev/null @@ -1,612 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import argparse -import copy -import json -import os -import os.path as osp -import shutil -import tempfile -from pathlib import Path - -import pytest -import yaml -from mmengine import dump, load - -from mmcv import Config, ConfigDict, DictAction - -data_path = osp.join(osp.dirname(osp.dirname(__file__)), 'data') - - -def test_construct(): - cfg = Config() - assert cfg.filename is None - assert cfg.text == '' - assert len(cfg) == 0 - assert cfg._cfg_dict == {} - - with pytest.raises(TypeError): - Config([0, 1]) - - cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4='test') - # test a.py - cfg_file = osp.join(data_path, 'config/a.py') - cfg_file_path = Path(cfg_file) - file_list = [cfg_file, cfg_file_path] - for item in file_list: - cfg = Config(cfg_dict, filename=item) - assert isinstance(cfg, Config) - assert isinstance(cfg.filename, str) and cfg.filename == str(item) - assert cfg.text == open(item).read() - assert cfg.dump() == cfg.pretty_text - with tempfile.TemporaryDirectory() as temp_config_dir: - dump_file = osp.join(temp_config_dir, 'a.py') - cfg.dump(dump_file) - assert cfg.dump() == open(dump_file).read() - assert Config.fromfile(dump_file) - - # test b.json - cfg_file = osp.join(data_path, 'config/b.json') - cfg = Config(cfg_dict, filename=cfg_file) - assert isinstance(cfg, Config) - assert cfg.filename == cfg_file - assert cfg.text == open(cfg_file).read() - assert cfg.dump() == json.dumps(cfg_dict) - with tempfile.TemporaryDirectory() as temp_config_dir: - dump_file = osp.join(temp_config_dir, 'b.json') - cfg.dump(dump_file) - assert cfg.dump() == open(dump_file).read() - assert Config.fromfile(dump_file) - - # test c.yaml - cfg_file = osp.join(data_path, 'config/c.yaml') - cfg = Config(cfg_dict, filename=cfg_file) - assert isinstance(cfg, Config) - assert cfg.filename == cfg_file - assert cfg.text == open(cfg_file).read() - assert cfg.dump() == yaml.dump(cfg_dict) - with tempfile.TemporaryDirectory() as temp_config_dir: - dump_file = osp.join(temp_config_dir, 'c.yaml') - cfg.dump(dump_file) - assert cfg.dump() == open(dump_file).read() - assert Config.fromfile(dump_file) - - # test h.py - cfg_file = osp.join(data_path, 'config/h.py') - path = osp.join(osp.dirname(__file__), 'data', 'config') - # the value of osp.dirname(__file__) may be `D:\a\xxx` in windows - # environment. When dumping the cfg_dict to file, `D:\a\xxx` will be - # converted to `D:\x07\xxx` and it will cause unexpected result when - # checking whether `D:\a\xxx` equals to `D:\x07\xxx`. Therefore, we forcely - # convert a string representation of the path with forward slashes (/) - path = Path(path).as_posix() - cfg_dict = dict(item1='h.py', item2=path, item3='abc_h') - cfg = Config(cfg_dict, filename=cfg_file) - assert isinstance(cfg, Config) - assert cfg.filename == cfg_file - assert cfg.text == open(cfg_file).read() - assert cfg.dump() == cfg.pretty_text - with tempfile.TemporaryDirectory() as temp_config_dir: - dump_file = osp.join(temp_config_dir, 'h.py') - cfg.dump(dump_file) - assert cfg.dump() == open(dump_file).read() - assert Config.fromfile(dump_file) - assert Config.fromfile(dump_file)['item1'] == cfg_dict['item1'] - assert Config.fromfile(dump_file)['item2'] == cfg_dict['item2'] - assert Config.fromfile(dump_file)['item3'] == cfg_dict['item3'] - - # test no use_predefined_variable - cfg_dict = dict( - item1='{{fileBasename}}', - item2='{{ fileDirname}}', - item3='abc_{{ fileBasenameNoExtension }}') - assert Config.fromfile(cfg_file, False) - assert Config.fromfile(cfg_file, False)['item1'] == cfg_dict['item1'] - assert Config.fromfile(cfg_file, False)['item2'] == cfg_dict['item2'] - assert Config.fromfile(cfg_file, False)['item3'] == cfg_dict['item3'] - - # test p.yaml - cfg_file = osp.join(data_path, 'config/p.yaml') - cfg_dict = dict(item1=osp.join(osp.dirname(__file__), 'data', 'config')) - cfg = Config(cfg_dict, filename=cfg_file) - assert isinstance(cfg, Config) - assert cfg.filename == cfg_file - assert cfg.text == open(cfg_file).read() - assert cfg.dump() == yaml.dump(cfg_dict) - with tempfile.TemporaryDirectory() as temp_config_dir: - dump_file = osp.join(temp_config_dir, 'p.yaml') - cfg.dump(dump_file) - assert cfg.dump() == open(dump_file).read() - assert Config.fromfile(dump_file) - assert Config.fromfile(dump_file)['item1'] == cfg_dict['item1'] - - # test no use_predefined_variable - assert Config.fromfile(cfg_file, False) - assert Config.fromfile(cfg_file, False)['item1'] == '{{ fileDirname }}' - - # test o.json - cfg_file = osp.join(data_path, 'config/o.json') - cfg_dict = dict(item1=osp.join(osp.dirname(__file__), 'data', 'config')) - cfg = Config(cfg_dict, filename=cfg_file) - assert isinstance(cfg, Config) - assert cfg.filename == cfg_file - assert cfg.text == open(cfg_file).read() - assert cfg.dump() == json.dumps(cfg_dict) - with tempfile.TemporaryDirectory() as temp_config_dir: - dump_file = osp.join(temp_config_dir, 'o.json') - cfg.dump(dump_file) - assert cfg.dump() == open(dump_file).read() - assert Config.fromfile(dump_file) - assert Config.fromfile(dump_file)['item1'] == cfg_dict['item1'] - - # test no use_predefined_variable - assert Config.fromfile(cfg_file, False) - assert Config.fromfile(cfg_file, False)['item1'] == '{{ fileDirname }}' - - -def test_fromfile(): - for filename in ['a.py', 'a.b.py', 'b.json', 'c.yaml']: - cfg_file = osp.join(data_path, 'config', filename) - cfg_file_path = Path(cfg_file) - file_list = [cfg_file, cfg_file_path] - for item in file_list: - cfg = Config.fromfile(item) - assert isinstance(cfg, Config) - assert isinstance(cfg.filename, str) and cfg.filename == str(item) - assert cfg.text == osp.abspath(osp.expanduser(item)) + '\n' + \ - open(item).read() - - # test custom_imports for Config.fromfile - cfg_file = osp.join(data_path, 'config', 'q.py') - imported_file = osp.join(data_path, 'config', 'r.py') - target_pkg = osp.join(osp.dirname(__file__), 'r.py') - - # Since the imported config will be regarded as a tmp file - # it should be copied to the directory at the same level - shutil.copy(imported_file, target_pkg) - Config.fromfile(cfg_file, import_custom_modules=True) - - assert os.environ.pop('TEST_VALUE') == 'test' - os.remove(target_pkg) - - with pytest.raises(FileNotFoundError): - Config.fromfile('no_such_file.py') - with pytest.raises(IOError): - Config.fromfile(osp.join(data_path, 'color.jpg')) - - -def test_fromstring(): - for filename in ['a.py', 'a.b.py', 'b.json', 'c.yaml']: - cfg_file = osp.join(data_path, 'config', filename) - file_format = osp.splitext(filename)[-1] - in_cfg = Config.fromfile(cfg_file) - - out_cfg = Config.fromstring(in_cfg.pretty_text, '.py') - assert in_cfg._cfg_dict == out_cfg._cfg_dict - - cfg_str = open(cfg_file).read() - out_cfg = Config.fromstring(cfg_str, file_format) - assert in_cfg._cfg_dict == out_cfg._cfg_dict - - # test pretty_text only supports py file format - cfg_file = osp.join(data_path, 'config', 'b.json') - in_cfg = Config.fromfile(cfg_file) - with pytest.raises(Exception): - Config.fromstring(in_cfg.pretty_text, '.json') - - # test file format error - cfg_str = open(cfg_file).read() - with pytest.raises(Exception): - Config.fromstring(cfg_str, '.py') - - -def test_merge_from_base(): - cfg_file = osp.join(data_path, 'config/d.py') - cfg = Config.fromfile(cfg_file) - assert isinstance(cfg, Config) - assert cfg.filename == cfg_file - base_cfg_file = osp.join(data_path, 'config/base.py') - merge_text = osp.abspath(osp.expanduser(base_cfg_file)) + '\n' + \ - open(base_cfg_file).read() - merge_text += '\n' + osp.abspath(osp.expanduser(cfg_file)) + '\n' + \ - open(cfg_file).read() - assert cfg.text == merge_text - assert cfg.item1 == [2, 3] - assert cfg.item2.a == 1 - assert cfg.item3 is False - assert cfg.item4 == 'test_base' - - with pytest.raises(TypeError): - Config.fromfile(osp.join(data_path, 'config/e.py')) - - -def test_merge_from_multiple_bases(): - cfg_file = osp.join(data_path, 'config/l.py') - cfg = Config.fromfile(cfg_file) - assert isinstance(cfg, Config) - assert cfg.filename == cfg_file - # cfg.field - assert cfg.item1 == [1, 2] - assert cfg.item2.a == 0 - assert cfg.item3 is False - assert cfg.item4 == 'test' - assert cfg.item5 == dict(a=0, b=1) - assert cfg.item6 == [dict(a=0), dict(b=1)] - assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) - - with pytest.raises(KeyError): - Config.fromfile(osp.join(data_path, 'config/m.py')) - - -def test_base_variables(): - for file in ['t.py', 't.json', 't.yaml']: - cfg_file = osp.join(data_path, f'config/{file}') - cfg = Config.fromfile(cfg_file) - assert isinstance(cfg, Config) - assert cfg.filename == cfg_file - # cfg.field - assert cfg.item1 == [1, 2] - assert cfg.item2.a == 0 - assert cfg.item3 is False - assert cfg.item4 == 'test' - assert cfg.item5 == dict(a=0, b=1) - assert cfg.item6 == [dict(a=0), dict(b=1)] - assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) - assert cfg.item8 == file - assert cfg.item9 == dict(a=0) - assert cfg.item10 == [3.1, 4.2, 5.3] - - # test nested base - for file in ['u.py', 'u.json', 'u.yaml']: - cfg_file = osp.join(data_path, f'config/{file}') - cfg = Config.fromfile(cfg_file) - assert isinstance(cfg, Config) - assert cfg.filename == cfg_file - # cfg.field - assert cfg.base == '_base_.item8' - assert cfg.item1 == [1, 2] - assert cfg.item2.a == 0 - assert cfg.item3 is False - assert cfg.item4 == 'test' - assert cfg.item5 == dict(a=0, b=1) - assert cfg.item6 == [dict(a=0), dict(b=1)] - assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) - assert cfg.item8 == 't.py' - assert cfg.item9 == dict(a=0) - assert cfg.item10 == [3.1, 4.2, 5.3] - assert cfg.item11 == 't.py' - assert cfg.item12 == dict(a=0) - assert cfg.item13 == [3.1, 4.2, 5.3] - assert cfg.item14 == [1, 2] - assert cfg.item15 == dict( - a=dict(b=dict(a=0)), - b=[False], - c=['test'], - d=[[{ - 'e': 0 - }], [{ - 'a': 0 - }, { - 'b': 1 - }]], - e=[1, 2]) - - # test reference assignment for py - cfg_file = osp.join(data_path, 'config/v.py') - cfg = Config.fromfile(cfg_file) - assert isinstance(cfg, Config) - assert cfg.filename == cfg_file - assert cfg.item21 == 't.py' - assert cfg.item22 == 't.py' - assert cfg.item23 == [3.1, 4.2, 5.3] - assert cfg.item24 == [3.1, 4.2, 5.3] - assert cfg.item25 == dict( - a=dict(b=[3.1, 4.2, 5.3]), - b=[[3.1, 4.2, 5.3]], - c=[[{ - 'e': 't.py' - }], [{ - 'a': 0 - }, { - 'b': 1 - }]], - e='t.py') - - -def test_merge_recursive_bases(): - cfg_file = osp.join(data_path, 'config/f.py') - cfg = Config.fromfile(cfg_file) - assert isinstance(cfg, Config) - assert cfg.filename == cfg_file - # cfg.field - assert cfg.item1 == [2, 3] - assert cfg.item2.a == 1 - assert cfg.item3 is False - assert cfg.item4 == 'test_recursive_bases' - - -def test_merge_from_dict(): - cfg_file = osp.join(data_path, 'config/a.py') - cfg = Config.fromfile(cfg_file) - input_options = {'item2.a': 1, 'item2.b': 0.1, 'item3': False} - cfg.merge_from_dict(input_options) - assert cfg.item2 == dict(a=1, b=0.1) - assert cfg.item3 is False - - cfg_file = osp.join(data_path, 'config/s.py') - cfg = Config.fromfile(cfg_file) - - # Allow list keys - input_options = {'item.0.a': 1, 'item.1.b': 1} - cfg.merge_from_dict(input_options, allow_list_keys=True) - assert cfg.item == [{'a': 1}, {'b': 1, 'c': 0}] - - # allow_list_keys is False - input_options = {'item.0.a': 1, 'item.1.b': 1} - with pytest.raises(TypeError): - cfg.merge_from_dict(input_options, allow_list_keys=False) - - # Overflowed index number - input_options = {'item.2.a': 1} - with pytest.raises(KeyError): - cfg.merge_from_dict(input_options, allow_list_keys=True) - - -def test_merge_delete(): - cfg_file = osp.join(data_path, 'config/delete.py') - cfg = Config.fromfile(cfg_file) - # cfg.field - assert cfg.item1 == dict(a=0) - assert cfg.item2 == dict(a=0, b=0) - assert cfg.item3 is True - assert cfg.item4 == 'test' - assert '_delete_' not in cfg.item2 - - # related issue: https://github.com/open-mmlab/mmcv/issues/1570 - assert type(cfg.item1) == ConfigDict - assert type(cfg.item2) == ConfigDict - - -def test_merge_intermediate_variable(): - - cfg_file = osp.join(data_path, 'config/i_child.py') - cfg = Config.fromfile(cfg_file) - # cfg.field - assert cfg.item1 == [1, 2] - assert cfg.item2 == dict(a=0) - assert cfg.item3 is True - assert cfg.item4 == 'test' - assert cfg.item_cfg == dict(b=2) - assert cfg.item5 == dict(cfg=dict(b=1)) - assert cfg.item6 == dict(cfg=dict(b=2)) - - -def test_fromfile_in_config(): - cfg_file = osp.join(data_path, 'config/code.py') - cfg = Config.fromfile(cfg_file) - # cfg.field - assert cfg.cfg.item1 == [1, 2] - assert cfg.cfg.item2 == dict(a=0) - assert cfg.cfg.item3 is True - assert cfg.cfg.item4 == 'test' - assert cfg.item5 == 1 - - -def test_dict(): - cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4='test') - - for filename in ['a.py', 'b.json', 'c.yaml']: - cfg_file = osp.join(data_path, 'config', filename) - cfg = Config.fromfile(cfg_file) - - # len(cfg) - assert len(cfg) == 4 - # cfg.keys() - assert set(cfg.keys()) == set(cfg_dict.keys()) - assert set(cfg._cfg_dict.keys()) == set(cfg_dict.keys()) - # cfg.values() - for value in cfg.values(): - assert value in cfg_dict.values() - # cfg.items() - for name, value in cfg.items(): - assert name in cfg_dict - assert value in cfg_dict.values() - # cfg.field - assert cfg.item1 == cfg_dict['item1'] - assert cfg.item2 == cfg_dict['item2'] - assert cfg.item2.a == 0 - assert cfg.item3 == cfg_dict['item3'] - assert cfg.item4 == cfg_dict['item4'] - with pytest.raises(AttributeError): - cfg.not_exist - # field in cfg, cfg[field], cfg.get() - for name in ['item1', 'item2', 'item3', 'item4']: - assert name in cfg - assert cfg[name] == cfg_dict[name] - assert cfg.get(name) == cfg_dict[name] - assert cfg.get('not_exist') is None - assert cfg.get('not_exist', 0) == 0 - with pytest.raises(KeyError): - cfg['not_exist'] - assert 'item1' in cfg - assert 'not_exist' not in cfg - # cfg.update() - cfg.update(dict(item1=0)) - assert cfg.item1 == 0 - cfg.update(dict(item2=dict(a=1))) - assert cfg.item2.a == 1 - - -@pytest.mark.parametrize('file', ['a.json', 'b.py', 'c.yaml', 'd.yml', None]) -def test_dump(file): - # config loaded from dict - cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4='test') - cfg = Config(cfg_dict=cfg_dict) - assert cfg.item1 == cfg_dict['item1'] - assert cfg.item2 == cfg_dict['item2'] - assert cfg.item3 == cfg_dict['item3'] - assert cfg.item4 == cfg_dict['item4'] - assert cfg._filename is None - if file is not None: - # dump without a filename argument is only returning pretty_text. - with tempfile.TemporaryDirectory() as temp_config_dir: - cfg_file = osp.join(temp_config_dir, file) - cfg.dump(cfg_file) - dumped_cfg = Config.fromfile(cfg_file) - assert dumped_cfg._cfg_dict == cfg._cfg_dict - else: - assert cfg.dump() == cfg.pretty_text - - # The key of json must be a string, so key `1` will be converted to `'1'`. - def compare_json_cfg(ori_cfg, dumped_json_cfg): - for key, value in ori_cfg.items(): - assert str(key) in dumped_json_cfg - if not isinstance(value, dict): - assert ori_cfg[key] == dumped_json_cfg[str(key)] - else: - compare_json_cfg(value, dumped_json_cfg[str(key)]) - - # config loaded from file - cfg_file = osp.join(data_path, 'config/n.py') - cfg = Config.fromfile(cfg_file) - if file is not None: - with tempfile.TemporaryDirectory() as temp_config_dir: - cfg_file = osp.join(temp_config_dir, file) - cfg.dump(cfg_file) - dumped_cfg = Config.fromfile(cfg_file) - if not file.endswith('.json'): - assert dumped_cfg._cfg_dict == cfg._cfg_dict - else: - compare_json_cfg(cfg._cfg_dict, dumped_cfg._cfg_dict) - else: - assert cfg.dump() == cfg.pretty_text - - -def test_setattr(): - cfg = Config() - cfg.item1 = [1, 2] - cfg.item2 = {'a': 0} - cfg['item5'] = {'a': {'b': None}} - assert cfg._cfg_dict['item1'] == [1, 2] - assert cfg.item1 == [1, 2] - assert cfg._cfg_dict['item2'] == {'a': 0} - assert cfg.item2.a == 0 - assert cfg._cfg_dict['item5'] == {'a': {'b': None}} - assert cfg.item5.a.b is None - - -def test_pretty_text(): - cfg_file = osp.join(data_path, 'config/l.py') - cfg = Config.fromfile(cfg_file) - with tempfile.TemporaryDirectory() as temp_config_dir: - text_cfg_filename = osp.join(temp_config_dir, '_text_config.py') - with open(text_cfg_filename, 'w') as f: - f.write(cfg.pretty_text) - text_cfg = Config.fromfile(text_cfg_filename) - assert text_cfg._cfg_dict == cfg._cfg_dict - - -def test_dict_action(): - parser = argparse.ArgumentParser(description='Train a detector') - parser.add_argument( - '--options', nargs='+', action=DictAction, help='custom options') - # Nested brackets - args = parser.parse_args( - ['--options', 'item2.a=a,b', 'item2.b=[(a,b), [1,2], false]']) - out_dict = {'item2.a': ['a', 'b'], 'item2.b': [('a', 'b'), [1, 2], False]} - assert args.options == out_dict - # Single Nested brackets - args = parser.parse_args(['--options', 'item2.a=[[1]]']) - out_dict = {'item2.a': [[1]]} - assert args.options == out_dict - # Imbalance bracket - with pytest.raises(AssertionError): - parser.parse_args(['--options', 'item2.a=[(a,b), [1,2], false']) - # Normal values - args = parser.parse_args([ - '--options', 'item2.a=1', 'item2.b=0.1', 'item2.c=x', 'item3=false', - 'item4=none', 'item5=None' - ]) - out_dict = { - 'item2.a': 1, - 'item2.b': 0.1, - 'item2.c': 'x', - 'item3': False, - 'item4': 'none', - 'item5': None, - } - assert args.options == out_dict - cfg_file = osp.join(data_path, 'config/a.py') - cfg = Config.fromfile(cfg_file) - cfg.merge_from_dict(args.options) - assert cfg.item2 == dict(a=1, b=0.1, c='x') - assert cfg.item3 is False - - -def test_reserved_key(): - cfg_file = osp.join(data_path, 'config/g.py') - with pytest.raises(KeyError): - Config.fromfile(cfg_file) - - -def test_syntax_error(): - # the name can not be used to open the file a second time in windows, - # so `delete` should be set as `False` and we need to manually remove it - # more details can be found at https://github.com/open-mmlab/mmcv/pull/1077 - temp_cfg_file = tempfile.NamedTemporaryFile(suffix='.py', delete=False) - temp_cfg_path = temp_cfg_file.name - # write a file with syntax error - with open(temp_cfg_path, 'w') as f: - f.write('a=0b=dict(c=1)') - with pytest.raises( - SyntaxError, match='There are syntax errors in config file'): - Config.fromfile(temp_cfg_path) - temp_cfg_file.close() - os.remove(temp_cfg_path) - - -def test_pickle_support(): - cfg_file = osp.join(data_path, 'config/n.py') - cfg = Config.fromfile(cfg_file) - - with tempfile.TemporaryDirectory() as temp_config_dir: - pkl_cfg_filename = osp.join(temp_config_dir, '_pickle.pkl') - dump(cfg, pkl_cfg_filename) - pkl_cfg = load(pkl_cfg_filename) - - assert pkl_cfg._cfg_dict == cfg._cfg_dict - - -def test_deprecation(): - deprecated_cfg_files = [ - osp.join(data_path, 'config/deprecated.py'), - osp.join(data_path, 'config/deprecated_as_base.py') - ] - - for cfg_file in deprecated_cfg_files: - with pytest.warns(DeprecationWarning): - cfg = Config.fromfile(cfg_file) - assert cfg.item1 == 'expected' - - -def test_deepcopy(): - cfg_file = osp.join(data_path, 'config/n.py') - cfg = Config.fromfile(cfg_file) - new_cfg = copy.deepcopy(cfg) - - assert isinstance(new_cfg, Config) - assert new_cfg._cfg_dict == cfg._cfg_dict - assert new_cfg._cfg_dict is not cfg._cfg_dict - assert new_cfg._filename == cfg._filename - assert new_cfg._text == cfg._text - - -def test_copy(): - cfg_file = osp.join(data_path, 'config/n.py') - cfg = Config.fromfile(cfg_file) - new_cfg = copy.copy(cfg) - - assert isinstance(new_cfg, Config) - assert new_cfg is not cfg - assert new_cfg._cfg_dict is cfg._cfg_dict - assert new_cfg._filename == cfg._filename - assert new_cfg._text == cfg._text diff --git a/tests/test_utils/test_logging.py b/tests/test_utils/test_logging.py deleted file mode 100644 index ab66a34b94..0000000000 --- a/tests/test_utils/test_logging.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import logging -import os -import platform -import tempfile -from unittest.mock import patch - -import pytest - -from mmcv import get_logger, print_log - -if platform.system() == 'Windows': - import regex as re -else: - import re - - -@patch('torch.distributed.get_rank', lambda: 0) -@patch('torch.distributed.is_initialized', lambda: True) -@patch('torch.distributed.is_available', lambda: True) -def test_get_logger_rank0(): - logger = get_logger('rank0.pkg1') - assert isinstance(logger, logging.Logger) - assert len(logger.handlers) == 1 - assert isinstance(logger.handlers[0], logging.StreamHandler) - assert logger.handlers[0].level == logging.INFO - - logger = get_logger('rank0.pkg2', log_level=logging.DEBUG) - assert isinstance(logger, logging.Logger) - assert len(logger.handlers) == 1 - assert logger.handlers[0].level == logging.DEBUG - - # the name can not be used to open the file a second time in windows, - # so `delete` should be set as `False` and we need to manually remove it - # more details can be found at https://github.com/open-mmlab/mmcv/pull/1077 - with tempfile.NamedTemporaryFile(delete=False) as f: - logger = get_logger('rank0.pkg3', log_file=f.name) - assert isinstance(logger, logging.Logger) - assert len(logger.handlers) == 2 - assert isinstance(logger.handlers[0], logging.StreamHandler) - assert isinstance(logger.handlers[1], logging.FileHandler) - logger_pkg3 = get_logger('rank0.pkg3') - assert id(logger_pkg3) == id(logger) - # flushing and closing all handlers in order to remove `f.name` - logging.shutdown() - - os.remove(f.name) - - logger_pkg3 = get_logger('rank0.pkg3.subpkg') - assert logger_pkg3.handlers == logger_pkg3.handlers - - -@patch('torch.distributed.get_rank', lambda: 1) -@patch('torch.distributed.is_initialized', lambda: True) -@patch('torch.distributed.is_available', lambda: True) -def test_get_logger_rank1(): - logger = get_logger('rank1.pkg1') - assert isinstance(logger, logging.Logger) - assert len(logger.handlers) == 1 - assert isinstance(logger.handlers[0], logging.StreamHandler) - assert logger.handlers[0].level == logging.INFO - - # the name can not be used to open the file a second time in windows, - # so `delete` should be set as `False` and we need to manually remove it - # more details can be found at https://github.com/open-mmlab/mmcv/pull/1077 - with tempfile.NamedTemporaryFile(delete=False) as f: - logger = get_logger('rank1.pkg2', log_file=f.name) - assert isinstance(logger, logging.Logger) - assert len(logger.handlers) == 1 - assert logger.handlers[0].level == logging.INFO - # flushing and closing all handlers in order to remove `f.name` - logging.shutdown() - - os.remove(f.name) - - -def test_print_log_print(capsys): - print_log('welcome', logger=None) - out, _ = capsys.readouterr() - assert out == 'welcome\n' - - -def test_print_log_silent(capsys, caplog): - print_log('welcome', logger='silent') - out, _ = capsys.readouterr() - assert out == '' - assert len(caplog.records) == 0 - - -def test_print_log_logger(caplog): - print_log('welcome', logger='mmcv') - assert caplog.record_tuples[-1] == ('mmcv', logging.INFO, 'welcome') - - print_log('welcome', logger='mmcv', level=logging.ERROR) - assert caplog.record_tuples[-1] == ('mmcv', logging.ERROR, 'welcome') - - # the name can not be used to open the file a second time in windows, - # so `delete` should be set as `False` and we need to manually remove it - # more details can be found at https://github.com/open-mmlab/mmcv/pull/1077 - with tempfile.NamedTemporaryFile(delete=False) as f: - logger = get_logger('abc', log_file=f.name) - print_log('welcome', logger=logger) - assert caplog.record_tuples[-1] == ('abc', logging.INFO, 'welcome') - with open(f.name) as fin: - log_text = fin.read() - regex_time = r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}' - match = re.fullmatch(regex_time + r' - abc - INFO - welcome\n', - log_text) - assert match is not None - # flushing and closing all handlers in order to remove `f.name` - logging.shutdown() - - os.remove(f.name) - - -def test_print_log_exception(): - with pytest.raises(TypeError): - print_log('welcome', logger=0) diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py deleted file mode 100644 index 2b14c00778..0000000000 --- a/tests/test_utils/test_misc.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import pytest - -import mmcv -from mmcv import deprecated_api_warning -from mmcv.utils.misc import has_method - - -def test_to_ntuple(): - single_number = 2 - assert mmcv.utils.to_1tuple(single_number) == (single_number, ) - assert mmcv.utils.to_2tuple(single_number) == (single_number, - single_number) - assert mmcv.utils.to_3tuple(single_number) == (single_number, - single_number, - single_number) - assert mmcv.utils.to_4tuple(single_number) == (single_number, - single_number, - single_number, - single_number) - assert mmcv.utils.to_ntuple(5)(single_number) == (single_number, - single_number, - single_number, - single_number, - single_number) - assert mmcv.utils.to_ntuple(6)(single_number) == (single_number, - single_number, - single_number, - single_number, - single_number, - single_number) - - -def test_iter_cast(): - assert mmcv.list_cast([1, 2, 3], int) == [1, 2, 3] - assert mmcv.list_cast(['1.1', 2, '3'], float) == [1.1, 2.0, 3.0] - assert mmcv.list_cast([1, 2, 3], str) == ['1', '2', '3'] - assert mmcv.tuple_cast((1, 2, 3), str) == ('1', '2', '3') - assert next(mmcv.iter_cast([1, 2, 3], str)) == '1' - with pytest.raises(TypeError): - mmcv.iter_cast([1, 2, 3], '') - with pytest.raises(TypeError): - mmcv.iter_cast(1, str) - - -def test_is_seq_of(): - assert mmcv.is_seq_of([1.0, 2.0, 3.0], float) - assert mmcv.is_seq_of([(1, ), (2, ), (3, )], tuple) - assert mmcv.is_seq_of((1.0, 2.0, 3.0), float) - assert mmcv.is_list_of([1.0, 2.0, 3.0], float) - assert not mmcv.is_seq_of((1.0, 2.0, 3.0), float, seq_type=list) - assert not mmcv.is_tuple_of([1.0, 2.0, 3.0], float) - assert not mmcv.is_seq_of([1.0, 2, 3], int) - assert not mmcv.is_seq_of((1.0, 2, 3), int) - - -def test_slice_list(): - in_list = [1, 2, 3, 4, 5, 6] - assert mmcv.slice_list(in_list, [1, 2, 3]) == [[1], [2, 3], [4, 5, 6]] - assert mmcv.slice_list(in_list, [len(in_list)]) == [in_list] - with pytest.raises(TypeError): - mmcv.slice_list(in_list, 2.0) - with pytest.raises(ValueError): - mmcv.slice_list(in_list, [1, 2]) - - -def test_concat_list(): - assert mmcv.concat_list([[1, 2]]) == [1, 2] - assert mmcv.concat_list([[1, 2], [3, 4, 5], [6]]) == [1, 2, 3, 4, 5, 6] - - -def test_requires_package(capsys): - - @mmcv.requires_package('nnn') - def func_a(): - pass - - @mmcv.requires_package(['numpy', 'n1', 'n2']) - def func_b(): - pass - - @mmcv.requires_package('numpy') - def func_c(): - return 1 - - with pytest.raises(RuntimeError): - func_a() - out, _ = capsys.readouterr() - assert out == ('Prerequisites "nnn" are required in method "func_a" but ' - 'not found, please install them first.\n') - - with pytest.raises(RuntimeError): - func_b() - out, _ = capsys.readouterr() - assert out == ( - 'Prerequisites "n1, n2" are required in method "func_b" but not found,' - ' please install them first.\n') - - assert func_c() == 1 - - -def test_requires_executable(capsys): - - @mmcv.requires_executable('nnn') - def func_a(): - pass - - @mmcv.requires_executable(['ls', 'n1', 'n2']) - def func_b(): - pass - - @mmcv.requires_executable('mv') - def func_c(): - return 1 - - with pytest.raises(RuntimeError): - func_a() - out, _ = capsys.readouterr() - assert out == ('Prerequisites "nnn" are required in method "func_a" but ' - 'not found, please install them first.\n') - - with pytest.raises(RuntimeError): - func_b() - out, _ = capsys.readouterr() - assert out == ( - 'Prerequisites "n1, n2" are required in method "func_b" but not found,' - ' please install them first.\n') - - assert func_c() == 1 - - -def test_import_modules_from_strings(): - # multiple imports - import os.path as osp_ - import sys as sys_ - osp, sys = mmcv.import_modules_from_strings(['os.path', 'sys']) - assert osp == osp_ - assert sys == sys_ - - # single imports - osp = mmcv.import_modules_from_strings('os.path') - assert osp == osp_ - # No imports - assert mmcv.import_modules_from_strings(None) is None - assert mmcv.import_modules_from_strings([]) is None - assert mmcv.import_modules_from_strings('') is None - # Unsupported types - with pytest.raises(TypeError): - mmcv.import_modules_from_strings(1) - with pytest.raises(TypeError): - mmcv.import_modules_from_strings([1]) - # Failed imports - with pytest.raises(ImportError): - mmcv.import_modules_from_strings('_not_implemented_module') - with pytest.warns(UserWarning): - imported = mmcv.import_modules_from_strings( - '_not_implemented_module', allow_failed_imports=True) - assert imported is None - with pytest.warns(UserWarning): - imported = mmcv.import_modules_from_strings( - ['os.path', '_not_implemented'], allow_failed_imports=True) - assert imported[0] == osp - assert imported[1] is None - - -def test_is_method_overridden(): - - class Base: - - def foo1(): - pass - - def foo2(): - pass - - class Sub(Base): - - def foo1(): - pass - - # test passing sub class directly - assert mmcv.is_method_overridden('foo1', Base, Sub) - assert not mmcv.is_method_overridden('foo2', Base, Sub) - - # test passing instance of sub class - sub_instance = Sub() - assert mmcv.is_method_overridden('foo1', Base, sub_instance) - assert not mmcv.is_method_overridden('foo2', Base, sub_instance) - - # base_class should be a class, not instance - base_instance = Base() - with pytest.raises(AssertionError): - mmcv.is_method_overridden('foo1', base_instance, sub_instance) - - -def test_has_method(): - - class Foo: - - def __init__(self, name): - self.name = name - - def print_name(self): - print(self.name) - - foo = Foo('foo') - assert not has_method(foo, 'name') - assert has_method(foo, 'print_name') - - -def test_deprecated_api_warning(): - - @deprecated_api_warning(name_dict=dict(old_key='new_key')) - def dummy_func(new_key=1): - return new_key - - # replace `old_key` to `new_key` - assert dummy_func(old_key=2) == 2 - - # The expected behavior is to replace the - # deprecated key `old_key` to `new_key`, - # but got them in the arguments at the same time - with pytest.raises(AssertionError): - dummy_func(old_key=1, new_key=2) diff --git a/tests/test_utils/test_parrots_jit.py b/tests/test_utils/test_parrots_jit.py index 71be929fb4..7e35567bdb 100644 --- a/tests/test_utils/test_parrots_jit.py +++ b/tests/test_utils/test_parrots_jit.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import pytest import torch +from mmengine.utils import TORCH_VERSION import mmcv -from mmcv.utils import TORCH_VERSION pytest.skip('this test not ready now', allow_module_level=True) skip_no_parrots = pytest.mark.skipif( diff --git a/tests/test_utils/test_path.py b/tests/test_utils/test_path.py deleted file mode 100644 index 56d65ce264..0000000000 --- a/tests/test_utils/test_path.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os.path as osp -from pathlib import Path - -import pytest - -import mmcv - - -def test_is_filepath(): - assert mmcv.is_filepath(__file__) - assert mmcv.is_filepath('abc') - assert mmcv.is_filepath(Path('/etc')) - assert not mmcv.is_filepath(0) - - -def test_fopen(): - assert hasattr(mmcv.fopen(__file__), 'read') - assert hasattr(mmcv.fopen(Path(__file__)), 'read') - - -def test_check_file_exist(): - mmcv.check_file_exist(__file__) - with pytest.raises(FileNotFoundError): - mmcv.check_file_exist('no_such_file.txt') - - -def test_scandir(): - folder = osp.join(osp.dirname(osp.dirname(__file__)), 'data/for_scan') - filenames = ['a.bin', '1.txt', '2.txt', '1.json', '2.json', '3.TXT'] - assert set(mmcv.scandir(folder)) == set(filenames) - assert set(mmcv.scandir(Path(folder))) == set(filenames) - assert set(mmcv.scandir(folder, '.txt')) == { - filename - for filename in filenames if filename.endswith('.txt') - } - assert set(mmcv.scandir(folder, ('.json', '.txt'))) == { - filename - for filename in filenames if filename.endswith(('.txt', '.json')) - } - assert set(mmcv.scandir(folder, '.png')) == set() - - # path of sep is `\\` in windows but `/` in linux, so osp.join should be - # used to join string for compatibility - filenames_recursive = [ - 'a.bin', '1.txt', '2.txt', '1.json', '2.json', '3.TXT', - osp.join('sub', '1.json'), - osp.join('sub', '1.txt'), '.file' - ] - # .file starts with '.' and is a file so it will not be scanned - assert set(mmcv.scandir(folder, recursive=True)) == { - filename - for filename in filenames_recursive if filename != '.file' - } - assert set(mmcv.scandir(Path(folder), recursive=True)) == { - filename - for filename in filenames_recursive if filename != '.file' - } - assert set(mmcv.scandir(folder, '.txt', recursive=True)) == { - filename - for filename in filenames_recursive if filename.endswith('.txt') - } - assert set( - mmcv.scandir(folder, '.TXT', recursive=True, - case_sensitive=False)) == { - filename - for filename in filenames_recursive - if filename.endswith(('.txt', '.TXT')) - } - assert set( - mmcv.scandir( - folder, ('.TXT', '.JSON'), recursive=True, - case_sensitive=False)) == { - filename - for filename in filenames_recursive - if filename.endswith(('.txt', '.json', '.TXT')) - } - with pytest.raises(TypeError): - list(mmcv.scandir(123)) - with pytest.raises(TypeError): - list(mmcv.scandir(folder, 111)) diff --git a/tests/test_utils/test_progressbar.py b/tests/test_utils/test_progressbar.py deleted file mode 100644 index 982aa247f7..0000000000 --- a/tests/test_utils/test_progressbar.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import os -import time -from io import StringIO -from unittest.mock import patch - -import mmcv - - -def reset_string_io(io): - io.truncate(0) - io.seek(0) - - -class TestProgressBar: - - def test_start(self): - out = StringIO() - bar_width = 20 - # without total task num - prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out) - assert out.getvalue() == 'completed: 0, elapsed: 0s' - reset_string_io(out) - prog_bar = mmcv.ProgressBar(bar_width=bar_width, start=False, file=out) - assert out.getvalue() == '' - reset_string_io(out) - prog_bar.start() - assert out.getvalue() == 'completed: 0, elapsed: 0s' - # with total task num - reset_string_io(out) - prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out) - assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:' - reset_string_io(out) - prog_bar = mmcv.ProgressBar( - 10, bar_width=bar_width, start=False, file=out) - assert out.getvalue() == '' - reset_string_io(out) - prog_bar.start() - assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:' - - def test_update(self): - out = StringIO() - bar_width = 20 - # without total task num - prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out) - time.sleep(1) - reset_string_io(out) - prog_bar.update() - assert out.getvalue() == 'completed: 1, elapsed: 1s, 1.0 tasks/s' - reset_string_io(out) - # with total task num - prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out) - time.sleep(1) - reset_string_io(out) - prog_bar.update() - assert out.getvalue() == f'\r[{">" * 2 + " " * 18}] 1/10, 1.0 ' \ - 'task/s, elapsed: 1s, ETA: 9s' - - def test_adaptive_length(self): - with patch.dict('os.environ', {'COLUMNS': '80'}): - out = StringIO() - bar_width = 20 - prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out) - time.sleep(1) - reset_string_io(out) - prog_bar.update() - assert len(out.getvalue()) == 66 - - os.environ['COLUMNS'] = '30' - reset_string_io(out) - prog_bar.update() - assert len(out.getvalue()) == 48 - - os.environ['COLUMNS'] = '60' - reset_string_io(out) - prog_bar.update() - assert len(out.getvalue()) == 60 - - -def sleep_1s(num): - time.sleep(1) - return num - - -def test_track_progress_list(): - out = StringIO() - ret = mmcv.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out) - assert out.getvalue() == ( - '[ ] 0/3, elapsed: 0s, ETA:' - '\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s' - '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s' - '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n') - assert ret == [1, 2, 3] - - -def test_track_progress_iterator(): - out = StringIO() - ret = mmcv.track_progress( - sleep_1s, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out) - assert out.getvalue() == ( - '[ ] 0/3, elapsed: 0s, ETA:' - '\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s' - '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s' - '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n') - assert ret == [1, 2, 3] - - -def test_track_iter_progress(): - out = StringIO() - ret = [] - for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out): - ret.append(sleep_1s(num)) - assert out.getvalue() == ( - '[ ] 0/3, elapsed: 0s, ETA:' - '\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s' - '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s' - '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n') - assert ret == [1, 2, 3] - - -def test_track_enum_progress(): - out = StringIO() - ret = [] - count = [] - for i, num in enumerate( - mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out)): - ret.append(sleep_1s(num)) - count.append(i) - assert out.getvalue() == ( - '[ ] 0/3, elapsed: 0s, ETA:' - '\r[> ] 1/3, 1.0 task/s, elapsed: 1s, ETA: 2s' - '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA: 1s' - '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA: 0s\n') - assert ret == [1, 2, 3] - assert count == [0, 1, 2] - - -def test_track_parallel_progress_list(): - out = StringIO() - results = mmcv.track_parallel_progress( - sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out) - # The following cannot pass CI on Github Action - # assert out.getvalue() == ( - # '[ ] 0/4, elapsed: 0s, ETA:' - # '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s' - # '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s' - # '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s' - # '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n') - assert results == [1, 2, 3, 4] - - -def test_track_parallel_progress_iterator(): - out = StringIO() - results = mmcv.track_parallel_progress( - sleep_1s, ((i for i in [1, 2, 3, 4]), 4), 2, bar_width=4, file=out) - # The following cannot pass CI on Github Action - # assert out.getvalue() == ( - # '[ ] 0/4, elapsed: 0s, ETA:' - # '\r[> ] 1/4, 1.0 task/s, elapsed: 1s, ETA: 3s' - # '\r[>> ] 2/4, 2.0 task/s, elapsed: 1s, ETA: 1s' - # '\r[>>> ] 3/4, 1.5 task/s, elapsed: 2s, ETA: 1s' - # '\r[>>>>] 4/4, 2.0 task/s, elapsed: 2s, ETA: 0s\n') - assert results == [1, 2, 3, 4] diff --git a/tests/test_utils/test_registry.py b/tests/test_utils/test_registry.py deleted file mode 100644 index 09dc46b7cd..0000000000 --- a/tests/test_utils/test_registry.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import pytest - -import mmcv - - -def test_registry(): - CATS = mmcv.Registry('cat') - assert CATS.name == 'cat' - assert CATS.module_dict == {} - assert len(CATS) == 0 - - @CATS.register_module() - class BritishShorthair: - pass - - assert len(CATS) == 1 - assert CATS.get('BritishShorthair') is BritishShorthair - - class Munchkin: - pass - - CATS.register_module(Munchkin) - assert len(CATS) == 2 - assert CATS.get('Munchkin') is Munchkin - assert 'Munchkin' in CATS - - with pytest.raises(KeyError): - CATS.register_module(Munchkin) - - CATS.register_module(Munchkin, force=True) - assert len(CATS) == 2 - - # force=False - with pytest.raises(KeyError): - - @CATS.register_module() - class BritishShorthair: - pass - - @CATS.register_module(force=True) - class BritishShorthair: - pass - - assert len(CATS) == 2 - - assert CATS.get('PersianCat') is None - assert 'PersianCat' not in CATS - - @CATS.register_module(name=['Siamese', 'Siamese2']) - class SiameseCat: - pass - - assert CATS.get('Siamese').__name__ == 'SiameseCat' - assert CATS.get('Siamese2').__name__ == 'SiameseCat' - - class SphynxCat: - pass - - CATS.register_module(name='Sphynx', module=SphynxCat) - assert CATS.get('Sphynx') is SphynxCat - - CATS.register_module(name=['Sphynx1', 'Sphynx2'], module=SphynxCat) - assert CATS.get('Sphynx2') is SphynxCat - - repr_str = 'Registry(name=cat, items={' - repr_str += ("'BritishShorthair': .BritishShorthair'>, ") - repr_str += ("'Munchkin': .Munchkin'>, ") - repr_str += ("'Siamese': .SiameseCat'>, ") - repr_str += ("'Siamese2': .SiameseCat'>, ") - repr_str += ("'Sphynx': .SphynxCat'>, ") - repr_str += ("'Sphynx1': .SphynxCat'>, ") - repr_str += ("'Sphynx2': .SphynxCat'>") - repr_str += '})' - assert repr(CATS) == repr_str - - # name type - with pytest.raises(TypeError): - CATS.register_module(name=7474741, module=SphynxCat) - - # the registered module should be a class - with pytest.raises(TypeError): - CATS.register_module(0) - - @CATS.register_module() - def muchkin(): - pass - - assert CATS.get('muchkin') is muchkin - assert 'muchkin' in CATS - - # can only decorate a class or a function - with pytest.raises(TypeError): - - class Demo: - - def some_method(self): - pass - - method = Demo().some_method - CATS.register_module(name='some_method', module=method) - - # begin: test old APIs - with pytest.warns(DeprecationWarning): - CATS.register_module(SphynxCat) - assert CATS.get('SphynxCat').__name__ == 'SphynxCat' - - with pytest.warns(DeprecationWarning): - CATS.register_module(SphynxCat, force=True) - assert CATS.get('SphynxCat').__name__ == 'SphynxCat' - - with pytest.warns(DeprecationWarning): - - @CATS.register_module - class NewCat: - pass - - assert CATS.get('NewCat').__name__ == 'NewCat' - - with pytest.warns(DeprecationWarning): - CATS.deprecated_register_module(SphynxCat, force=True) - assert CATS.get('SphynxCat').__name__ == 'SphynxCat' - - with pytest.warns(DeprecationWarning): - - @CATS.deprecated_register_module - class CuteCat: - pass - - assert CATS.get('CuteCat').__name__ == 'CuteCat' - - with pytest.warns(DeprecationWarning): - - @CATS.deprecated_register_module(force=True) - class NewCat2: - pass - - assert CATS.get('NewCat2').__name__ == 'NewCat2' - - # end: test old APIs - - -def test_multi_scope_registry(): - DOGS = mmcv.Registry('dogs') - assert DOGS.name == 'dogs' - assert DOGS.scope == 'test_registry' - assert DOGS.module_dict == {} - assert len(DOGS) == 0 - - @DOGS.register_module() - class GoldenRetriever: - pass - - assert len(DOGS) == 1 - assert DOGS.get('GoldenRetriever') is GoldenRetriever - - HOUNDS = mmcv.Registry('dogs', parent=DOGS, scope='hound') - - @HOUNDS.register_module() - class BloodHound: - pass - - assert len(HOUNDS) == 1 - assert HOUNDS.get('BloodHound') is BloodHound - assert DOGS.get('hound.BloodHound') is BloodHound - assert HOUNDS.get('hound.BloodHound') is BloodHound - - LITTLE_HOUNDS = mmcv.Registry('dogs', parent=HOUNDS, scope='little_hound') - - @LITTLE_HOUNDS.register_module() - class Dachshund: - pass - - assert len(LITTLE_HOUNDS) == 1 - assert LITTLE_HOUNDS.get('Dachshund') is Dachshund - assert LITTLE_HOUNDS.get('hound.BloodHound') is BloodHound - assert HOUNDS.get('little_hound.Dachshund') is Dachshund - assert DOGS.get('hound.little_hound.Dachshund') is Dachshund - - MID_HOUNDS = mmcv.Registry('dogs', parent=HOUNDS, scope='mid_hound') - - @MID_HOUNDS.register_module() - class Beagle: - pass - - assert MID_HOUNDS.get('Beagle') is Beagle - assert HOUNDS.get('mid_hound.Beagle') is Beagle - assert DOGS.get('hound.mid_hound.Beagle') is Beagle - assert LITTLE_HOUNDS.get('hound.mid_hound.Beagle') is Beagle - assert MID_HOUNDS.get('hound.BloodHound') is BloodHound - assert MID_HOUNDS.get('hound.Dachshund') is None - - -def test_build_from_cfg(): - BACKBONES = mmcv.Registry('backbone') - - @BACKBONES.register_module() - class ResNet: - - def __init__(self, depth, stages=4): - self.depth = depth - self.stages = stages - - @BACKBONES.register_module() - class ResNeXt: - - def __init__(self, depth, stages=4): - self.depth = depth - self.stages = stages - - cfg = dict(type='ResNet', depth=50) - model = mmcv.build_from_cfg(cfg, BACKBONES) - assert isinstance(model, ResNet) - assert model.depth == 50 and model.stages == 4 - - cfg = dict(type='ResNet', depth=50) - model = mmcv.build_from_cfg(cfg, BACKBONES, default_args={'stages': 3}) - assert isinstance(model, ResNet) - assert model.depth == 50 and model.stages == 3 - - cfg = dict(type='ResNeXt', depth=50, stages=3) - model = mmcv.build_from_cfg(cfg, BACKBONES) - assert isinstance(model, ResNeXt) - assert model.depth == 50 and model.stages == 3 - - cfg = dict(type=ResNet, depth=50) - model = mmcv.build_from_cfg(cfg, BACKBONES) - assert isinstance(model, ResNet) - assert model.depth == 50 and model.stages == 4 - - # type defined using default_args - cfg = dict(depth=50) - model = mmcv.build_from_cfg( - cfg, BACKBONES, default_args=dict(type='ResNet')) - assert isinstance(model, ResNet) - assert model.depth == 50 and model.stages == 4 - - cfg = dict(depth=50) - model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet)) - assert isinstance(model, ResNet) - assert model.depth == 50 and model.stages == 4 - - # not a registry - with pytest.raises(TypeError): - cfg = dict(type='VGG') - model = mmcv.build_from_cfg(cfg, 'BACKBONES') - - # non-registered class - with pytest.raises(KeyError): - cfg = dict(type='VGG') - model = mmcv.build_from_cfg(cfg, BACKBONES) - - # default_args must be a dict or None - with pytest.raises(TypeError): - cfg = dict(type='ResNet', depth=50) - model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=1) - - # cfg['type'] should be a str or class - with pytest.raises(TypeError): - cfg = dict(type=1000) - model = mmcv.build_from_cfg(cfg, BACKBONES) - - # cfg should contain the key "type" - with pytest.raises(KeyError, match='must contain the key "type"'): - cfg = dict(depth=50, stages=4) - model = mmcv.build_from_cfg(cfg, BACKBONES) - - # cfg or default_args should contain the key "type" - with pytest.raises(KeyError, match='must contain the key "type"'): - cfg = dict(depth=50) - model = mmcv.build_from_cfg( - cfg, BACKBONES, default_args=dict(stages=4)) - - # incorrect registry type - with pytest.raises(TypeError): - cfg = dict(type='ResNet', depth=50) - model = mmcv.build_from_cfg(cfg, 'BACKBONES') - - # incorrect default_args type - with pytest.raises(TypeError): - cfg = dict(type='ResNet', depth=50) - model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=0) - - # incorrect arguments - with pytest.raises(TypeError): - cfg = dict(type='ResNet', non_existing_arg=50) - model = mmcv.build_from_cfg(cfg, BACKBONES) diff --git a/tests/test_utils/test_testing.py b/tests/test_utils/test_testing.py deleted file mode 100644 index c6f8e8d123..0000000000 --- a/tests/test_utils/test_testing.py +++ /dev/null @@ -1,195 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import numpy as np -import pytest - -import mmcv - -try: - import torch -except ImportError: - torch = None -else: - import torch.nn as nn - - -def test_assert_dict_contains_subset(): - dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6)} - - # case 1 - expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6)} - assert mmcv.assert_dict_contains_subset(dict_obj, expected_subset) - - # case 2 - expected_subset = {'a': 'test1', 'b': 2, 'c': (6, 4)} - assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) - - # case 3 - expected_subset = {'a': 'test1', 'b': 2, 'c': None} - assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) - - # case 4 - expected_subset = {'a': 'test1', 'b': 2, 'd': (4, 6)} - assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) - - # case 5 - dict_obj = { - 'a': 'test1', - 'b': 2, - 'c': (4, 6), - 'd': np.array([[5, 3, 5], [1, 2, 3]]) - } - expected_subset = { - 'a': 'test1', - 'b': 2, - 'c': (4, 6), - 'd': np.array([[5, 3, 5], [6, 2, 3]]) - } - assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) - - # case 6 - dict_obj = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])} - expected_subset = {'a': 'test1', 'b': 2, 'c': (4, 6), 'd': np.array([[1]])} - assert mmcv.assert_dict_contains_subset(dict_obj, expected_subset) - - if torch is not None: - dict_obj = { - 'a': 'test1', - 'b': 2, - 'c': (4, 6), - 'd': torch.tensor([5, 3, 5]) - } - - # case 7 - expected_subset = {'d': torch.tensor([5, 5, 5])} - assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) - - # case 8 - expected_subset = {'d': torch.tensor([[5, 3, 5], [4, 1, 2]])} - assert not mmcv.assert_dict_contains_subset(dict_obj, expected_subset) - - -def test_assert_attrs_equal(): - - class TestExample: - a, b, c = 1, ('wvi', 3), [4.5, 3.14] - - def test_func(self): - return self.b - - # case 1 - assert mmcv.assert_attrs_equal(TestExample, { - 'a': 1, - 'b': ('wvi', 3), - 'c': [4.5, 3.14] - }) - - # case 2 - assert not mmcv.assert_attrs_equal(TestExample, { - 'a': 1, - 'b': ('wvi', 3), - 'c': [4.5, 3.14, 2] - }) - - # case 3 - assert not mmcv.assert_attrs_equal(TestExample, { - 'bc': 54, - 'c': [4.5, 3.14] - }) - - # case 4 - assert mmcv.assert_attrs_equal(TestExample, { - 'b': ('wvi', 3), - 'test_func': TestExample.test_func - }) - - if torch is not None: - - class TestExample: - a, b = torch.tensor([1]), torch.tensor([4, 5]) - - # case 5 - assert mmcv.assert_attrs_equal(TestExample, { - 'a': torch.tensor([1]), - 'b': torch.tensor([4, 5]) - }) - - # case 6 - assert not mmcv.assert_attrs_equal(TestExample, { - 'a': torch.tensor([1]), - 'b': torch.tensor([4, 6]) - }) - - -assert_dict_has_keys_data_1 = [({ - 'res_layer': 1, - 'norm_layer': 2, - 'dense_layer': 3 -})] -assert_dict_has_keys_data_2 = [(['res_layer', 'dense_layer'], True), - (['res_layer', 'conv_layer'], False)] - - -@pytest.mark.parametrize('obj', assert_dict_has_keys_data_1) -@pytest.mark.parametrize('expected_keys, ret_value', - assert_dict_has_keys_data_2) -def test_assert_dict_has_keys(obj, expected_keys, ret_value): - assert mmcv.assert_dict_has_keys(obj, expected_keys) == ret_value - - -assert_keys_equal_data_1 = [(['res_layer', 'norm_layer', 'dense_layer'])] -assert_keys_equal_data_2 = [(['res_layer', 'norm_layer', 'dense_layer'], True), - (['res_layer', 'dense_layer', 'norm_layer'], True), - (['res_layer', 'norm_layer'], False), - (['res_layer', 'conv_layer', 'norm_layer'], False)] - - -@pytest.mark.parametrize('result_keys', assert_keys_equal_data_1) -@pytest.mark.parametrize('target_keys, ret_value', assert_keys_equal_data_2) -def test_assert_keys_equal(result_keys, target_keys, ret_value): - assert mmcv.assert_keys_equal(result_keys, target_keys) == ret_value - - -@pytest.mark.skipif(torch is None, reason='requires torch library') -def test_assert_is_norm_layer(): - # case 1 - assert not mmcv.assert_is_norm_layer(nn.Conv3d(3, 64, 3)) - - # case 2 - assert mmcv.assert_is_norm_layer(nn.BatchNorm3d(128)) - - # case 3 - assert mmcv.assert_is_norm_layer(nn.GroupNorm(8, 64)) - - # case 4 - assert not mmcv.assert_is_norm_layer(nn.Sigmoid()) - - -@pytest.mark.skipif(torch is None, reason='requires torch library') -def test_assert_params_all_zeros(): - demo_module = nn.Conv2d(3, 64, 3) - nn.init.constant_(demo_module.weight, 0) - nn.init.constant_(demo_module.bias, 0) - assert mmcv.assert_params_all_zeros(demo_module) - - nn.init.xavier_normal_(demo_module.weight) - nn.init.constant_(demo_module.bias, 0) - assert not mmcv.assert_params_all_zeros(demo_module) - - demo_module = nn.Linear(2048, 400, bias=False) - nn.init.constant_(demo_module.weight, 0) - assert mmcv.assert_params_all_zeros(demo_module) - - nn.init.normal_(demo_module.weight, mean=0, std=0.01) - assert not mmcv.assert_params_all_zeros(demo_module) - - -def test_check_python_script(capsys): - mmcv.utils.check_python_script('./tests/data/scripts/hello.py zz') - captured = capsys.readouterr().out - assert captured == 'hello zz!\n' - mmcv.utils.check_python_script('./tests/data/scripts/hello.py agent') - captured = capsys.readouterr().out - assert captured == 'hello agent!\n' - # Make sure that wrong cmd raises an error - with pytest.raises(SystemExit): - mmcv.utils.check_python_script('./tests/data/scripts/hello.py li zz') diff --git a/tests/test_utils/test_timer.py b/tests/test_utils/test_timer.py deleted file mode 100644 index 983f64f58e..0000000000 --- a/tests/test_utils/test_timer.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import time - -import pytest - -import mmcv - - -def test_timer_init(): - timer = mmcv.Timer(start=False) - assert not timer.is_running - timer.start() - assert timer.is_running - timer = mmcv.Timer() - assert timer.is_running - - -def test_timer_run(): - timer = mmcv.Timer() - time.sleep(1) - assert abs(timer.since_start() - 1) < 1e-2 - time.sleep(1) - assert abs(timer.since_last_check() - 1) < 1e-2 - assert abs(timer.since_start() - 2) < 1e-2 - timer = mmcv.Timer(False) - with pytest.raises(mmcv.TimerError): - timer.since_start() - with pytest.raises(mmcv.TimerError): - timer.since_last_check() - - -def test_timer_context(capsys): - with mmcv.Timer(): - time.sleep(1) - out, _ = capsys.readouterr() - assert abs(float(out) - 1) < 1e-2 - with mmcv.Timer(print_tmpl='time: {:.1f}s'): - time.sleep(1) - out, _ = capsys.readouterr() - assert out == 'time: 1.0s\n' diff --git a/tests/test_utils/test_torch_ops.py b/tests/test_utils/test_torch_ops.py deleted file mode 100644 index e8752e0fd6..0000000000 --- a/tests/test_utils/test_torch_ops.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import pytest -import torch - -from mmcv.utils import torch_meshgrid - - -def test_torch_meshgrid(): - # torch_meshgrid should not throw warning - with pytest.warns(None) as record: - x = torch.tensor([1, 2, 3]) - y = torch.tensor([4, 5, 6]) - grid_x, grid_y = torch_meshgrid(x, y) - - assert len(record) == 0 diff --git a/tests/test_utils/test_trace.py b/tests/test_utils/test_trace.py deleted file mode 100644 index 2dbf2c8549..0000000000 --- a/tests/test_utils/test_trace.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import pytest -import torch - -from mmcv.utils import digit_version, is_jit_tracing - - -@pytest.mark.skipif( - digit_version(torch.__version__) < digit_version('1.6.0'), - reason='torch.jit.is_tracing is not available before 1.6.0') -def test_is_jit_tracing(): - - def foo(x): - if is_jit_tracing(): - return x - else: - return x.tolist() - - x = torch.rand(3) - # test without trace - assert isinstance(foo(x), list) - - # test with trace - traced_foo = torch.jit.trace(foo, (torch.rand(1), )) - assert isinstance(traced_foo(x), torch.Tensor) diff --git a/tests/test_utils/test_version_utils.py b/tests/test_utils/test_version_utils.py deleted file mode 100644 index 5400e3c86a..0000000000 --- a/tests/test_utils/test_version_utils.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from unittest.mock import patch - -import pytest - -from mmcv import get_git_hash, parse_version_info -from mmcv.utils import digit_version - - -def test_digit_version(): - assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0) - assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0) - assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0) - assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1) - assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0) - assert digit_version('1.0') == digit_version('1.0.0') - assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5') - assert digit_version('1.0.0dev') < digit_version('1.0.0a') - assert digit_version('1.0.0a') < digit_version('1.0.0a1') - assert digit_version('1.0.0a') < digit_version('1.0.0b') - assert digit_version('1.0.0b') < digit_version('1.0.0rc') - assert digit_version('1.0.0rc1') < digit_version('1.0.0') - assert digit_version('1.0.0') < digit_version('1.0.0post') - assert digit_version('1.0.0post') < digit_version('1.0.0post1') - assert digit_version('v1') == (1, 0, 0, 0, 0, 0) - assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0) - with pytest.raises(AssertionError): - digit_version('a') - with pytest.raises(AssertionError): - digit_version('1x') - with pytest.raises(AssertionError): - digit_version('1.x') - - -def test_parse_version_info(): - assert parse_version_info('0.2.16') == (0, 2, 16, 0, 0, 0) - assert parse_version_info('1.2.3') == (1, 2, 3, 0, 0, 0) - assert parse_version_info('1.2.3rc0') == (1, 2, 3, 0, 'rc', 0) - assert parse_version_info('1.2.3rc1') == (1, 2, 3, 0, 'rc', 1) - assert parse_version_info('1.0rc0') == (1, 0, 0, 0, 'rc', 0) - - -def _mock_cmd_success(cmd): - return b'3b46d33e90c397869ad5103075838fdfc9812aa0' - - -def _mock_cmd_fail(cmd): - raise OSError - - -def test_get_git_hash(): - with patch('mmcv.utils.version_utils._minimal_ext_cmd', _mock_cmd_success): - assert get_git_hash() == '3b46d33e90c397869ad5103075838fdfc9812aa0' - assert get_git_hash(digits=6) == '3b46d3' - assert get_git_hash(digits=100) == get_git_hash() - with patch('mmcv.utils.version_utils._minimal_ext_cmd', _mock_cmd_fail): - assert get_git_hash() == 'unknown' - assert get_git_hash(fallback='n/a') == 'n/a'