diff --git a/engine.py b/engine.py index 6bda0d5..fd53703 100644 --- a/engine.py +++ b/engine.py @@ -46,7 +46,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, with torch.cuda.amp.autocast(enabled=args.amp): if need_tgt_for_training: outputs, mask_dict = model(samples, dn_args=(targets, args.scalar, args.label_noise_scale, - args.box_noise_scale, args.num_patterns)) + args.box_noise_scale, args.num_patterns, args.contrastive)) loss_dict = criterion(outputs, targets, mask_dict) else: outputs = model(samples) @@ -82,6 +82,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, scaler.update() else: # original backward function + optimizer.zero_grad() losses.backward() if max_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) diff --git a/main.py b/main.py index da3f607..8c7bd5a 100644 --- a/main.py +++ b/main.py @@ -8,10 +8,10 @@ import random import time from pathlib import Path +from os import path import os, sys from typing import Optional - from util.logger import setup_logger import numpy as np @@ -24,6 +24,7 @@ from datasets import build_dataset, get_coco_api_from_dataset from engine import evaluate, train_one_epoch from models import build_DABDETR, build_dab_deformable_detr, build_dab_deformable_detr_deformable_encoder_only +from models import build_dab_dino_deformable_detr from util.utils import clean_state_dict @@ -39,6 +40,12 @@ def get_args_parser(): help="label noise ratio to flip") parser.add_argument('--box_noise_scale', default=0.4, type=float, help="box noise scale to shift and scale") + parser.add_argument('--contrastive', action="store_true", + help="use contrastive training.") + parser.add_argument('--use_mqs', action="store_true", + help="use mixed query selection from DINO.") + parser.add_argument('--use_lft', action="store_true", + help="use look forward twice from DINO.") # about lr parser.add_argument('--lr', default=1e-4, type=float, @@ -50,6 +57,7 @@ def get_args_parser(): 
parser.add_argument('--weight_decay', default=1e-4, type=float) parser.add_argument('--epochs', default=50, type=int) parser.add_argument('--lr_drop', default=40, type=int) + parser.add_argument('--override_resumed_lr_drop', default=False, action='store_true') parser.add_argument('--drop_lr_now', action="store_true", help="load checkpoint and drop for 12epoch setting") parser.add_argument('--save_checkpoint_interval', default=10, type=int) parser.add_argument('--clip_max_norm', default=0.1, type=float, @@ -57,7 +65,7 @@ def get_args_parser(): # Model parameters parser.add_argument('--modelname', '-m', type=str, required=True, choices=['dn_dab_detr', 'dn_dab_deformable_detr', - 'dn_dab_deformable_detr_deformable_encoder_only']) + 'dn_dab_deformable_detr_deformable_encoder_only', 'dn_dab_dino_deformable_detr']) parser.add_argument('--frozen_weights', type=str, default=None, help="Path to the pretrained model. If set, only the mask head will be trained") @@ -94,6 +102,8 @@ def get_args_parser(): help="Number of attention heads inside the transformer's attentions") parser.add_argument('--num_queries', default=300, type=int, help="Number of query slots") + parser.add_argument('--num_results', default=300, type=int, + help="Number of detection results") parser.add_argument('--pre_norm', action='store_true', help="Using pre-norm in the Transformer blocks.") parser.add_argument('--num_select', default=300, type=int, @@ -170,7 +180,7 @@ def get_args_parser(): parser.add_argument('--num_workers', default=10, type=int) parser.add_argument('--debug', action='store_true', help="For debug only. It will perform only a few steps during trainig and val.") - parser.add_argument('--find_unused_params', action='store_true') + parser.add_argument('--find_unused_params', default=False, action='store_true') parser.add_argument('--save_results', action='store_true', help="For eval only. 
Save the outputs for all images.") @@ -196,6 +206,8 @@ def build_model_main(args): model, criterion, postprocessors = build_dab_deformable_detr(args) elif args.modelname.lower() == 'dn_dab_deformable_detr_deformable_encoder_only': model, criterion, postprocessors = build_dab_deformable_detr_deformable_encoder_only(args) + elif args.modelname.lower() == 'dn_dab_dino_deformable_detr': + model, criterion, postprocessors = build_dab_dino_deformable_detr(args) else: raise NotImplementedError @@ -222,8 +234,8 @@ def main(args): logger.info('local_rank: {}'.format(args.local_rank)) logger.info("args: " + str(args) + '\n') - if args.frozen_weights is not None: - assert args.masks, "Frozen training is meant for segmentation only" + #if args.frozen_weights is not None: + # assert args.masks, "Frozen training is meant for segmentation only" print(args) device = torch.device(args.device) @@ -293,7 +305,7 @@ def main(args): model_without_ddp.detr.load_state_dict(checkpoint['model']) output_dir = Path(args.output_dir) - if args.resume: + if args.resume and (args.resume.startswith('https') or path.exists(args.resume)): if args.resume.startswith('https'): checkpoint = torch.hub.load_state_dict_from_url( args.resume, map_location='cpu', check_hash=True) @@ -303,6 +315,11 @@ def main(args): if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + if args.override_resumed_lr_drop: + print('Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler.') + lr_scheduler.step_size = args.lr_drop + lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) + lr_scheduler.step(lr_scheduler.last_epoch) args.start_epoch = checkpoint['epoch'] + 1 if args.drop_lr_now: diff --git a/models/__init__.py b/models/__init__.py index 
e65990b..3d1b227 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -7,3 +7,4 @@ from .DN_DAB_DETR import build_DABDETR from .dn_dab_deformable_detr import build_dab_deformable_detr from .dn_dab_deformable_detr_deformable_encoder_only import build_dab_deformable_detr_deformable_encoder_only +from .dn_dab_dino_deformable_detr import build_dab_dino_deformable_detr diff --git a/models/dn_dab_dino_deformable_detr/__init__.py b/models/dn_dab_dino_deformable_detr/__init__.py new file mode 100644 index 0000000..6ea58f0 --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/__init__.py @@ -0,0 +1,14 @@ +# ------------------------------------------------------------------------ +# DAB-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +from .dab_deformable_detr import build_dab_dino_deformable_detr diff --git a/models/dn_dab_dino_deformable_detr/backbone.py b/models/dn_dab_dino_deformable_detr/backbone.py new file mode 100644 index 0000000..ef9dec7 --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/backbone.py @@ -0,0 +1,142 @@ +# ------------------------------------------------------------------------ +# DAB-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Backbone modules. +""" +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from typing import Dict, List + +from util.misc import NestedTensor, is_main_process + +from .position_encoding import build_position_encoding + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. 
+ """ + + def __init__(self, n, eps=1e-5): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + self.eps = eps + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = self.eps + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool): + super().__init__() + for name, parameter in backbone.named_parameters(): + if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + parameter.requires_grad_(False) + if return_interm_layers: + # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} + self.strides = [8, 16, 32] + self.num_channels = [512, 1024, 2048] + else: + return_layers = {'layer4': "0"} + self.strides = [32] + self.num_channels = [2048] + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + + def forward(self, tensor_list: NestedTensor): + xs = self.body(tensor_list.tensors) + out: Dict[str, NestedTensor] = {} + for name, x in xs.items(): + m = 
tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = NestedTensor(x, mask) + return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool): + norm_layer = FrozenBatchNorm2d + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), norm_layer=norm_layer) + assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded" + super().__init__(backbone, train_backbone, return_interm_layers) + if dilation: + self.strides[-1] = self.strides[-1] // 2 + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + self.strides = backbone.strides + self.num_channels = backbone.num_channels + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in sorted(xs.items()): + out.append(x) + + # position encoding + for x in out: + pos.append(self[1](x).to(x.tensors.dtype)) + + return out, pos + + +def build_backbone(args): + position_embedding = build_position_encoding(args) + train_backbone = args.lr_backbone > 0 + return_interm_layers = args.masks or (args.num_feature_levels > 1) + backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) + model = Joiner(backbone, position_embedding) + return model diff --git a/models/dn_dab_dino_deformable_detr/dab_deformable_detr.py b/models/dn_dab_dino_deformable_detr/dab_deformable_detr.py new file mode 100644 index 0000000..a7eca9a --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/dab_deformable_detr.py @@ -0,0 +1,585 @@ +# ------------------------------------------------------------------------ +# DN-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# DAB-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + + +import os +import torch +import torch.nn.functional as F +from torch import nn +import math + +from util import box_ops +from util.misc import (NestedTensor, nested_tensor_from_tensor_list, + accuracy, get_world_size, interpolate, + is_dist_avail_and_initialized, inverse_sigmoid) + +from .backbone import build_backbone +from .matcher import build_matcher +from .segmentation import (DETRsegm, PostProcessPanoptic, PostProcessSegm, + dice_loss, sigmoid_focal_loss) +from .deformable_transformer import build_deforamble_transformer +import copy + +from .dn_components import prepare_for_dn, dn_post_process, compute_dn_loss + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DABDeformableDETR(nn.Module): + """ This is the DAB-Deformable-DETR for object detection """ + def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels, + aux_loss=True, with_box_refine=True, two_stage=False, + use_dab=True, + num_patterns=0, + random_refpoints_xy=False, + ): + """ Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. 
See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of object classes + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + with_box_refine: iterative bounding box refinement + two_stage: two-stage Deformable DETR + use_dab: using dynamic anchor boxes formulation + num_patterns: number of pattern embeddings + random_refpoints_xy: random init the x,y of anchor boxes and freeze them. (It sometimes helps to improve the performance) + """ + super().__init__() + self.num_queries = num_queries + self.transformer = transformer + self.hidden_dim = hidden_dim = transformer.d_model + self.num_classes = num_classes + self.class_embed = nn.Linear(hidden_dim, num_classes) + self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + self.num_feature_levels = num_feature_levels + self.use_dab = use_dab + self.num_patterns = num_patterns + self.random_refpoints_xy = random_refpoints_xy + self.two_stage = two_stage + # dn label enc + self.label_enc = nn.Embedding(num_classes + 1, hidden_dim - 1) # # for indicator + if not use_dab: + self.query_embed = nn.Embedding(num_queries, hidden_dim*2) + else: + if not self.two_stage: + self.tgt_embed = nn.Embedding(num_queries, hidden_dim-1) # for indicator + self.refpoint_embed = nn.Embedding(num_queries, 4) + + if random_refpoints_xy: + # import ipdb; ipdb.set_trace() + self.refpoint_embed.weight.data[:, :2].uniform_(0,1) + self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2]) + self.refpoint_embed.weight.data[:, :2].requires_grad = False + + if self.num_patterns > 0: + self.patterns_embed = nn.Embedding(self.num_patterns, hidden_dim) + + if num_feature_levels > 1: + num_backbone_outs = len(backbone.strides) + input_proj_list = [] + 
for _ in range(num_backbone_outs): + in_channels = backbone.num_channels[_] + input_proj_list.append(nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + )) + for _ in range(num_feature_levels - num_backbone_outs): + input_proj_list.append(nn.Sequential( + nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), + nn.GroupNorm(32, hidden_dim), + )) + in_channels = hidden_dim + self.input_proj = nn.ModuleList(input_proj_list) + else: + self.input_proj = nn.ModuleList([ + nn.Sequential( + nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1), + nn.GroupNorm(32, hidden_dim), + )]) + self.backbone = backbone + self.aux_loss = aux_loss + self.with_box_refine = with_box_refine + + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed.bias.data = torch.ones(num_classes) * bias_value + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + for proj in self.input_proj: + nn.init.xavier_uniform_(proj[0].weight, gain=1) + nn.init.constant_(proj[0].bias, 0) + + # if two-stage, the last class_embed and bbox_embed is for region proposal generation + num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers + + if with_box_refine: + self.class_embed = _get_clones(self.class_embed, num_pred) + self.bbox_embed = _get_clones(self.bbox_embed, num_pred) + nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) + # hack implementation for iterative bounding box refinement + self.transformer.decoder.bbox_embed = self.bbox_embed + else: + nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) + self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) + self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) + self.transformer.decoder.bbox_embed = None + if two_stage: + # hack implementation for two-stage + 
self.transformer.decoder.class_embed = self.class_embed + for box_embed in self.bbox_embed: + nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) + + def forward(self, samples: NestedTensor, dn_args=None): + """ The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x (num_classes + 1)] + - "pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, height, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. 
+ """ + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.backbone(samples) + + srcs = [] + masks = [] + for l, feat in enumerate(features): + src, mask = feat.decompose() + srcs.append(self.input_proj[l](src)) + masks.append(mask) + assert mask is not None + # import ipdb; ipdb.set_trace() + + if self.num_feature_levels > len(srcs): + _len_srcs = len(srcs) + for l in range(_len_srcs, self.num_feature_levels): + if l == _len_srcs: + src = self.input_proj[l](features[-1].tensors) + else: + src = self.input_proj[l](srcs[-1]) + m = samples.mask + mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] + pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) + srcs.append(src) + masks.append(mask) + pos.append(pos_l) + + #if self.two_stage: + # assert NotImplementedError + #elif self.use_dab: + if self.use_dab: + if not self.two_stage: + if self.num_patterns == 0: + tgt_all_embed = tgt_embed = self.tgt_embed.weight # nq, 256 + refanchor = self.refpoint_embed.weight # nq, 4 + # query_embeds = torch.cat((tgt_embed, refanchor), dim=1) + else: + # multi patterns is not used in this version + assert NotImplementedError + else: + tgt_all_embed = None + refanchor = None + else: + assert NotImplementedError + + # prepare for dn + input_query_label, input_query_bbox, attn_mask, mask_dict = \ + prepare_for_dn(dn_args, tgt_all_embed, refanchor, src.size(0), self.training, self.num_queries, self.num_classes, + self.hidden_dim, self.label_enc) + if input_query_label is not None and input_query_bbox is not None: + # sometimes the target is empty, add a zero part of label_enc to avoid unused parameters + input_query_label += self.label_enc.weight[0][0]*torch.tensor(0).cuda() + query_embeds = torch.cat((input_query_label, input_query_bbox), dim=2) + else: + query_embeds = None + + hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = \ + self.transformer(srcs, 
masks, pos, query_embeds, attn_mask) + levels = hs.shape[0] + + outputs_classes = [] + outputs_coords = [] + #for lvl in range(hs.shape[0]): + for lvl in range(levels): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.class_embed[lvl](hs[lvl]) + tmp = self.bbox_embed[lvl](hs[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = tmp.sigmoid() + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + outputs_class = torch.stack(outputs_classes) + outputs_coord = torch.stack(outputs_coords) + # dn post process + outputs_class, outputs_coord = dn_post_process(outputs_class, outputs_coord, mask_dict) + + out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} + if self.aux_loss: + out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) + + if self.two_stage: + enc_outputs_coord = enc_outputs_coord_unact.sigmoid() + out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord} + if os.environ.get('IPDB_SHILONG_DEBUG') == 'INFO': + import ipdb; ipdb.set_trace() + return out, mask_dict + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{'pred_logits': a, 'pred_boxes': b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. 
+ The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25): + """ Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + losses: list of all the losses to be applied. See get_loss for list of available losses. + focal_alpha: alpha in Focal Loss + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.focal_alpha = focal_alpha + + def loss_labels(self, outputs, targets, indices, num_boxes, log=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1], + dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) + target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) + + target_classes_onehot = target_classes_onehot[:,:,:-1] + loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] + losses = {'loss_ce': loss_ce} + + if log: + # TODO this should probably 
be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients + """ + pred_logits = outputs['pred_logits'] + device = pred_logits.device + tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {'cardinality_error': card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size. + """ + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') + + losses = {} + losses['loss_bbox'] = loss_bbox.sum() / num_boxes + + loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes), + box_ops.box_cxcywh_to_xyxy(target_boxes))) + losses['loss_giou'] = loss_giou.sum() / num_boxes + return losses + + def loss_masks(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the masks: the focal loss and the dice loss. 
+ targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] + """ + assert "pred_masks" in outputs + + src_idx = self._get_src_permutation_idx(indices) + tgt_idx = self._get_tgt_permutation_idx(indices) + + src_masks = outputs["pred_masks"] + + # TODO use valid to mask invalid areas due to padding in loss + target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets]).decompose() + target_masks = target_masks.to(src_masks) + + src_masks = src_masks[src_idx] + # upsample predictions to the target size + src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:], + mode="bilinear", align_corners=False) + src_masks = src_masks[:, 0].flatten(1) + + target_masks = target_masks[tgt_idx].flatten(1) + + losses = { + "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), + "loss_dice": dice_loss(src_masks, target_masks, num_boxes), + } + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + 'labels': self.loss_labels, + 'cardinality': self.loss_cardinality, + 'boxes': self.loss_boxes, + 'masks': self.loss_masks + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def forward(self, outputs, targets, mask_dict=None): + """ This performs the loss computation. 
+ Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs' and k != 'enc_outputs'} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["labels"]) for t in targets) + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + + # Compute all the requested losses + losses = {} + for loss in self.losses: + kwargs = {} + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if 'aux_outputs' in outputs: + for i, aux_outputs in enumerate(outputs['aux_outputs']): + indices = self.matcher(aux_outputs, targets) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. 
+ continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs['log'] = False + l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + if 'enc_outputs' in outputs: + enc_outputs = outputs['enc_outputs'] + bin_targets = copy.deepcopy(targets) + for bt in bin_targets: + bt['labels'] = torch.zeros_like(bt['labels']) + if os.environ.get('IPDB_SHILONG_DEBUG') == 'INFO': + import ipdb; ipdb.set_trace() + indices = self.matcher(enc_outputs, bin_targets) + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. + continue + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs['log'] = False + l_dict = self.get_loss(loss, enc_outputs, bin_targets, indices, num_boxes, **kwargs) + l_dict = {k + f'_enc': v for k, v in l_dict.items()} + losses.update(l_dict) + + # dn loss computation + aux_num = 0 + if 'aux_outputs' in outputs: + aux_num = len(outputs['aux_outputs']) + dn_losses = compute_dn_loss(mask_dict, self.training, aux_num, self.focal_alpha) + losses.update(dn_losses) + + return losses + + +class PostProcess(nn.Module): + """ This module converts the model's output into the format expected by the coco api""" + def __init__(self, num_select=300, nms_iou_threshold=-1) -> None: + super().__init__() + self.num_select = num_select + self.nms_iou_threshold = nms_iou_threshold + + @torch.no_grad() + def forward(self, outputs, target_sizes): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + """ + num_select = self.num_select + 
out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] + + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = out_logits.sigmoid() + topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), num_select, dim=1) + scores = topk_values + topk_boxes = topk_indexes // out_logits.shape[2] + labels = topk_indexes % out_logits.shape[2] + boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) + boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)) + + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] + + return results + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def build_dab_dino_deformable_detr(args): + num_classes = 20 if args.dataset_file != 'coco' else 91 + if args.dataset_file == "coco_panoptic": + num_classes = 250 + device = torch.device(args.device) + + backbone = build_backbone(args) + + transformer = build_deforamble_transformer(args) + model = DABDeformableDETR( + backbone, + transformer, + num_classes=num_classes, + num_queries=args.num_queries, + num_feature_levels=args.num_feature_levels, + aux_loss=args.aux_loss, + two_stage=args.two_stage, + use_dab=True, + num_patterns=args.num_patterns, + random_refpoints_xy=args.random_refpoints_xy + ) + if args.masks: + model = DETRsegm(model, 
freeze_detr=(args.frozen_weights is not None)) + matcher = build_matcher(args) + weight_dict = {'loss_ce': args.cls_loss_coef, 'loss_bbox': args.bbox_loss_coef} + weight_dict['loss_giou'] = args.giou_loss_coef + # dn loss + if args.use_dn: + weight_dict['tgt_loss_ce'] = args.cls_loss_coef + weight_dict['tgt_loss_bbox'] = args.bbox_loss_coef + weight_dict['tgt_loss_giou'] = args.giou_loss_coef + if args.masks: + weight_dict["loss_mask"] = args.mask_loss_coef + weight_dict["loss_dice"] = args.dice_loss_coef + # TODO this is a hack + if args.aux_loss: + aux_weight_dict = {} + num_layers = args.dec_layers + for i in range(num_layers - 1): + aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) + aux_weight_dict.update({k + f'_enc': v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + losses = ['labels', 'boxes', 'cardinality'] + if args.masks: + losses += ["masks"] + # num_classes, matcher, weight_dict, losses, focal_alpha=0.25 + criterion = SetCriterion(num_classes, matcher, weight_dict, losses, focal_alpha=args.focal_alpha) + criterion.to(device) + postprocessors = {'bbox': PostProcess(num_select=args.num_results)} + if args.masks: + postprocessors['segm'] = PostProcessSegm() + if args.dataset_file == "coco_panoptic": + is_thing_map = {i: i <= 90 for i in range(201)} + postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85) + + return model, criterion, postprocessors diff --git a/models/dn_dab_dino_deformable_detr/deformable_transformer.py b/models/dn_dab_dino_deformable_detr/deformable_transformer.py new file mode 100644 index 0000000..04fa5d4 --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/deformable_transformer.py @@ -0,0 +1,560 @@ +# ------------------------------------------------------------------------ +# DN-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# DAB-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +import copy +from typing import Optional, List +import math + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ + +from util.misc import inverse_sigmoid +from .ops.modules import MSDeformAttn + + +class DeformableTransformer(nn.Module): + def __init__(self, d_model=256, nhead=8, + num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1, + activation="relu", return_intermediate_dec=False, + num_feature_levels=4, dec_n_points=4, enc_n_points=4, + two_stage=False, two_stage_num_proposals=300, + use_dab=False, use_mqs=False, use_look_forward_twice=False, high_dim_query_update=False, no_sine_embed=False): + super().__init__() + + self.d_model = d_model + self.nhead = nhead + self.two_stage = two_stage + self.two_stage_num_proposals = two_stage_num_proposals + self.use_dab = use_dab + self.use_mqs = use_mqs + self.return_intermediate = return_intermediate_dec + + encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward, + dropout, activation, + num_feature_levels, nhead, enc_n_points) + self.encoder = 
DeformableTransformerEncoder(encoder_layer, num_encoder_layers) + + decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward, + dropout, activation, + num_feature_levels, nhead, dec_n_points) + self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec, + use_dab=use_dab, use_look_forward_twice=use_look_forward_twice, + d_model=d_model, high_dim_query_update=high_dim_query_update, no_sine_embed=no_sine_embed) + self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) + self.num_queries = two_stage_num_proposals + + if two_stage: + self.enc_output = nn.Linear(d_model, d_model) + self.enc_output_norm = nn.LayerNorm(d_model) + if use_mqs: + self.tgt_embed = nn.Embedding(self.num_queries, d_model) + else: + self.pos_trans = nn.Linear(d_model * 2, d_model * 2) + self.pos_trans_norm = nn.LayerNorm(d_model * 2) + nn.init.normal_(self.tgt_embed.weight.data) + else: + self.tgt_embed = None + if not self.use_dab: + self.reference_points = nn.Linear(d_model, 2) + + self.high_dim_query_update = high_dim_query_update + if high_dim_query_update: + assert not self.use_dab, "use_dab must be True" + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + if not self.two_stage and not self.use_dab: + xavier_uniform_(self.reference_points.weight.data, gain=1.0) + constant_(self.reference_points.bias.data, 0.) 
+ normal_(self.level_embed) + + def get_proposal_pos_embed(self, proposals, d_model): + num_pos_feats = d_model // 4 + temperature = 10000 + scale = 2 * math.pi + + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) + return pos + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): + N_, S_, C_ = memory.shape + base_scale = 4.0 + proposals = [] + _cur = 0 + for lvl, (H_, W_) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), + torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device)) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl) + proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) + proposals.append(proposal) + _cur += (H_ * W_) + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf')) + + output_memory = memory + output_memory = 
output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def forward(self, srcs, masks, pos_embeds, query_embed=None, attn_mask=None): + """ + Input: + - srcs: List([bs, c, h, w]) + - masks: List([bs, h, w]) + """ + assert self.two_stage or query_embed is not None + + # prepare input for encoder + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + + src = src.flatten(2).transpose(1, 2) # bs, hw, c + mask = mask.flatten(1) # bs, hw + pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + src_flatten.append(src) + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) # bs, \sum{hxw}, c + mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + + # encoder + memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten) + + # prepare input 
for decoder + bs, _, c = memory.shape + if self.two_stage: + output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes) + + # hack implementation for two-stage Deformable DETR + + enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory) + enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals + + topk = self.two_stage_num_proposals + topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + + # MQS is dab + mqs + if self.use_mqs: + assert self.use_dab + reference_points_mqs = reference_points + + # sometimes the target is empty, add a zero part of query_embed to avoid unused parameters + reference_points_mqs += self.tgt_embed.weight[0][0]*torch.tensor(0).cuda() + tgt_mqs = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # nq, bs, d_model + + # query_embed is non None when training. 
+ if query_embed is not None: + reference_points_dab = query_embed[..., self.d_model:].sigmoid() + tgt_dab = query_embed[..., :self.d_model] + + reference_points = torch.cat([reference_points_dab, reference_points_mqs], dim=1) + tgt = torch.cat([tgt_dab, tgt_mqs], dim=1) + else: + reference_points = reference_points_mqs + tgt = tgt_mqs + else: + pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact, self.d_model))) + query_embed, tgt = torch.split(pos_trans_out, c, dim=2) + else: + if self.use_dab: + reference_points = query_embed[..., self.d_model:].sigmoid() + tgt = query_embed[..., :self.d_model] + # tgt = tgt.unsqueeze(0).expand(bs, -1, -1) + else: + query_embed, tgt = torch.split(query_embed, c, dim=1) + query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1) + tgt = tgt.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_embed).sigmoid() + # bs, num_quires, 2 + + init_reference_out = reference_points + + # decoder + hs, inter_references = self.decoder(tgt, reference_points, memory, + spatial_shapes, level_start_index, valid_ratios, + query_pos=query_embed if not self.use_dab else None, + src_padding_mask=mask_flatten, attn_mask=attn_mask) + + if self.return_intermediate: + reference_points = inter_references[-1] + else: + reference_points = inter_references + + if not self.two_stage: + enc_outputs_class, enc_outputs_coord_unact = None, None + + + inter_references_out = inter_references + #return hs, init_reference_out, inter_references_out, None, None + return hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact + + +class DeformableTransformerEncoderLayer(nn.Module): + def __init__(self, + d_model=256, d_ffn=1024, + dropout=0.1, activation="relu", + n_levels=4, n_heads=8, n_points=4): + super().__init__() + + # self attention + self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = 
nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout2 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src): + src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None): + # self attention + src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask) + src = src + self.dropout1(src2) + src = self.norm1(src) + + # ffn + src = self.forward_ffn(src) + + return src + + +class DeformableTransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + + ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None): + """ + Input: + - src: [bs, sum(hi*wi), 
256] + - spatial_shapes: h,w of each level [num_level, 2] + - level_start_index: [num_level] start point of level in sum(hi*wi). + - valid_ratios: [bs, num_level, 2] + - pos: pos embed for src. [bs, sum(hi*wi), 256] + - padding_mask: [bs, sum(hi*wi)] + Intermediate: + - reference_points: [bs, sum(hi*wi), num_level, 2] + """ + output = src + # bs, sum(hi*wi), 256 + # import ipdb; ipdb.set_trace() + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) + for _, layer in enumerate(self.layers): + output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask) + + return output + + +class DeformableTransformerDecoderLayer(nn.Module): + def __init__(self, d_model=256, d_ffn=1024, + dropout=0.1, activation="relu", + n_levels=4, n_heads=8, n_points=4): + super().__init__() + + # cross attention + self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # self attention + self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, + src_padding_mask=None, self_attn_mask=None): + # self attention + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), 
tgt.transpose(0, 1), attn_mask=self_attn_mask)[0].transpose(0, 1) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # cross attention + tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), + reference_points, + src, src_spatial_shapes, level_start_index, src_padding_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # ffn + tgt = self.forward_ffn(tgt) + + return tgt + + +class DeformableTransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, return_intermediate=False, use_dab=False, use_look_forward_twice=False, d_model=256, high_dim_query_update=False, no_sine_embed=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + self.use_dab = use_dab + self.d_model = d_model + self.no_sine_embed = no_sine_embed + self.use_look_forward_twice = use_look_forward_twice + + if use_dab: + self.query_scale = MLP(d_model, d_model, d_model, 2) + if self.no_sine_embed: + self.ref_point_head = MLP(4, d_model, d_model, 3) + else: + self.ref_point_head = MLP(2 * d_model, d_model, d_model, 2) + self.high_dim_query_update = high_dim_query_update + if high_dim_query_update: + self.high_dim_query_proj = MLP(d_model, d_model, d_model, 2) + + + def forward(self, tgt, reference_points, src, src_spatial_shapes, + src_level_start_index, src_valid_ratios, + query_pos=None, src_padding_mask=None, attn_mask=None): + output = tgt + if self.use_dab: + assert query_pos is None + # bs = src.shape[0] + # reference_points = reference_points[None].repeat(bs, 1, 1) # bs, nq, 4(xywh) + + + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + # import ipdb; ipdb.set_trace() + if reference_points.shape[-1] == 4: + reference_points_input = 
reference_points[:, :, None] \ + * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None] # bs, nq, 4, 4 + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None] + if self.use_dab: + # import ipdb; ipdb.set_trace() + if self.no_sine_embed: + raw_query_pos = self.ref_point_head(reference_points_input) + else: + query_sine_embed = gen_sineembed_for_position(reference_points_input[:, :, 0, :]) # bs, nq, 256*2 + raw_query_pos = self.ref_point_head(query_sine_embed) # bs, nq, 256 + pos_scale = self.query_scale(output) if lid != 0 else 1 + query_pos = pos_scale * raw_query_pos + if self.high_dim_query_update and lid != 0: + query_pos = query_pos + self.high_dim_query_proj(output) + + + output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, + src_padding_mask, self_attn_mask=attn_mask) + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: + tmp = self.bbox_embed[lid](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output) + + if self.use_look_forward_twice: + intermediate_reference_points.append(new_reference_points) + else: + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack(intermediate_reference_points) + + return output, reference_points + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + 
"""Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +def build_deforamble_transformer(args): + return DeformableTransformer( + d_model=args.hidden_dim, + nhead=args.nheads, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + dim_feedforward=args.dim_feedforward, + dropout=args.dropout, + activation="relu", + return_intermediate_dec=True, + num_feature_levels=args.num_feature_levels, + dec_n_points=args.dec_n_points, + enc_n_points=args.enc_n_points, + two_stage=args.two_stage, + two_stage_num_proposals=args.num_queries, + use_dab=True, + use_mqs=args.use_mqs, + use_look_forward_twice=args.use_lft) + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +def gen_sineembed_for_position(pos_tensor): + # n_query, bs, _ = pos_tensor.size() + # sineembed_tensor = torch.zeros(n_query, bs, 256) + scale = 2 * math.pi + dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * (dim_t // 2) / 128) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) + if pos_tensor.size(-1) == 2: + pos = 
torch.cat((pos_y, pos_x), dim=2) + elif pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) + return pos diff --git a/models/dn_dab_dino_deformable_detr/dn_components.py b/models/dn_dab_dino_deformable_detr/dn_components.py new file mode 100644 index 0000000..281bf08 --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/dn_components.py @@ -0,0 +1,369 @@ +# ------------------------------------------------------------------------ +# DN-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] + + +import torch +from util.misc import (NestedTensor, nested_tensor_from_tensor_list, + accuracy, get_world_size, interpolate, + is_dist_avail_and_initialized, inverse_sigmoid) +# from .DABDETR import sigmoid_focal_loss +from util import box_ops +import torch.nn.functional as F + + +def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = 0.25; a negative value disables the weighting. + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. 
+ Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + + return loss.mean(1).sum() / num_boxes + +def prepare_for_dn(dn_args, tgt_weight, embedweight, batch_size, training, num_queries, num_classes, hidden_dim, label_enc): + """ + The major difference from DN-DAB-DETR is that the authors process the pattern embedding in the detector's + forward function and use a learnable tgt embedding, so we change this function a little bit. + :param dn_args: (targets, scalar, label_noise_scale, box_noise_scale, num_patterns, contrastive) when training; only num_patterns at inference + :param tgt_weight: learnable tgt embedding used in dab deformable detr + :param embedweight: positional anchor queries + :param batch_size: bs + :param training: if it is training or inference + :param num_queries: number of queries + :param num_classes: number of classes + :param hidden_dim: transformer hidden dim + :param label_enc: encode labels in dn + :return: input_query_label, input_query_bbox, attn_mask, mask_dict + """ + + if training: + targets, scalar, label_noise_scale, box_noise_scale, num_patterns, contrastive = dn_args + else: + num_patterns = dn_args + + if num_patterns == 0: + num_patterns = 1 + if tgt_weight is not None and embedweight is not None: + indicator0 = torch.zeros([num_queries * num_patterns, 1]).cuda() + # sometimes the target is empty, add a zero part of label_enc to avoid unused parameters + tgt = torch.cat([tgt_weight, indicator0], dim=1) + label_enc.weight[0][0]*torch.tensor(0).cuda() + refpoint_emb = embedweight + else: + tgt = None + refpoint_emb = None + + if training: + if contrastive: + new_targets = [] + for t in targets: + new_t = {} + new_t['labels'] = torch.cat([t['labels'], torch.tensor(len(t['labels']) * [num_classes], dtype=torch.int64).cuda()], dim=0) + new_t['boxes'] = torch.cat([t['boxes'], 
t['boxes']], dim=0) + new_targets.append(new_t) + targets = new_targets + known = [(torch.ones_like(t['labels'])).cuda() for t in targets] # [ [ 1, 1], [1, 1, 1], ... ] + know_idx = [torch.nonzero(t) for t in known] # [ [0, 1], [0, 1, 2], ... ] + known_num = [sum(k) for k in known] # [ 2, 3, ... ] + + # to use fix number of dn queries + if int(max(known_num)) == 0: + scalar = 1 + elif scalar >= 100 and int(max(known_num))>0: + scalar=scalar//int(max(known_num)) + + if scalar <= 0: + scalar = 1 + + # can be modified to selectively denosie some label or boxes; also known label prediction + unmask_bbox = unmask_label = torch.cat(known) + # torch.cat(known) = [1, 1, 1, 1, 1, ... ] + labels = torch.cat([t['labels'] for t in targets]) + boxes = torch.cat([t['boxes'] for t in targets]) + batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)]) + # batch_idx = [ 0, 0, 1, 1, 1, ... ] + + known_indice = torch.nonzero(unmask_label + unmask_bbox) + # known_indice = [ 0, 1, 2, 3, 4, ... 
] "elementwise addition = logical_and" of labels and bbox + known_indice = known_indice.view(-1) + + # add noise + known_indice = known_indice.repeat(scalar, 1).view(-1) + known_bid = batch_idx.repeat(scalar, 1).view(-1) + known_labels = labels.repeat(scalar, 1).view(-1) + known_bboxs = boxes.repeat(scalar, 1) + known_labels_expaned = known_labels.clone() + known_bbox_expand = known_bboxs.clone() + #print("known_bbox_expand = " +str(known_bbox_expand.shape)) + + # noise on the label + if label_noise_scale > 0: + p = torch.rand_like(known_labels_expaned.float()) + chosen_indice = torch.nonzero(p < (label_noise_scale)).view(-1) # usually half of bbox noise + new_label = torch.randint_like(chosen_indice, 0, num_classes) # randomly put a new one here + known_labels_expaned.scatter_(0, chosen_indice, new_label) + + # noise on the box + if box_noise_scale > 0: + known_bbox_ = torch.zeros_like(known_bboxs) + known_bbox_[:, :2] = known_bboxs[:, :2] - known_bboxs[:, 2:] / 2 + known_bbox_[:, 2:] = known_bboxs[:, :2] + known_bboxs[:, 2:] / 2 + + diff = torch.zeros_like(known_bbox_expand) + diff[:, :2] = known_bbox_expand[:, 2:] / 2 + diff[:, 2:] = known_bbox_expand[:, 2:] / 2 + + if contrastive: + rand_sign = torch.randint_like(known_bbox_expand, low=0, high=2, dtype=torch.float32) * 2.0 - 1.0 + rand_part = torch.rand_like(known_bbox_expand) + positive_idx = torch.tensor(range(len(boxes)//2)).long().cuda().unsqueeze(0).repeat(scalar, 1) + positive_idx += (torch.tensor(range(scalar)) * len(boxes)).long().cuda().unsqueeze(1) + positive_idx = positive_idx.flatten() + negative_idx = positive_idx + len(boxes)//2 + rand_part[negative_idx] += 1.0 + rand_part *= rand_sign + + known_bbox_ += torch.mul(rand_part, diff).cuda() * box_noise_scale + + else: + known_bbox_ += torch.mul((torch.rand_like(known_bbox_expand) * 2 - 1.0), + diff).cuda() * box_noise_scale + + known_bbox_ = known_bbox_.clamp(min=0.0, max=1.0) + known_bbox_expand[:, :2] = (known_bbox_[:, :2] + known_bbox_[:, 2:]) / 2 
+ known_bbox_expand[:, 2:] = known_bbox_[:, 2:] - known_bbox_[:, :2] + + # in the case of negatives, override the label with "num_classes" label + if contrastive: + known_labels_expaned.scatter_(0, negative_idx, num_classes) + + m = known_labels_expaned.long().to('cuda') + input_label_embed = label_enc(m) + # add dn part indicator + indicator1 = torch.ones([input_label_embed.shape[0], 1]).cuda() + input_label_embed = torch.cat([input_label_embed, indicator1], dim=1) + input_bbox_embed = inverse_sigmoid(known_bbox_expand) + single_pad = int(max(known_num)) + pad_size = int(single_pad * scalar) + padding_label = torch.zeros(pad_size, hidden_dim).cuda() + padding_bbox = torch.zeros(pad_size, 4).cuda() + + if tgt is not None and refpoint_emb is not None: + input_query_label = torch.cat([padding_label, tgt], dim=0).repeat(batch_size, 1, 1) + input_query_bbox = torch.cat([padding_bbox, refpoint_emb], dim=0).repeat(batch_size, 1, 1) + else: + input_query_label = padding_label.repeat(batch_size, 1, 1) + input_query_bbox = padding_bbox.repeat(batch_size, 1, 1) + + # map in order + map_known_indice = torch.tensor([]).to('cuda') + if len(known_num): + map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num]) # [0, 1, 0, 1, 2] + map_known_indice = torch.cat([map_known_indice + single_pad * i for i in range(scalar)]).long() + # + if len(known_bid): + input_query_label[(known_bid.long(), map_known_indice)] = input_label_embed # [ bs, query_idx, hidden_dim ] + input_query_bbox[(known_bid.long(), map_known_indice)] = input_bbox_embed + + tgt_size = pad_size + num_queries * num_patterns + attn_mask = torch.ones(tgt_size, tgt_size).to('cuda') < 0 + # match query cannot see the reconstruct + attn_mask[pad_size:, :pad_size] = True + # reconstruct cannot see each other + for i in range(scalar): + if i == 0: + attn_mask[single_pad * i:single_pad * (i + 1), single_pad * (i + 1):pad_size] = True + if i == scalar - 1: + attn_mask[single_pad * i:single_pad * (i + 1), 
def _zero_loss(device):
    """Return a fresh 0-dim float tensor holding 0 on ``device``."""
    return torch.zeros((), device=device)


def _zero_dn_losses(device):
    """Return the four denoising loss entries, all zero, on ``device``."""
    return {
        'tgt_loss_bbox': _zero_loss(device),
        'tgt_loss_giou': _zero_loss(device),
        'tgt_loss_ce': _zero_loss(device),
        'tgt_class_error': _zero_loss(device),
    }


def _fallback_device(mask_dict):
    """Pick a device for placeholder losses when no prediction tensor is at hand.

    Prefers the device of the tensors stored under ``known_lbs_bboxes``;
    otherwise uses CUDA when it is available and CPU when it is not.
    """
    if mask_dict:
        for item in mask_dict.get('known_lbs_bboxes', ()):
            if torch.is_tensor(item):
                return item.device
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def dn_post_process(outputs_class, outputs_coord, mask_dict):
    """
    post process of dn after output from the transformer
    put the dn part in the mask_dict

    Args:
        outputs_class: class predictions, sliced along dim 2 (the query axis)
        outputs_coord: box predictions, sliced along dim 2 (the query axis)
        mask_dict: dn information; ``None``/empty or ``pad_size == 0`` means
            there is no dn part and the inputs are returned untouched
    Returns:
        (outputs_class, outputs_coord) without the first ``pad_size`` queries.
        The removed slices are stored in ``mask_dict['output_known_lbs_bboxes']``.
    """
    if mask_dict and mask_dict['pad_size'] > 0:
        pad_size = mask_dict['pad_size']
        output_known_class = outputs_class[:, :, :pad_size, :]
        output_known_coord = outputs_coord[:, :, :pad_size, :]
        outputs_class = outputs_class[:, :, pad_size:, :]
        outputs_coord = outputs_coord[:, :, pad_size:, :]
        mask_dict['output_known_lbs_bboxes'] = (output_known_class, output_known_coord)
    return outputs_class, outputs_coord


