Skip to content

Commit

Permalink
fix for loading models with num_batches_tracked in frozen bn
Browse files Browse the repository at this point in the history
  • Loading branch information
szagoruyko committed Jan 10, 2020
1 parent be6dd47 commit 5bcc337
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions torchvision/ops/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ def __init__(self, n):
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
num_batches_tracked_key = prefix + 'num_batches_tracked'
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]

super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)

def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
Expand Down

0 comments on commit 5bcc337

Please sign in to comment.