Skip to content

Commit

Permalink
Fix train_bn arg
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Feb 17, 2021
1 parent 23a594e commit efcff12
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions pl_examples/domain_templates/computer_vision_fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(self, milestones: tuple = (5, 10), train_bn: bool = False):
self.train_bn = train_bn

def freeze_before_training(self, pl_module: pl.LightningModule):
self.freeze(modules=pl_module.feature_extractor, train_bn=False)
self.freeze(modules=pl_module.feature_extractor, train_bn=self.train_bn)

def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
if epoch == self.milestones[0]:
Expand Down Expand Up @@ -164,21 +164,19 @@ class TransferLearningModel(pl.LightningModule):
def __init__(
self,
backbone: str = "resnet50",
train_bn: bool = True,
milestones: tuple = (5, 10),
batch_size: int = 32,
lr: float = 1e-2,
lr_scheduler_gamma: float = 1e-1,
num_workers: int = 6,
**kwargs,
**_,
) -> None:
"""
Args:
dl_path: Path where the data will be downloaded
"""
super().__init__()
self.backbone = backbone
self.train_bn = train_bn
self.milestones = milestones
self.batch_size = batch_size
self.lr = lr
Expand Down Expand Up @@ -334,7 +332,7 @@ def main(args: argparse.Namespace) -> None:
dl_path=os.path.join(args.root_data_path, 'data'), batch_size=args.batch_size, num_workers=args.num_workers
)
model = TransferLearningModel(**vars(args))
finetuning_callback = MilestonesFinetuning(milestones=args.milestones)
finetuning_callback = MilestonesFinetuning(milestones=args.milestones, train_bn=args.train_bn)

trainer = pl.Trainer(
weights_summary=None,
Expand Down

0 comments on commit efcff12

Please sign in to comment.