From 06daa3f289c4613a961acc270ec5ac98e57feb8d Mon Sep 17 00:00:00 2001
From: matteobeltrami
Date: Tue, 21 Nov 2023 13:27:55 +0100
Subject: [PATCH] little train optim

---
 micromind/utils/yolo_helpers.py      | 60 ++++++++++++++--------------
 recipes/objection_detection/train.py |  3 +-
 2 files changed, 32 insertions(+), 31 deletions(-)

diff --git a/micromind/utils/yolo_helpers.py b/micromind/utils/yolo_helpers.py
index 436d647..b58dfc3 100644
--- a/micromind/utils/yolo_helpers.py
+++ b/micromind/utils/yolo_helpers.py
@@ -8,7 +8,6 @@
 import types
 from pathlib import Path
 import yaml
-import numpy as np
 import cv2
 from collections import defaultdict
 import time
@@ -422,9 +421,7 @@ def non_max_suppression(
         prediction = prediction[0]  # select only inference output
 
     device = prediction.device
-    mps = "mps" in device.type  # Apple MPS
-    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
-        prediction = prediction.cpu()
+
     bs = prediction.shape[0]  # batch size
     nc = nc or (prediction.shape[1] - 4)  # number of classes
     nm = prediction.shape[1] - nc - 4
@@ -488,8 +485,6 @@ def non_max_suppression(
         i = i[:max_det]  # limit detections
 
         output[xi] = x[i]
-        if mps:
-            output[xi] = output[xi].to(device)
         if (time.time() - t) > time_limit:
             # LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
             break  # time limit exceeded
@@ -497,6 +492,7 @@
 
     return output
 
 
+@torch.no_grad()
 def postprocess(preds, img, orig_imgs):
     """Perform post-processing on the predictions.
@@ -518,7 +514,7 @@ def postprocess(preds, img, orig_imgs):
         A list of post-processed prediction arrays, each containing
         bounding boxes and associated information.
     """
-    preds = preds
+    tt1 = time.time()
     preds = non_max_suppression(
         prediction=preds,
         conf_thres=0.25,
@@ -527,6 +523,7 @@ def postprocess(preds, img, orig_imgs):
         max_det=300,
         multi_label=True,
     )
+
     all_preds = []
     for i, pred in enumerate(preds):
         orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
@@ -660,17 +657,17 @@ def clip_boxes(boxes, shape):
 
     Arguments
     ---------
-    boxes : numpy.ndarray
-        An array containing bounding boxes in the format [x1, y1, x2, y2].
+    boxes : torch.Tensor
+        A tensor containing bounding boxes in the format [x1, y1, x2, y2].
     shape : tuple
         A tuple representing the shape of the image in the format
         (height, width).
 
     Returns
     -------
-    An array containing the clipped bounding boxes : numpy.ndarray
+    A tensor containing the clipped bounding boxes : torch.Tensor
     """
-    boxes[..., [0, 2]] = np.clip(boxes[..., [0, 2]], 0, shape[1])  # x1, x2
-    boxes[..., [1, 3]] = np.clip(boxes[..., [1, 3]], 0, shape[0])  # y1, y2
+    boxes[..., [0, 2]] = torch.clip(boxes[..., [0, 2]], 0, shape[1])  # x1, x2
+    boxes[..., [1, 3]] = torch.clip(boxes[..., [1, 3]], 0, shape[0])  # y1, y2
     return boxes
 
@@ -685,8 +682,8 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
     img1_shape : tuple
         A tuple representing the shape of the target image in the
         format (height, width).
-    boxes : numpy.ndarray or torch.Tensor
-        An array or tensor containing bounding boxes in the
+    boxes : torch.Tensor
+        A tensor containing bounding boxes in the
         format [x1, y1, x2, y2].
     img0_shape : tuple
         A tuple representing the shape of the source image in the
@@ -709,12 +706,13 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
         (img1_shape[1] - img0_shape[1] * gain) / 2,
         (img1_shape[0] - img0_shape[0] * gain) / 2,
     )
-    boxes_np = boxes.numpy() if isinstance(boxes, torch.Tensor) else boxes
-    boxes_np[..., [0, 2]] -= pad[0]
-    boxes_np[..., [1, 3]] -= pad[1]
-    boxes_np[..., :4] /= gain
-    boxes_np = clip_boxes(boxes_np, img0_shape)
-    return torch.tensor(boxes_np)
+
+    boxes[..., [0, 2]] -= pad[0]
+    boxes[..., [1, 3]] -= pad[1]
+    boxes /= gain
+    boxes = clip_boxes(boxes, img0_shape)
+
+    return boxes
 
 
 def xywh2xyxy(x):
@@ -728,8 +726,8 @@ def xywh2xyxy(x):
 
     Arguments
     ---------
-    x : numpy.ndarray or torch.Tensor
-        An array or tensor containing bounding box coordinates in the
+    x : torch.Tensor
+        A tensor containing bounding box coordinates in the
         format (center_x, center_y, width, height).
 
     Returns
@@ -742,8 +740,8 @@ def xywh2xyxy(x):
     wh = x[..., 2:4]  # width, height
     xy1 = xy - wh / 2  # top left x, y
     xy2 = xy + wh / 2  # bottom right x, y
-    result = np.concatenate((xy1, xy2), axis=-1)
-    return torch.Tensor(result)
+    result = torch.cat((xy1, xy2), dim=-1)
+    return result
 
 
 def bbox_format(box):
@@ -792,8 +790,6 @@ def calculate_iou(box1, box2):
     float
         The intersection over union of the two bounding boxes.
     """
-    box1 = bbox_format(box1)
-    box2 = bbox_format(box2)
 
     x1 = torch.max(box1[0], box2[0])
     y1 = torch.max(box1[1], box2[1])
@@ -829,10 +825,11 @@ def average_precision(predictions, ground_truth, class_id, iou_threshold=0.5):
     float
         The average precision for the specified class.
     """
-    predictions = [p for p in predictions if p[5] == class_id]
-    ground_truth = [g for g in ground_truth if g[5] == class_id]
+    predictions = predictions[predictions[:, 5] == class_id]
+    ground_truth = ground_truth[ground_truth[:, 5] == class_id]
 
-    predictions.sort(key=lambda x: x[4], reverse=True)
+    _, indices = torch.sort(predictions[:, 4], descending=True)
+    predictions = predictions[indices]
     tp = torch.zeros(len(predictions))
     fp = torch.zeros(len(predictions))
     gt_count = len(ground_truth)
@@ -846,7 +843,10 @@ def average_precision(predictions, ground_truth, class_id, iou_threshold=0.5):
                 best_gt_idx = j
         if best_iou > 0:
             tp[i] = 1
-            ground_truth.pop(best_gt_idx)
+            tmp = torch.ones(ground_truth.shape[0])
+            tmp[best_gt_idx] = 0
+            ground_truth = ground_truth[tmp.bool()]
+            # ground_truth.pop(best_gt_idx)
         else:
             fp[i] = 1
diff --git a/recipes/objection_detection/train.py b/recipes/objection_detection/train.py
index 8204442..be1dc22 100644
--- a/recipes/objection_detection/train.py
+++ b/recipes/objection_detection/train.py
@@ -223,12 +223,13 @@ def configure_optimizers(self):
         )
         return opt, sched
 
+    @torch.no_grad()
     def mAP(self, pred, batch):
         batch_size = len(batch["im_file"])
         preprocessed_batch = self.preprocess_batch(batch)
         post_predictions = postprocess(
-            preds=pred[0].detach().cpu(), img=preprocessed_batch, orig_imgs=batch
+            preds=pred[0], img=preprocessed_batch, orig_imgs=batch
         )
 
         batch_bboxes_xyxy = xywh2xyxy(batch["bboxes"])
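
The average_precision hunks above replace Python-list filtering, sorting, and list.pop with pure tensor operations, so the metric computation never leaves torch. Below is a minimal sketch of the three idioms involved; the toy detection rows and class id are invented for illustration and are not part of the patch:

    import torch

    # Toy detections: one row per box, [x1, y1, x2, y2, confidence, class_id].
    preds = torch.tensor(
        [
            [0.0, 0.0, 10.0, 10.0, 0.90, 1.0],
            [1.0, 1.0, 9.0, 9.0, 0.40, 0.0],
            [2.0, 2.0, 8.0, 8.0, 0.80, 1.0],
        ]
    )

    # 1. Boolean masking replaces [p for p in predictions if p[5] == class_id].
    preds = preds[preds[:, 5] == 1.0]

    # 2. torch.sort on the confidence column replaces
    #    predictions.sort(key=lambda x: x[4], reverse=True).
    _, order = torch.sort(preds[:, 4], descending=True)
    preds = preds[order]

    # 3. Dropping row best_gt_idx without list.pop: zero one entry of a
    #    ones mask and index with its boolean view, as the patch does.
    best_gt_idx = 0
    keep = torch.ones(preds.shape[0], dtype=torch.bool)
    keep[best_gt_idx] = False
    preds = preds[keep]

Note that the patch allocates the mask as float ones and converts with .bool(); allocating it as torch.bool directly (and, if ground_truth can live on the GPU, passing device=ground_truth.device) is an equivalent form that also avoids a possible CPU/CUDA mask mismatch.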
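
The @torch.no_grad() decorators on postprocess and mAP are what make it safe to drop the explicit .detach().cpu() in mAP: no autograd graph is recorded inside the decorated call, so predictions can be post-processed in place and on whatever device training runs on, without keeping the forward pass alive. A short sketch of the mechanism, using a placeholder function rather than the micromind API:

    import torch

    @torch.no_grad()
    def postprocess_sketch(boxes: torch.Tensor, gain: float = 2.0) -> torch.Tensor:
        # Nothing in here is tracked by autograd, so the in-place division
        # neither raises nor retains references to the model's graph.
        out = boxes.clone()
        out[..., :4] /= gain
        return out

    preds = torch.rand(4, 6, requires_grad=True)  # stand-in for a model output
    out = postprocess_sketch(preds)
    assert not out.requires_grad  # detached: created under no_grad

Keeping the tensors on their original device also removes one host-device copy per validation step, which is the point of the pred[0].detach().cpu() -> pred[0] change in train.py.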