Fix the consistency of pre-processing with yolov5 #293

Merged · 49 commits · Feb 4, 2022

Commits
45dc664
Modify copyrights
zhiqwang Jan 26, 2022
bf1dd4e
Fix docstring
zhiqwang Jan 26, 2022
fe76c96
add letterbox function
Jan 26, 2022
ed406a7
Apply pre-commit format
Jan 26, 2022
339f07e
Fix importing torch.nn.functional
zhiqwang Jan 27, 2022
92ab438
Add unittest for letterbox
zhiqwang Jan 27, 2022
1c58ac7
Move letterbox into yolort.models.transform
zhiqwang Jan 27, 2022
8595755
Cleanup letterbox
zhiqwang Jan 27, 2022
3e2a0d8
Move 255 into the letterbox method
zhiqwang Jan 28, 2022
8e7b4a4
Add more test for letterbox
zhiqwang Jan 28, 2022
25aae82
Add fill_color to be filled in NestedTensor
zhiqwang Jan 28, 2022
69e2e49
Use numpy testing in test_letterbox
zhiqwang Jan 28, 2022
3a1f4cd
Add attributes size_divisible and fill_color in YOLOTransform
zhiqwang Jan 29, 2022
0926c15
Fixing docstrings
zhiqwang Jan 29, 2022
c83e23f
Adopt torchvision's structure
zhiqwang Jan 29, 2022
855c61d
Fix testing for batch_images
zhiqwang Jan 29, 2022
a58e187
Fixing docstrings
zhiqwang Jan 29, 2022
ee947ec
Remove fixed_size from YOLOTransform
zhiqwang Jan 29, 2022
170f108
Fixing parameters and docstrings in YOLOv5
zhiqwang Jan 30, 2022
028073a
Updating with torchvision
zhiqwang Jan 30, 2022
6503980
Padding into 2 sides in YOLOTransform
zhiqwang Jan 30, 2022
21f3fe6
Fixing type annotation
zhiqwang Jan 31, 2022
62e50b1
Move up parameter num_classes
zhiqwang Jan 31, 2022
3496b7b
Fix fill_color scale in YOLOTransform
zhiqwang Jan 31, 2022
210ac5c
Rename resize_boxes to scale_coords
zhiqwang Feb 1, 2022
1a1a4c6
Fix classmethod load_from_yolov5 in YOLOv5
zhiqwang Feb 1, 2022
98778aa
Fix the image_sizes in NestedTensor
zhiqwang Feb 1, 2022
ce6a951
Fixing docstrings
zhiqwang Feb 1, 2022
791fc7a
Fix size_divisible in test_load_from_yolov5
zhiqwang Feb 1, 2022
9652ae8
Minor fixes for type annotation
zhiqwang Feb 1, 2022
cfb250e
Fixing exporting ONNX
zhiqwang Feb 1, 2022
e0a770a
Fixing TestONNXExporter
zhiqwang Feb 1, 2022
4a09f34
Cleanup padding rule when batching
zhiqwang Feb 2, 2022
ecc96a5
Minor fixes
zhiqwang Feb 2, 2022
9af5443
Set fail-fast to false in GH Actions
zhiqwang Feb 2, 2022
1b7383e
Minor fixes
zhiqwang Feb 2, 2022
36ba3aa
Fixing type casting and annotations
zhiqwang Feb 2, 2022
0338fb8
Using consistent torch.int32 when casting in YOLOTransform
zhiqwang Feb 2, 2022
f285ac4
Fixing batching inference
zhiqwang Feb 3, 2022
a39710c
Cleanup ONNX Tester
zhiqwang Feb 3, 2022
11fbab3
Fixing casting Tensor.item() when exporting ONNX
zhiqwang Feb 3, 2022
f51fd33
Just test ONNX export without postprocess
zhiqwang Feb 4, 2022
2be9407
Apply pre-commit
zhiqwang Feb 4, 2022
614a40f
Fixing shape inference when tracing
zhiqwang Feb 4, 2022
15c60ed
Add types to letterbox
zhiqwang Feb 4, 2022
af38239
Cleanup YOLOTransform
zhiqwang Feb 4, 2022
b95e357
Fixing YOLOTransform letterbox batching
zhiqwang Feb 4, 2022
8bc3e78
Updating Intuition for yolort Notebook
zhiqwang Feb 4, 2022
3394321
Minor fixes
zhiqwang Feb 4, 2022
1 change: 1 addition & 0 deletions .github/workflows/ci-test.yml
@@ -12,6 +12,7 @@ jobs:
   Unittest:
     runs-on: ${{ matrix.image }}
     strategy:
+      fail-fast: false
       matrix:
         image: [ 'ubuntu-latest' ]
         torch: [ 'PyTorch 1.9.1+cpu', 'PyTorch 1.10.2+cpu' ]
137 changes: 79 additions & 58 deletions notebooks/inference-pytorch-export-libtorch.ipynb

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions test/test_engine.py
@@ -7,7 +7,7 @@
 from torchvision.io import read_image
 from yolort.data import COCOEvaluator, DetectionDataModule, _helper as data_helper
 from yolort.models import yolov5s
-from yolort.models.transform import nested_tensor_from_tensor_list
+from yolort.models.transform import YOLOTransform
 from yolort.models.yolo import yolov5_darknet_pan_s_r31


@@ -28,9 +28,10 @@ def test_train_with_vanilla_model():
     img_tensor = default_loader(img_name)
     assert img_tensor.ndim == 3
     # Add a dummy image to train
-    img_dummy = torch.rand((3, 416, 360), dtype=torch.float32)
+    img_dummy = torch.rand((3, 1080, 810), dtype=torch.float32)

-    images = nested_tensor_from_tensor_list([img_tensor, img_dummy])
+    yolo_transform = YOLOTransform(640, 640)
+    images = yolo_transform.batch_images([img_tensor, img_dummy])
     targets = torch.tensor(
         [
             [0, 7, 0.3790, 0.5487, 0.3220, 0.2047],
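For orientation, a minimal usage sketch (assuming the YOLOTransform API this PR introduces) of the new preprocessing entry point that replaces nested_tensor_from_tensor_list:

```python
# Illustrative only: batch two differently sized images the way the updated test does.
# YOLOTransform resizes with the letterbox rule and pads the batch to a shape divisible
# by its size_divisible attribute (32 for the non-P6 models in these tests).
import torch
from yolort.models.transform import YOLOTransform

images = [torch.rand(3, 720, 480), torch.rand(3, 1080, 810)]
yolo_transform = YOLOTransform(640, 640)
batched = yolo_transform.batch_images(images)
print(batched.shape)  # a single 4-D tensor; the exact H and W depend on the padding rule
```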
37 changes: 14 additions & 23 deletions test/test_models.py
@@ -340,20 +340,15 @@ def test_torchscript(arch):


 @pytest.mark.parametrize(
-    "arch, version, upstream_version, hash_prefix",
+    "arch, size_divisible, version, upstream_version, hash_prefix",
     [
-        ("yolov5s", "r4.0", "v4.0", "9ca9a642"),
-        ("yolov5n", "r6.0", "v6.0", "649e089f"),
-        ("yolov5s", "r6.0", "v6.0", "c3b140f3"),
-        ("yolov5n6", "r6.0", "v6.0", "beecbbae"),
+        ("yolov5s", 32, "r4.0", "v4.0", "9ca9a642"),
+        ("yolov5n", 32, "r6.0", "v6.0", "649e089f"),
+        ("yolov5s", 32, "r6.0", "v6.0", "c3b140f3"),
+        ("yolov5n6", 64, "r6.0", "v6.0", "beecbbae"),
     ],
 )
-def test_load_from_yolov5(
-    arch: str,
-    version: str,
-    upstream_version: str,
-    hash_prefix: str,
-):
+def test_load_from_yolov5(arch, size_divisible, version, upstream_version, hash_prefix):
     img_path = "test/assets/bus.jpg"

     base_url = "https://github.com/ultralytics/yolov5/releases/download/"
@@ -366,6 +361,7 @@ def test_load_from_yolov5(
         checkpoint_path,
         score_thresh=score_thresh,
         version=version,
+        size_divisible=size_divisible,
     )
     model_yolov5.eval()
     out_from_yolov5 = model_yolov5.predict(img_path)
@@ -388,28 +384,23 @@ def test_load_from_yolov5(


 @pytest.mark.parametrize(
-    "arch, version, upstream_version, hash_prefix",
+    "arch, size_divisible, version, upstream_version, hash_prefix",
     [
-        ("yolov5s", "r4.0", "v4.0", "9ca9a642"),
-        ("yolov5n", "r6.0", "v6.0", "649e089f"),
-        ("yolov5s", "r6.0", "v6.0", "c3b140f3"),
-        ("yolov5n6", "r6.0", "v6.0", "beecbbae"),
+        ("yolov5s", 32, "r4.0", "v4.0", "9ca9a642"),
+        ("yolov5n", 32, "r6.0", "v6.0", "649e089f"),
+        ("yolov5s", 32, "r6.0", "v6.0", "c3b140f3"),
+        ("yolov5n6", 64, "r6.0", "v6.0", "beecbbae"),
     ],
 )
-def test_load_from_yolov5_torchscript(
-    arch: str,
-    version: str,
-    upstream_version: str,
-    hash_prefix: str,
-):
+def test_load_from_yolov5_torchscript(arch, size_divisible, version, upstream_version, hash_prefix):
     import cv2
     from yolort.utils import read_image_to_tensor
     from yolort.v5 import letterbox

     # Loading and pre-processing the image
     img_path = "test/assets/zidane.jpg"
     img_raw = cv2.imread(img_path)
-    img = letterbox(img_raw, new_shape=(640, 640))[0]
+    img = letterbox(img_raw, new_shape=(640, 640), stride=size_divisible)[0]
     img = read_image_to_tensor(img)

     base_url = "https://github.com/ultralytics/yolov5/releases/download/"
32 changes: 32 additions & 0 deletions test/test_models_transform.py
@@ -1,5 +1,8 @@
+# Copyright (c) 2022, yolort team. All rights reserved.
 import copy

+import numpy as np
+import pytest
 import torch
 from yolort.models.transform import YOLOTransform, NestedTensor

@@ -19,3 +22,32 @@ def test_yolo_transform():
     # Test annotations after transformation
     torch.testing.assert_close(annotations[0]["boxes"], annotations_copy[0]["boxes"], rtol=0, atol=0)
     torch.testing.assert_close(annotations[1]["boxes"], annotations_copy[1]["boxes"], rtol=0, atol=0)
+
+
+@pytest.mark.parametrize("img_h", [300, 500, 720, 800, 1080, 1280])
+@pytest.mark.parametrize("img_w", [300, 500, 720, 800, 1080, 1280])
+@pytest.mark.parametrize("auto", [True])
+@pytest.mark.parametrize("stride", [32, 64])
+def test_letterbox(img_h, img_w, auto, stride):
+
+    from yolort.models.transform import _resize_image_and_masks
+    from yolort.v5 import letterbox
+
+    new_shape = (640, 640)  # height, width
+
+    img_tensor = torch.randint(0, 255, (3, img_h, img_w))
+    img_numpy = img_tensor.permute(1, 2, 0).numpy().astype("uint8")
+
+    yolo_transform = YOLOTransform(new_shape[0], new_shape[1], size_divisible=stride, auto_rectangle=auto)
+
+    im3 = img_tensor / 255
+    im3, _ = _resize_image_and_masks(im3.float(), new_shape)
+    out1 = yolo_transform.batch_images([im3])
+
+    out2 = letterbox(img_numpy, new_shape=new_shape, auto=auto, stride=stride)
+
+    aug1 = out1[0].numpy()
+    aug2 = out2[0].astype(np.float32)  # uint8 to float32
+    aug2 = np.transpose(aug2 / 255.0, [2, 0, 1])
+    assert aug1.shape == aug2.shape
+    np.testing.assert_allclose(aug1, aug2, rtol=1e-4, atol=1e-2)
Comment from @zhiqwang (owner) · Feb 4, 2022:

PyTorch's interpolate operator currently only matches OpenCV for floating-point inputs, while the letterbox implemented in yolov5 operates on uint8 images, so the tolerance set here is relatively loose.

Check pytorch/pytorch#5580 (comment) for more details.
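A rough standalone sketch of the effect described in this comment (not part of the PR): resizing the same image through OpenCV's uint8 path, as yolov5's letterbox does, versus torch.nn.functional.interpolate on normalized floats, as YOLOTransform does, leaves small rounding differences, which is why the assertion above uses atol=1e-2 rather than an exact match.

```python
# Hypothetical comparison of the two resize paths; the exact error magnitude
# depends on the image, but it is nonzero because OpenCV rounds to uint8.
import cv2
import numpy as np
import torch
import torch.nn.functional as F

img = np.random.randint(0, 255, (320, 320, 3), dtype=np.uint8)

# OpenCV path: bilinear resize in uint8, then normalize (yolov5-style).
ref = cv2.resize(img, (640, 640), interpolation=cv2.INTER_LINEAR).astype(np.float32) / 255.0

# PyTorch path: normalize first, then bilinear interpolation in float32.
ten = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
out = F.interpolate(ten[None], size=(640, 640), mode="bilinear", align_corners=False)[0]
out = out.permute(1, 2, 0).numpy()

print(np.abs(out - ref).max())  # small but nonzero, hence the loose tolerance
```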

93 changes: 41 additions & 52 deletions test/test_onnx.py
@@ -1,42 +1,37 @@
 """
-Test for exporting model to ONNX and inference with ONNXRuntime
+Test for exporting model to ONNX and inference with ONNX Runtime
 """
 import io
 from pathlib import Path

 import pytest
 import torch
-from PIL import Image
 from torch import Tensor
-from torchvision import transforms
+from torchvision.io import read_image
 from torchvision.ops._register_onnx_ops import _onnx_opset_version
 from yolort import models
+from yolort.utils.image_utils import to_numpy

 # In environments without onnxruntime we prefer to
 # invoke all tests in the repo and have this one skipped rather than fail.
 onnxruntime = pytest.importorskip("onnxruntime")


 class TestONNXExporter:
-    @classmethod
-    def setUpClass(cls):
-        torch.manual_seed(123)
-
     def run_model(
         self,
         model,
         inputs_list,
-        tolerate_small_mismatch=False,
         do_constant_folding=True,
-        dynamic_axes=None,
-        output_names=None,
         input_names=None,
+        output_names=None,
+        dynamic_axes=None,
     ):
         """
-        The core part of exporting model to ONNX and inference with ONNXRuntime
+        The core part of exporting model to ONNX and inference with ONNX Runtime
         Copy-paste from <https://github.com/pytorch/vision/blob/07fb8ba/test/test_onnx.py#L34>
         """
-        model.eval()
+        model = model.eval()

         onnx_io = io.BytesIO()
         if isinstance(inputs_list[0][-1], dict):
@@ -50,9 +45,9 @@ def run_model(
                 onnx_io,
                 do_constant_folding=do_constant_folding,
                 opset_version=_onnx_opset_version,
-                dynamic_axes=dynamic_axes,
                 input_names=input_names,
                 output_names=output_names,
+                dynamic_axes=dynamic_axes,
             )
         # validate the exported model with onnx runtime
         for test_inputs in inputs_list:
@@ -62,85 +57,79 @@ def run_model(
             test_outputs = model(*test_inputs)
             if isinstance(test_outputs, Tensor):
                 test_outputs = (test_outputs,)
-            self.ort_validate(onnx_io, test_inputs, test_outputs, tolerate_small_mismatch)
+            self.ort_validate(onnx_io, test_inputs, test_outputs)

-    def ort_validate(self, onnx_io, inputs, outputs, tolerate_small_mismatch=False):
+    def ort_validate(self, onnx_io, inputs, outputs):

         inputs, _ = torch.jit._flatten(inputs)
         outputs, _ = torch.jit._flatten(outputs)

-        def to_numpy(tensor):
-            if tensor.requires_grad:
-                return tensor.detach().cpu().numpy()
-            else:
-                return tensor.cpu().numpy()
-
         inputs = list(map(to_numpy, inputs))
         outputs = list(map(to_numpy, outputs))

         ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
-        # compute onnxruntime output prediction
+        # Inference on ONNX Runtime
         ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs))
         ort_outs = ort_session.run(None, ort_inputs)

         for i in range(0, len(outputs)):
-            try:
-                torch.testing.assert_close(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
-            except AssertionError as error:
-                if tolerate_small_mismatch:
-                    self.assertIn("(0.00%)", str(error), str(error))
-                else:
-                    raise
+            torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)

-    def get_image(self, img_name, size):
+    def get_image(self, img_name):

         img_path = Path(__file__).parent.resolve() / "assets" / img_name
-        image = Image.open(img_path).convert("RGB").resize(size, Image.BILINEAR)
+        image = read_image(str(img_path)) / 255

-        return transforms.ToTensor()(image)
+        return image

     def get_test_images(self):
-        return (
-            [self.get_image("bus.jpg", (416, 320))],
-            [self.get_image("zidane.jpg", (352, 480))],
-        )
+        return [self.get_image("bus.jpg")], [self.get_image("zidane.jpg")]

     @pytest.mark.parametrize(
-        "arch, upstream_version",
+        "arch, auto_rectangle, upstream_version",
         [
-            ("yolov5s", "r3.1"),
-            ("yolov5m", "r4.0"),
-            # ("yolov5ts", "r4.0"),
+            ("yolov5s", True, "r3.1"),
+            ("yolov5m", True, "r4.0"),
+            ("yolov5n", True, "r6.0"),
+            ("yolov5n6", True, "r6.0"),
         ],
     )
-    def test_yolort_export_onnx(self, arch, upstream_version):
+    def test_yolort_export_onnx(self, arch, auto_rectangle, upstream_version):
         images_one, images_two = self.get_test_images()
-        images_dummy = [torch.ones(3, 100, 100) * 0.3]
+        images_dummy = [torch.ones(3, 1080, 720) * 0.3]

         model = models.__dict__[arch](
             upstream_version=upstream_version,
             export_friendly=True,
             pretrained=True,
+            size=(640, 640),
+            auto_rectangle=auto_rectangle,
             score_thresh=0.45,
         )
-        model.eval()
+        model = model.eval()
         model(images_one)
         # Test exported model on images of different size, or dummy input
         self.run_model(
             model,
             [(images_one,), (images_two,), (images_dummy,)],
-            input_names=["images_tensors"],
-            output_names=["outputs"],
-            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
-            tolerate_small_mismatch=True,
+            input_names=["images"],
+            output_names=["scores", "labels", "boxes"],
+            dynamic_axes={
+                "images": [1, 2],
+                "boxes": [0, 1],
+                "labels": [0],
+                "scores": [0],
+            },
         )
         # Test exported model for an image with no detections on other images
         self.run_model(
             model,
             [(images_dummy,), (images_one,)],
-            input_names=["images_tensors"],
-            output_names=["outputs"],
-            dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]},
-            tolerate_small_mismatch=True,
+            input_names=["images"],
+            output_names=["scores", "labels", "boxes"],
+            dynamic_axes={
+                "images": [1, 2],
+                "boxes": [0, 1],
+                "labels": [0],
+                "scores": [0],
+            },
         )
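For context, a hedged sketch (not taken verbatim from the PR) of the export call that run_model drives for these tests; the model keyword arguments mirror the parametrization above and may differ in other yolort versions.

```python
# Sketch of exporting a yolort model with the named inputs/outputs and
# per-axis dynamic shapes used by the updated test.
import io

import torch
from torchvision.ops._register_onnx_ops import _onnx_opset_version
from yolort import models

model = models.yolov5s(
    upstream_version="r6.0",
    export_friendly=True,
    pretrained=True,
    size=(640, 640),
    score_thresh=0.45,
)
model = model.eval()

dummy_inputs = [torch.rand(3, 640, 640)]
onnx_io = io.BytesIO()
torch.onnx.export(
    model,
    (dummy_inputs,),
    onnx_io,
    do_constant_folding=True,
    opset_version=_onnx_opset_version,
    input_names=["images"],
    output_names=["scores", "labels", "boxes"],
    dynamic_axes={"images": [1, 2], "boxes": [0, 1], "labels": [0], "scores": [0]},
)
```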
6 changes: 3 additions & 3 deletions yolort/models/__init__.py
@@ -119,7 +119,7 @@ def yolov5n6(upstream_version: str = "r6.0", export_friendly: bool = False, **kw
             Default: False.
     """
     if upstream_version == "r6.0":
-        model = YOLOv5(arch="yolov5_darknet_pan_n6_r60", **kwargs)
+        model = YOLOv5(arch="yolov5_darknet_pan_n6_r60", size_divisible=64, **kwargs)
     else:
         raise NotImplementedError("Currently only supports r6.0 version")

@@ -138,7 +138,7 @@ def yolov5s6(upstream_version: str = "r6.0", export_friendly: bool = False, **kw
             Default: False.
     """
     if upstream_version == "r6.0":
-        model = YOLOv5(arch="yolov5_darknet_pan_s6_r60", **kwargs)
+        model = YOLOv5(arch="yolov5_darknet_pan_s6_r60", size_divisible=64, **kwargs)
     else:
         raise NotImplementedError("Currently only supports r5.0 and r6.0 versions")

@@ -157,7 +157,7 @@ def yolov5m6(upstream_version: str = "r6.0", export_friendly: bool = False, **kw
             Default: False.
     """
     if upstream_version == "r6.0":
-        model = YOLOv5(arch="yolov5_darknet_pan_m6_r60", **kwargs)
+        model = YOLOv5(arch="yolov5_darknet_pan_m6_r60", size_divisible=64, **kwargs)
     else:
         raise NotImplementedError("Currently only supports r5.0 and r6.0 versions")
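A small illustration (not from the PR) of why the *6 architectures above pass size_divisible=64: their extra P6 head downsamples by 64, so the batched image shape must divide evenly by that stride. The helper below mirrors the rounding that a letterbox-style transform performs; the function name is hypothetical.

```python
import math

def letterboxed_shape(height, width, new_size=640, stride=32):
    # Scale the longer side down to new_size, keeping the aspect ratio.
    ratio = new_size / max(height, width)
    unpad_h, unpad_w = round(height * ratio), round(width * ratio)
    # Pad each dimension up to the next multiple of the stride.
    pad_h = math.ceil(unpad_h / stride) * stride
    pad_w = math.ceil(unpad_w / stride) * stride
    return pad_h, pad_w

print(letterboxed_shape(1080, 810, stride=32))  # (640, 480)
print(letterboxed_shape(1080, 810, stride=64))  # (640, 512)
```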