Skip to content

Commit

Permalink
RTMDet-tiny enablement for detection task (#3542)
Browse files Browse the repository at this point in the history
- RTMDet-tiny enablement on detection task
- Minor refactoring
  • Loading branch information
sungchul2 authored May 31, 2024
1 parent 78662c1 commit 52897e2
Show file tree
Hide file tree
Showing 18 changed files with 1,096 additions and 88 deletions.
2 changes: 1 addition & 1 deletion src/otx/algo/detection/atss.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _customize_outputs(
elif isinstance(v, torch.Tensor):
losses[k] = v
else:
msg = "Loss output should be list or torch.tensor but got {type(v)}"
msg = f"Loss output should be list or torch.tensor but got {type(v)}"
raise TypeError(msg)
return losses

Expand Down
5 changes: 3 additions & 2 deletions src/otx/algo/detection/heads/atss_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from otx.algo.detection.heads.class_incremental_mixin import (
ClassIncrementalMixin,
)
from otx.algo.detection.losses.cross_entropy_loss import CrossEntropyLoss
from otx.algo.detection.losses.cross_focal_loss import (
CrossSigmoidFocalLoss,
)
Expand Down Expand Up @@ -57,12 +58,12 @@ def __init__(
self,
num_classes: int,
in_channels: int,
loss_centerness: nn.Module,
pred_kernel_size: int = 3,
stacked_convs: int = 4,
conv_cfg: dict | None = None,
norm_cfg: dict | None = None,
reg_decoded_bbox: bool = True,
loss_centerness: nn.Module | None = None,
init_cfg: dict | None = None,
bg_loss_weight: float = -1.0,
use_qfl: bool = False,
Expand All @@ -89,7 +90,7 @@ def __init__(
)

self.sampling = False
self.loss_centerness = loss_centerness
self.loss_centerness = loss_centerness or CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0)

if use_qfl:
kwargs["loss_cls"] = (
Expand Down
52 changes: 3 additions & 49 deletions src/otx/algo/detection/heads/delta_xywh_bbox_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import torch
from torch import Tensor

from otx.algo.detection.utils.utils import clip_bboxes_export


# This class and its supporting functions below lightly adapted from the mmdet DeltaXYWHBBoxCoder available at:
# https://github.com/open-mmlab/mmdetection/blob/cfd5d3a985b0249de009b67d04f37263e11cdf3d/mmdet/models/task_modules/coders/delta_xywh_bbox_coder.py
Expand Down Expand Up @@ -360,54 +362,6 @@ def delta2bbox_export(
y2 = xy2[..., 1]

if clip_border and max_shape is not None:
x1, y1, x2, y2 = clip_bboxes(x1, y1, x2, y2, max_shape)
x1, y1, x2, y2 = clip_bboxes_export(x1, y1, x2, y2, max_shape)

return torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())


def clip_bboxes(
x1: Tensor,
y1: Tensor,
x2: Tensor,
y2: Tensor,
max_shape: Tensor | tuple[int, ...],
) -> tuple[Tensor, ...]:
"""Clip bboxes for onnx.
Since torch.clamp cannot have dynamic `min` and `max`, we scale the
boxes by 1/max_shape and clamp in the range [0, 1] if necessary.
Args:
x1 (Tensor): The x1 for bounding boxes.
y1 (Tensor): The y1 for bounding boxes.
x2 (Tensor): The x2 for bounding boxes.
y2 (Tensor): The y2 for bounding boxes.
max_shape (Tensor | Sequence[int]): The (H,W) of original image.
Returns:
tuple(Tensor): The clipped x1, y1, x2, y2.
"""
if isinstance(max_shape, torch.Tensor):
# scale by 1/max_shape
x1 = x1 / max_shape[1]
y1 = y1 / max_shape[0]
x2 = x2 / max_shape[1]
y2 = y2 / max_shape[0]

# clamp [0, 1]
x1 = torch.clamp(x1, 0, 1)
y1 = torch.clamp(y1, 0, 1)
x2 = torch.clamp(x2, 0, 1)
y2 = torch.clamp(y2, 0, 1)

# scale back
x1 = x1 * max_shape[1]
y1 = y1 * max_shape[0]
x2 = x2 * max_shape[1]
y2 = y2 * max_shape[0]
else:
x1 = torch.clamp(x1, 0, max_shape[1])
y1 = torch.clamp(y1, 0, max_shape[0])
x2 = torch.clamp(x2, 0, max_shape[1])
y2 = torch.clamp(y2, 0, max_shape[0])
return x1, y1, x2, y2
31 changes: 29 additions & 2 deletions src/otx/algo/detection/heads/distance_point_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OpenMMLab. All rights reserved.
""""Distance Point BBox coder."""
"""Distance Point BBox coder."""

from __future__ import annotations

from typing import TYPE_CHECKING

from otx.algo.detection.utils.utils import bbox2distance, distance2bbox
from otx.algo.detection.utils.utils import bbox2distance, distance2bbox, distance2bbox_export

if TYPE_CHECKING:
from torch import Tensor
Expand Down Expand Up @@ -84,3 +84,30 @@ def decode(
if self.clip_border is False:
max_shape = None
return distance2bbox(points, pred_bboxes, max_shape)

def decode_export(
self,
points: Tensor,
pred_bboxes: Tensor,
max_shape: tuple[int, ...] | Tensor | tuple[tuple[int, ...], ...] | None = None,
) -> Tensor:
"""Decode distance prediction to bounding box for export."""
if points.size(0) != pred_bboxes.size(0):
msg = (
f"The batch of points (={points.size(0)}) and the batch of pred_bboxes "
f"(={pred_bboxes.size(0)}) should be same."
)
raise ValueError(msg)

if points.size(-1) != 2:
msg = f"points should have the format with size of 2, given {points.size(-1)}."
raise ValueError(msg)

if pred_bboxes.size(-1) != 4:
msg = f"pred_bboxes should have the format with size of 4, given {pred_bboxes.size(-1)}."
raise ValueError(msg)

if self.clip_border is False:
max_shape = None

return distance2bbox_export(points, pred_bboxes, max_shape)
Loading

0 comments on commit 52897e2

Please sign in to comment.