forked from open-mmlab/mmdetection
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request open-mmlab#16 from myownskyW7/dev
add high level api
- Loading branch information
Showing 8 changed files with 275 additions and 147 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Public high-level APIs for training and inference.
from .env import init_dist, get_root_logger, set_random_seed
from .inference import inference_detector
from .train import train_detector

__all__ = [
    'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
    'inference_detector'
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import logging | ||
import os | ||
import random | ||
|
||
import numpy as np | ||
import torch | ||
import torch.distributed as dist | ||
import torch.multiprocessing as mp | ||
from mmcv.runner import get_dist_info | ||
|
||
|
||
def init_dist(launcher, backend='nccl', **kwargs): | ||
if mp.get_start_method(allow_none=True) is None: | ||
mp.set_start_method('spawn') | ||
if launcher == 'pytorch': | ||
_init_dist_pytorch(backend, **kwargs) | ||
elif launcher == 'mpi': | ||
_init_dist_mpi(backend, **kwargs) | ||
elif launcher == 'slurm': | ||
_init_dist_slurm(backend, **kwargs) | ||
else: | ||
raise ValueError('Invalid launcher type: {}'.format(launcher)) | ||
|
||
|
||
def _init_dist_pytorch(backend, **kwargs):
    """Initialize ``torch.distributed`` for the PyTorch launcher.

    Reads the process rank from the ``RANK`` environment variable set by
    the launcher.
    """
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    # Bind this process to one GPU, wrapping around when there are more
    # processes than devices on the node.
    gpu_id = rank % torch.cuda.device_count()
    torch.cuda.set_device(gpu_id)
    dist.init_process_group(backend=backend, **kwargs)
|
||
|
||
def _init_dist_mpi(backend, **kwargs): | ||
raise NotImplementedError | ||
|
||
|
||
def _init_dist_slurm(backend, **kwargs): | ||
raise NotImplementedError | ||
|
||
|
||
def set_random_seed(seed):
    """Seed every random number generator used during training.

    Args:
        seed (int): Seed shared by the python, numpy and torch RNGs.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
|
||
|
||
def get_root_logger(log_level=logging.INFO):
    """Return the root logger, configuring it on first use.

    Non-master ranks are silenced down to ERROR so multi-process training
    does not produce duplicated log output.

    Args:
        log_level (int): Level used when the root logger has no handlers yet.

    Returns:
        logging.Logger: The (possibly newly configured) root logger.
    """
    root = logging.getLogger()
    if not root.hasHandlers():
        logging.basicConfig(
            format='%(asctime)s - %(levelname)s - %(message)s',
            level=log_level)
    rank, _ = get_dist_info()
    if rank != 0:
        root.setLevel('ERROR')
    return root
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import mmcv | ||
import numpy as np | ||
import torch | ||
|
||
from mmdet.datasets import to_tensor | ||
from mmdet.datasets.transforms import ImageTransform | ||
from mmdet.core import get_classes | ||
|
||
|
||
def _prepare_data(img, img_transform, cfg, device):
    """Transform one loaded image into the dict format the detector expects.

    Args:
        img (ndarray): Loaded image.
        img_transform (ImageTransform): Resize/normalize/pad transform.
        cfg: Config providing ``data.test.img_scale``.
        device: Device the image tensor is moved to.

    Returns:
        dict: ``img`` and ``img_meta`` lists as consumed by the model.
    """
    original_shape = img.shape
    img, img_shape, pad_shape, scale_factor = img_transform(
        img, scale=cfg.data.test.img_scale)
    img_tensor = to_tensor(img).to(device).unsqueeze(0)
    meta = dict(
        ori_shape=original_shape,
        img_shape=img_shape,
        pad_shape=pad_shape,
        scale_factor=scale_factor,
        flip=False)
    return dict(img=[img_tensor], img_meta=[[meta]])
|
||
|
||
def inference_detector(model, imgs, cfg, device='cuda:0'):
    """Run a detector on one image or a list of images.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list): Image path(s) or loaded image(s).
        cfg: Config providing test-time data settings.
        device (str): Device inference runs on.

    Yields:
        The detection result for each image, in input order.
    """
    if not isinstance(imgs, list):
        imgs = [imgs]
    transform = ImageTransform(
        size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
    model = model.to(device)
    model.eval()
    for raw_img in imgs:
        loaded = mmcv.imread(raw_img)
        data = _prepare_data(loaded, transform, cfg, device)
        with torch.no_grad():
            yield model(return_loss=False, rescale=True, **data)
|
||
|
||
def show_result(img, result, dataset='coco', score_thr=0.3):
    """Draw detection results on top of an image.

    Args:
        img (ndarray): The image the detector ran on.
        result (list[ndarray]): Per-class bbox arrays, indexed by class.
        dataset (str): Dataset name used to look up class names.
        score_thr (float): Minimum score for a box to be drawn.
    """
    class_names = get_classes(dataset)
    # Build one label per bbox from its position in the per-class list.
    labels = np.concatenate([
        np.full(class_bboxes.shape[0], cls_idx, dtype=np.int32)
        for cls_idx, class_bboxes in enumerate(result)
    ])
    all_bboxes = np.vstack(result)
    mmcv.imshow_det_bboxes(
        img.copy(),
        all_bboxes,
        labels,
        class_names=class_names,
        score_thr=score_thr)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from __future__ import division | ||
|
||
from collections import OrderedDict | ||
|
||
import torch | ||
from mmcv.runner import Runner, DistSamplerSeedHook | ||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel | ||
|
||
from mmdet.core import (DistOptimizerHook, CocoDistEvalRecallHook, | ||
CocoDistEvalmAPHook) | ||
from mmdet.datasets import build_dataloader | ||
from mmdet.models import RPN | ||
from .env import get_root_logger | ||
|
||
|
||
def parse_losses(losses):
    """Reduce raw loss outputs into a total loss and scalar log values.

    Args:
        losses (dict): Maps each name to a tensor or a list of tensors.

    Returns:
        tuple: ``(loss, log_vars)`` where ``loss`` is the sum of every
        entry whose name contains ``'loss'`` and ``log_vars`` maps every
        entry (plus ``'loss'``) to a python float for logging.

    Raises:
        TypeError: If a value is neither a tensor nor a list of tensors.
    """
    log_vars = OrderedDict()
    for name, value in losses.items():
        if isinstance(value, torch.Tensor):
            log_vars[name] = value.mean()
        elif isinstance(value, list):
            log_vars[name] = sum(item.mean() for item in value)
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(name))

    # Only entries whose name mentions 'loss' contribute to the total;
    # other entries (e.g. accuracies) are logged but not optimized.
    loss = sum(v for k, v in log_vars.items() if 'loss' in k)
    log_vars['loss'] = loss

    for key in log_vars:
        log_vars[key] = log_vars[key].item()

    return loss, log_vars
|
||
|
||
def batch_processor(model, data, train_mode):
    """Run one forward pass and package the outputs for the Runner.

    Args:
        model (nn.Module): Model whose forward returns a dict of losses.
        data (dict): One batch; ``data['img'].data`` holds the samples.
        train_mode (bool): Unused here; kept for the Runner interface.

    Returns:
        dict: ``loss``, ``log_vars`` and ``num_samples`` for logging.
    """
    raw_losses = model(**data)
    loss, log_vars = parse_losses(raw_losses)
    return dict(
        loss=loss,
        log_vars=log_vars,
        num_samples=len(data['img'].data))
|
||
|
||
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   logger=None):
    """Train a detector, dispatching to the dist / non-dist driver.

    Args:
        model (nn.Module): Detector to train.
        dataset: Training dataset.
        cfg: Full training config (optimizer, schedules, work_dir, ...).
        distributed (bool): Whether to use DistributedDataParallel.
        validate (bool): Whether to register evaluation hooks.
        logger (logging.Logger, optional): Created from ``cfg`` if omitted.
    """
    if logger is None:
        logger = get_root_logger(cfg.log_level)

    # start training
    train_fn = _dist_train if distributed else _non_dist_train
    train_fn(model, dataset, cfg, validate=validate)
|
||
|
||
def _dist_train(model, dataset, cfg, validate=False):
    """Distributed training driver: wraps the model in DDP and runs it."""
    # prepare data loaders
    loaders = [
        build_dataloader(
            dataset,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            dist=True)
    ]
    # put model on gpus
    model = MMDistributedDataParallel(model.cuda())
    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    # register hooks; gradients are all-reduced by the dist optimizer hook
    optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        # RPN-only models are evaluated by proposal recall; full detectors
        # on COCO are evaluated by mAP.
        if isinstance(model.module, RPN):
            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
        elif cfg.data.val.type == 'CocoDataset':
            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(loaders, cfg.workflow, cfg.total_epochs)
|
||
|
||
def _non_dist_train(model, dataset, cfg, validate=False):
    """Single-node (DataParallel) training driver."""
    # prepare data loaders
    loaders = [
        build_dataloader(
            dataset,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            cfg.gpus,
            dist=False)
    ]
    # put model on gpus
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(loaders, cfg.workflow, cfg.total_epochs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
# NOTE(review): this span is a diff with old and new lines interleaved;
# reconstructed here as the post-commit version (init_dist moved out of
# dist_utils into the high-level api module).
from .dist_utils import allreduce_grads, DistOptimizerHook
from .misc import tensor2imgs, unmap, multi_apply

__all__ = [
    'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs', 'unmap',
    'multi_apply'
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.