def prepare_for_loss(mask_dict):
    """
    prepare dn components to calculate loss
    Args:
        mask_dict: a dict that contains dn information. Reads the keys
            'output_known_lbs_bboxes', 'known_lbs_bboxes', 'map_known_indice',
            'known_indice', 'batch_idx' and, when 'contrastive' is set, 'scalar'.
    Returns:
        (known_labels, known_bboxs, output_known_class, output_known_coord, num_tgt)
        In contrastive mode the boxes (targets and predictions) keep only the
        positive queries and ``num_tgt`` is halved; labels and class
        predictions keep every query.
    """
    output_known_class, output_known_coord = mask_dict['output_known_lbs_bboxes']
    known_labels, known_bboxs = mask_dict['known_lbs_bboxes']
    map_known_indice = mask_dict['map_known_indice']
    # [0, 1, 2, 3, 4, ..., 0, 1, 2, 3, 4, ...]

    known_indice = mask_dict['known_indice']
    # [0, 1, 2, 3, 4, ...]

    batch_idx = mask_dict['batch_idx']
    bid = batch_idx[known_indice]
    num_tgt = known_indice.numel()

    if len(output_known_class) > 0:
        # [ levels, bs, qs, hdim ] -> [ bs, qs, lvls, hdim ] -> [ lvls, bs * qs, hdim ]
        output_known_class = output_known_class.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)
        output_known_coord = output_known_coord.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)

    # .get(): mask dicts built without the contrastive flag behave as non-contrastive.
    if mask_dict.get('contrastive', False):
        scalar = mask_dict['scalar']
        num_tgt = num_tgt // 2
        num_box = num_tgt // scalar
        # Build the index on the tensors' own device instead of forcing CUDA.
        device = known_bboxs.device
        # Each of the `scalar` groups holds 2 * num_box queries; the first
        # num_box of every group are the positives.
        group_start = torch.arange(scalar, device=device).unsqueeze(1) * (num_box * 2)
        positive_idx = (torch.arange(num_box, device=device).unsqueeze(0) + group_start).flatten()
        # bbox reconstruction only use positive cases
        # but, class reconstruction use both positive and negative(with no-object)
        output_known_coord = output_known_coord[:, positive_idx, :]
        known_bboxs = known_bboxs[positive_idx, :]

    return known_labels, known_bboxs, output_known_class, output_known_coord, num_tgt


