little train optim
matteobeltrami committed Nov 21, 2023
1 parent 5309864 commit 06daa3f
Showing 2 changed files with 32 additions and 31 deletions.
60 changes: 30 additions & 30 deletions micromind/utils/yolo_helpers.py
@@ -8,7 +8,6 @@
import types
from pathlib import Path
import yaml
import numpy as np
import cv2
from collections import defaultdict
import time
@@ -422,9 +421,7 @@ def non_max_suppression(
prediction = prediction[0] # select only inference output

device = prediction.device
mps = "mps" in device.type # Apple MPS
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
prediction = prediction.cpu()

bs = prediction.shape[0] # batch size
nc = nc or (prediction.shape[1] - 4) # number of classes
nm = prediction.shape[1] - nc - 4
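For context, the two fragments this commit deletes from non_max_suppression (here and further down in the next hunk) implemented a CPU fallback for Apple MPS. Pieced together, the removed pattern looked roughly like the sketch below, reassembled from the deleted lines; it forced a host copy before NMS and a copy back afterwards on every call:

```python
import torch

# Sketch of the removed fallback, reassembled from the deleted lines;
# not code that remains in the repository.
def nms_with_cpu_fallback(prediction: torch.Tensor) -> torch.Tensor:
    device = prediction.device
    mps = "mps" in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet: move tensors to CPU before NMS
        prediction = prediction.cpu()  # host copy on every call
    output = prediction  # ... NMS itself runs here ...
    if mps:
        output = output.to(device)  # second copy to move results back
    return output
```

Dropping both branches keeps tensors on their original device for the whole call.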
@@ -488,15 +485,14 @@ def non_max_suppression(
i = i[:max_det] # limit detections

output[xi] = x[i]
if mps:
output[xi] = output[xi].to(device)
if (time.time() - t) > time_limit:
# LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
break # time limit exceeded

return output


@torch.no_grad()
def postprocess(preds, img, orig_imgs):
"""Perform post-processing on the predictions.
@@ -518,7 +514,7 @@ def postprocess(preds, img, orig_imgs):
A list of post-processed prediction arrays, each containing bounding
boxes and associated information.
"""
preds = preds
tt1 = time.time()
preds = non_max_suppression(
prediction=preds,
conf_thres=0.25,
@@ -527,6 +523,7 @@
max_det=300,
multi_label=True,
)

all_preds = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
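As a quick smoke test of the call above, one can feed random data through non_max_suppression. The input shape is an assumption based on the usual YOLO head layout (4 box coordinates plus 80 class scores per anchor), and the tensor is wrapped in a tuple because the function selects element 0 as the inference output:

```python
import torch

# Hypothetical input: batch of 1, 4 box coords + 80 class scores, 8400 anchors.
dummy = torch.rand(1, 84, 8400)
out = non_max_suppression(
    prediction=(dummy,),
    conf_thres=0.25,
    iou_thres=0.45,
    max_det=300,
    multi_label=True,
)
# One tensor per image; columns are x1, y1, x2, y2, confidence, class.
print(out[0].shape)
```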
@@ -660,17 +657,17 @@ def clip_boxes(boxes, shape):
Arguments
---------
boxes : numpy.ndarray
An array containing bounding boxes in the format [x1, y1, x2, y2].
boxes : torch.Tensor
A tensor containing bounding boxes in the format [x1, y1, x2, y2].
shape : tuple
A tuple representing the shape of the image in the format (height, width).
Returns
-------
An array containing the clipped bounding boxes : numpy.ndarray
A tensor containing the clipped bounding boxes : torch.Tensor
"""
boxes[..., [0, 2]] = np.clip(boxes[..., [0, 2]], 0, shape[1]) # x1, x2
boxes[..., [1, 3]] = np.clip(boxes[..., [1, 3]], 0, shape[0]) # y1, y2
boxes[..., [0, 2]] = torch.clip(boxes[..., [0, 2]], 0, shape[1]) # x1, x2
boxes[..., [1, 3]] = torch.clip(boxes[..., [1, 3]], 0, shape[0]) # y1, y2
return boxes


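With clip_boxes now pure torch, a quick check with made-up values (note that the clipping writes into its argument in place, hence the clone):

```python
import torch

boxes = torch.tensor([[-5.0, 10.0, 700.0, 350.0]])
clipped = clip_boxes(boxes.clone(), (480, 640))  # shape is (height, width)
print(clipped)  # tensor([[  0.,  10., 640., 350.]])
```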
@@ -685,8 +682,8 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
img1_shape : tuple
A tuple representing the shape of the target image in the
format (height, width).
boxes : numpy.ndarray or torch.Tensor
An array or tensor containing bounding boxes in the
boxes : torch.Tensor
A tensor containing bounding boxes in the
format [x1, y1, x2, y2].
img0_shape : tuple
A tuple representing the shape of the source image in the
@@ -709,12 +706,13 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
(img1_shape[1] - img0_shape[1] * gain) / 2,
(img1_shape[0] - img0_shape[0] * gain) / 2,
)
boxes_np = boxes.numpy() if isinstance(boxes, torch.Tensor) else boxes
boxes_np[..., [0, 2]] -= pad[0]
boxes_np[..., [1, 3]] -= pad[1]
boxes_np[..., :4] /= gain
boxes_np = clip_boxes(boxes_np, img0_shape)
return torch.tensor(boxes_np)

boxes[..., [0, 2]] -= pad[0]
boxes[..., [1, 3]] -= pad[1]
boxes /= gain
boxes = clip_boxes(boxes, img0_shape)

return boxes


def xywh2xyxy(x):
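A worked example for the slimmed-down scale_boxes, assuming the standard letterbox gain computation collapsed in this view (gain = min of the height and width ratios): mapping boxes from a 640x640 letterboxed input back to a hypothetical 480x640 original gives gain 1.0 and a vertical pad of 80 px, so only the y coordinates shift.

```python
import torch

# gain = min(640 / 480, 640 / 640) = 1.0; pad = (0, 80).
boxes = torch.tensor([[100.0, 180.0, 200.0, 280.0]])
restored = scale_boxes((640, 640), boxes.clone(), (480, 640))  # in place, so clone
print(restored)  # tensor([[100., 100., 200., 200.]])
```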
@@ -728,8 +726,8 @@ def xywh2xyxy(x):
Arguments
---------
x : numpy.ndarray or torch.Tensor
An array or tensor containing bounding box coordinates in the
x : torch.Tensor
A tensor containing bounding box coordinates in the
format (center_x, center_y, width, height).
Returns
@@ -742,8 +740,8 @@
wh = x[..., 2:4] # width, height
xy1 = xy - wh / 2 # top left x, y
xy2 = xy + wh / 2 # bottom right x, y
result = np.concatenate((xy1, xy2), axis=-1)
return torch.Tensor(result)
result = torch.cat((xy1, xy2), dim=-1)
return result


def bbox_format(box):
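The conversion in xywh2xyxy above is now a single torch.cat with no numpy round-trip; a tiny sanity check with made-up numbers:

```python
import torch

# Center format (cx, cy, w, h) -> corner format (x1, y1, x2, y2).
xywh = torch.tensor([[50.0, 50.0, 20.0, 10.0]])
print(xywh2xyxy(xywh))  # tensor([[40., 45., 60., 55.]])
```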
@@ -792,8 +790,6 @@ def calculate_iou(box1, box2):
float
The intersection over union of the two bounding boxes.
"""
box1 = bbox_format(box1)
box2 = bbox_format(box2)

x1 = torch.max(box1[0], box2[0])
y1 = torch.max(box1[1], box2[1])
@@ -829,10 +825,11 @@ def average_precision(predictions, ground_truth, class_id, iou_threshold=0.5):
float
The average precision for the specified class.
"""
predictions = [p for p in predictions if p[5] == class_id]
ground_truth = [g for g in ground_truth if g[5] == class_id]
predictions = predictions[predictions[:, 5] == class_id]
ground_truth = ground_truth[ground_truth[:, 5] == class_id]

predictions.sort(key=lambda x: x[4], reverse=True)
_, indices = torch.sort(predictions[:, 4], descending=True)
predictions = predictions[indices]
tp = torch.zeros(len(predictions))
fp = torch.zeros(len(predictions))
gt_count = len(ground_truth)
@@ -846,7 +843,10 @@ def average_precision(predictions, ground_truth, class_id, iou_threshold=0.5):
best_gt_idx = j
if best_iou > 0:
tp[i] = 1
ground_truth.pop(best_gt_idx)
tmp = torch.ones(ground_truth.shape[0])
tmp[best_gt_idx] = 0
ground_truth = ground_truth[tmp.bool()]
# ground_truth.pop(best_gt_idx)
else:
fp[i] = 1

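The matched ground-truth row is now dropped with a boolean mask instead of list.pop, so predictions and ground truth stay tensors end to end. A standalone sketch of the idiom used in the hunk above:

```python
import torch

# Drop row best_gt_idx from an (N, 6) tensor without leaving tensor-land.
ground_truth = torch.arange(18.0).reshape(3, 6)
best_gt_idx = 1
keep = torch.ones(ground_truth.shape[0])
keep[best_gt_idx] = 0
ground_truth = ground_truth[keep.bool()]
print(ground_truth.shape)  # torch.Size([2, 6])
```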
3 changes: 2 additions & 1 deletion recipes/objection_detection/train.py
@@ -223,12 +223,13 @@ def configure_optimizers(self):
)
return opt, sched

@torch.no_grad()
def mAP(self, pred, batch):
batch_size = len(batch["im_file"])

preprocessed_batch = self.preprocess_batch(batch)
post_predictions = postprocess(
preds=pred[0].detach().cpu(), img=preprocessed_batch, orig_imgs=batch
preds=pred[0], img=preprocessed_batch, orig_imgs=batch
)

batch_bboxes_xyxy = xywh2xyxy(batch["bboxes"])
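The @torch.no_grad() decorator added to mAP (and to postprocess in the other file) stops autograd from recording the metric computation, so intermediate buffers are freed eagerly during evaluation; it also makes the explicit .detach().cpu() on the predictions unnecessary. A minimal illustration of the effect:

```python
import torch

x = torch.randn(4, requires_grad=True)

@torch.no_grad()
def metric(t):
    # No graph is recorded here; the result is detached from autograd.
    return (t * 2).sum()

print(metric(x).requires_grad)  # False
```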