From 5bcc337ab430c4fd09f91c13e40d186c2e023b6b Mon Sep 17 00:00:00 2001 From: Sergey Zagoruyko Date: Wed, 8 Jan 2020 05:22:32 -0800 Subject: [PATCH] fix for loading models with num_batches_tracked in frozen bn --- torchvision/ops/misc.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 8e6a5aea2e3..c21d630ee44 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -130,6 +130,16 @@ def __init__(self, n): self.register_buffer("running_mean", torch.zeros(n)) self.register_buffer("running_var", torch.ones(n)) + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + def forward(self, x): # move reshapes to the beginning # to make it fuser-friendly