def tgt_loss_boxes(src_boxes, tgt_boxes, num_tgt,):
    """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
       targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
       The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.

    Returns a dict with 'tgt_loss_bbox' and 'tgt_loss_giou', both summed over
    the boxes and divided by ``num_tgt``. With no target boxes both are zero,
    created on the device of ``tgt_boxes``.
    """
    if len(tgt_boxes) == 0:
        return {
            'tgt_loss_bbox': _zero_loss(tgt_boxes.device),
            'tgt_loss_giou': _zero_loss(tgt_boxes.device),
        }

    loss_bbox = F.l1_loss(src_boxes, tgt_boxes, reduction='none')

    losses = {}
    losses['tgt_loss_bbox'] = loss_bbox.sum() / num_tgt

    # Only the diagonal is needed: prediction i is paired with target i.
    loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
        box_ops.box_cxcywh_to_xyxy(src_boxes),
        box_ops.box_cxcywh_to_xyxy(tgt_boxes)))
    losses['tgt_loss_giou'] = loss_giou.sum() / num_tgt
    return losses


def tgt_loss_labels(src_logits_, tgt_labels_, num_tgt, focal_alpha, log=True):
    """Classification loss (sigmoid focal loss)
    targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]

    Returns a dict with 'tgt_loss_ce' and 'tgt_class_error'. With no target
    labels both are zero, created on the device of ``src_logits_``.
    ``log`` is accepted for interface compatibility and is not used.
    """
    if len(tgt_labels_) == 0:
        return {
            'tgt_loss_ce': _zero_loss(src_logits_.device),
            'tgt_class_error': _zero_loss(src_logits_.device),
        }

    src_logits, tgt_labels = src_logits_.unsqueeze(0), tgt_labels_.unsqueeze(0)

    # One extra column absorbs the label index equal to the number of logits
    # (the no-object label); it is dropped right after the scatter.
    target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
                                        dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
    target_classes_onehot.scatter_(2, tgt_labels.unsqueeze(-1), 1)

    target_classes_onehot = target_classes_onehot[:, :, :-1]
    loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_tgt, alpha=focal_alpha, gamma=2) * src_logits.shape[1]

    losses = {'tgt_loss_ce': loss_ce}

    losses['tgt_class_error'] = 100 - accuracy(src_logits_, tgt_labels_)[0]
    return losses


