diff --git a/mmdet/models/roi_heads/bbox_heads/bbox_head.py b/mmdet/models/roi_heads/bbox_heads/bbox_head.py index ac98f3c3e03..3b2e8aae083 100644 --- a/mmdet/models/roi_heads/bbox_heads/bbox_head.py +++ b/mmdet/models/roi_heads/bbox_heads/bbox_head.py @@ -87,8 +87,9 @@ def __init__(self, out_dim_reg = box_dim if reg_class_agnostic else \ box_dim * num_classes reg_predictor_cfg_ = self.reg_predictor_cfg.copy() - reg_predictor_cfg_.update( - in_features=in_channels, out_features=out_dim_reg) + if isinstance(reg_predictor_cfg_, (dict, ConfigDict)): + reg_predictor_cfg_.update( + in_features=in_channels, out_features=out_dim_reg) self.fc_reg = MODELS.build(reg_predictor_cfg_) self.debug_imgs = None if init_cfg is None: diff --git a/mmdet/models/roi_heads/bbox_heads/convfc_bbox_head.py b/mmdet/models/roi_heads/bbox_heads/convfc_bbox_head.py index 28d76c70c45..cb6aadd86d3 100644 --- a/mmdet/models/roi_heads/bbox_heads/convfc_bbox_head.py +++ b/mmdet/models/roi_heads/bbox_heads/convfc_bbox_head.py @@ -95,8 +95,9 @@ def __init__(self, out_dim_reg = box_dim if self.reg_class_agnostic else \ box_dim * self.num_classes reg_predictor_cfg_ = self.reg_predictor_cfg.copy() - reg_predictor_cfg_.update( - in_features=self.reg_last_dim, out_features=out_dim_reg) + if isinstance(reg_predictor_cfg_, (dict, ConfigDict)): + reg_predictor_cfg_.update( + in_features=self.reg_last_dim, out_features=out_dim_reg) self.fc_reg = MODELS.build(reg_predictor_cfg_) if init_cfg is None: diff --git a/projects/Detic/README.md b/projects/Detic/README.md new file mode 100644 index 00000000000..4e99779342d --- /dev/null +++ b/projects/Detic/README.md @@ -0,0 +1,154 @@ +# Detecting Twenty-thousand Classes using Image-level Supervision + +## Description + +**Detic**: A **Det**ector with **i**mage **c**lasses that can use image-level labels to easily train detectors. + +

+
+> [**Detecting Twenty-thousand Classes using Image-level Supervision**](http://arxiv.org/abs/2201.02605),
+> Xingyi Zhou, Rohit Girdhar, Armand Joulin, Philipp Krähenbühl, Ishan Misra,
+> *ECCV 2022 ([arXiv 2201.02605](http://arxiv.org/abs/2201.02605))*
+
+## Usage
+
+### Installation
+
+Detic requires CLIP to be installed:
+
+```shell
+pip install git+https://github.com/openai/CLIP.git
+```
+
+### Demo
+
+#### Inference with existing dataset vocabulary embeddings
+
+First, go to the Detic project folder.
+
+```shell
+cd projects/Detic
+```
+
+Then, download the pre-computed CLIP embeddings from [dataset metainfo](https://github.com/facebookresearch/Detic/tree/main/datasets/metadata) to the `datasets/metadata` folder.
+The CLIP embeddings are loaded into the zero-shot classifier at inference time.
+For example, you can download the LVIS class-name embeddings with the following command:
+
+```shell
+wget -P datasets/metadata https://raw.githubusercontent.com/facebookresearch/Detic/main/datasets/metadata/lvis_v1_clip_a%2Bcname.npy
+```
+
+You can then run the demo as follows:
+
+```shell
+python demo.py \
+    ${IMAGE_PATH} \
+    ${CONFIG_PATH} \
+    ${MODEL_PATH} \
+    --show \
+    --score-thr 0.5 \
+    --dataset lvis
+```
+
+![image](https://user-images.githubusercontent.com/12907710/213624759-f0a2ba0c-0f5c-4424-a350-5ba5349e5842.png)
+
+#### Inference with custom vocabularies
+
+Detic can detect any class given only its name, because the class names are encoded with CLIP.
+You can pass custom classes with the `--class-name` option:
+
+```shell
+python demo.py \
+    ${IMAGE_PATH} \
+    ${CONFIG_PATH} \
+    ${MODEL_PATH} \
+    --show \
+    --score-thr 0.3 \
+    --class-name headphone webcam paper coffe
+```
+
+![image](https://user-images.githubusercontent.com/12907710/213624637-e9e8a313-9821-4782-a18a-4408c876852b.png)
+
+Note that `headphone`, `paper` and `coffe` (typo intended) are not LVIS classes. Despite the misspelled class name, Detic can still produce a reasonable detection for `coffe`.
+
+## Results
+
+Here we only provide the Detic Swin-B model for the open-vocabulary demo. Multi-dataset training and open-vocabulary testing will be supported in the future.
+
+To find more variants, please visit the [official model zoo](https://github.com/facebookresearch/Detic/blob/main/docs/MODEL_ZOO.md).
+
+| Backbone |       Training data        |                                 Config                                 |                                                                                      Download                                                                                      |
+| :------: | :------------------------: | :--------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+|  Swin-B  | ImageNet-21K & LVIS & COCO | [config](./configs/detic_centernet2_swin-b_fpn_4x_lvis-coco-in21k.py)  | [model](https://download.openmmlab.com/mmdetection/v3.0/detic/detic_centernet2_swin-b_fpn_4x_lvis-coco-in21k/detic_centernet2_swin-b_fpn_4x_lvis-coco-in21k_20230120-0d301978.pth) |
+
+## Citation
+
+If you find Detic useful in your research or applications, please consider giving a star 🌟 to the [official repository](https://github.com/facebookresearch/Detic) and citing Detic with the following BibTeX entry.
+
+```BibTeX
+@inproceedings{zhou2022detecting,
+  title={Detecting Twenty-thousand Classes using Image-level Supervision},
+  author={Zhou, Xingyi and Girdhar, Rohit and Joulin, Armand and Kr{\"a}henb{\"u}hl, Philipp and Misra, Ishan},
+  booktitle={ECCV},
+  year={2022}
+}
+```
+
+## Checklist
+
+- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
+ + - [x] Finish the code + + + + - [x] Basic docstrings & proper citation + + + + - [x] Test-time correctness + + + + - [x] A full README + + + +- [ ] Milestone 2: Indicates a successful model implementation. + + - [ ] Training-time correctness + + + +- [ ] Milestone 3: Good to be a part of our core package! + + - [ ] Type hints and docstrings + + + + - [ ] Unit tests + + + + - [ ] Code polishing + + + + - [ ] Metafile.yml + + + +- [ ] Move your modules into the core package following the codebase's file hierarchy structure. + + + +- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure. diff --git a/projects/Detic/configs/detic_centernet2_swin-b_fpn_4x_lvis-coco-in21k.py b/projects/Detic/configs/detic_centernet2_swin-b_fpn_4x_lvis-coco-in21k.py new file mode 100644 index 00000000000..19a17aea7bc --- /dev/null +++ b/projects/Detic/configs/detic_centernet2_swin-b_fpn_4x_lvis-coco-in21k.py @@ -0,0 +1,298 @@ +_base_ = 'mmdet::common/lsj-200e_coco-detection.py' + +custom_imports = dict( + imports=['projects.Detic.detic'], allow_failed_imports=False) + +image_size = (1024, 1024) +batch_augments = [dict(type='BatchFixedSizePad', size=image_size)] + +cls_layer = dict( + type='ZeroShotClassifier', + zs_weight_path='rand', + zs_weight_dim=512, + use_bias=0.0, + norm_weight=True, + norm_temperature=50.0) +reg_layer = [ + dict(type='Linear', in_features=1024, out_features=1024), + dict(type='ReLU', inplace=True), + dict(type='Linear', in_features=1024, out_features=4) +] + +num_classes = 22047 + +model = dict( + type='CascadeRCNN', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32, + batch_augments=batch_augments), + backbone=dict( + type='SwinTransformer', + embed_dims=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=7, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.3, + patch_norm=True, + out_indices=(1, 2, 3), + with_cp=False), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024], + out_channels=256, + start_level=0, + add_extra_convs='on_output', + num_outs=5, + init_cfg=dict(type='Caffe2Xavier', layer='Conv2d'), + relu_before_extra_convs=True), + rpn_head=dict( + type='CenterNetRPNHead', + num_classes=1, + in_channels=256, + stacked_convs=4, + feat_channels=256, + strides=[8, 16, 32, 64, 128], + conv_bias=True, + norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), + loss_cls=dict( + type='GaussianFocalLoss', + pos_weight=0.25, + neg_weight=0.75, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0), + ), + roi_head=dict( + type='DeticRoIHead', + num_stages=3, + stage_loss_weights=[1, 0.5, 0.25], + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict( + type='RoIAlign', + output_size=7, + sampling_ratio=0, + use_torchvision=True), + out_channels=256, + featmap_strides=[8, 16, 32], + # approximately equal to + # canonical_box_size=224, canonical_level=4 in D2 + finest_scale=112), + bbox_head=[ + dict( + type='DeticBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=num_classes, + cls_predictor_cfg=cls_layer, + reg_predictor_cfg=reg_layer, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, + loss_weight=1.0), + 
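+                # NOTE: unlike `cls_predictor_cfg`, `reg_predictor_cfg` here
+                # is a *list* of layer cfgs (the registry builds such lists
+                # into an nn.Sequential), which is why bbox_head.py and
+                # convfc_bbox_head.py above now guard their
+                # in_features/out_features update with an isinstance check.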
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, + loss_weight=1.0)), + dict( + type='DeticBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=num_classes, + cls_predictor_cfg=cls_layer, + reg_predictor_cfg=reg_layer, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.05, 0.05, 0.1, 0.1]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, + loss_weight=1.0)), + dict( + type='DeticBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=num_classes, + cls_predictor_cfg=cls_layer, + reg_predictor_cfg=reg_layer, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.033, 0.033, 0.067, 0.067]), + reg_class_agnostic=True, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) + ], + mask_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), + out_channels=256, + featmap_strides=[8, 16, 32], + # approximately equal to + # canonical_box_size=224, canonical_level=4 in D2 + finest_scale=112), + mask_head=dict( + type='FCNMaskHead', + num_convs=4, + in_channels=256, + conv_out_channels=256, + class_agnostic=True, + num_classes=num_classes, + loss_mask=dict( + type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), + # model training and testing settings + train_cfg=dict( + rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=[ + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.6, + neg_iou_thr=0.6, + min_pos_iou=0.6, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.7, + min_pos_iou=0.7, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False), + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.8, + neg_iou_thr=0.8, + min_pos_iou=0.8, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False) + ]), + test_cfg=dict( + rpn=dict( + score_thr=0.0001, + nms_pre=1000, + max_per_img=256, + nms=dict(type='nms', iou_threshold=0.9), + min_bbox_size=0), + rcnn=dict( + score_thr=0.02, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=300, + mask_thr_binary=0.5))) + +backend = 'pillow' +test_pipeline = [ + dict( + type='LoadImageFromFile', + file_client_args=_base_.file_client_args, + imdecode_backend=backend), + dict(type='Resize', scale=(1333, 800), keep_ratio=True, backend=backend), + dict( + 
type='LoadAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +train_dataloader = dict(batch_size=8, num_workers=4) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader +# Enable automatic-mixed-precision training with AmpOptimWrapper. +optim_wrapper = dict( + type='AmpOptimWrapper', + optimizer=dict( + type='SGD', lr=0.01 * 4, momentum=0.9, weight_decay=0.00004), + paramwise_cfg=dict(norm_decay_mult=0.)) + +param_scheduler = [ + dict( + type='LinearLR', + start_factor=0.00025, + by_epoch=False, + begin=0, + end=4000), + dict( + type='MultiStepLR', + begin=0, + end=25, + by_epoch=True, + milestones=[22, 24], + gamma=0.1) +] + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (8 GPUs) x (8 samples per GPU) +auto_scale_lr = dict(base_batch_size=64) diff --git a/projects/Detic/demo.py b/projects/Detic/demo.py new file mode 100644 index 00000000000..d5c80c9aa5f --- /dev/null +++ b/projects/Detic/demo.py @@ -0,0 +1,142 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import urllib +from argparse import ArgumentParser + +import mmcv +import torch +from mmengine.logging import print_log +from mmengine.utils import ProgressBar, scandir + +from mmdet.apis import inference_detector, init_detector +from mmdet.registry import VISUALIZERS +from mmdet.utils import register_all_modules + +IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', + '.tiff', '.webp') + + +def get_file_list(source_root: str) -> [list, dict]: + """Get file list. + + Args: + source_root (str): image or video source path + + Return: + source_file_path_list (list): A list for all source file. + source_type (dict): Source type: file or url or dir. 
+ """ + is_dir = os.path.isdir(source_root) + is_url = source_root.startswith(('http:/', 'https:/')) + is_file = os.path.splitext(source_root)[-1].lower() in IMG_EXTENSIONS + + source_file_path_list = [] + if is_dir: + # when input source is dir + for file in scandir(source_root, IMG_EXTENSIONS, recursive=True): + source_file_path_list.append(os.path.join(source_root, file)) + elif is_url: + # when input source is url + filename = os.path.basename( + urllib.parse.unquote(source_root).split('?')[0]) + file_save_path = os.path.join(os.getcwd(), filename) + print(f'Downloading source file to {file_save_path}') + torch.hub.download_url_to_file(source_root, file_save_path) + source_file_path_list = [file_save_path] + elif is_file: + # when input source is single image + source_file_path_list = [source_root] + else: + print('Cannot find image file.') + + source_type = dict(is_dir=is_dir, is_url=is_url, is_file=is_file) + + return source_file_path_list, source_type + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + 'img', help='Image path, include image file, dir and URL.') + parser.add_argument('config', help='Config file') + parser.add_argument('checkpoint', help='Checkpoint file') + parser.add_argument( + '--out-dir', default='./output', help='Path to output file') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + parser.add_argument( + '--show', action='store_true', help='Show the detection results') + parser.add_argument( + '--score-thr', type=float, default=0.3, help='Bbox score threshold') + parser.add_argument( + '--dataset', type=str, help='dataset name to load the text embedding') + parser.add_argument( + '--class-name', nargs='+', type=str, help='custom class names') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # register all modules in mmdet into the registries + register_all_modules() + + # build the model from a config file and a checkpoint file + model = init_detector(args.config, args.checkpoint, device=args.device) + + if not os.path.exists(args.out_dir) and not args.show: + os.mkdir(args.out_dir) + + # init visualizer + visualizer = VISUALIZERS.build(model.cfg.visualizer) + visualizer.dataset_meta = model.dataset_meta + + # get file list + files, source_type = get_file_list(args.img) + from detic.utils import (get_class_names, get_text_embeddings, + reset_cls_layer_weight) + + # class name embeddings + if args.class_name: + dataset_classes = args.class_name + elif args.dataset: + dataset_classes = get_class_names(args.dataset) + embedding = get_text_embeddings( + dataset=args.dataset, custom_vocabulary=args.class_name) + visualizer.dataset_meta['classes'] = dataset_classes + reset_cls_layer_weight(model, embedding) + + # start detector inference + progress_bar = ProgressBar(len(files)) + for file in files: + result = inference_detector(model, file) + + img = mmcv.imread(file) + img = mmcv.imconvert(img, 'bgr', 'rgb') + + if source_type['is_dir']: + filename = os.path.relpath(file, args.img).replace('/', '_') + else: + filename = os.path.basename(file) + out_file = None if args.show else os.path.join(args.out_dir, filename) + + progress_bar.update() + + visualizer.add_datasample( + filename, + img, + data_sample=result, + draw_gt=False, + show=args.show, + wait_time=0, + out_file=out_file, + pred_score_thr=args.score_thr) + + if not args.show: + print_log( + f'\nResults have been saved at {os.path.abspath(args.out_dir)}') + + +if __name__ == '__main__': + main() diff --git 
a/projects/Detic/detic/__init__.py b/projects/Detic/detic/__init__.py
new file mode 100644
index 00000000000..d0ad070259a
--- /dev/null
+++ b/projects/Detic/detic/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .centernet_rpn_head import CenterNetRPNHead
+from .detic_bbox_head import DeticBBoxHead
+from .detic_roi_head import DeticRoIHead
+from .zero_shot_classifier import ZeroShotClassifier
+
+__all__ = [
+    'CenterNetRPNHead', 'DeticBBoxHead', 'DeticRoIHead', 'ZeroShotClassifier'
+]
diff --git a/projects/Detic/detic/centernet_rpn_head.py b/projects/Detic/detic/centernet_rpn_head.py
new file mode 100644
index 00000000000..765d6dfb2b6
--- /dev/null
+++ b/projects/Detic/detic/centernet_rpn_head.py
@@ -0,0 +1,196 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from typing import List, Sequence, Tuple
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import Scale
+from mmengine import ConfigDict
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmdet.models.dense_heads import CenterNetUpdateHead
+from mmdet.models.utils import multi_apply
+from mmdet.registry import MODELS
+
+INF = 1000000000
+RangeType = Sequence[Tuple[int, int]]
+
+
+@MODELS.register_module(force=True)  # avoid bug
+class CenterNetRPNHead(CenterNetUpdateHead):
+    """RPN variant of CenterNetUpdateHead, the improved CenterNet head used
+    in CenterNet2.
+
+    Paper link: `<https://arxiv.org/abs/2103.07461>`_.
+    """
+
+    def _init_layers(self) -> None:
+        """Initialize layers of the head."""
+        self._init_reg_convs()
+        self._init_predictor()
+
+    def _init_predictor(self) -> None:
+        """Initialize predictor layers of the head."""
+        self.conv_cls = nn.Conv2d(
+            self.feat_channels, self.num_classes, 3, padding=1)
+        self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
+
+    def forward(self, x: Tuple[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
+        """Forward features from the upstream network.
+
+        Args:
+            x (tuple[Tensor]): Features from the upstream network, each is
+                a 4D-tensor.
+
+        Returns:
+            tuple: A tuple of each level outputs.
+
+            - cls_scores (list[Tensor]): Box scores for each scale level, \
+              each is a 4D-tensor, the channel number is num_classes.
+            - bbox_preds (list[Tensor]): Box energies / deltas for each \
+              scale level, each is a 4D-tensor, the channel number is 4.
+        """
+        res = multi_apply(self.forward_single, x, self.scales, self.strides)
+        return res
+
+    def forward_single(self, x: Tensor, scale: Scale,
+                       stride: int) -> Tuple[Tensor, Tensor]:
+        """Forward features of a single scale level.
+
+        Args:
+            x (Tensor): FPN feature maps of the specified stride.
+            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
+                the bbox prediction.
+            stride (int): The corresponding stride for feature maps.
+
+        Returns:
+            tuple: scores for each class, bbox predictions of
+            input feature maps.
+        """
+        for m in self.reg_convs:
+            x = m(x)
+        cls_score = self.conv_cls(x)
+        bbox_pred = self.conv_reg(x)
+        # scale the bbox_pred of different level
+        # float to avoid overflow when enabling FP16
+        bbox_pred = scale(bbox_pred).float()
+        # bbox_pred needed for gradient computation has been modified
+        # by F.relu(bbox_pred) when run with PyTorch 1.10.
So replace + # F.relu(bbox_pred) with bbox_pred.clamp(min=0) + bbox_pred = bbox_pred.clamp(min=0) + if not self.training: + bbox_pred *= stride + return cls_score, bbox_pred # score aligned, box larger + + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + score_factor_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: ConfigDict, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image, each item has shape + (num_priors * 1, H, W). + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid. In all + anchor-based methods, it has shape (num_priors, 4). In + all anchor-free methods, it has shape (num_priors, 2) + when `with_stride=True`, otherwise it still has shape + (num_priors, 4). + img_meta (dict): Image meta info. + cfg (mmengine.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
+ """ + + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bbox_preds = [] + mlvl_valid_priors = [] + mlvl_scores = [] + mlvl_labels = [] + + for level_idx, (cls_score, bbox_pred, score_factor, priors) in \ + enumerate(zip(cls_score_list, bbox_pred_list, + score_factor_list, mlvl_priors)): + + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + + dim = self.bbox_coder.encode_size + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim) + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + heatmap = cls_score.sigmoid() + score_thr = cfg.get('score_thr', 0) + + candidate_inds = heatmap > score_thr # 0.05 + pre_nms_top_n = candidate_inds.sum() # N + pre_nms_top_n = pre_nms_top_n.clamp(max=nms_pre) # N + + heatmap = heatmap[candidate_inds] # n + + candidate_nonzeros = candidate_inds.nonzero() # n + box_loc = candidate_nonzeros[:, 0] # n + labels = candidate_nonzeros[:, 1] # n + + bbox_pred = bbox_pred[box_loc] # n x 4 + per_grids = priors[box_loc] # n x 2 + + if candidate_inds.sum().item() > pre_nms_top_n.item(): + heatmap, top_k_indices = \ + heatmap.topk(pre_nms_top_n, sorted=False) + labels = labels[top_k_indices] + bbox_pred = bbox_pred[top_k_indices] + per_grids = per_grids[top_k_indices] + + bboxes = self.bbox_coder.decode(per_grids, bbox_pred) + # avoid invalid boxes in RoI heads + bboxes[:, 2] = torch.max(bboxes[:, 2], bboxes[:, 0] + 0.01) + bboxes[:, 3] = torch.max(bboxes[:, 3], bboxes[:, 1] + 0.01) + + mlvl_bbox_preds.append(bboxes) + mlvl_valid_priors.append(priors) + mlvl_scores.append(torch.sqrt(heatmap)) + mlvl_labels.append(labels) + + results = InstanceData() + results.bboxes = torch.cat(mlvl_bbox_preds) + results.scores = torch.cat(mlvl_scores) + results.labels = torch.cat(mlvl_labels) + + return self._bbox_post_process( + results=results, + cfg=cfg, + rescale=rescale, + with_nms=with_nms, + img_meta=img_meta) diff --git a/projects/Detic/detic/detic_bbox_head.py b/projects/Detic/detic/detic_bbox_head.py new file mode 100644 index 00000000000..9408cbe04fd --- /dev/null +++ b/projects/Detic/detic/detic_bbox_head.py @@ -0,0 +1,112 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Union + +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models.layers import multiclass_nms +from mmdet.models.roi_heads.bbox_heads import Shared2FCBBoxHead +from mmdet.models.utils import empty_instances +from mmdet.registry import MODELS +from mmdet.structures.bbox import get_box_tensor, scale_boxes + + +@MODELS.register_module(force=True) # avoid bug +class DeticBBoxHead(Shared2FCBBoxHead): + + def __init__(self, + *args, + init_cfg: Optional[Union[dict, ConfigDict]] = None, + **kwargs) -> None: + super().__init__(*args, init_cfg=init_cfg, **kwargs) + # reconstruct fc_cls and fc_reg since input channels are changed + assert self.with_cls + cls_channels = self.num_classes + cls_predictor_cfg_ = self.cls_predictor_cfg.copy() + cls_predictor_cfg_.update( + in_features=self.cls_last_dim, out_features=cls_channels) + self.fc_cls = MODELS.build(cls_predictor_cfg_) + + def _predict_by_feat_single( + self, + roi: Tensor, + cls_score: Tensor, + bbox_pred: Tensor, + img_meta: dict, + rescale: bool = False, + rcnn_test_cfg: Optional[ConfigDict] = None) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + roi (Tensor): Boxes to be transformed. 
Has shape (num_boxes, 5). + last dimension 5 arrange as (batch_index, x1, y1, x2, y2). + cls_score (Tensor): Box scores, has shape + (num_boxes, num_classes + 1). + bbox_pred (Tensor): Box energies / deltas. + has shape (num_boxes, num_classes * 4). + img_meta (dict): image information. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. + Defaults to None + + Returns: + :obj:`InstanceData`: Detection results of each image\ + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + results = InstanceData() + if roi.shape[0] == 0: + return empty_instances([img_meta], + roi.device, + task_type='bbox', + instance_results=[results], + box_type=self.predict_box_type, + use_box_type=False, + num_classes=self.num_classes, + score_per_cls=rcnn_test_cfg is None)[0] + scores = cls_score + img_shape = img_meta['img_shape'] + num_rois = roi.size(0) + + num_classes = 1 if self.reg_class_agnostic else self.num_classes + roi = roi.repeat_interleave(num_classes, dim=0) + bbox_pred = bbox_pred.view(-1, self.bbox_coder.encode_size) + bboxes = self.bbox_coder.decode( + roi[..., 1:], bbox_pred, max_shape=img_shape) + + if rescale and bboxes.size(0) > 0: + assert img_meta.get('scale_factor') is not None + scale_factor = [1 / s for s in img_meta['scale_factor']] + bboxes = scale_boxes(bboxes, scale_factor) + + # Get the inside tensor when `bboxes` is a box type + bboxes = get_box_tensor(bboxes) + box_dim = bboxes.size(-1) + bboxes = bboxes.view(num_rois, -1) + + if rcnn_test_cfg is None: + # This means that it is aug test. + # It needs to return the raw results without nms. + results.bboxes = bboxes + results.scores = scores + else: + det_bboxes, det_labels = multiclass_nms( + bboxes, + scores, + rcnn_test_cfg.score_thr, + rcnn_test_cfg.nms, + rcnn_test_cfg.max_per_img, + box_dim=box_dim) + results.bboxes = det_bboxes[:, :-1] + results.scores = det_bboxes[:, -1] + results.labels = det_labels + return results diff --git a/projects/Detic/detic/detic_roi_head.py b/projects/Detic/detic/detic_roi_head.py new file mode 100644 index 00000000000..a09c11c6e69 --- /dev/null +++ b/projects/Detic/detic/detic_roi_head.py @@ -0,0 +1,326 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Sequence, Tuple + +import torch +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models.roi_heads import CascadeRoIHead +from mmdet.models.task_modules.samplers import SamplingResult +from mmdet.models.test_time_augs import merge_aug_masks +from mmdet.models.utils.misc import empty_instances +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.bbox import bbox2roi, get_box_tensor +from mmdet.utils import ConfigType, InstanceList, MultiConfig + + +@MODELS.register_module(force=True) # avoid bug +class DeticRoIHead(CascadeRoIHead): + + def init_mask_head(self, mask_roi_extractor: MultiConfig, + mask_head: MultiConfig) -> None: + """Initialize mask head and mask roi extractor. + + Args: + mask_head (dict): Config of mask in mask head. + mask_roi_extractor (:obj:`ConfigDict`, dict or list): + Config of mask roi extractor. 
+ """ + self.mask_head = MODELS.build(mask_head) + + if mask_roi_extractor is not None: + self.share_roi_extractor = False + self.mask_roi_extractor = MODELS.build(mask_roi_extractor) + else: + self.share_roi_extractor = True + self.mask_roi_extractor = self.bbox_roi_extractor + + def _refine_roi(self, x: Tuple[Tensor], rois: Tensor, + batch_img_metas: List[dict], + num_proposals_per_img: Sequence[int], **kwargs) -> tuple: + """Multi-stage refinement of RoI. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rois (Tensor): shape (n, 5), [batch_ind, x1, y1, x2, y2] + batch_img_metas (list[dict]): List of image information. + num_proposals_per_img (sequence[int]): number of proposals + in each image. + + Returns: + tuple: + + - rois (Tensor): Refined RoI. + - cls_scores (list[Tensor]): Average predicted + cls score per image. + - bbox_preds (list[Tensor]): Bbox branch predictions + for the last stage of per image. + """ + # "ms" in variable names means multi-stage + ms_scores = [] + for stage in range(self.num_stages): + bbox_results = self._bbox_forward( + stage=stage, x=x, rois=rois, **kwargs) + + # split batch bbox prediction back to each image + cls_scores = bbox_results['cls_score'].sigmoid() + bbox_preds = bbox_results['bbox_pred'] + + rois = rois.split(num_proposals_per_img, 0) + cls_scores = cls_scores.split(num_proposals_per_img, 0) + ms_scores.append(cls_scores) + bbox_preds = bbox_preds.split(num_proposals_per_img, 0) + + if stage < self.num_stages - 1: + bbox_head = self.bbox_head[stage] + refine_rois_list = [] + for i in range(len(batch_img_metas)): + if rois[i].shape[0] > 0: + bbox_label = cls_scores[i][:, :-1].argmax(dim=1) + # Refactor `bbox_head.regress_by_class` to only accept + # box tensor without img_idx concatenated. + refined_bboxes = bbox_head.regress_by_class( + rois[i][:, 1:], bbox_label, bbox_preds[i], + batch_img_metas[i]) + refined_bboxes = get_box_tensor(refined_bboxes) + refined_rois = torch.cat( + [rois[i][:, [0]], refined_bboxes], dim=1) + refine_rois_list.append(refined_rois) + rois = torch.cat(refine_rois_list) + # ms_scores aligned + # average scores of each image by stages + cls_scores = [ + sum([score[i] for score in ms_scores]) / float(len(ms_scores)) + for i in range(len(batch_img_metas)) + ] # aligned + return rois, cls_scores, bbox_preds + + def _bbox_forward(self, stage: int, x: Tuple[Tensor], + rois: Tensor) -> dict: + """Box head forward function used in both training and testing. + + Args: + stage (int): The current stage in Cascade RoI Head. + x (tuple[Tensor]): List of multi-level img features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extract bbox RoI features. 
+ """ + bbox_roi_extractor = self.bbox_roi_extractor[stage] + bbox_head = self.bbox_head[stage] + bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs], + rois) + # do not support caffe_c4 model anymore + cls_score, bbox_pred = bbox_head(bbox_feats) + + bbox_results = dict( + cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats) + return bbox_results + + def predict_bbox(self, + x: Tuple[Tensor], + batch_img_metas: List[dict], + rpn_results_list: InstanceList, + rcnn_test_cfg: ConfigType, + rescale: bool = False, + **kwargs) -> InstanceList: + """Perform forward propagation of the bbox head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + proposals = [res.bboxes for res in rpn_results_list] + proposal_scores = [res.scores for res in rpn_results_list] + num_proposals_per_img = tuple(len(p) for p in proposals) + rois = bbox2roi(proposals) + + if rois.shape[0] == 0: + return empty_instances( + batch_img_metas, + rois.device, + task_type='bbox', + box_type=self.bbox_head[-1].predict_box_type, + num_classes=self.bbox_head[-1].num_classes, + score_per_cls=rcnn_test_cfg is None) + # rois aligned + rois, cls_scores, bbox_preds = self._refine_roi( + x=x, + rois=rois, + batch_img_metas=batch_img_metas, + num_proposals_per_img=num_proposals_per_img, + **kwargs) + + # score reweighting in centernet2 + cls_scores = [(s * ps[:, None])**0.5 + for s, ps in zip(cls_scores, proposal_scores)] + cls_scores = [ + s * (s == s[:, :-1].max(dim=1)[0][:, None]).float() + for s in cls_scores + ] + + # fast_rcnn_inference + results_list = self.bbox_head[-1].predict_by_feat( + rois=rois, + cls_scores=cls_scores, + bbox_preds=bbox_preds, + batch_img_metas=batch_img_metas, + rescale=rescale, + rcnn_test_cfg=rcnn_test_cfg) + return results_list + + def _mask_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict: + """Mask head forward function used in both training and testing. + + Args: + stage (int): The current stage in Cascade RoI Head. + x (tuple[Tensor]): Tuple of multi-level img features. + rois (Tensor): RoIs with the shape (n, 5) where the first + column indicates batch id of each RoI. + + Returns: + dict: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + """ + mask_feats = self.mask_roi_extractor( + x[:self.mask_roi_extractor.num_inputs], rois) + # do not support caffe_c4 model anymore + mask_preds = self.mask_head(mask_feats) + + mask_results = dict(mask_preds=mask_preds) + return mask_results + + def mask_loss(self, x, sampling_results: List[SamplingResult], + batch_gt_instances: InstanceList) -> dict: + """Run forward function and calculate loss for mask head in training. + + Args: + x (tuple[Tensor]): Tuple of multi-level img features. 
+ sampling_results (list["obj:`SamplingResult`]): Sampling results. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + + Returns: + dict: Usually returns a dictionary with keys: + + - `mask_preds` (Tensor): Mask prediction. + - `loss_mask` (dict): A dictionary of mask loss components. + """ + pos_rois = bbox2roi([res.pos_priors for res in sampling_results]) + mask_results = self._mask_forward(x, pos_rois) + + mask_loss_and_target = self.mask_head.loss_and_target( + mask_preds=mask_results['mask_preds'], + sampling_results=sampling_results, + batch_gt_instances=batch_gt_instances, + rcnn_train_cfg=self.train_cfg[-1]) + mask_results.update(mask_loss_and_target) + + return mask_results + + def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList, + batch_data_samples: SampleList) -> dict: + """Perform forward propagation and loss calculation of the detection + roi on the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + rpn_results_list (list[:obj:`InstanceData`]): List of region + proposals. + batch_data_samples (list[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict[str, Tensor]: A dictionary of loss components + """ + raise NotImplementedError + + def predict_mask(self, + x: Tuple[Tensor], + batch_img_metas: List[dict], + results_list: List[InstanceData], + rescale: bool = False) -> List[InstanceData]: + """Perform forward propagation of the mask head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). 
+ """ + bboxes = [res.bboxes for res in results_list] + mask_rois = bbox2roi(bboxes) + if mask_rois.shape[0] == 0: + results_list = empty_instances( + batch_img_metas, + mask_rois.device, + task_type='mask', + instance_results=results_list, + mask_thr_binary=self.test_cfg.mask_thr_binary) + return results_list + + num_mask_rois_per_img = [len(res) for res in results_list] + aug_masks = [] + mask_results = self._mask_forward(x, mask_rois) + mask_preds = mask_results['mask_preds'] + # split batch mask prediction back to each image + mask_preds = mask_preds.split(num_mask_rois_per_img, 0) + aug_masks.append([m.sigmoid().detach() for m in mask_preds]) + + merged_masks = [] + for i in range(len(batch_img_metas)): + aug_mask = [mask[i] for mask in aug_masks] + merged_mask = merge_aug_masks(aug_mask, batch_img_metas[i]) + merged_masks.append(merged_mask) + results_list = self.mask_head.predict_by_feat( + mask_preds=merged_masks, + results_list=results_list, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=self.test_cfg, + rescale=rescale, + activate_map=True) + return results_list diff --git a/projects/Detic/detic/text_encoder.py b/projects/Detic/detic/text_encoder.py new file mode 100644 index 00000000000..f0024efaf30 --- /dev/null +++ b/projects/Detic/detic/text_encoder.py @@ -0,0 +1,50 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Union + +import torch +import torch.nn as nn + + +class CLIPTextEncoder(nn.Module): + + def __init__(self, model_name='ViT-B/32'): + super().__init__() + import clip + from clip.simple_tokenizer import SimpleTokenizer + self.tokenizer = SimpleTokenizer() + pretrained_model, _ = clip.load(model_name, device='cpu') + self.clip = pretrained_model + + @property + def device(self): + return self.clip.device + + @property + def dtype(self): + return self.clip.dtype + + def tokenize(self, + texts: Union[str, List[str]], + context_length: int = 77) -> torch.LongTensor: + if isinstance(texts, str): + texts = [texts] + + sot_token = self.tokenizer.encoder['<|startoftext|>'] + eot_token = self.tokenizer.encoder['<|endoftext|>'] + all_tokens = [[sot_token] + self.tokenizer.encode(text) + [eot_token] + for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + st = torch.randint(len(tokens) - context_length + 1, + (1, ))[0].item() + tokens = tokens[st:st + context_length] + result[i, :len(tokens)] = torch.tensor(tokens) + + return result + + def forward(self, text): + text = self.tokenize(text) + text_features = self.clip.encode_text(text) + return text_features diff --git a/projects/Detic/detic/utils.py b/projects/Detic/detic/utils.py new file mode 100644 index 00000000000..56d4fd429d7 --- /dev/null +++ b/projects/Detic/detic/utils.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
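+# Demo helpers: map a dataset name or a custom vocabulary to CLIP text
+# embeddings and swap them into the zero-shot classification layers of a
+# trained model (see `reset_cls_layer_weight` below).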
+import numpy as np
+import torch
+import torch.nn.functional as F
+from mmengine.logging import print_log
+
+from .text_encoder import CLIPTextEncoder
+
+# download from
+# https://github.com/facebookresearch/Detic/tree/main/datasets/metadata
+DATASET_EMBEDDINGS = {
+    'lvis': 'datasets/metadata/lvis_v1_clip_a+cname.npy',
+    'objects365': 'datasets/metadata/o365_clip_a+cnamefix.npy',
+    'openimages': 'datasets/metadata/oid_clip_a+cname.npy',
+    'coco': 'datasets/metadata/coco_clip_a+cname.npy',
+}
+
+
+def get_text_embeddings(dataset=None,
+                        custom_vocabulary=None,
+                        prompt_prefix='a '):
+    assert (dataset is None) ^ (custom_vocabulary is None), \
+        'Either `dataset` or `custom_vocabulary` should be specified.'
+    if dataset:
+        if dataset in DATASET_EMBEDDINGS:
+            # return the path to the pre-computed embeddings; the str case
+            # is handled by `reset_cls_layer_weight`
+            return DATASET_EMBEDDINGS[dataset]
+        else:
+            custom_vocabulary = get_class_names(dataset)
+
+    text_encoder = CLIPTextEncoder()
+    text_encoder.eval()
+    texts = [prompt_prefix + x for x in custom_vocabulary]
+    print_log(
+        f'Computing text embeddings for {len(custom_vocabulary)} classes.')
+    embeddings = text_encoder(texts).detach().permute(1, 0).contiguous().cpu()
+    return embeddings
+
+
+def get_class_names(dataset):
+    if dataset == 'coco':
+        from mmdet.datasets import CocoDataset
+        class_names = CocoDataset.METAINFO['classes']
+    elif dataset == 'cityscapes':
+        from mmdet.datasets import CityscapesDataset
+        class_names = CityscapesDataset.METAINFO['classes']
+    elif dataset == 'voc':
+        from mmdet.datasets import VOCDataset
+        class_names = VOCDataset.METAINFO['classes']
+    elif dataset == 'openimages':
+        from mmdet.datasets import OpenImagesDataset
+        class_names = OpenImagesDataset.METAINFO['classes']
+    elif dataset == 'lvis':
+        from mmdet.datasets import LVISV1Dataset
+        class_names = LVISV1Dataset.METAINFO['classes']
+    else:
+        raise ValueError(f'Unsupported dataset name: {dataset}')
+    return class_names
+
+
+def reset_cls_layer_weight(model, weight):
+    if isinstance(weight, str):
+        print_log(f'Resetting cls_layer_weight from file: {weight}')
+        zs_weight = torch.tensor(
+            np.load(weight),
+            dtype=torch.float32).permute(1, 0).contiguous()  # D x C
+    else:
+        zs_weight = weight
+    zs_weight = torch.cat(
+        [zs_weight, zs_weight.new_zeros(
+            (zs_weight.shape[0], 1))], dim=1)  # D x (C + 1)
+    zs_weight = F.normalize(zs_weight, p=2, dim=0)
+    zs_weight = zs_weight.to('cuda')
+    num_classes = zs_weight.shape[-1]
+
+    for bbox_head in model.roi_head.bbox_head:
+        bbox_head.num_classes = num_classes
+        del bbox_head.fc_cls.zs_weight
+        bbox_head.fc_cls.zs_weight = zs_weight
diff --git a/projects/Detic/detic/zero_shot_classifier.py b/projects/Detic/detic/zero_shot_classifier.py
new file mode 100644
index 00000000000..35c9e49285c
--- /dev/null
+++ b/projects/Detic/detic/zero_shot_classifier.py
@@ -0,0 +1,73 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
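+# A classification layer whose class weights are CLIP text embeddings
+# (`zs_weight`, D x (C + 1); the trailing all-zero column is the background
+# class). With `norm_weight=True`, the logits are temperature-scaled cosine
+# similarities, so the vocabulary can be swapped at test time.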
+import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from mmdet.registry import MODELS + + +@MODELS.register_module(force=True) # avoid bug +class ZeroShotClassifier(nn.Module): + + def __init__( + self, + in_features: int, + out_features: int, # num_classes + zs_weight_path: str, + zs_weight_dim: int = 512, + use_bias: float = 0.0, + norm_weight: bool = True, + norm_temperature: float = 50.0, + ): + super().__init__() + num_classes = out_features + self.norm_weight = norm_weight + self.norm_temperature = norm_temperature + + self.use_bias = use_bias < 0 + if self.use_bias: + self.cls_bias = nn.Parameter(torch.ones(1) * use_bias) + + self.linear = nn.Linear(in_features, zs_weight_dim) + + if zs_weight_path == 'rand': + zs_weight = torch.randn((zs_weight_dim, num_classes)) + nn.init.normal_(zs_weight, std=0.01) + else: + zs_weight = torch.tensor( + np.load(zs_weight_path), + dtype=torch.float32).permute(1, 0).contiguous() # D x C + zs_weight = torch.cat( + [zs_weight, zs_weight.new_zeros( + (zs_weight_dim, 1))], dim=1) # D x (C + 1) + + if self.norm_weight: + zs_weight = F.normalize(zs_weight, p=2, dim=0) + + if zs_weight_path == 'rand': + self.zs_weight = nn.Parameter(zs_weight) + else: + self.register_buffer('zs_weight', zs_weight) + + assert self.zs_weight.shape[1] == num_classes + 1, self.zs_weight.shape + + def forward(self, x, classifier=None): + ''' + Inputs: + x: B x D' + classifier_info: (C', C' x D) + ''' + x = self.linear(x) + if classifier is not None: + zs_weight = classifier.permute(1, 0).contiguous() # D x C' + zs_weight = F.normalize(zs_weight, p=2, dim=0) \ + if self.norm_weight else zs_weight + else: + zs_weight = self.zs_weight + if self.norm_weight: + x = self.norm_temperature * F.normalize(x, p=2, dim=1) + x = torch.mm(x, zs_weight) + if self.use_bias: + x = x + self.cls_bias + return x
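
For reference, here is how the pieces above fit together at the Python level. This is a minimal sketch of the demo's custom-vocabulary flow, not part of the PR itself: it assumes CLIP is installed, a CUDA device is available (`reset_cls_layer_weight` moves the classifier weights to GPU), the released checkpoint has been downloaded into `projects/Detic`, the repository root is on `PYTHONPATH` so the config's `custom_imports` resolves, and `demo.jpg` is a placeholder image path.

```python
from mmdet.apis import inference_detector, init_detector
from mmdet.utils import register_all_modules

from detic.utils import get_text_embeddings, reset_cls_layer_weight

register_all_modules()

# Build Detic from the project config and the released checkpoint.
model = init_detector(
    'configs/detic_centernet2_swin-b_fpn_4x_lvis-coco-in21k.py',
    'detic_centernet2_swin-b_fpn_4x_lvis-coco-in21k_20230120-0d301978.pth',
    device='cuda:0')

# Encode a custom vocabulary with CLIP (a D x C tensor) and swap it into
# the ZeroShotClassifier of every cascade stage.
class_names = ['headphone', 'webcam', 'paper']
embeddings = get_text_embeddings(custom_vocabulary=class_names)
reset_cls_layer_weight(model, embeddings)

result = inference_detector(model, 'demo.jpg')  # placeholder image path
print(result.pred_instances.scores[:5])
```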