* Add a way to evaluate the descriptor YOLO v7 model.
* Update the descriptor YOLO v7 weights.
Showing 7 changed files with 319 additions and 7 deletions.
tools/dnn_training/object_detection/datasets/object_detection_objects365.py (89 additions, 0 deletions)
@@ -0,0 +1,89 @@
import os
from collections import defaultdict

import torch
from PIL import Image

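# Number of classes in Objects365, and the Objects365 class indexes whose categories
# also exist in COCO. The descriptor evaluation script passes this set as
# ignored_classes so that classes overlapping with COCO are excluded.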
CLASS_COUNT = 365
COCO_OBJECTS365_CLASS_INDEXES = {0, 46, 5, 58, 114, 55, 116, 65, 21, 40, 176, 127, 249, 24, 56, 139, 92, 78, 99, 96,
                                 144, 295, 178, 180, 38, 39, 13, 43, 194, 219, 119, 173, 154, 137, 113, 145, 146, 204,
                                 8, 35, 10, 88, 84, 93, 26, 112, 82, 265, 104, 141, 152, 234, 143, 150, 97, 2, 50, 25,
                                 75, 98, 153, 37, 73, 115, 132, 106, 64, 163, 149, 277, 81, 133, 18, 94, 30, 169, 328,
                                 226, 239, 156, 165, 177, 206}

class ObjectDetectionObjects365(torch.utils.data.Dataset):
    # The default split was 'train', which no branch below accepts; 'training' is used instead.
    def __init__(self, root, split='training', transforms=None, ignored_classes=None):
        if ignored_classes is None:
            ignored_classes = set()
        else:
            ignored_classes = set(ignored_classes)

        if split == 'training':
            self._image_root = os.path.join(root, 'images', 'train')
            self._label_root = os.path.join(root, 'labels', 'train')
        elif split == 'validation':
            self._image_root = os.path.join(root, 'images', 'val')
            self._label_root = os.path.join(root, 'labels', 'val')
        else:
            raise ValueError('Invalid split')

        self._image_files, self._bboxes = self._list_images(self._image_root, self._label_root, ignored_classes)
        self._transforms = transforms

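    # Label files are YOLO-format text: one "class x_center y_center width height" line
    # per box, with coordinates normalized to [0, 1].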
    def _list_images(self, image_path, label_path, ignored_classes):
        image_files = os.listdir(image_path)
        bboxes = defaultdict(list)

        for image_file in image_files:
            with open(os.path.join(label_path, os.path.splitext(image_file)[0] + '.txt'), 'r') as f:
                for line in f:
                    values = line.split(' ')
                    class_index = int(values[0])
                    if class_index in ignored_classes:
                        continue

                    x_center = float(values[1])
                    y_center = float(values[2])
                    width = float(values[3])
                    height = float(values[4])

                    bboxes[image_file].append({
                        'class_index': class_index,
                        'x_center': x_center,
                        'y_center': y_center,
                        'width': width,
                        'height': height
                    })

        return image_files, bboxes

    def __len__(self):
        return len(self._image_files)

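    # Returns the transformed image tensor, the target boxes scaled to pixel coordinates,
    # and metadata describing the original size and the resize applied by the transforms.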
    def __getitem__(self, index):
        image_file = self._image_files[index]
        image = Image.open(os.path.join(self._image_root, image_file)).convert('RGB')

        initial_width, initial_height = image.size

        target = []
        for bbox in self._bboxes[image_file]:
            target.append({
                'class_index': bbox['class_index'],
                'x_center': bbox['x_center'] * initial_width,
                'y_center': bbox['y_center'] * initial_height,
                'width': bbox['width'] * initial_width,
                'height': bbox['height'] * initial_height
            })

        image, target, transforms_metadata = self._transforms(image, target)
        metadata = {
            'initial_width': initial_width,
            'initial_height': initial_height,
            'scale': transforms_metadata['scale'],
            'offset_x': transforms_metadata['offset_x'],
            'offset_y': transforms_metadata['offset_y']
        }

        return image, target, metadata
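
A minimal usage sketch of this dataset, assuming an Objects365 root laid out as images/{train,val} and labels/{train,val}; the root path is a placeholder and the stub transform stands in for the real validation transforms defined in the next file:

import torchvision.transforms.functional as F

def stub_transforms(image, target):
    # Stand-in for the real validation transforms: tensor conversion only, no resize.
    return F.to_tensor(image), target, {'scale': 1.0, 'offset_x': 0, 'offset_y': 0}

dataset = ObjectDetectionObjects365('datasets/objects365',  # placeholder root
                                    split='validation',
                                    transforms=stub_transforms,
                                    ignored_classes=COCO_OBJECTS365_CLASS_INDEXES)
image, target, metadata = dataset[0]
print(len(dataset), metadata['initial_width'], metadata['initial_height'])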
tools/dnn_training/object_detection/datasets/objects365_detection_transforms.py (46 additions, 0 deletions)
@@ -0,0 +1,46 @@
import torch
import torchvision.transforms.functional as F

from object_detection.datasets.coco_detection_transforms import _resize_image
from object_detection.datasets.object_detection_objects365 import CLASS_COUNT

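# Converts the list of per-box dicts produced by ObjectDetectionObjects365 into the tensor
# format expected by the YOLO code: an N x 4 'bbox' tensor (x_center, y_center, width,
# height in resized-image pixels) and a 'class' tensor that is either N x CLASS_COUNT
# one-hot or N class indexes.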
def _convert_bbox_to_yolo(target, scale, offset_x, offset_y, one_hot_class):
    if one_hot_class:
        converted_target = {'bbox': torch.zeros(len(target), 4, dtype=torch.float),
                            'class': torch.zeros(len(target), CLASS_COUNT, dtype=torch.float)}
    else:
        converted_target = {'bbox': torch.zeros(len(target), 4, dtype=torch.float),
                            'class': torch.zeros(len(target), dtype=torch.long)}

    for i in range(len(target)):
        converted_target['bbox'][i] = torch.tensor([target[i]['x_center'] * scale + offset_x,
                                                    target[i]['y_center'] * scale + offset_y,
                                                    target[i]['width'] * scale,
                                                    target[i]['height'] * scale], dtype=torch.float)
        if one_hot_class:
            converted_target['class'][i, target[i]['class_index']] = 1.0
        else:
            converted_target['class'][i] = target[i]['class_index']

    return converted_target

class Objects365DetectionValidationTransforms:
    def __init__(self, image_size, one_hot_class):
        self._image_size = image_size
        self._one_hot_class = one_hot_class

    def __call__(self, image, target):
        resized_image, scale, offset_x, offset_y = _resize_image(image, self._image_size)
        resized_image_tensor = F.to_tensor(resized_image)

        if target is not None:
            target = _convert_bbox_to_yolo(target, scale, offset_x, offset_y, self._one_hot_class)

        metadata = {
            'scale': scale,
            'offset_x': offset_x,
            'offset_y': offset_y
        }
        return resized_image_tensor, target, metadata
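
A minimal sketch of the transform in isolation. The blank test image and box values are illustrative, and since _resize_image's signature is not shown in this diff, the (640, 640) image_size tuple is an assumption (the evaluation script passes model.get_image_size()):

from PIL import Image

transforms = Objects365DetectionValidationTransforms(image_size=(640, 640), one_hot_class=False)
image = Image.new('RGB', (1024, 768))
target = [{'class_index': 5, 'x_center': 512.0, 'y_center': 384.0, 'width': 100.0, 'height': 50.0}]

image_tensor, target, metadata = transforms(image, target)
print(target['bbox'][0])   # box center/size mapped into the resized image
print(target['class'][0])  # tensor(5) since one_hot_class=False
print(metadata)            # {'scale': ..., 'offset_x': ..., 'offset_y': ...}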
@@ -0,0 +1,171 @@
import argparse
import os

import numpy as np
import torch
from tqdm import tqdm

from common.metrics import RocDistancesThresholdsEvaluation
from common.modules import load_checkpoint

from object_detection.criterions.yolo_v4_loss import calculate_iou
from object_detection.datasets import CocoDetectionValidationTransforms, ObjectDetectionCoco
from object_detection.datasets import Objects365DetectionValidationTransforms, ObjectDetectionObjects365, \
    COCO_OBJECTS365_CLASS_INDEXES
from object_detection.descriptor_yolo_v7 import DescriptorYoloV7
from object_detection.datasets.object_detection_coco import CLASS_COUNT
from object_detection.filter_yolo_predictions import group_predictions, filter_yolo_predictions

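# Matching thresholds: in "comparable" mode nearly every prediction is kept so that each
# ground-truth box can be assigned an embedding; otherwise predictions are filtered at a
# confidence of 0.25 and a match must also reach IoU > 0.5 (see compute_embedding below).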
COMPARABLE_CONFIDENCE_THRESHOLD = 0.01
NOT_COMPARABLE_CONFIDENCE_THRESHOLD = 0.25
NMS_THRESHOLD = 0.45
NOT_COMPARABLE_IOU_THRESHOLD = 0.5

class CocoDescriptorEvaluation(RocDistancesThresholdsEvaluation):
    def __init__(self, embeddings_class_pairs, interval, output_path):
        super(CocoDescriptorEvaluation, self).__init__(output_path, thresholds=np.arange(0, 2, 0.0001))
        # The embeddings are cast to half precision here, which makes the original
        # device-conditional second .half() cast redundant; it has been dropped.
        self._embeddings = torch.stack([p[0] for p in embeddings_class_pairs], dim=0).half()
        self._classes = torch.stack([p[1] for p in embeddings_class_pairs], dim=0).to(torch.int16)
        self._interval = interval

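    # Euclidean distances are computed between each embedding and every interval-th later
    # embedding, then the result is subsampled by interval again, keeping the number of
    # ROC pairs tractable. For example, with interval = 2, embedding 0 is compared against
    # embeddings 1, 3, 5, and so on.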
    def _calculate_distances(self):
        N = self._embeddings.size(0)
        distances = torch.zeros(self._calculate_pair_count(N),
                                dtype=self._embeddings.dtype,
                                device=self._embeddings.device)

        k = 0
        for i in range(N):
            others = self._embeddings[i + 1::self._interval]
            distances[k:k + others.size(0)] = \
                (self._embeddings[i].repeat(others.size(0), 1) - others).pow(2).sum(dim=1).sqrt()
            k += others.size(0)

        torch.cuda.empty_cache()
        return distances[::self._interval]

    def _get_is_same_person_target(self):
        N = self._classes.size(0)
        is_same_person_target = torch.zeros(self._calculate_pair_count(N),
                                            dtype=torch.bool,
                                            device=self._classes.device)

        k = 0
        for i in range(N):
            others = self._classes[i + 1::self._interval]
            is_same_person_target[k:k + others.size(0)] = self._classes[i] == others
            k += others.size(0)

        torch.cuda.empty_cache()
        return is_same_person_target[::self._interval]

    def _calculate_pair_count(self, N):
        c = 0
        for i in range(N):
            c += self._embeddings[i + 1::self._interval].size(0)

        return c


def main():
    parser = argparse.ArgumentParser(description='Test the specified descriptor YOLO model')
    parser.add_argument('--use_gpu', action='store_true', help='Use the GPU')
    parser.add_argument('--embedding_size', type=int, help='Choose the embedding size', required=True)
    parser.add_argument('--checkpoint', type=str, help='Choose the checkpoint file path', required=True)
    parser.add_argument('--dataset_root', type=str, help='Choose the dataset root path', required=True)
    parser.add_argument('--dataset_type', type=str, choices=['coco', 'objects365'], help='Choose the dataset type',
                        required=True)
    parser.add_argument('--comparable', action='store_true', help='Enable comparable results')
    parser.add_argument('--output_path', type=str, help='Choose the output path', required=True)

    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu')

    model = DescriptorYoloV7(CLASS_COUNT, embedding_size=args.embedding_size, class_probs=False)
    load_checkpoint(model, args.checkpoint)

    if args.dataset_type == 'coco':
        transforms = CocoDetectionValidationTransforms(model.get_image_size(), one_hot_class=False)
        dataset = ObjectDetectionCoco(os.path.join(args.dataset_root, 'val2017'),
                                      os.path.join(args.dataset_root, 'instances_val2017.json'),
                                      transforms)
        interval = 2 if args.comparable else 1
    elif args.dataset_type == 'objects365':
        transforms = Objects365DetectionValidationTransforms(model.get_image_size(), one_hot_class=False)
        dataset = ObjectDetectionObjects365(os.path.join(args.dataset_root),
                                            split='validation',
                                            transforms=transforms,
                                            ignored_classes=COCO_OBJECTS365_CLASS_INDEXES)
        interval = 1000 if args.comparable else 30
    else:
        raise ValueError(f'Invalid dataset ({args.dataset_type})')

    os.makedirs(args.output_path, exist_ok=True)

    evaluate(model, args.embedding_size, dataset, device, args.comparable, interval, args.output_path)

def evaluate(model, embedding_size, dataset, device, comparable, interval, output_path):
    model = model.to(device)
    model.eval()

    embeddings_class_pairs = []

    bbox_count = 0
    with torch.no_grad():
        for image, target, metadata in tqdm(dataset):
            target['bbox'] = target['bbox'].to(device)
            target['class'] = target['class'].to(device)

            bbox_count += target['bbox'].size(0)
            embeddings_class_pairs.extend(
                compute_embedding(model, embedding_size, image.to(device), target, comparable))

    torch.cuda.empty_cache()

    print(f'{len(embeddings_class_pairs)} boxes out of {bbox_count} detected')
    coco_descriptor_evaluation = CocoDescriptorEvaluation(embeddings_class_pairs, interval, output_path)
    coco_descriptor_evaluation.evaluate()

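# Each prediction row returned by group_predictions holds the box coordinates in its
# first four values (used for the IoU match below) and the descriptor in its last
# embedding_size values, which is why the embedding is read with
# best_predicted_box[-embedding_size:].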
def compute_embedding(model, embedding_size, image_tensor, target, comparable):
    predictions = model(image_tensor.unsqueeze(0))
    predictions = group_predictions(predictions)[0]
    C = predictions.size(1)
    predictions = filter_yolo_predictions(
        predictions,
        confidence_threshold=COMPARABLE_CONFIDENCE_THRESHOLD if comparable else NOT_COMPARABLE_CONFIDENCE_THRESHOLD,
        nms_threshold=NMS_THRESHOLD)

    if len(predictions) == 0:
        print('Warning: No predictions found')
        predicted_boxes = torch.zeros(1, C).to(image_tensor.device)
    else:
        predicted_boxes = torch.stack(predictions, dim=0)

    embeddings_class_pairs = []

    for i in range(target['bbox'].size(0)):
        target_box = target['bbox'][i]
        target_class = target['class'][i]

        ious = calculate_iou(predicted_boxes[:, :4], target_box.repeat(len(predicted_boxes), 1))
        best_index = ious.argmax()
        best_predicted_box = predicted_boxes[best_index]

        if comparable or ious[best_index] > NOT_COMPARABLE_IOU_THRESHOLD:
            embeddings_class_pairs.append((best_predicted_box[-embedding_size:], target_class))

    return embeddings_class_pairs


if __name__ == '__main__':
    main()
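
For reference, a sketch of driving the evaluation programmatically, mirroring the objects365 branch of main(); the checkpoint path, dataset root, output path, and the embedding size of 128 are placeholders:

model = DescriptorYoloV7(CLASS_COUNT, embedding_size=128, class_probs=False)
load_checkpoint(model, 'checkpoints/descriptor_yolo_v7.pth')  # placeholder checkpoint path

transforms = Objects365DetectionValidationTransforms(model.get_image_size(), one_hot_class=False)
dataset = ObjectDetectionObjects365('datasets/objects365',  # placeholder dataset root
                                    split='validation',
                                    transforms=transforms,
                                    ignored_classes=COCO_OBJECTS365_CLASS_INDEXES)

# interval=30 matches the non-comparable objects365 setting in main().
evaluate(model, 128, dataset, torch.device('cpu'), comparable=False, interval=30,
         output_path='descriptor_evaluation_output')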