Skip to content

Commit

Permalink
Minor additions
Browse files Browse the repository at this point in the history
  • Loading branch information
fmassa committed Mar 25, 2019
1 parent 5d0ecca commit fea2b19
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion references/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,12 @@ def main(args):

model = models.get_model(args.model, args.backbone, num_classes=dataset.num_classes, aux=args.aux_loss)
model.to(device)
model = torch.nn.utils.convert_sync_batchnorm(model)
if args.distributed:
model = torch.nn.utils.convert_sync_batchnorm(model)

if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'])

model_without_ddp = model
if args.distributed:
Expand Down Expand Up @@ -178,6 +183,7 @@ def main(args):
dest='weight_decay')
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--local_rank', default=0, type=int, help='print frequency')

args = parser.parse_args()
Expand Down

0 comments on commit fea2b19

Please sign in to comment.