Variable number of detection heads (#75)
* start 1-head

* neck forward optim

* variable number of heads working

* add check on heads

* fix num_classes in map

* fix load voc dataset

* add VOC file

* improved docs

* fix voc download

* fix download coco8

* minor cosmetic

* minor fix to the docs

* fix VOC download

* fix COCO download

* fix download to specific location

* fix COCO download

* update README changelog

* cosmetic

* cosmetic

* more info

---------

Co-authored-by: fpaissan <me@francescopaissan.it>
matteobeltrami and fpaissan authored Dec 17, 2023
1 parent d69bd0e commit 2ba9318
Showing 8 changed files with 337 additions and 63 deletions.
127 changes: 89 additions & 38 deletions micromind/networks/yolo.py
@@ -345,16 +345,22 @@ class Yolov8Neck(nn.Module):
     Arguments
     ---------
-    w : float
-        Width multiple of the Darknet.
-    r : float
-        Ratio multiple of the Darknet.
-    d : float
-        Depth multiple of the Darknet.
     filters : list, optional
         List of filter sizes for different layers. Default: [256, 512, 768].
     up : list, optional
         List of upsampling factors. Default: [2, 2].
+    heads : list, optional
+        List indicating whether each detection head is active.
+        Default: [True, True, True].
+    d : float, optional
+        Depth multiple of the Darknet. Default: 1.
     """

-    def __init__(self, filters=[256, 512, 768], up=[2, 2], d=1):
+    def __init__(
+        self, filters=[256, 512, 768], up=[2, 2], heads=[True, True, True], d=1
+    ):
         super().__init__()
+        self.heads = heads
         self.up1 = Upsample(up[0], mode="nearest")
         self.up2 = Upsample(up[1], mode="nearest")
         self.n1 = C2f(
@@ -369,50 +375,84 @@ def __init__(self, filters=[256, 512, 768], up=[2, 2], d=1):
             n=round(3 * d),
             shortcut=False,
         )
-        self.n3 = Conv(
-            c1=int(filters[0]), c2=int(filters[0]), kernel_size=3, stride=2, padding=1
-        )
-        self.n4 = C2f(
-            c1=int(filters[0] + filters[1]),
-            c2=int(filters[1]),
-            n=round(3 * d),
-            shortcut=False,
-        )
-        self.n5 = Conv(
-            c1=int(filters[1]), c2=int(filters[1]), kernel_size=3, stride=2, padding=1
-        )
-        self.n6 = C2f(
-            c1=int(filters[1] + filters[2]),
-            c2=int(filters[2]),
-            n=round(3 * d),
-            shortcut=False,
-        )
+        """
+        The blocks needed by the 2nd and 3rd detection heads are defined
+        only if those heads are used. Otherwise, the unused blocks would be
+        initialized (and thus occupy memory) without ever being executed.
+        """
+        if self.heads[1] or self.heads[2]:
+            self.n3 = Conv(
+                c1=int(filters[0]),
+                c2=int(filters[0]),
+                kernel_size=3,
+                stride=2,
+                padding=1,
+            )
+            self.n4 = C2f(
+                c1=int(filters[0] + filters[1]),
+                c2=int(filters[1]),
+                n=round(3 * d),
+                shortcut=False,
+            )
+        if self.heads[2]:
+            self.n5 = Conv(
+                c1=int(filters[1]),
+                c2=int(filters[1]),
+                kernel_size=3,
+                stride=2,
+                padding=1,
+            )
+            self.n6 = C2f(
+                c1=int(filters[1] + filters[2]),
+                c2=int(filters[2]),
+                n=round(3 * d),
+                shortcut=False,
+            )
     def forward(self, p3, p4, p5):
         """Executes YOLOv8 neck.
         Arguments
         ---------
-        x : tuple
-            Input to the neck.
+        p3 : torch.Tensor
+            First feature map coming from the backbone.
+        p4 : torch.Tensor
+            Second feature map coming from the backbone.
+        p5 : torch.Tensor
+            Third feature map coming from the backbone.
         Returns
         -------
-        Three intermediate representations with different resolutions : list
+        Three intermediate representations with different resolutions. : List
         """
         x = self.up1(p5)
         x = torch.cat((x, p4), dim=1)
         x = self.n1(x)
         h1 = self.up2(x)
         h1 = torch.cat((h1, p3), dim=1)
         head_1 = self.n2(h1)
-        h2 = self.n3(head_1)
-        h2 = torch.cat((h2, x), dim=1)
-        head_2 = self.n4(h2)
-        h3 = self.n5(head_2)
-        h3 = torch.cat((h3, p5), dim=1)
-        head_3 = self.n6(h3)
-        return [head_1, head_2, head_3]
+        return_heads = []
+
+        # check whether the 1st head should be returned
+        if self.heads[0]:
+            return_heads.append(head_1)
+
+        # check whether the 2nd head should be executed
+        if self.heads[1] or self.heads[2]:
+            h2 = self.n3(head_1)
+            h2 = torch.cat((h2, x), dim=1)
+            head_2 = self.n4(h2)
+            # check whether the 2nd head should be returned
+            if self.heads[1]:
+                return_heads.append(head_2)
+
+        # check whether the 3rd head should be executed and returned
+        if self.heads[2]:
+            h3 = self.n5(head_2)
+            h3 = torch.cat((h3, p5), dim=1)
+            head_3 = self.n6(h3)
+            return_heads.append(head_3)
+        return return_heads
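For context, a minimal sketch of how the variable-head neck behaves end to end. The import path follows the file touched in this diff; the feature-map shapes are illustrative assumptions for a 640x640 input, not values taken from the commit.

    import torch
    from micromind.networks.yolo import Yolov8Neck

    # Hypothetical backbone feature maps for a 640x640 input
    # (shapes are illustrative assumptions).
    p3 = torch.randn(1, 256, 80, 80)
    p4 = torch.randn(1, 512, 40, 40)
    p5 = torch.randn(1, 768, 20, 20)

    # Keep only the first detection head: n3..n6 are never created,
    # so the unused blocks take no memory.
    neck = Yolov8Neck(filters=[256, 512, 768], up=[2, 2], heads=[True, False, False])
    outputs = neck(p3, p4, p5)
    print(len(outputs))  # 1 -- only the active head is returned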


@@ -424,15 +464,23 @@ class DetectionHead(nn.Module):
         Number of classes to predict.
     filters : tuple
         Number of channels of the three inputs of the detection head.
+    heads : list, optional
+        List indicating whether each detection head is active.
+        Default: [True, True, True].
     """

-    def __init__(self, nc=80, filters=()):
+    def __init__(self, nc=80, filters=(), heads=[True, True, True]):
         super().__init__()
         self.reg_max = 16
         self.nc = nc
         self.nl = len(filters)
         self.no = nc + self.reg_max * 4
         self.stride = torch.tensor([8.0, 16.0, 32.0], dtype=torch.float16)
+        assertion_error = """Expected at least one head to be active. \
+        Please change the `heads` parameter to a valid configuration. \
+        Every configuration other than [False, False, False] is a valid option."""
+        assert heads != [False, False, False], " ".join(assertion_error.split())
+        self.stride = self.stride[torch.tensor(heads)]
         c2, c3 = max((16, filters[0] // 4, self.reg_max * 4)), max(
             filters[0], min(self.nc, 104)
         )  # channels
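The stride masking above relies on boolean-tensor indexing, which keeps only the entries whose mask is True. A standalone sketch in plain PyTorch:

    import torch

    stride = torch.tensor([8.0, 16.0, 32.0], dtype=torch.float16)
    heads = [True, False, True]

    # Boolean-mask indexing drops the strides of inactive heads.
    active = stride[torch.tensor(heads)]
    print(active)  # tensor([ 8., 32.], dtype=torch.float16)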
@@ -454,8 +502,11 @@ def forward(self, x):
         Arguments
         ---------
-        x : list
+        x : list[torch.Tensor]
             Input to the detection head.
+            In the standard YOLOv8 implementation it contains the three
+            outputs of the neck. In the general case it contains as many
+            tensors as there are heads active at initialization.
         Returns
         -------
@@ -473,7 +524,7 @@ def forward(self, x):
         )

         y = [(i.reshape(x[0].shape[0], self.no, -1)) for i in x]
-        x_cat = torch.cat((y[0], y[1], y[2]), dim=2)
+        x_cat = torch.cat(y, dim=2)
         box, cls = x_cat[:, : self.reg_max * 4], x_cat[:, self.reg_max * 4 :]
         dbox = (
             dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1)
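Replacing `torch.cat((y[0], y[1], y[2]), dim=2)` with `torch.cat(y, dim=2)` is what lets the head accept however many feature maps the neck returns. A small sketch with assumed shapes (144 = 80 classes + 16 * 4 DFL channels):

    import torch

    # Two active heads instead of three: flattened predictions per scale.
    y = [torch.randn(1, 144, 6400), torch.randn(1, 144, 400)]

    # torch.cat accepts any non-empty list, so no fixed head count is baked in.
    x_cat = torch.cat(y, dim=2)
    print(x_cat.shape)  # torch.Size([1, 144, 6800])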
12 changes: 8 additions & 4 deletions micromind/utils/yolo.py
@@ -51,15 +51,15 @@ def load_config(file_path):
         if not isinstance(config["train"], list):
             train = Path(path / config["train"])
         else:
-            train = [Path(path / p) for p in config["train"]]
+            train = [Path(path).joinpath(p) for p in config["train"]]
     else:
         train = None

     if "val" in config:
         if not isinstance(config["val"], list):
             val = Path(path / config["val"])
         else:
-            val = [Path(path / p) for p in config["val"]]
+            val = [Path(path).joinpath(p) for p in config["val"]]
     else:
         val = None
     # val = Path(path / config["val"]) if "val" in config else None
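A note on the `joinpath` fix: if `path` arrives as a plain string from the YAML config, `path / p` raises a `TypeError` (the `/` operator only works once one operand is a `Path`), while `Path(path).joinpath(p)` is safe either way. A quick sketch, with a hypothetical dataset path:

    from pathlib import Path

    path = "datasets/VOC"  # may arrive as a plain string from the YAML config

    try:
        bad = path / "images/train"  # str / str is not defined
    except TypeError as err:
        print("fails:", err)

    good = Path(path).joinpath("images/train")
    print(good)  # datasets/VOC/images/train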
@@ -882,7 +882,9 @@ def average_precision(predictions, ground_truth, class_id, iou_threshold=0.5):
     return ap.item()


-def mean_average_precision(post_predictions, batch, batch_bboxes, iou_threshold=0.5):
+def mean_average_precision(
+    post_predictions, batch, batch_bboxes, num_classes=80, iou_threshold=0.5
+):
     """
     Calculate the mean average precision (mAP) for all classes in YOLO predictions.
@@ -896,6 +898,8 @@ def mean_average_precision(post_predictions, batch, batch_bboxes, iou_threshold=
         Tensor containing batch bounding boxes.
     iou_threshold : float
         The IoU threshold for considering a prediction as correct.
+    num_classes : int
+        The number of classes in the dataset. Default: 80.
     Returns
     -------
@@ -915,7 +919,7 @@ def mean_average_precision(post_predictions, batch, batch_bboxes, iou_threshold=
         (bboxes, torch.ones((num_obj, 1)).to(batch["img"].device), classes), dim=1
     )

-    for class_id in range(80):
+    for class_id in range(num_classes):
         ap = average_precision(
             post_predictions[batch_el].to(batch["img"].device),
             gt,
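Why the hard-coded `range(80)` mattered: on a 20-class dataset such as VOC, the 60 absent classes would each be scored as well, dragging the mean down. A back-of-the-envelope sketch, under the assumption that absent classes contribute an AP of 0 (the exact behavior depends on `average_precision`):

    # Hypothetical per-class APs for a 20-class dataset (e.g. VOC).
    aps = [0.6] * 20

    # Old behavior: loop over range(80) -> 60 absent classes contribute AP = 0.
    map_hardcoded = (sum(aps) + 0.0 * 60) / 80  # 0.15
    # New behavior: loop over range(num_classes) with num_classes=20.
    map_correct = sum(aps) / 20                 # 0.6
    print(map_hardcoded, map_correct)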
6 changes: 4 additions & 2 deletions recipes/object_detection/README.md
@@ -1,6 +1,7 @@
 ## Object Detection using YOLO

-**[1 Dec 2023]** Fix DDP handling and computational graph
+**[17 Dec 2023]** Add VOC dataset, selective head option, and instructions for dataset download.<br />
+**[1 Dec 2023]** Fix DDP handling and computational graph.

 **Disclaimer**: we will shortly release HuggingFace checkpoints for COCO and VOC for both PhiNet and XiNet.

@@ -13,7 +14,8 @@ To reproduce our results, you can follow these steps:

 1. install `micromind` with `pip install git+https://github.com/fpaissan/micromind`
 2. install the additional dependencies for this recipe with `pip install -r extra_requirements.txt`
 3. start a training!

+**Note**: if you need to download the dataset, do not start the training process with DDP.

 ### Training