diff --git a/configs/lad/README.md b/configs/lad/README.md
new file mode 100644
index 00000000000..907a27316f3
--- /dev/null
+++ b/configs/lad/README.md
@@ -0,0 +1,32 @@
+# Improving Object Detection by Label Assignment Distillation
+
+```latex
+@inproceedings{nguyen2021improving,
+  title={Improving Object Detection by Label Assignment Distillation},
+  author={Chuong H. Nguyen and Thuy C. Nguyen and Tuan N. Tang and Nam L. H. Phan},
+  booktitle={WACV},
+  year={2022}
+}
+```
+
+## Results and Models
+
+We provide config files to reproduce the object detection results of the
+WACV 2022 paper *Improving Object Detection by Label Assignment
+Distillation*.
+
+### PAA with LAD
+
+| Teacher | Student | Training schedule | AP (val) |                    Config                     |
+| :-----: | :-----: | :---------------: | :------: | :-------------------------------------------: |
+|   --    |  R-50   |        1x         |   40.4   |                                               |
+|   --    |  R-101  |        1x         |   42.6   |                                               |
+|  R-101  |  R-50   |        1x         |   41.6   | [config](lad_r50_paa_r101_fpn_coco_1x.py)     |
+|  R-50   |  R-101  |        1x         |   43.2   | [config](lad_r101_paa_r50_fpn_coco_1x.py)     |
+
+## Note
+
+- Config naming convention: `lad_r50` (student model) + `paa` (distillation based on PAA) + `r101` (teacher model) + `fpn` (neck) + `coco` (dataset) + `1x` (12-epoch schedule).
+- Results may fluctuate by about 0.2 mAP.
diff --git a/configs/lad/lad_r101_paa_r50_fpn_coco_1x.py b/configs/lad/lad_r101_paa_r50_fpn_coco_1x.py
new file mode 100644
index 00000000000..c7e3fcf6d54
--- /dev/null
+++ b/configs/lad/lad_r101_paa_r50_fpn_coco_1x.py
@@ -0,0 +1,121 @@
+_base_ = [
+    '../_base_/datasets/coco_detection.py',
+    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+]
+teacher_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_coco/paa_r50_fpn_1x_coco_20200821-936edec3.pth'  # noqa
+model = dict(
+    type='LAD',
+    # student
+    backbone=dict(
+        type='ResNet',
+        depth=101,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        norm_eval=True,
+        style='pytorch',
+        init_cfg=dict(type='Pretrained',
+                      checkpoint='torchvision://resnet101')),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs='on_output',
+        num_outs=5),
+    bbox_head=dict(
+        type='LADHead',
+        reg_decoded_bbox=True,
+        score_voting=True,
+        topk=9,
+        num_classes=80,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        anchor_generator=dict(
+            type='AnchorGenerator',
+            ratios=[1.0],
+            octave_base_scale=8,
+            scales_per_octave=1,
+            strides=[8, 16, 32, 64, 128]),
+        bbox_coder=dict(
+            type='DeltaXYWHBBoxCoder',
+            target_means=[.0, .0, .0, .0],
+            target_stds=[0.1, 0.1, 0.2, 0.2]),
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
+        loss_centerness=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
+    # teacher
+    teacher_ckpt=teacher_ckpt,
+    teacher_backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        norm_eval=True,
+        style='pytorch'),
+    teacher_neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs='on_output',
+        num_outs=5),
+    teacher_bbox_head=dict(
+        type='LADHead',
+        reg_decoded_bbox=True,
+        score_voting=True,
+        topk=9,
+        num_classes=80,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        anchor_generator=dict(
+            type='AnchorGenerator',
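+            # a single square anchor per location (ratio 1.0, one scale per
+            # octave), mirroring the student's and PAA's anchor settings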
+            ratios=[1.0],
+            octave_base_scale=8,
+            scales_per_octave=1,
+            strides=[8, 16, 32, 64, 128]),
+        bbox_coder=dict(
+            type='DeltaXYWHBBoxCoder',
+            target_means=[.0, .0, .0, .0],
+            target_stds=[0.1, 0.1, 0.2, 0.2]),
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
+        loss_centerness=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
+    # training and testing settings
+    train_cfg=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.1,
+            neg_iou_thr=0.1,
+            min_pos_iou=0,
+            ignore_iof_thr=-1),
+        allowed_border=-1,
+        pos_weight=-1,
+        debug=False),
+    test_cfg=dict(
+        nms_pre=1000,
+        min_bbox_size=0,
+        score_thr=0.05,
+        score_voting=True,
+        nms=dict(type='nms', iou_threshold=0.6),
+        max_per_img=100))
+data = dict(samples_per_gpu=8, workers_per_gpu=4)
+optimizer = dict(lr=0.01)
+fp16 = dict(loss_scale=512.)
diff --git a/configs/lad/lad_r50_paa_r101_fpn_coco_1x.py b/configs/lad/lad_r50_paa_r101_fpn_coco_1x.py
new file mode 100644
index 00000000000..3c46f7db831
--- /dev/null
+++ b/configs/lad/lad_r50_paa_r101_fpn_coco_1x.py
@@ -0,0 +1,120 @@
+_base_ = [
+    '../_base_/datasets/coco_detection.py',
+    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+]
+teacher_ckpt = 'http://download.openmmlab.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_coco/paa_r101_fpn_1x_coco_20200821-0a1825a4.pth'  # noqa
+model = dict(
+    type='LAD',
+    # student
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        norm_eval=True,
+        style='pytorch',
+        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs='on_output',
+        num_outs=5),
+    bbox_head=dict(
+        type='LADHead',
+        reg_decoded_bbox=True,
+        score_voting=True,
+        topk=9,
+        num_classes=80,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        anchor_generator=dict(
+            type='AnchorGenerator',
+            ratios=[1.0],
+            octave_base_scale=8,
+            scales_per_octave=1,
+            strides=[8, 16, 32, 64, 128]),
+        bbox_coder=dict(
+            type='DeltaXYWHBBoxCoder',
+            target_means=[.0, .0, .0, .0],
+            target_stds=[0.1, 0.1, 0.2, 0.2]),
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
+        loss_centerness=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
+    # teacher
+    teacher_ckpt=teacher_ckpt,
+    teacher_backbone=dict(
+        type='ResNet',
+        depth=101,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=True),
+        norm_eval=True,
+        style='pytorch'),
+    teacher_neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        start_level=1,
+        add_extra_convs='on_output',
+        num_outs=5),
+    teacher_bbox_head=dict(
+        type='LADHead',
+        reg_decoded_bbox=True,
+        score_voting=True,
+        topk=9,
+        num_classes=80,
+        in_channels=256,
+        stacked_convs=4,
+        feat_channels=256,
+        anchor_generator=dict(
+            type='AnchorGenerator',
+            ratios=[1.0],
+            octave_base_scale=8,
+            scales_per_octave=1,
+            strides=[8, 16, 32, 64, 128]),
+        bbox_coder=dict(
+            type='DeltaXYWHBBoxCoder',
+            target_means=[.0, .0, .0, .0],
+            target_stds=[0.1, 0.1, 0.2, 0.2]),
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
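+        # same loss setup as the student head; regression operates on
+        # decoded boxes (reg_decoded_bbox=True above), hence the GIoU loss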
+        loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
+        loss_centerness=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)),
+    # training and testing settings
+    train_cfg=dict(
+        assigner=dict(
+            type='MaxIoUAssigner',
+            pos_iou_thr=0.1,
+            neg_iou_thr=0.1,
+            min_pos_iou=0,
+            ignore_iof_thr=-1),
+        allowed_border=-1,
+        pos_weight=-1,
+        debug=False),
+    test_cfg=dict(
+        nms_pre=1000,
+        min_bbox_size=0,
+        score_thr=0.05,
+        score_voting=True,
+        nms=dict(type='nms', iou_threshold=0.6),
+        max_per_img=100))
+data = dict(samples_per_gpu=8, workers_per_gpu=4)
+optimizer = dict(lr=0.01)
+fp16 = dict(loss_scale=512.)
diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py
index c3440fed559..e78444aca59 100644
--- a/mmdet/models/dense_heads/__init__.py
+++ b/mmdet/models/dense_heads/__init__.py
@@ -18,6 +18,7 @@ from .ga_rpn_head import GARPNHead
 from .gfl_head import GFLHead
 from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
+from .lad_head import LADHead
 from .ld_head import LDHead
 from .nasfcos_head import NASFCOSHead
 from .paa_head import PAAHead
@@ -47,5 +48,5 @@
     'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead',
     'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead',
     'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
-    'DecoupledSOLOLightHead'
+    'DecoupledSOLOLightHead', 'LADHead'
 ]
diff --git a/mmdet/models/dense_heads/lad_head.py b/mmdet/models/dense_heads/lad_head.py
new file mode 100644
index 00000000000..d45480b64c9
--- /dev/null
+++ b/mmdet/models/dense_heads/lad_head.py
@@ -0,0 +1,231 @@
+import torch
+from mmcv.runner import force_fp32
+
+from mmdet.core import bbox_overlaps, multi_apply
+from ..builder import HEADS
+from .paa_head import PAAHead, levels_to_images
+
+
+@HEADS.register_module()
+class LADHead(PAAHead):
+    """Label Assignment Head from the paper: `Improving Object Detection by
+    Label Assignment Distillation <https://arxiv.org/abs/2108.10520>`_"""
+
+    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds'))
+    def get_label_assignment(self,
+                             cls_scores,
+                             bbox_preds,
+                             iou_preds,
+                             gt_bboxes,
+                             gt_labels,
+                             img_metas,
+                             gt_bboxes_ignore=None):
+        """Get label assignment (from teacher).
+
+        Args:
+            cls_scores (list[Tensor]): Box scores for each scale level.
+                Has shape (N, num_anchors * num_classes, H, W)
+            bbox_preds (list[Tensor]): Box energies / deltas for each scale
+                level with shape (N, num_anchors * 4, H, W)
+            iou_preds (list[Tensor]): IoU predictions for each scale
+                level with shape (N, num_anchors * 1, H, W)
+            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (list[Tensor]): Class indices corresponding to each box.
+            img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
+                boxes can be ignored when computing the loss.
+
+        Returns:
+            tuple: A tuple of label assignment variables.
+
+                - labels (Tensor): Labels of all anchors, with shape
+                  (num_anchors,).
+                - labels_weight (Tensor): Label weights of all anchors, with
+                  shape (num_anchors,).
+                - bboxes_target (Tensor): BBox targets of all anchors, with
+                  shape (num_anchors, 4).
+                - bboxes_weight (Tensor): BBox weights of all anchors, with
+                  shape (num_anchors, 4).
+                - pos_inds_flatten (Tensor): Indices of all positive samples
+                  among all anchors.
+                - pos_anchors (Tensor): Positive anchors.
+                - num_pos (int): Number of positive anchors.
+        """
+
+        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+        assert len(featmap_sizes) == self.anchor_generator.num_levels
+
+        device = cls_scores[0].device
+        anchor_list, valid_flag_list = self.get_anchors(
+            featmap_sizes, img_metas, device=device)
+        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+        cls_reg_targets = self.get_targets(
+            anchor_list,
+            valid_flag_list,
+            gt_bboxes,
+            img_metas,
+            gt_bboxes_ignore_list=gt_bboxes_ignore,
+            gt_labels_list=gt_labels,
+            label_channels=label_channels,
+        )
+        (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds,
+         pos_gt_index) = cls_reg_targets
+        cls_scores = levels_to_images(cls_scores)
+        cls_scores = [
+            item.reshape(-1, self.cls_out_channels) for item in cls_scores
+        ]
+        bbox_preds = levels_to_images(bbox_preds)
+        bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
+        pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list,
+                                       cls_scores, bbox_preds, labels,
+                                       labels_weight, bboxes_target,
+                                       bboxes_weight, pos_inds)
+
+        with torch.no_grad():
+            reassign_labels, reassign_label_weight, \
+                reassign_bbox_weights, num_pos = multi_apply(
+                    self.paa_reassign,
+                    pos_losses_list,
+                    labels,
+                    labels_weight,
+                    bboxes_weight,
+                    pos_inds,
+                    pos_gt_index,
+                    anchor_list)
+            num_pos = sum(num_pos)
+        # convert the per-image tensor lists to flattened tensors
+        labels = torch.cat(reassign_labels, 0).view(-1)
+        flatten_anchors = torch.cat(
+            [torch.cat(item, 0) for item in anchor_list])
+        labels_weight = torch.cat(reassign_label_weight, 0).view(-1)
+        bboxes_target = torch.cat(bboxes_target,
+                                  0).view(-1, bboxes_target[0].size(-1))
+
+        pos_inds_flatten = ((labels >= 0)
+                            &
+                            (labels < self.num_classes)).nonzero().reshape(-1)
+
+        if num_pos:
+            pos_anchors = flatten_anchors[pos_inds_flatten]
+        else:
+            pos_anchors = None
+
+        label_assignment_results = (labels, labels_weight, bboxes_target,
+                                    bboxes_weight, pos_inds_flatten,
+                                    pos_anchors, num_pos)
+        return label_assignment_results
+
+    def forward_train(self,
+                      x,
+                      label_assignment_results,
+                      img_metas,
+                      gt_bboxes,
+                      gt_labels=None,
+                      gt_bboxes_ignore=None,
+                      **kwargs):
+        """Forward train with the label assignment that the student receives
+        from the teacher.
+
+        Args:
+            x (list[Tensor]): Features from FPN.
+            label_assignment_results (tuple): The label assignment results,
+                as returned by ``self.get_label_assignment``.
+            img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            gt_bboxes (Tensor): Ground truth bboxes of the image,
+                shape (num_gts, 4).
+            gt_labels (Tensor): Ground truth labels of each box,
+                shape (num_gts,).
+            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+                ignored, shape (num_ignored_gts, 4).
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        outs = self(x)
+        if gt_labels is None:
+            loss_inputs = outs + (gt_bboxes, img_metas)
+        else:
+            loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
+        losses = self.loss(
+            *loss_inputs,
+            gt_bboxes_ignore=gt_bboxes_ignore,
+            label_assignment_results=label_assignment_results)
+        return losses
+
+    @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds'))
+    def loss(self,
+             cls_scores,
+             bbox_preds,
+             iou_preds,
+             gt_bboxes,
+             gt_labels,
+             img_metas,
+             gt_bboxes_ignore=None,
+             label_assignment_results=None):
+        """Compute losses of the head.
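+
+        Unlike ``PAAHead.loss``, this method reuses the label assignment
+        computed by the teacher (``label_assignment_results``) instead of
+        recomputing the assignment from the head's own predictions.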
+
+        Args:
+            cls_scores (list[Tensor]): Box scores for each scale level.
+                Has shape (N, num_anchors * num_classes, H, W)
+            bbox_preds (list[Tensor]): Box energies / deltas for each scale
+                level with shape (N, num_anchors * 4, H, W)
+            iou_preds (list[Tensor]): IoU predictions for each scale
+                level with shape (N, num_anchors * 1, H, W)
+            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (list[Tensor]): Class indices corresponding to each box.
+            img_metas (list[dict]): Meta information of each image, e.g.,
+                image size, scaling factor, etc.
+            gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
+                boxes can be ignored when computing the loss.
+            label_assignment_results (tuple): The label assignment results,
+                as returned by ``self.get_label_assignment``.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+
+        (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds_flatten,
+         pos_anchors, num_pos) = label_assignment_results
+
+        cls_scores = levels_to_images(cls_scores)
+        cls_scores = [
+            item.reshape(-1, self.cls_out_channels) for item in cls_scores
+        ]
+        bbox_preds = levels_to_images(bbox_preds)
+        bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
+        iou_preds = levels_to_images(iou_preds)
+        iou_preds = [item.reshape(-1, 1) for item in iou_preds]
+
+        # convert the per-image tensor lists to flattened tensors
+        cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1))
+        bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1))
+        iou_preds = torch.cat(iou_preds, 0).view(-1, iou_preds[0].size(-1))
+
+        losses_cls = self.loss_cls(
+            cls_scores,
+            labels,
+            labels_weight,
+            avg_factor=max(num_pos, len(img_metas)))  # avoid num_pos=0
+        if num_pos:
+            pos_bbox_pred = self.bbox_coder.decode(
+                pos_anchors, bbox_preds[pos_inds_flatten])
+            pos_bbox_target = bboxes_target[pos_inds_flatten]
+            iou_target = bbox_overlaps(
+                pos_bbox_pred.detach(), pos_bbox_target, is_aligned=True)
+            losses_iou = self.loss_centerness(
+                iou_preds[pos_inds_flatten],
+                iou_target.unsqueeze(-1),
+                avg_factor=num_pos)
+            losses_bbox = self.loss_bbox(
+                pos_bbox_pred, pos_bbox_target, avg_factor=num_pos)
+        else:
+            losses_iou = iou_preds.sum() * 0
+            losses_bbox = bbox_preds.sum() * 0
+
+        return dict(
+            loss_cls=losses_cls, loss_bbox=losses_bbox, loss_iou=losses_iou)
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index 690f67ddf35..637434a60b5 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -16,6 +16,7 @@ from .grid_rcnn import GridRCNN
 from .htc import HybridTaskCascade
 from .kd_one_stage import KnowledgeDistillationSingleStageDetector
+from .lad import LAD
 from .mask_rcnn import MaskRCNN
 from .mask_scoring_rcnn import MaskScoringRCNN
 from .nasfcos import NASFCOS
@@ -47,5 +48,5 @@
     'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA', 'YOLOV3', 'YOLACT',
     'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet', 'SOLO',
     'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX',
-    'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst'
+    'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD'
 ]
diff --git a/mmdet/models/detectors/lad.py b/mmdet/models/detectors/lad.py
new file mode 100644
index 00000000000..6d232197e8f
--- /dev/null
+++ b/mmdet/models/detectors/lad.py
@@ -0,0 +1,91 @@
+import torch
+import torch.nn as nn
+from mmcv.runner import load_checkpoint
+
+from ..builder import DETECTORS, build_backbone, build_head, build_neck
+from .kd_one_stage import KnowledgeDistillationSingleStageDetector
+
+
+@DETECTORS.register_module()
+class LAD(KnowledgeDistillationSingleStageDetector):
+    """Implementation of `LAD <https://arxiv.org/abs/2108.10520>`_."""
+
+    def __init__(self,
+                 backbone,
+                 neck,
+                 bbox_head,
+                 teacher_backbone,
+                 teacher_neck,
+                 teacher_bbox_head,
+                 teacher_ckpt,
+                 eval_teacher=True,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        # init as a plain single-stage detector and build the teacher
+        # manually below, bypassing the KD parent's teacher construction
+        super(KnowledgeDistillationSingleStageDetector,
+              self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
+                             pretrained)
+        self.eval_teacher = eval_teacher
+        self.teacher_model = nn.Module()
+        self.teacher_model.backbone = build_backbone(teacher_backbone)
+        if teacher_neck is not None:
+            self.teacher_model.neck = build_neck(teacher_neck)
+        teacher_bbox_head.update(train_cfg=train_cfg)
+        teacher_bbox_head.update(test_cfg=test_cfg)
+        self.teacher_model.bbox_head = build_head(teacher_bbox_head)
+        if teacher_ckpt is not None:
+            load_checkpoint(
+                self.teacher_model, teacher_ckpt, map_location='cpu')
+
+    @property
+    def with_teacher_neck(self):
+        """bool: whether the detector has a teacher_neck"""
+        return hasattr(self.teacher_model, 'neck') and \
+            self.teacher_model.neck is not None
+
+    def extract_teacher_feat(self, img):
+        """Directly extract teacher features from the backbone + neck."""
+        x = self.teacher_model.backbone(img)
+        if self.with_teacher_neck:
+            x = self.teacher_model.neck(x)
+        return x
+
+    def forward_train(self,
+                      img,
+                      img_metas,
+                      gt_bboxes,
+                      gt_labels,
+                      gt_bboxes_ignore=None):
+        """
+        Args:
+            img (Tensor): Input images of shape (N, C, H, W).
+                Typically these should be mean centered and std scaled.
+            img_metas (list[dict]): A list of image info dicts where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                :class:`mmdet.datasets.pipelines.Collect`.
+            gt_bboxes (list[Tensor]): Each item is the ground-truth boxes of
+                one image, in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (list[Tensor]): Class indices corresponding to each box.
+            gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+                boxes can be ignored when computing the loss.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        # get the label assignment from the teacher
+        with torch.no_grad():
+            x_teacher = self.extract_teacher_feat(img)
+            outs_teacher = self.teacher_model.bbox_head(x_teacher)
+            label_assignment_results = \
+                self.teacher_model.bbox_head.get_label_assignment(
+                    *outs_teacher, gt_bboxes, gt_labels, img_metas,
+                    gt_bboxes_ignore)
+
+        # the student learns with the label assignment from the teacher
+        x = self.extract_feat(img)
+        losses = self.bbox_head.forward_train(x, label_assignment_results,
+                                              img_metas, gt_bboxes, gt_labels,
+                                              gt_bboxes_ignore)
+        return losses
diff --git a/tests/test_models/test_dense_heads/test_lad_head.py b/tests/test_models/test_dense_heads/test_lad_head.py
new file mode 100644
index 00000000000..0ca45f4b98e
--- /dev/null
+++ b/tests/test_models/test_dense_heads/test_lad_head.py
@@ -0,0 +1,149 @@
+# Copyright (c) OpenMMLab. All rights reserved.
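+# These tests patch the sklearn GaussianMixture used by PAA's label
+# reassignment with a deterministic mock, so no sklearn install is needed.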
+import mmcv
+import numpy as np
+import torch
+
+from mmdet.models.dense_heads import LADHead, paa_head
+from mmdet.models.dense_heads.lad_head import levels_to_images
+
+
+def test_lad_head_loss():
+    """Tests lad head loss when truth is empty and non-empty."""
+
+    class mock_skm:
+
+        def GaussianMixture(self, *args, **kwargs):
+            return self
+
+        def fit(self, loss):
+            pass
+
+        def predict(self, loss):
+            # np.long was removed in recent NumPy; int64 is equivalent here
+            components = np.zeros_like(loss, dtype=np.int64)
+            return components.reshape(-1)
+
+        def score_samples(self, loss):
+            scores = np.random.random(len(loss))
+            return scores
+
+    # ``skm`` is looked up in paa_head (which implements paa_reassign),
+    # so the mock must be patched there
+    paa_head.skm = mock_skm()
+
+    s = 256
+    img_metas = [{
+        'img_shape': (s, s, 3),
+        'scale_factor': 1,
+        'pad_shape': (s, s, 3)
+    }]
+    train_cfg = mmcv.Config(
+        dict(
+            assigner=dict(
+                type='MaxIoUAssigner',
+                pos_iou_thr=0.1,
+                neg_iou_thr=0.1,
+                min_pos_iou=0,
+                ignore_iof_thr=-1),
+            allowed_border=-1,
+            pos_weight=-1,
+            debug=False))
+    # use CrossEntropyLoss since Focal Loss is not supported on CPU
+    self = LADHead(
+        num_classes=4,
+        in_channels=1,
+        train_cfg=train_cfg,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
+        loss_centerness=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5))
+    teacher_model = LADHead(
+        num_classes=4,
+        in_channels=1,
+        train_cfg=train_cfg,
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
+        loss_centerness=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5))
+    feat = [
+        torch.rand(1, 1, s // feat_size, s // feat_size)
+        for feat_size in [4, 8, 16, 32, 64]
+    ]
+    self.init_weights()
+    teacher_model.init_weights()
+
+    # Test that empty ground truth encourages the network to predict
+    # background
+    gt_bboxes = [torch.empty((0, 4))]
+    gt_labels = [torch.LongTensor([])]
+    gt_bboxes_ignore = None
+
+    outs_teacher = teacher_model(feat)
+    label_assignment_results = teacher_model.get_label_assignment(
+        *outs_teacher, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore)
+
+    # the student computes its loss using the teacher's label assignment
+    outs = self(feat)
+    empty_gt_losses = self.loss(*outs, gt_bboxes, gt_labels, img_metas,
+                                gt_bboxes_ignore, label_assignment_results)
+    # When there is no truth, the cls loss should be nonzero but there should
+    # be no box loss.
+    empty_cls_loss = empty_gt_losses['loss_cls']
+    empty_box_loss = empty_gt_losses['loss_bbox']
+    empty_iou_loss = empty_gt_losses['loss_iou']
+    assert empty_cls_loss.item() > 0, 'cls loss should be non-zero'
+    assert empty_box_loss.item() == 0, (
+        'there should be no box loss when there are no true boxes')
+    assert empty_iou_loss.item() == 0, (
+        'there should be no iou loss when there are no true boxes')
+
+    # When truth is non-empty, both cls and box loss should be nonzero for
+    # random inputs
+    gt_bboxes = [
+        torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
+    ]
+    gt_labels = [torch.LongTensor([2])]
+
+    label_assignment_results = teacher_model.get_label_assignment(
+        *outs_teacher, gt_bboxes, gt_labels, img_metas, gt_bboxes_ignore)
+
+    one_gt_losses = self.loss(*outs, gt_bboxes, gt_labels, img_metas,
+                              gt_bboxes_ignore, label_assignment_results)
+    onegt_cls_loss = one_gt_losses['loss_cls']
+    onegt_box_loss = one_gt_losses['loss_bbox']
+    onegt_iou_loss = one_gt_losses['loss_iou']
+    assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero'
+    assert onegt_box_loss.item() > 0, 'box loss should be non-zero'
+    assert onegt_iou_loss.item() > 0, 'iou loss should be non-zero'
+    n, c, h, w = 10, 4, 20, 20
+    mlvl_tensor = [torch.ones(n, c, h, w) for i in range(5)]
+    results = levels_to_images(mlvl_tensor)
+    assert len(results) == n
+    assert results[0].size() == (h * w * 5, c)
+    assert self.with_score_voting
+
+    self = LADHead(
+        num_classes=4,
+        in_channels=1,
+        train_cfg=train_cfg,
+        anchor_generator=dict(
+            type='AnchorGenerator',
+            ratios=[1.0],
+            octave_base_scale=8,
+            scales_per_octave=1,
+            strides=[8]),
+        loss_cls=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+        loss_bbox=dict(type='GIoULoss', loss_weight=1.3),
+        loss_centerness=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5))
+    cls_scores = [torch.ones(2, 4, 5, 5)]
+    bbox_preds = [torch.ones(2, 4, 5, 5)]
+    iou_preds = [torch.ones(2, 1, 5, 5)]
+    cfg = mmcv.Config(
+        dict(
+            nms_pre=1000,
+            min_bbox_size=0,
+            score_thr=0.05,
+            nms=dict(type='nms', iou_threshold=0.6),
+            max_per_img=100))
+    rescale = False
+    self.get_bboxes(
+        cls_scores, bbox_preds, iou_preds, img_metas, cfg, rescale=rescale)
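+    # a minimal sanity check (assumption: `get_bboxes` returns one
+    # (det_bboxes, det_labels) pair per image in `img_metas`)
+    results = self.get_bboxes(
+        cls_scores, bbox_preds, iou_preds, img_metas, cfg, rescale=rescale)
+    assert len(results) == len(img_metas)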