Skip to content

Commit

Permalink
Fixing shape inference when tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Feb 4, 2022
1 parent 2be9407 commit 614a40f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
13 changes: 5 additions & 8 deletions yolort/models/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def forward(
if targets is not None and target_index is not None:
targets[i] = target_index

images = self.batch_images(images)
image_sizes = [img.shape[-2:] for img in images]
images = self.batch_images(images)
image_sizes_list: List[Tuple[int, int]] = []
for image_size in image_sizes:
assert len(image_size) == 2
Expand Down Expand Up @@ -257,13 +257,13 @@ def batch_images(self, images: List[Tensor]) -> Tensor:
def postprocess(
self,
result: List[Dict[str, Tensor]],
image_shapes: List[Tuple[int, int]],
image_shapes: Tensor,
original_image_sizes: List[Tuple[int, int]],
) -> List[Dict[str, Tensor]]:

for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
for i, (pred, o_im_s) in enumerate(zip(result, original_image_sizes)):
boxes = pred["boxes"]
boxes = scale_coords(boxes, im_s, o_im_s)
boxes = scale_coords(boxes, image_shapes, o_im_s)
result[i]["boxes"] = boxes

return result
Expand Down Expand Up @@ -308,14 +308,11 @@ def _resize_image_and_masks(
return image, target


def scale_coords(boxes: Tensor, new_size: Tuple[int, int], original_size: Tuple[int, int]) -> Tensor:
def scale_coords(boxes: Tensor, new_size: Tensor, original_size: Tuple[int, int]) -> Tensor:
"""
Rescale boxes (xyxy) from new_size to original_size
"""
new_size = torch.tensor(new_size, dtype=torch.float32, device=boxes.device)
original_size = torch.tensor(original_size, dtype=torch.float32, device=boxes.device)
gain = torch.min(new_size[0] / original_size[0], new_size[1] / original_size[1])
# wh padding
pad = (new_size[1] - original_size[1] * gain) / 2, (new_size[0] - original_size[0] * gain) / 2
xmin, ymin, xmax, ymax = boxes.unbind(1)

Expand Down
11 changes: 8 additions & 3 deletions yolort/models/yolo_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from typing import Any, List, Dict, Tuple, Optional, Union, Callable

import torch
import torchvision
from pytorch_lightning import LightningModule
from torch import nn, Tensor
from torchvision.io import read_image
from yolort.data import COCOEvaluator, contains_any_tensor

from . import yolo
from ._utils import _evaluate_iou
from .transform import YOLOTransform
from .transform import YOLOTransform, _get_shape_onnx
from .yolo import YOLO

__all__ = ["YOLOv5"]
Expand Down Expand Up @@ -135,8 +136,12 @@ def _forward_impl(
else:
result = outputs

# detections = self.transform.postprocess(result, samples.image_sizes, original_image_sizes)
detections = result
if torchvision._is_tracing():
im_shape = _get_shape_onnx(samples.tensors)
else:
im_shape = torch.tensor(samples.tensors.shape[-2:])

detections = self.transform.postprocess(result, im_shape, original_image_sizes)

if torch.jit.is_scripting():
if not self._has_warned:
Expand Down

0 comments on commit 614a40f

Please sign in to comment.