def compute_dn_loss(mask_dict, training, aux_num, focal_alpha):
    """
    compute dn loss in criterion
    Args:
        mask_dict: a dict for dn information (may be None or empty)
        training: training or inference flag
        aux_num: aux loss number
        focal_alpha: for focal loss
    Returns:
        dict with 'tgt_loss_ce', 'tgt_class_error', 'tgt_loss_bbox' and
        'tgt_loss_giou', plus the same four keys suffixed ``_{i}`` for every
        auxiliary level ``i`` in ``range(aux_num)``. All entries are zero when
        not training or when the dn outputs are missing from ``mask_dict``.
    """
    use_dn = bool(training) and mask_dict is not None and 'output_known_lbs_bboxes' in mask_dict
    device = None if use_dn else _fallback_device(mask_dict)

    losses = {}
    if use_dn:
        known_labels, known_bboxs, output_known_class, output_known_coord, \
            num_tgt = prepare_for_loss(mask_dict)
        # -1 is the final level [ levels, bs * qs, hidden_dim ]
        losses.update(tgt_loss_labels(output_known_class[-1], known_labels, num_tgt, focal_alpha))
        losses.update(tgt_loss_boxes(output_known_coord[-1], known_bboxs, num_tgt))
    else:
        losses.update(_zero_dn_losses(device))

    if aux_num:
        for i in range(aux_num):
            # dn aux loss
            if use_dn:
                level_losses = tgt_loss_labels(output_known_class[i], known_labels, num_tgt, focal_alpha)
                level_losses.update(tgt_loss_boxes(output_known_coord[i], known_bboxs, num_tgt))
            else:
                level_losses = _zero_dn_losses(device)
            losses.update({k + f'_{i}': v for k, v in level_losses.items()})
    return losses
class HungarianMatcher(nn.Module):
    """Assign predictions to ground-truth targets one-to-one.

    The targets carry no "no object" entry, so there are usually more
    predictions than targets. The cheapest predictions are matched one-to-one
    and every remaining prediction stays unmatched (treated as non-object).
    """

    def __init__(self,
                 cost_class: float = 1,
                 cost_bbox: float = 1,
                 cost_giou: float = 1):
        """Store the relative weights of the three matching-cost terms.

        Params:
            cost_class: weight of the classification term
            cost_bbox: weight of the L1 box-coordinate term
            cost_giou: weight of the generalized-IoU term

        At least one of the weights has to be non-zero.
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert not (cost_class == 0 and cost_bbox == 0 and cost_giou == 0), "all costs cant be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Compute the minimum-cost assignment for every batch element.

        Params:
            outputs: dict with at least
                 "pred_logits": Tensor [batch_size, num_queries, num_classes], classification logits
                 "pred_boxes": Tensor [batch_size, num_queries, 4], predicted boxes

            targets: list (one dict per batch element) with
                 "labels": Tensor [num_target_boxes], class labels
                 "boxes": Tensor [num_target_boxes, 4], target boxes

        Returns:
            A list of size batch_size with tuples (index_i, index_j) of int64 tensors:
                - index_i: indices of the selected predictions
                - index_j: indices of the corresponding targets
            with len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
        """
        logits = outputs["pred_logits"]
        batch, n_queries = logits.shape[:2]

        # Flatten batch and query axes so the cost matrix is built in one shot.
        prob = logits.flatten(0, 1).sigmoid()
        pred_boxes = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # All targets of the batch, concatenated.
        gt_labels = torch.cat([t["labels"] for t in targets])
        gt_boxes = torch.cat([t["boxes"] for t in targets])

        # Focal-style classification cost.
        alpha = 0.25
        gamma = 2.0
        negative_term = (1 - alpha) * (prob ** gamma) * (-(1 - prob + 1e-8).log())
        positive_term = alpha * ((1 - prob) ** gamma) * (-(prob + 1e-8).log())
        class_cost = positive_term[:, gt_labels] - negative_term[:, gt_labels]

        # L1 cost between boxes.
        l1_cost = torch.cdist(pred_boxes, gt_boxes, p=1)

        # GIoU cost between boxes (negated: larger overlap -> lower cost).
        giou_cost = -generalized_box_iou(box_cxcywh_to_xyxy(pred_boxes),
                                         box_cxcywh_to_xyxy(gt_boxes))

        total = self.cost_bbox * l1_cost + self.cost_class * class_cost + self.cost_giou * giou_cost
        total = total.view(batch, n_queries, -1).cpu()

        # Split the target axis per image; chunk b is only meaningful for image b.
        per_image = [len(t["boxes"]) for t in targets]
        matches = []
        for b, chunk in enumerate(total.split(per_image, -1)):
            rows, cols = linear_sum_assignment(chunk[b])
            matches.append((torch.as_tensor(rows, dtype=torch.int64),
                            torch.as_tensor(cols, dtype=torch.int64)))
        return matches


def build_matcher(args):
    """Create a HungarianMatcher from the ``set_cost_*`` fields of ``args``."""
    weights = {
        'cost_class': args.set_cost_class,
        'cost_bbox': args.set_cost_bbox,
        'cost_giou': args.set_cost_giou,
    }
    return HungarianMatcher(**weights)
