diff --git a/yolort/models/anchor_utils.py b/yolort/models/anchor_utils.py index b444df59..6aff8aaa 100644 --- a/yolort/models/anchor_utils.py +++ b/yolort/models/anchor_utils.py @@ -5,6 +5,8 @@ import torch from torch import nn, Tensor +from yolort.utils import check_version + class AnchorGenerator(nn.Module): def __init__(self, strides: List[int], anchor_grids: List[List[float]]): @@ -29,7 +31,10 @@ def _generate_grids( widths = torch.arange(width, dtype=torch.int32, device=device).to(dtype=dtype) heights = torch.arange(height, dtype=torch.int32, device=device).to(dtype=dtype) - shift_y, shift_x = torch.meshgrid(heights, widths) + if check_version(torch.__version__, "1.10.0"): + shift_y, shift_x = torch.meshgrid(heights, widths, indexing="ij") + else: + shift_y, shift_x = torch.meshgrid(heights, widths) grid = torch.stack((shift_x, shift_y), 2).expand((1, self.num_anchors, height, width, 2)) grids.append(grid) diff --git a/yolort/v5/helper.py b/yolort/v5/helper.py index 2677e6d7..a2a92be6 100644 --- a/yolort/v5/helper.py +++ b/yolort/v5/helper.py @@ -71,11 +71,13 @@ def load_yolov5_model(checkpoint_path: str, fuse: bool = False): model = ckpt["ema" if ckpt.get("ema") else "model"].float().eval() # Compatibility updates - for m in model.modules(): - if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]: - if isinstance(m, Detect): - if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility - delattr(m, "anchor_grid") - setattr(m, "anchor_grid", [torch.zeros(1)] * m.nl) + for sub_module in model.modules(): + if type(sub_module) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]: + if isinstance(sub_module, Detect): + if not isinstance(sub_module.anchor_grid, list): # new Detect Layer compatibility + delattr(sub_module, "anchor_grid") + setattr(sub_module, "anchor_grid", [torch.zeros(1)] * sub_module.nl) + elif isinstance(sub_module, nn.Upsample) and not hasattr(sub_module, "recompute_scale_factor"): + sub_module.recompute_scale_factor = None # torch 1.11.0 compatibility return model