diff --git a/README.md b/README.md index 5f30e59..7665b29 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ This is the official implementation of the paper "[DETRs Beat YOLOs on Real-time ## Updates!!! --- +- \[2023.11.05\] upgrade the logic of `remap_mscoco_category` to facilitate training of custom datasets, see detils in [Train custom data](./rtdetr_pytorch/) part. - \[2023.10.23\] Add [*discussion for deployments*](https://github.com/lyuwenyu/RT-DETR/issues/95), supported onnxruntime, TensorRT, openVINO - \[2023.10.12\] Add tuning code for pytorch version, now you can tuning rtdetr based on pretrained weights - \[2023.09.19\] Upload [*pytorch weights*](https://github.com/lyuwenyu/RT-DETR/issues/42) convert from paddle version diff --git a/rtdetr_pytorch/README.md b/rtdetr_pytorch/README.md index 5998fb9..cae34f1 100644 --- a/rtdetr_pytorch/README.md +++ b/rtdetr_pytorch/README.md @@ -79,7 +79,7 @@ python tools/export_onnx.py -c configs/rtdetr/rtdetr_r18vd_6x_coco.yml -r path/t
Train custom data -1. set `remap_mscoco_category: False`. This variable only works for ms-coco dataset. +1. set `remap_mscoco_category: False`. This variable only works for ms-coco dataset. If you want to use `remap_mscoco_category` logic on your dataset, please modify variable [`mscoco_category2name`](https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetr_pytorch/src/data/coco/coco_dataset.py) based on your dataset. 2. add `-t path/to/checkpoint` (optinal) to tuning rtdetr based on pretrained checkpoint. see [training script details](./tools/README.md).
diff --git a/rtdetr_pytorch/src/data/coco/__init__.py b/rtdetr_pytorch/src/data/coco/__init__.py index fb16dca..c83b002 100644 --- a/rtdetr_pytorch/src/data/coco/__init__.py +++ b/rtdetr_pytorch/src/data/coco/__init__.py @@ -1,5 +1,9 @@ - -from .coco_dataset import * +from .coco_dataset import ( + CocoDetection, + mscoco_category2label, + mscoco_label2category, + mscoco_category2name, +) from .coco_eval import * from .coco_utils import get_coco_api_from_dataset \ No newline at end of file diff --git a/rtdetr_pytorch/src/data/coco/coco_dataset.py b/rtdetr_pytorch/src/data/coco/coco_dataset.py index 1b27387..0ef7849 100644 --- a/rtdetr_pytorch/src/data/coco/coco_dataset.py +++ b/rtdetr_pytorch/src/data/coco/coco_dataset.py @@ -104,7 +104,7 @@ def __call__(self, image, target): boxes[:, 1::2].clamp_(min=0, max=h) if self.remap_mscoco_category: - classes = [category2label[obj["category_id"]] - 1 for obj in anno] + classes = [mscoco_category2label[obj["category_id"]] for obj in anno] else: classes = [obj["category_id"] for obj in anno] @@ -151,10 +151,7 @@ def __call__(self, image, target): return image, target - - -names = { - 0: 'background', +mscoco_category2name = { 1: 'person', 2: 'bicycle', 3: 'car', @@ -237,88 +234,5 @@ def __call__(self, image, target): 90: 'toothbrush' } - -label2category = { - 1: 1, - 2: 2, - 3: 3, - 4: 4, - 5: 5, - 6: 6, - 7: 7, - 8: 8, - 9: 9, - 10: 10, - 11: 11, - 12: 13, - 13: 14, - 14: 15, - 15: 16, - 16: 17, - 17: 18, - 18: 19, - 19: 20, - 20: 21, - 21: 22, - 22: 23, - 23: 24, - 24: 25, - 25: 27, - 26: 28, - 27: 31, - 28: 32, - 29: 33, - 30: 34, - 31: 35, - 32: 36, - 33: 37, - 34: 38, - 35: 39, - 36: 40, - 37: 41, - 38: 42, - 39: 43, - 40: 44, - 41: 46, - 42: 47, - 43: 48, - 44: 49, - 45: 50, - 46: 51, - 47: 52, - 48: 53, - 49: 54, - 50: 55, - 51: 56, - 52: 57, - 53: 58, - 54: 59, - 55: 60, - 56: 61, - 57: 62, - 58: 63, - 59: 64, - 60: 65, - 61: 67, - 62: 70, - 63: 72, - 64: 73, - 65: 74, - 66: 75, - 67: 76, - 68: 77, - 69: 78, - 70: 79, - 71: 80, - 72: 81, - 73: 82, - 74: 84, - 75: 85, - 76: 86, - 77: 87, - 78: 88, - 79: 89, - 80: 90 -} - -category2label = {v: k for k, v in label2category.items()} +mscoco_category2label = {k: i for i, k in enumerate(mscoco_category2name.keys())} +mscoco_label2category = {v: k for k, v in mscoco_category2label.items()} \ No newline at end of file diff --git a/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py b/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py index 61e8a04..344d69a 100644 --- a/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py +++ b/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py @@ -58,7 +58,8 @@ def forward(self, outputs, orig_target_sizes): # TODO if self.remap_mscoco_category: - labels = torch.tensor([self.mscoco_label_category_map[int(x.item()) + 1] for x in labels.flatten()])\ + from ...data.coco import mscoco_label2category + labels = torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()])\ .to(boxes.device).reshape(labels.shape) results = [] @@ -77,90 +78,3 @@ def deploy(self, ): @property def iou_types(self, ): return ('bbox', ) - - - @property - def mscoco_label_category_map(self, ): - return { - 1: 1, - 2: 2, - 3: 3, - 4: 4, - 5: 5, - 6: 6, - 7: 7, - 8: 8, - 9: 9, - 10: 10, - 11: 11, - 12: 13, - 13: 14, - 14: 15, - 15: 16, - 16: 17, - 17: 18, - 18: 19, - 19: 20, - 20: 21, - 21: 22, - 22: 23, - 23: 24, - 24: 25, - 25: 27, - 26: 28, - 27: 31, - 28: 32, - 29: 33, - 30: 34, - 31: 35, - 32: 36, - 33: 37, - 34: 38, - 35: 39, - 36: 40, - 37: 41, - 38: 42, - 39: 43, - 40: 44, - 41: 46, - 42: 47, - 43: 48, - 44: 49, - 45: 50, - 46: 51, - 47: 52, - 48: 53, - 49: 54, - 50: 55, - 51: 56, - 52: 57, - 53: 58, - 54: 59, - 55: 60, - 56: 61, - 57: 62, - 58: 63, - 59: 64, - 60: 65, - 61: 67, - 62: 70, - 63: 72, - 64: 73, - 65: 74, - 66: 75, - 67: 76, - 68: 77, - 69: 78, - 70: 79, - 71: 80, - 72: 81, - 73: 82, - 74: 84, - 75: 85, - 76: 86, - 77: 87, - 78: 88, - 79: 89, - 80: 90 - } -