class MSDeformAttnFunction(Function):
    """Autograd wrapper around the compiled MultiScaleDeformableAttention kernels."""

    @staticmethod
    def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
        """Run the forward kernel and stash its tensor inputs for backward."""
        ctx.im2col_step = im2col_step
        tensor_inputs = (value, value_spatial_shapes, value_level_start_index,
                         sampling_locations, attention_weights)
        result = MSDA.ms_deform_attn_forward(*tensor_inputs, ctx.im2col_step)
        ctx.save_for_backward(*tensor_inputs)
        return result

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Run the backward kernel.

        Gradients are returned for ``value``, ``sampling_locations`` and
        ``attention_weights``; the remaining three inputs get ``None``.
        """
        saved = ctx.saved_tensors
        value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = saved
        grads = MSDA.ms_deform_attn_backward(
            value, value_spatial_shapes, value_level_start_index,
            sampling_locations, attention_weights, grad_output, ctx.im2col_step)
        grad_value, grad_sampling_loc, grad_attn_weight = grads

        return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
    """Pure-PyTorch multi-scale deformable attention.

    For debug and test only; the CUDA version should be used instead.
    ``value`` is unpacked as (N, S, M, D) and ``sampling_locations`` as
    (N, Lq, M, L, P, 2); the result has shape (N, Lq, M * D).
    """
    batch, _, _, head_dim = value.shape
    _, n_query, n_heads, n_levels, n_points, _ = sampling_locations.shape

    level_values = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    # grid_sample expects coordinates in [-1, 1].
    grids = 2 * sampling_locations - 1

    sampled_per_level = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # N, H*W, M, D -> N, H*W, M*D -> N, M*D, H*W -> N*M, D, H, W
        feature_map = level_values[level].flatten(2).transpose(1, 2)
        feature_map = feature_map.reshape(batch * n_heads, head_dim, H_, W_)
        # N, Lq, M, P, 2 -> N, M, Lq, P, 2 -> N*M, Lq, P, 2
        level_grid = grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # N*M, D, Lq, P
        sampled = F.grid_sample(feature_map, level_grid,
                                mode='bilinear', padding_mode='zeros', align_corners=False)
        sampled_per_level.append(sampled)

    # (N, Lq, M, L, P) -> (N, M, Lq, L, P) -> (N*M, 1, Lq, L*P)
    weights = attention_weights.transpose(1, 2).reshape(batch * n_heads, 1, n_query, n_levels * n_points)
    stacked = torch.stack(sampled_per_level, dim=-2).flatten(-2)
    attended = (stacked * weights).sum(-1).view(batch, n_heads * head_dim, n_query)
    return attended.transpose(1, 2).contiguous()
b/models/dn_dab_dino_deformable_detr/ops/make.sh new file mode 100755 index 0000000..eae08d0 --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/ops/make.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + + +# TORCH_CUDA_ARCH_LIST="8.0" CUDA_HOME='/path/to/your/cuda/dir' +python setup.py build install diff --git a/models/dn_dab_dino_deformable_detr/ops/modules/__init__.py b/models/dn_dab_dino_deformable_detr/ops/modules/__init__.py new file mode 100644 index 0000000..f82cb1a --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/ops/modules/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from .ms_deform_attn import MSDeformAttn diff --git a/models/dn_dab_dino_deformable_detr/ops/modules/ms_deform_attn.py b/models/dn_dab_dino_deformable_detr/ops/modules/ms_deform_attn.py new file mode 100644 index 0000000..663d64a --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/ops/modules/ms_deform_attn.py @@ -0,0 +1,115 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import warnings +import math + +import torch +from torch import nn +import torch.nn.functional as F +from torch.nn.init import xavier_uniform_, constant_ + +from ..functions import MSDeformAttnFunction + + +def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + return (n & (n-1) == 0) and n != 0 + + +class MSDeformAttn(nn.Module): + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """ + Multi-Scale Deformable Attention Module + :param d_model hidden dimension + :param n_levels number of feature 
levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level + """ + super().__init__() + if d_model % n_heads != 0: + raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) + _d_per_head = d_model // n_heads + # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation + if not _is_power_of_2(_d_per_head): + warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " + "which is more efficient in our CUDA implementation.") + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.) + constant_(self.attention_weights.bias.data, 0.) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) 
+ + def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): + """ + :param query (N, Length_{query}, C) + :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes + :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements + + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) + # N, Len_q, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets / offset_normalizer[None, None, None, :, None, :] + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, 
None, :2] \ + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 + else: + raise ValueError( + 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) + output = MSDeformAttnFunction.apply( + value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) + output = self.output_proj(output) + return output diff --git a/models/dn_dab_dino_deformable_detr/ops/setup.py b/models/dn_dab_dino_deformable_detr/ops/setup.py new file mode 100644 index 0000000..049f923 --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/ops/setup.py @@ -0,0 +1,73 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +import os +import glob + +import torch + +from torch.utils.cpp_extension import CUDA_HOME +from torch.utils.cpp_extension import CppExtension +from torch.utils.cpp_extension import CUDAExtension + +from setuptools import find_packages +from setuptools import setup + +requirements = ["torch", "torchvision"] + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "src") + + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) + source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) + + sources = main_file + source_cpu + extension = CppExtension + extra_compile_args = {"cxx": []} + define_macros = [] + + # 
import ipdb; ipdb.set_trace() + + if torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension + sources += source_cuda + define_macros += [("WITH_CUDA", None)] + extra_compile_args["nvcc"] = [ + "-DCUDA_HAS_FP16=1", + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + else: + raise NotImplementedError('Cuda is not availabel') + + sources = [os.path.join(extensions_dir, s) for s in sources] + include_dirs = [extensions_dir] + ext_modules = [ + extension( + "MultiScaleDeformableAttention", + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + return ext_modules + +setup( + name="MultiScaleDeformableAttention", + version="1.0", + author="Weijie Su", + url="https://github.com/fundamentalvision/Deformable-DETR", + description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", + packages=find_packages(exclude=("configs", "tests",)), + ext_modules=get_extensions(), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) diff --git a/models/dn_dab_dino_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.cpp b/models/dn_dab_dino_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.cpp new file mode 100644 index 0000000..e1bf854 --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.cpp @@ -0,0 +1,41 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include + +#include +#include + + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + AT_ERROR("Not implement on cpu"); +} + diff --git a/models/dn_dab_dino_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.h b/models/dn_dab_dino_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.h new file mode 100644 index 0000000..81b7b58 --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/ops/src/cpu/ms_deform_attn_cpu.h @@ -0,0 +1,33 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once +#include + +at::Tensor +ms_deform_attn_cpu_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step); + +std::vector +ms_deform_attn_cpu_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step); + + diff --git a/models/dn_dab_dino_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.cu b/models/dn_dab_dino_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.cu new file mode 100644 index 0000000..d6d5836 --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/ops/src/cuda/ms_deform_attn_cuda.cu @@ -0,0 +1,153 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. 
// NOTE(review): header names and template arguments (std::vector<...>,
// Tensor::data<...>()) were missing in this copy (angle-bracket contents
// stripped). They are restored below following the upstream Deformable-DETR
// sources; confirm the int64_t index type against ms_deform_im2col_cuda.cuh.
#include <vector>
#include "cuda/ms_deform_im2col_cuda.cuh"

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>


// Forward pass of multi-scale deformable attention on CUDA.
//
// Sizes are read as: value (batch, spatial_size, num_heads, channels),
// spatial_shapes dim 0 = num_levels, sampling_loc dim 1 = num_query and
// dim 4 = num_point. All tensors must be contiguous CUDA tensors.
// The batch is processed in chunks of min(batch, im2col_step) samples, which
// must divide the batch size.
// Returns a tensor of shape (batch, num_query, num_heads * channels).
at::Tensor ms_deform_attn_cuda_forward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const int im2col_step)
{
    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    // Never use a chunk larger than the batch itself.
    const int im2col_step_ = std::min(batch, im2col_step);

    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);

    auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());

    const int batch_n = im2col_step_;
    // View the output as (num_chunks, chunk, ...) so each chunk can be selected.
    auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
    // Element counts per sample, used to offset the raw pointers per chunk.
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        auto columns = output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
            ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                spatial_shapes.data<int64_t>(),
                level_start_index.data<int64_t>(),
                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                columns.data<scalar_t>());

        }));
    }

    output = output.view({batch, num_query, num_heads*channels});

    return output;
}


