-
Notifications
You must be signed in to change notification settings - Fork 7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds bounding boxes conversion (#2710)
* adds boxes conversion * adds documentation * adds xywh tests * fixes small typo * adds tests * Remove sphinx theme * corrects assertions * cleans code as per suggestion Signed-off-by: Aditya Oke <okeaditya315@gmail.com> * reverts assertion * fixes to assertEqual * fixes inplace operations * Adds docstrings * added documentation * changes tests * moves code to box_convert * adds more tests * Apply suggestions from code review Let's leave those changes to a separate PR * fixes documentation Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
- Loading branch information
1 parent
786ec32
commit e70c91a
Showing
5 changed files
with
239 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import torch | ||
from torch.jit.annotations import Tuple | ||
from torch import Tensor | ||
import torchvision | ||
|
||
|
||
def _box_cxcywh_to_xyxy(boxes: Tensor) -> Tensor: | ||
""" | ||
Converts bounding boxes from (cx, cy, w, h) format to (x1, y1, x2, y2) format. | ||
(cx, cy) refers to center of bounding box | ||
(w, h) are width and height of bounding box | ||
Arguments: | ||
boxes (Tensor[N, 4]): boxes in (cx, cy, w, h) format which will be converted. | ||
Returns: | ||
boxes (Tensor(N, 4)): boxes in (x1, y1, x2, y2) format. | ||
""" | ||
# We need to change all 4 of them so some temporary variable is needed. | ||
cx, cy, w, h = boxes.unbind(-1) | ||
x1 = cx - 0.5 * w | ||
y1 = cy - 0.5 * h | ||
x2 = cx + 0.5 * w | ||
y2 = cy + 0.5 * h | ||
|
||
boxes = torch.stack((x1, y1, x2, y2), dim=-1) | ||
|
||
return boxes | ||
|
||
|
||
def _box_xyxy_to_cxcywh(boxes: Tensor) -> Tensor: | ||
""" | ||
Converts bounding boxes from (x1, y1, x2, y2) format to (cx, cy, w, h) format. | ||
(x1, y1) refer to top left of bounding box | ||
(x2, y2) refer to bottom right of bounding box | ||
Arguments: | ||
boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format which will be converted. | ||
Returns: | ||
boxes (Tensor(N, 4)): boxes in (cx, cy, w, h) format. | ||
""" | ||
x1, y1, x2, y2 = boxes.unbind(-1) | ||
cx = (x1 + x2) / 2 | ||
cy = (y1 + y2) / 2 | ||
w = x2 - x1 | ||
h = y2 - y1 | ||
|
||
boxes = torch.stack((cx, cy, w, h), dim=-1) | ||
|
||
return boxes | ||
|
||
|
||
def _box_xywh_to_xyxy(boxes: Tensor) -> Tensor: | ||
""" | ||
Converts bounding boxes from (x, y, w, h) format to (x1, y1, x2, y2) format. | ||
(x, y) refers to top left of bouding box. | ||
(w, h) refers to width and height of box. | ||
Arguments: | ||
boxes (Tensor[N, 4]): boxes in (x, y, w, h) which will be converted. | ||
Returns: | ||
boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format. | ||
""" | ||
x, y, w, h = boxes.unbind(-1) | ||
boxes = torch.stack([x, y, x + w, y + h], dim=-1) | ||
return boxes | ||
|
||
|
||
def _box_xyxy_to_xywh(boxes: Tensor) -> Tensor: | ||
""" | ||
Converts bounding boxes from (x1, y1, x2, y2) format to (x, y, w, h) format. | ||
(x1, y1) refer to top left of bounding box | ||
(x2, y2) refer to bottom right of bounding box | ||
Arguments: | ||
boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) which will be converted. | ||
Returns: | ||
boxes (Tensor[N, 4]): boxes in (x, y, w, h) format. | ||
""" | ||
x1, y1, x2, y2 = boxes.unbind(-1) | ||
x2 = x2 - x1 # x2 - x1 | ||
y2 = y2 - y1 # y2 - y1 | ||
boxes = torch.stack((x1, y1, x2, y2), dim=-1) | ||
return boxes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters