Skip to content

Commit

Permalink
Fix bug on detection backbones when trainable_layers == 0 (#3906)
Browse files Browse the repository at this point in the history
* Fix a bug when trainable_layers == 0

* Fix same issue on ssd.
  • Loading branch information
datumbox authored May 25, 2021
1 parent 21824ce commit c58d5d1
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def mobilenet_backbone(

# find the index of the layer from which we wont freeze
assert 0 <= trainable_layers <= num_stages
freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

for b in backbone[:freeze_before]:
for parameter in b.parameters():
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained

# find the index of the layer from which we wont freeze
assert 0 <= trainable_layers <= num_stages
freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

for b in backbone[:freeze_before]:
for parameter in b.parameters():
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/detection/ssdlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, t

# find the index of the layer from which we wont freeze
assert 0 <= trainable_layers <= num_stages
freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

for b in backbone[:freeze_before]:
for parameter in b.parameters():
Expand Down

0 comments on commit c58d5d1

Please sign in to comment.