// Backward pass of multi-scale deformable attention on CUDA.
//
// Takes the forward inputs plus grad_output and applies the same contiguity,
// device and chunking requirements as the forward pass.
// Returns {grad_value, grad_sampling_loc, grad_attn_weight}, each shaped like
// the corresponding input (allocated with zeros_like).
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
    const at::Tensor &value,
    const at::Tensor &spatial_shapes,
    const at::Tensor &level_start_index,
    const at::Tensor &sampling_loc,
    const at::Tensor &attn_weight,
    const at::Tensor &grad_output,
    const int im2col_step)
{

    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");

    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
    AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");

    const int batch = value.size(0);
    const int spatial_size = value.size(1);
    const int num_heads = value.size(2);
    const int channels = value.size(3);

    const int num_levels = spatial_shapes.size(0);

    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(4);

    // Never use a chunk larger than the batch itself.
    const int im2col_step_ = std::min(batch, im2col_step);

    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);

    auto grad_value = at::zeros_like(value);
    auto grad_sampling_loc = at::zeros_like(sampling_loc);
    auto grad_attn_weight = at::zeros_like(attn_weight);

    const int batch_n = im2col_step_;
    // Element counts per sample, used to offset the raw pointers per chunk.
    auto per_value_size = spatial_size * num_heads * channels;
    auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
    auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
    auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});

    for (int n = 0; n < batch/im2col_step_; ++n)
    {
        auto grad_output_g = grad_output_n.select(0, n);
        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
            ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
                                    grad_output_g.data<scalar_t>(),
                                    value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                                    spatial_shapes.data<int64_t>(),
                                    level_start_index.data<int64_t>(),
                                    sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                                    attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                                    batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
                                    grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
                                    grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
                                    grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);

        }));
    }

    return {
        grad_value, grad_sampling_loc, grad_attn_weight
    };
}
b/models/dn_dab_dino_deformable_detr/ops/src/cuda/ms_deform_im2col_cuda.cuh @@ -0,0 +1,1327 @@ +/*! +************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************** +* Modified from DCN (https://github.com/msracver/Deformable-ConvNets) +* Copyright (c) 2018 Microsoft +************************************************************************** +*/ + +#include +#include +#include + +#include +#include + +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N, const int num_threads) +{ + return (N + num_threads - 1) / num_threads; +} + + +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = 
h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 
0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + *grad_attn_weight = top_grad * val; + *grad_sampling_loc = width * grad_w_weight * top_grad_value; + *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; +} + + +template +__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, + const int &height, const int &width, const int &nheads, const int &channels, + const scalar_t &h, const scalar_t &w, const int &m, const int &c, + const scalar_t &top_grad, + const scalar_t &attn_weight, + scalar_t* &grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int h_low = floor(h); + const int w_low = floor(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; + + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; + + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * 
channels + c; + + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t top_grad_value = top_grad * attn_weight; + scalar_t grad_h_weight = 0, grad_w_weight = 0; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + grad_h_weight -= hw * v1; + grad_w_weight -= hh * v1; + atomicAdd(grad_value+ptr1, w1*top_grad_value); + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + grad_h_weight -= lw * v2; + grad_w_weight += hh * v2; + atomicAdd(grad_value+ptr2, w2*top_grad_value); + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + grad_h_weight += hw * v3; + grad_w_weight -= lh * v3; + atomicAdd(grad_value+ptr3, w3*top_grad_value); + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + grad_h_weight += lw * v4; + grad_w_weight += lh * v4; + atomicAdd(grad_value+ptr4, w4*top_grad_value); + } + + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + atomicAdd(grad_attn_weight, top_grad * val); + atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); + atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); +} + + +template +__global__ void ms_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *data_col) +{ + 
CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + scalar_t *data_col_ptr = data_col + index; + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + scalar_t col = 0; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; + } + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + } + } + *data_col_ptr = col; + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int 
num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + 
*(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockSize; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + *grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; + __shared__ scalar_t cache_grad_attn_weight[blockSize]; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + 
const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockSize/2; 
s>0; s>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + 
grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + if (tid == 0) + { + scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; + int sid=2; + for (unsigned int tid = 1; tid < blockDim.x; ++tid) + { + _grad_w += cache_grad_sampling_loc[sid]; + _grad_h += cache_grad_sampling_loc[sid + 1]; + _grad_a += cache_grad_attn_weight[tid]; + sid += 2; + } + + + *grad_sampling_loc = _grad_w; + *(grad_sampling_loc + 1) = _grad_h; + 
*grad_attn_weight = _grad_a; + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = 
data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + *grad_sampling_loc = cache_grad_sampling_loc[0]; + *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; + *grad_attn_weight = cache_grad_attn_weight[0]; + } + __syncthreads(); + + data_weight_ptr += 
1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + +template +__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + extern __shared__ int _s[]; + scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; + scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; + unsigned int tid = threadIdx.x; + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = 
data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; + *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; + *(cache_grad_attn_weight+threadIdx.x)=0; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); + } + + __syncthreads(); + + for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) + { + if (tid < s) { + const unsigned int xid1 = tid << 1; + const unsigned int xid2 = (tid + s) << 1; + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; + if (tid + (s << 1) < spre) + { + cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; + cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; + cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; + } + } + __syncthreads(); + } + + if (tid == 0) + { + atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); + atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); + atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); + } + __syncthreads(); + + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += 
grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, + const scalar_t *grad_col, + const scalar_t *data_value, + const int64_t *data_spatial_shapes, + const int64_t *data_level_start_index, + const scalar_t *data_sampling_loc, + const scalar_t *data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t *grad_value, + scalar_t *grad_sampling_loc, + scalar_t *grad_attn_weight) +{ + CUDA_KERNEL_LOOP(index, n) + { + int _temp = index; + const int c_col = _temp % channels; + _temp /= channels; + const int sampling_index = _temp; + const int m_col = _temp % num_heads; + _temp /= num_heads; + const int q_col = _temp % num_query; + _temp /= num_query; + const int b_col = _temp; + + const scalar_t top_grad = grad_col[index]; + + int data_weight_ptr = sampling_index * num_levels * num_point; + int data_loc_w_ptr = data_weight_ptr << 1; + const int grad_sampling_ptr = data_weight_ptr; + grad_sampling_loc += grad_sampling_ptr << 1; + grad_attn_weight += grad_sampling_ptr; + const int grad_weight_stride = 1; + const int grad_loc_stride = 2; + const int qid_stride = num_heads * channels; + const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; + + for (int l_col=0; l_col < num_levels; ++l_col) + { + const int level_start_id = data_level_start_index[l_col]; + const int spatial_h_ptr = l_col << 1; + const int spatial_h = data_spatial_shapes[spatial_h_ptr]; + const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; + const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; + const scalar_t *data_value_ptr = data_value + value_ptr_offset; + scalar_t *grad_value_ptr = grad_value + value_ptr_offset; + + for (int p_col=0; p_col < num_point; ++p_col) + { + const scalar_t loc_w = 
data_sampling_loc[data_loc_w_ptr]; + const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; + const scalar_t weight = data_attn_weight[data_weight_ptr]; + + const scalar_t h_im = loc_h * spatial_h - 0.5; + const scalar_t w_im = loc_w * spatial_w - 0.5; + if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) + { + ms_deform_attn_col2im_bilinear_gm( + data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, + top_grad, weight, grad_value_ptr, + grad_sampling_loc, grad_attn_weight); + } + data_weight_ptr += 1; + data_loc_w_ptr += 2; + grad_attn_weight += grad_weight_stride; + grad_sampling_loc += grad_loc_stride; + } + } + } +} + + +template +void ms_deformable_im2col_cuda(cudaStream_t stream, + const scalar_t* data_value, + const int64_t* data_spatial_shapes, + const int64_t* data_level_start_index, + const scalar_t* data_sampling_loc, + const scalar_t* data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* data_col) +{ + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + const int num_threads = CUDA_NUM_THREADS; + ms_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, + batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +template +void ms_deformable_col2im_cuda(cudaStream_t stream, + const scalar_t* grad_col, + const scalar_t* data_value, + const int64_t * data_spatial_shapes, + const int64_t * data_level_start_index, + const scalar_t * data_sampling_loc, + const scalar_t * 
data_attn_weight, + const int batch_size, + const int spatial_size, + const int num_heads, + const int channels, + const int num_levels, + const int num_query, + const int num_point, + scalar_t* grad_value, + scalar_t* grad_sampling_loc, + scalar_t* grad_attn_weight) +{ + const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels; + const int num_kernels = batch_size * num_query * num_heads * channels; + const int num_actual_kernels = batch_size * num_query * num_heads * channels; + if (channels > 1024) + { + if ((channels & 1023) == 0) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_gm + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + else{ + switch(channels) + { + case 1: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 2: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + 
grad_sampling_loc, + grad_attn_weight); + break; + case 4: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 8: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 16: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 32: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 64: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 128: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + 
<<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 256: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 512: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + case 1024: + ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + break; + default: + if (channels < 64) + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v1 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, + batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + else + { + ms_deformable_col2im_gpu_kernel_shm_reduce_v2 + <<>>( + num_kernels, + grad_col, + data_value, + data_spatial_shapes, + data_level_start_index, + data_sampling_loc, + data_attn_weight, 
+ batch_size, + spatial_size, + num_heads, + channels, + num_levels, + num_query, + num_point, + grad_value, + grad_sampling_loc, + grad_attn_weight); + } + } + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} \ No newline at end of file diff --git a/models/dn_dab_dino_deformable_detr/ops/src/ms_deform_attn.h b/models/dn_dab_dino_deformable_detr/ops/src/ms_deform_attn.h new file mode 100644 index 0000000..ac0ef2e --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/ops/src/ms_deform_attn.h @@ -0,0 +1,62 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#pragma once + +#include "cpu/ms_deform_attn_cpu.h" + +#ifdef WITH_CUDA +#include "cuda/ms_deform_attn_cuda.h" +#endif + + +at::Tensor +ms_deform_attn_forward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor &sampling_loc, + const at::Tensor &attn_weight, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_forward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +ms_deform_attn_backward( + const at::Tensor &value, + const at::Tensor &spatial_shapes, + const at::Tensor &level_start_index, + const at::Tensor 
&sampling_loc, + const at::Tensor &attn_weight, + const at::Tensor &grad_output, + const int im2col_step) +{ + if (value.type().is_cuda()) + { +#ifdef WITH_CUDA + return ms_deform_attn_cuda_backward( + value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + diff --git a/models/dn_dab_dino_deformable_detr/ops/src/vision.cpp b/models/dn_dab_dino_deformable_detr/ops/src/vision.cpp new file mode 100644 index 0000000..2201f63 --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/ops/src/vision.cpp @@ -0,0 +1,16 @@ +/*! +************************************************************************************************** +* Deformable DETR +* Copyright (c) 2020 SenseTime. All Rights Reserved. +* Licensed under the Apache License, Version 2.0 [see LICENSE for details] +************************************************************************************************** +* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +************************************************************************************************** +*/ + +#include "ms_deform_attn.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); + m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); +} diff --git a/models/dn_dab_dino_deformable_detr/ops/test.py b/models/dn_dab_dino_deformable_detr/ops/test.py new file mode 100644 index 0000000..8dbf6d5 --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/ops/test.py @@ -0,0 +1,89 @@ +# ------------------------------------------------------------------------------------------------ +# Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------------------------------ +# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 +# ------------------------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import time +import torch +import torch.nn as nn +from torch.autograd import gradcheck + +from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch + + +N, M, D = 1, 2, 2 +Lq, L, P = 2, 2, 2 +shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() +level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) +S = sum([(H*W).item() for H, W in shapes]) + + +torch.manual_seed(3) + + +@torch.no_grad() +def check_forward_equal_with_pytorch_double(): + value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() + output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() + fwdok = torch.allclose(output_cuda, output_pytorch) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +@torch.no_grad() +def check_forward_equal_with_pytorch_float(): + 
value = torch.rand(N, S, M, D).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() + output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() + fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) + max_abs_err = (output_cuda - output_pytorch).abs().max() + max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() + + print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') + + +def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): + + value = torch.rand(N, S, M, channels).cuda() * 0.01 + sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() + attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 + attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) + im2col_step = 2 + func = MSDeformAttnFunction.apply + + value.requires_grad = grad_value + sampling_locations.requires_grad = grad_sampling_loc + attention_weights.requires_grad = grad_attn_weight + + gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) + + print(f'* {gradok} check_gradient_numerical(D={channels})') + + +if __name__ == '__main__': + check_forward_equal_with_pytorch_double() + check_forward_equal_with_pytorch_float() + + for channels in [30, 32, 64, 71, 1025, 2048, 3096]: + check_gradient_numerical(channels, True, True, True) + + + diff --git a/models/dn_dab_dino_deformable_detr/position_encoding.py 
b/models/dn_dab_dino_deformable_detr/position_encoding.py new file mode 100644 index 0000000..868ff86 --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/position_encoding.py @@ -0,0 +1,101 @@ +# ------------------------------------------------------------------------ +# DAB-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +from util.misc import NestedTensor + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
+ """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. 
+ """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif args.position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/models/dn_dab_dino_deformable_detr/segmentation.py b/models/dn_dab_dino_deformable_detr/segmentation.py new file mode 100644 index 0000000..01e1982 --- /dev/null +++ b/models/dn_dab_dino_deformable_detr/segmentation.py @@ -0,0 +1,373 @@ +# ------------------------------------------------------------------------ +# DAB-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from Deformable DETR +# Copyright (c) 2020 SenseTime. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +This file provides the definition of the convolutional heads used to predict masks, as well as the losses +""" +import io +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image + +import util.box_ops as box_ops +from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list + +try: + from panopticapi.utils import id2rgb, rgb2id +except ImportError: + pass + + +class DETRsegm(nn.Module): + def __init__(self, detr, freeze_detr=False): + super().__init__() + self.detr = detr + + if freeze_detr: + for p in self.parameters(): + p.requires_grad_(False) + + hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead + self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0) + self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim) + + def forward(self, samples: NestedTensor): + if not isinstance(samples, NestedTensor): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.detr.backbone(samples) + + bs = features[-1].tensors.shape[0] + + src, mask = features[-1].decompose() + src_proj = self.detr.input_proj(src) + hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1]) + + outputs_class = self.detr.class_embed(hs) + outputs_coord = self.detr.bbox_embed(hs).sigmoid() + out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} + if self.detr.aux_loss: + out["aux_outputs"] = [ + {"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1]) + ] + + # FIXME 
h_boxes takes the last one computed, keep this in mind + bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask) + + seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors]) + outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]) + + out["pred_masks"] = outputs_seg_masks + return out + + +class MaskHeadSmallConv(nn.Module): + """ + Simple convolutional head, using group norm. + Upsampling is done using a FPN approach + """ + + def __init__(self, dim, fpn_dims, context_dim): + super().__init__() + + inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64] + self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1) + self.gn1 = torch.nn.GroupNorm(8, dim) + self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1) + self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) + self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) + self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) + self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) + self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) + self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) + self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) + self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1) + + self.dim = dim + + self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) + self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) + self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_uniform_(m.weight, a=1) + nn.init.constant_(m.bias, 0) + + def forward(self, x, bbox_mask, fpns): + def expand(tensor, length): + return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) + + x = torch.cat([expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) + + x = self.lay1(x) + x = self.gn1(x) + x = F.relu(x) 
+ x = self.lay2(x) + x = self.gn2(x) + x = F.relu(x) + + cur_fpn = self.adapter1(fpns[0]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay3(x) + x = self.gn3(x) + x = F.relu(x) + + cur_fpn = self.adapter2(fpns[1]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay4(x) + x = self.gn4(x) + x = F.relu(x) + + cur_fpn = self.adapter3(fpns[2]) + if cur_fpn.size(0) != x.size(0): + cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0)) + x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") + x = self.lay5(x) + x = self.gn5(x) + x = F.relu(x) + + x = self.out_lay(x) + return x + + +class MHAttentionMap(nn.Module): + """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" + + def __init__(self, query_dim, hidden_dim, num_heads, dropout=0, bias=True): + super().__init__() + self.num_heads = num_heads + self.hidden_dim = hidden_dim + self.dropout = nn.Dropout(dropout) + + self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + + nn.init.zeros_(self.k_linear.bias) + nn.init.zeros_(self.q_linear.bias) + nn.init.xavier_uniform_(self.k_linear.weight) + nn.init.xavier_uniform_(self.q_linear.weight) + self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 + + def forward(self, q, k, mask=None): + q = self.q_linear(q) + k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) + qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) + kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]) + weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh) + + if mask is 
not None: + weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) + weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights) + weights = self.dropout(weights) + return weights + + +def dice_loss(inputs, targets, num_boxes): + """ + Compute the DICE loss, similar to generalized IOU for masks + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + """ + inputs = inputs.sigmoid() + inputs = inputs.flatten(1) + numerator = 2 * (inputs * targets).sum(1) + denominator = inputs.sum(-1) + targets.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + return loss.sum() / num_boxes + + +def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): + """ + Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + Args: + inputs: A float tensor of arbitrary shape. + The predictions for each example. + targets: A float tensor with the same shape as inputs. Stores the binary + classification label for each element in inputs + (0 for the negative class and 1 for the positive class). + alpha: (optional) Weighting factor in range (0,1) to balance + positive vs negative examples. Default = 0.25; a negative value disables the weighting. + gamma: Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. 
+ Returns: + Loss tensor + """ + prob = inputs.sigmoid() + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + p_t = prob * targets + (1 - prob) * (1 - targets) + loss = ce_loss * ((1 - p_t) ** gamma) + + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + loss = alpha_t * loss + + return loss.mean(1).sum() / num_boxes + + +class PostProcessSegm(nn.Module): + def __init__(self, threshold=0.5): + super().__init__() + self.threshold = threshold + + @torch.no_grad() + def forward(self, results, outputs, orig_target_sizes, max_target_sizes): + assert len(orig_target_sizes) == len(max_target_sizes) + max_h, max_w = max_target_sizes.max(0)[0].tolist() + outputs_masks = outputs["pred_masks"].squeeze(2) + outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False) + outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu() + + for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): + img_h, img_w = t[0], t[1] + results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) + results[i]["masks"] = F.interpolate( + results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" + ).byte() + + return results + + +class PostProcessPanoptic(nn.Module): + """This class converts the output of the model to the final panoptic result, in the format expected by the + coco panoptic API """ + + def __init__(self, is_thing_map, threshold=0.85): + """ + Parameters: + is_thing_map: This is a mapping whose keys are the class ids, and the values a boolean indicating whether + the class is a thing (True) or a stuff (False) class + threshold: confidence threshold: segments with confidence lower than this will be deleted + """ + super().__init__() + self.threshold = threshold + self.is_thing_map = is_thing_map + + def forward(self, outputs, processed_sizes, target_sizes=None): + """ This function computes the panoptic prediction from the model's 
predictions. + Parameters: + outputs: This is a dict coming directly from the model. See the model doc for the content. + processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the + model, ie the size after data augmentation but before batching. + target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size + of each prediction. If left to None, it will default to the processed_sizes + """ + if target_sizes is None: + target_sizes = processed_sizes + assert len(processed_sizes) == len(target_sizes) + out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"] + assert len(out_logits) == len(raw_masks) == len(target_sizes) + preds = [] + + def to_tuple(tup): + if isinstance(tup, tuple): + return tup + return tuple(tup.cpu().tolist()) + + for cur_logits, cur_masks, cur_boxes, size, target_size in zip( + out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes + ): + # we filter empty queries and detection below threshold + scores, labels = cur_logits.softmax(-1).max(-1) + keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold) + cur_scores, cur_classes = cur_logits.softmax(-1).max(-1) + cur_scores = cur_scores[keep] + cur_classes = cur_classes[keep] + cur_masks = cur_masks[keep] + cur_masks = interpolate(cur_masks[None], to_tuple(size), mode="bilinear").squeeze(0) + cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep]) + + h, w = cur_masks.shape[-2:] + assert len(cur_boxes) == len(cur_classes) + + # It may be that we have several predicted masks for the same stuff class. 
+ # In the following, we track the list of mask ids for each stuff class (they are merged later on) + cur_masks = cur_masks.flatten(1) + stuff_equiv_classes = defaultdict(lambda: []) + for k, label in enumerate(cur_classes): + if not self.is_thing_map[label.item()]: + stuff_equiv_classes[label.item()].append(k) + + def get_ids_area(masks, scores, dedup=False): + # This helper function creates the final panoptic segmentation image + # It also returns the area of the masks that appears on the image + + m_id = masks.transpose(0, 1).softmax(-1) + + if m_id.shape[-1] == 0: + # We didn't detect any mask :( + m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device) + else: + m_id = m_id.argmax(-1).view(h, w) + + if dedup: + # Merge the masks corresponding to the same stuff class + for equiv in stuff_equiv_classes.values(): + if len(equiv) > 1: + for eq_id in equiv: + m_id.masked_fill_(m_id.eq(eq_id), equiv[0]) + + final_h, final_w = to_tuple(target_size) + + seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy())) + seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST) + + np_seg_img = ( + torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy() + ) + m_id = torch.from_numpy(rgb2id(np_seg_img)) + + area = [] + for i in range(len(scores)): + area.append(m_id.eq(i).sum().item()) + return area, seg_img + + area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True) + if cur_classes.numel() > 0: + # We now filter empty masks as long as we find some + while True: + filtered_small = torch.as_tensor( + [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device + ) + if filtered_small.any().item(): + cur_scores = cur_scores[~filtered_small] + cur_classes = cur_classes[~filtered_small] + cur_masks = cur_masks[~filtered_small] + area, seg_img = get_ids_area(cur_masks, cur_scores) + else: + break + + else: + cur_classes = torch.ones(1, dtype=torch.long, 
device=cur_classes.device) + + segments_info = [] + for i, a in enumerate(area): + cat = cur_classes[i].item() + segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a}) + del cur_classes + + with io.BytesIO() as out: + seg_img.save(out, format="PNG") + predictions = {"png_string": out.getvalue(), "segments_info": segments_info} + preds.append(predictions) + return preds