
add high level api #16

Merged: 5 commits, Oct 11, 2018
8 changes: 8 additions & 0 deletions mmdet/api/__init__.py
@@ -0,0 +1,8 @@
from .env import init_dist, get_root_logger, set_random_seed
from .train import train_detector
from .inference import inference_detector

__all__ = [
    'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
    'inference_detector'
]
57 changes: 57 additions & 0 deletions mmdet/api/env.py
@@ -0,0 +1,57 @@
import logging
import os
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from mmcv.runner import get_dist_info


def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError('Invalid launcher type: {}'.format(launcher))


def _init_dist_pytorch(backend, **kwargs):
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_mpi(backend, **kwargs):
    raise NotImplementedError


def _init_dist_slurm(backend, **kwargs):
    raise NotImplementedError


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_root_logger(log_level=logging.INFO):
    logger = logging.getLogger()
    if not logger.hasHandlers():
        logging.basicConfig(
            format='%(asctime)s - %(levelname)s - %(message)s',
            level=log_level)
    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')
    return logger
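
Usage sketch (not part of the diff): a launcher script might combine these helpers as below. The config path and seed value are placeholder assumptions, and init_dist('pytorch') requires the process to have been started with RANK set (e.g. via torch.distributed.launch).

# assumes a valid mmdetection config file; the path here is hypothetical
import mmcv
from mmdet.api import init_dist, get_root_logger, set_random_seed

cfg = mmcv.Config.fromfile('configs/some_config.py')
init_dist('pytorch')                      # spawn start method + NCCL process group
logger = get_root_logger(cfg.log_level)   # non-zero ranks are silenced to ERROR
set_random_seed(0)                        # seeds python, numpy and torch (CPU + all GPUs)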
54 changes: 54 additions & 0 deletions mmdet/api/inference.py
@@ -0,0 +1,54 @@
import mmcv
import numpy as np
import torch

from mmdet.datasets import to_tensor
from mmdet.datasets.transforms import ImageTransform
from mmdet.core import get_classes


def _prepare_data(img, img_transform, cfg, device):
    ori_shape = img.shape
    img, img_shape, pad_shape, scale_factor = img_transform(
        img, scale=cfg.data.test.img_scale)
    img = to_tensor(img).to(device).unsqueeze(0)
    img_meta = [
        dict(
            ori_shape=ori_shape,
            img_shape=img_shape,
            pad_shape=pad_shape,
            scale_factor=scale_factor,
            flip=False)
    ]
    return dict(img=[img], img_meta=[img_meta])


def inference_detector(model, imgs, cfg, device='cuda:0'):
    imgs = imgs if isinstance(imgs, list) else [imgs]
    img_transform = ImageTransform(
        size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
    model = model.to(device)
    model.eval()
    for img in imgs:
        img = mmcv.imread(img)
        data = _prepare_data(img, img_transform, cfg, device)
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
        yield result


def show_result(img, result, dataset='coco', score_thr=0.3):
    class_names = get_classes(dataset)
    labels = [
        np.full(bbox.shape[0], i, dtype=np.int32)
        for i, bbox in enumerate(result)
    ]
    labels = np.concatenate(labels)
    bboxes = np.vstack(result)
    mmcv.imshow_det_bboxes(
        img.copy(),
        bboxes,
        labels,
        class_names=class_names,
        score_thr=score_thr)
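
Usage sketch (not part of the diff): since inference_detector is a generator, results stream out one image at a time. The config, checkpoint, and image paths below are made up; build_detector and load_checkpoint are the existing mmdet/mmcv helpers.

import mmcv
from mmcv.runner import load_checkpoint
from mmdet.models import build_detector
from mmdet.api.inference import inference_detector, show_result

cfg = mmcv.Config.fromfile('configs/some_config.py')  # hypothetical path
cfg.model.pretrained = None
model = build_detector(cfg.model, test_cfg=cfg.test_cfg)
load_checkpoint(model, 'some_checkpoint.pth')  # hypothetical checkpoint

imgs = ['demo1.jpg', 'demo2.jpg']  # hypothetical image paths
for img, result in zip(imgs, inference_detector(model, imgs, cfg)):
    show_result(mmcv.imread(img), result)  # draws boxes above score_thr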
117 changes: 117 additions & 0 deletions mmdet/api/train.py
@@ -0,0 +1,117 @@
from __future__ import division

from collections import OrderedDict

import torch
from mmcv.runner import Runner, DistSamplerSeedHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

from mmdet.core import (DistOptimizerHook, CocoDistEvalRecallHook,
                        CocoDistEvalmAPHook)
from mmdet.datasets import build_dataloader
from mmdet.models import RPN
from .env import get_root_logger


def parse_losses(losses):
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(loss_name))

    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)

    log_vars['loss'] = loss
    for name in log_vars:
        log_vars[name] = log_vars[name].item()

    return loss, log_vars


def batch_processor(model, data, train_mode):
    losses = model(**data)
    loss, log_vars = parse_losses(losses)

    outputs = dict(
        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))

    return outputs


def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   logger=None):
    if logger is None:
        logger = get_root_logger(cfg.log_level)

    # start training
    if distributed:
        _dist_train(model, dataset, cfg, validate=validate)
    else:
        _non_dist_train(model, dataset, cfg, validate=validate)


def _dist_train(model, dataset, cfg, validate=False):
    # prepare data loaders
    data_loaders = [
        build_dataloader(
            dataset,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            dist=True)
    ]
    # put model on gpus
    model = MMDistributedDataParallel(model.cuda())
    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    # register hooks
    optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    runner.register_hook(DistSamplerSeedHook())
    # register eval hooks
    if validate:
        if isinstance(model.module, RPN):
            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
        elif cfg.data.val.type == 'CocoDataset':
            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)


def _non_dist_train(model, dataset, cfg, validate=False):
    # prepare data loaders
    data_loaders = [
        build_dataloader(
            dataset,
            cfg.data.imgs_per_gpu,
            cfg.data.workers_per_gpu,
            cfg.gpus,
            dist=False)
    ]
    # put model on gpus
    model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)
    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
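
Usage sketch (not part of the diff): a minimal driver for the training entry point. How the dataset object is built from cfg.data.train is an assumption here (obj_from_dict is one plausible route); adapt it to however the config is consumed.

import mmcv
from mmcv.runner import obj_from_dict
from mmdet import datasets
from mmdet.models import build_detector
from mmdet.api import init_dist, get_root_logger, set_random_seed, train_detector

cfg = mmcv.Config.fromfile('configs/some_config.py')  # hypothetical path
init_dist('pytorch')  # skip this for single-machine, non-distributed runs
logger = get_root_logger(cfg.log_level)
set_random_seed(0)

model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
train_dataset = obj_from_dict(cfg.data.train, datasets)  # assumption, see note above
train_detector(model, train_dataset, cfg, distributed=True, validate=True, logger=logger)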
6 changes: 3 additions & 3 deletions mmdet/core/utils/__init__.py
@@ -1,7 +1,7 @@
-from .dist_utils import init_dist, allreduce_grads, DistOptimizerHook
+from .dist_utils import allreduce_grads, DistOptimizerHook
 from .misc import tensor2imgs, unmap, multi_apply

 __all__ = [
-    'init_dist', 'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs',
-    'unmap', 'multi_apply'
+    'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs', 'unmap',
+    'multi_apply'
 ]
32 changes: 0 additions & 32 deletions mmdet/core/utils/dist_utils.py
@@ -1,43 +1,11 @@
-import os
 from collections import OrderedDict

 import torch
-import torch.multiprocessing as mp
 import torch.distributed as dist
 from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors,
                           _take_tensors)
 from mmcv.runner import OptimizerHook


-def init_dist(launcher, backend='nccl', **kwargs):
-    if mp.get_start_method(allow_none=True) is None:
-        mp.set_start_method('spawn')
-    if launcher == 'pytorch':
-        _init_dist_pytorch(backend, **kwargs)
-    elif launcher == 'mpi':
-        _init_dist_mpi(backend, **kwargs)
-    elif launcher == 'slurm':
-        _init_dist_slurm(backend, **kwargs)
-    else:
-        raise ValueError('Invalid launcher type: {}'.format(launcher))
-
-
-def _init_dist_pytorch(backend, **kwargs):
-    # TODO: use local_rank instead of rank % num_gpus
-    rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
-    dist.init_process_group(backend=backend, **kwargs)
-
-
-def _init_dist_mpi(backend, **kwargs):
-    raise NotImplementedError
-
-
-def _init_dist_slurm(backend, **kwargs):
-    raise NotImplementedError
-
-
 def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
     if bucket_size_mb > 0:
         bucket_size_bytes = bucket_size_mb * 1024 * 1024
2 changes: 1 addition & 1 deletion mmdet/datasets/loader/build_loader.py
@@ -15,7 +15,7 @@
 def build_dataloader(dataset,
                      imgs_per_gpu,
                      workers_per_gpu,
-                     num_gpus,
+                     num_gpus=1,
                      dist=True,
                      **kwargs):
     if dist:
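
Sketch: with num_gpus now defaulting to 1, distributed call sites (such as _dist_train above, which doesn't pass num_gpus) can simply omit the argument.

# distributed path: num_gpus is not needed, so the argument can be dropped
loader = build_dataloader(dataset, cfg.data.imgs_per_gpu,
                          cfg.data.workers_per_gpu, dist=True)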