From cb21b7539324234dfd8bfdbf8c5ffbe1336f1850 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 10 Mar 2022 12:43:47 +0100 Subject: [PATCH] PyTorch 1.11.0 compatibility updates Resolves `AttributeError: 'Upsample' object has no attribute 'recompute_scale_factor'` first raised in https://github.com/ultralytics/yolov5/issues/5499 and observed in all CI runs on just-released PyTorch 1.11.0. --- models/experimental.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/models/experimental.py b/models/experimental.py index 81fc9bb222..365bcf4cbf 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -94,21 +94,22 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True): model = Ensemble() for w in weights if isinstance(weights, list) else [weights]: ckpt = torch.load(attempt_download(w), map_location=map_location) # load - if fuse: - model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model - else: - model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # without layer fuse + ckpt = (ckpt['ema'] or ckpt['model']).float() # FP32 model + model.append(ckpt.fuse().eval() if fuse else ckpt.eval()) # fused or un-fused model in eval mode # Compatibility updates for m in model.modules(): - if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]: - m.inplace = inplace # pytorch 1.7.0 compatibility - if type(m) is Detect: + t = type(m) + if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model): + m.inplace = inplace # torch 1.7.0 compatibility + if t is Detect: if not isinstance(m.anchor_grid, list): # new Detect Layer compatibility delattr(m, 'anchor_grid') setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl) - elif type(m) is Conv: - m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + elif t is nn.Upsample: + m.recompute_scale_factor = None # torch 1.11.0 compatibility + elif t is Conv: + m._non_persistent_buffers_set = set() # torch 1.6.0 compatibility if len(model) == 1: return model[-1] # return model @@ -118,3 +119,4 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True): setattr(model, k, getattr(model[-1], k)) model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride return model # return ensemble +