Skip to content

Commit

Permalink
update remap_mscoco_category
Browse files Browse the repository at this point in the history
  • Loading branch information
lyuwenyu committed Nov 6, 2023
1 parent 5f0f43c commit 95fc522
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 181 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ This is the official implementation of the paper "[DETRs Beat YOLOs on Real-time

## Updates!!!
---
- \[2023.11.05\] upgrade the logic of `remap_mscoco_category` to facilitate training of custom datasets, see detils in [Train custom data](./rtdetr_pytorch/) part.
- \[2023.10.23\] Add [*discussion for deployments*](https://github.com/lyuwenyu/RT-DETR/issues/95), supported onnxruntime, TensorRT, openVINO
- \[2023.10.12\] Add tuning code for pytorch version, now you can tuning rtdetr based on pretrained weights
- \[2023.09.19\] Upload [*pytorch weights*](https://github.com/lyuwenyu/RT-DETR/issues/42) convert from paddle version
Expand Down
2 changes: 1 addition & 1 deletion rtdetr_pytorch/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ python tools/export_onnx.py -c configs/rtdetr/rtdetr_r18vd_6x_coco.yml -r path/t
<details open>
<summary>Train custom data</summary>

1. set `remap_mscoco_category: False`. This variable only works for ms-coco dataset.
1. set `remap_mscoco_category: False`. This variable only works for ms-coco dataset. If you want to use `remap_mscoco_category` logic on your dataset, please modify variable [`mscoco_category2name`](https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetr_pytorch/src/data/coco/coco_dataset.py) based on your dataset.

2. add `-t path/to/checkpoint` (optinal) to tuning rtdetr based on pretrained checkpoint. see [training script details](./tools/README.md).
</details>
8 changes: 6 additions & 2 deletions rtdetr_pytorch/src/data/coco/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@

from .coco_dataset import *
from .coco_dataset import (
CocoDetection,
mscoco_category2label,
mscoco_label2category,
mscoco_category2name,
)
from .coco_eval import *

from .coco_utils import get_coco_api_from_dataset
94 changes: 4 additions & 90 deletions rtdetr_pytorch/src/data/coco/coco_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __call__(self, image, target):
boxes[:, 1::2].clamp_(min=0, max=h)

if self.remap_mscoco_category:
classes = [category2label[obj["category_id"]] - 1 for obj in anno]
classes = [mscoco_category2label[obj["category_id"]] for obj in anno]
else:
classes = [obj["category_id"] for obj in anno]

Expand Down Expand Up @@ -151,10 +151,7 @@ def __call__(self, image, target):
return image, target




names = {
0: 'background',
mscoco_category2name = {
1: 'person',
2: 'bicycle',
3: 'car',
Expand Down Expand Up @@ -237,88 +234,5 @@ def __call__(self, image, target):
90: 'toothbrush'
}


label2category = {
1: 1,
2: 2,
3: 3,
4: 4,
5: 5,
6: 6,
7: 7,
8: 8,
9: 9,
10: 10,
11: 11,
12: 13,
13: 14,
14: 15,
15: 16,
16: 17,
17: 18,
18: 19,
19: 20,
20: 21,
21: 22,
22: 23,
23: 24,
24: 25,
25: 27,
26: 28,
27: 31,
28: 32,
29: 33,
30: 34,
31: 35,
32: 36,
33: 37,
34: 38,
35: 39,
36: 40,
37: 41,
38: 42,
39: 43,
40: 44,
41: 46,
42: 47,
43: 48,
44: 49,
45: 50,
46: 51,
47: 52,
48: 53,
49: 54,
50: 55,
51: 56,
52: 57,
53: 58,
54: 59,
55: 60,
56: 61,
57: 62,
58: 63,
59: 64,
60: 65,
61: 67,
62: 70,
63: 72,
64: 73,
65: 74,
66: 75,
67: 76,
68: 77,
69: 78,
70: 79,
71: 80,
72: 81,
73: 82,
74: 84,
75: 85,
76: 86,
77: 87,
78: 88,
79: 89,
80: 90
}

category2label = {v: k for k, v in label2category.items()}
mscoco_category2label = {k: i for i, k in enumerate(mscoco_category2name.keys())}
mscoco_label2category = {v: k for k, v in mscoco_category2label.items()}
90 changes: 2 additions & 88 deletions rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def forward(self, outputs, orig_target_sizes):

# TODO
if self.remap_mscoco_category:
labels = torch.tensor([self.mscoco_label_category_map[int(x.item()) + 1] for x in labels.flatten()])\
from ...data.coco import mscoco_label2category
labels = torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()])\
.to(boxes.device).reshape(labels.shape)

results = []
Expand All @@ -77,90 +78,3 @@ def deploy(self, ):
@property
def iou_types(self, ):
return ('bbox', )


@property
def mscoco_label_category_map(self, ):
return {
1: 1,
2: 2,
3: 3,
4: 4,
5: 5,
6: 6,
7: 7,
8: 8,
9: 9,
10: 10,
11: 11,
12: 13,
13: 14,
14: 15,
15: 16,
16: 17,
17: 18,
18: 19,
19: 20,
20: 21,
21: 22,
22: 23,
23: 24,
24: 25,
25: 27,
26: 28,
27: 31,
28: 32,
29: 33,
30: 34,
31: 35,
32: 36,
33: 37,
34: 38,
35: 39,
36: 40,
37: 41,
38: 42,
39: 43,
40: 44,
41: 46,
42: 47,
43: 48,
44: 49,
45: 50,
46: 51,
47: 52,
48: 53,
49: 54,
50: 55,
51: 56,
52: 57,
53: 58,
54: 59,
55: 60,
56: 61,
57: 62,
58: 63,
59: 64,
60: 65,
61: 67,
62: 70,
63: 72,
64: 73,
65: 74,
66: 75,
67: 76,
68: 77,
69: 78,
70: 79,
71: 80,
72: 81,
73: 82,
74: 84,
75: 85,
76: 86,
77: 87,
78: 88,
79: 89,
80: 90
}

0 comments on commit 95fc522

Please sign in to comment.