diff --git a/references/classification/train.py b/references/classification/train.py
index c32416c1259..3e35e787c13 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -59,7 +59,18 @@ def evaluate(model, criterion, data_loader, device):
     return metric_logger.acc1.global_avg
 
 
+def _get_cache_path(filepath):
+    import hashlib
+    h = hashlib.sha1(filepath.encode()).hexdigest()
+    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
+    cache_path = os.path.expanduser(cache_path)
+    return cache_path
+
+
 def main(args):
+    if args.output_dir:
+        utils.mkdir(args.output_dir)
+
     utils.init_distributed_mode(args)
     print(args)
 
@@ -76,28 +87,45 @@ def main(args):
 
     print("Loading training data")
     st = time.time()
-    scale = (0.08, 1.0)
-    if args.model == 'mobilenet_v2':
-        scale = (0.2, 1.0)
-    dataset = torchvision.datasets.ImageFolder(
-        traindir,
-        transforms.Compose([
-            transforms.RandomResizedCrop(224, scale=scale),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            normalize,
-        ]))
+    cache_path = _get_cache_path(traindir)
+    if args.cache_dataset and os.path.exists(cache_path):
+        # Attention: the transforms are also cached!
+        print("Loading dataset_train from {}".format(cache_path))
+        dataset, _ = torch.load(cache_path)
+    else:
+        dataset = torchvision.datasets.ImageFolder(
+            traindir,
+            transforms.Compose([
+                transforms.RandomResizedCrop(224),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                normalize,
+            ]))
+        if args.cache_dataset:
+            print("Saving dataset_train to {}".format(cache_path))
+            utils.mkdir(os.path.dirname(cache_path))
+            utils.save_on_master((dataset, traindir), cache_path)
     print("Took", time.time() - st)
 
     print("Loading validation data")
-    dataset_test = torchvision.datasets.ImageFolder(
-        valdir,
-        transforms.Compose([
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            normalize,
-        ]))
+    cache_path = _get_cache_path(valdir)
+    if args.cache_dataset and os.path.exists(cache_path):
+        # Attention: the transforms are also cached!
+ print("Loading dataset_test from {}".format(cache_path)) + dataset_test, _ = torch.load(cache_path) + else: + dataset_test = torchvision.datasets.ImageFolder( + valdir, + transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])) + if args.cache_dataset: + print("Saving dataset_test to {}".format(cache_path)) + utils.mkdir(os.path.dirname(cache_path)) + utils.save_on_master((dataset_test, valdir), cache_path) print("Creating data loaders") if args.distributed: @@ -118,7 +146,7 @@ def main(args): print("Creating model") model = torchvision.models.__dict__[args.model]() model.to(device) - if args.distributed: + if args.distributed and args.sync_bn: model = torch.nn.utils.convert_sync_batchnorm(model) model_without_ddp = model @@ -131,7 +159,6 @@ def main(args): optimizer = torch.optim.SGD( model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) - # if using mobilenet, step_size=2 and gamma=0.94 lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) if args.resume: @@ -139,6 +166,7 @@ def main(args): model_without_ddp.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + args.start_epoch = checkpoint['epoch'] + 1 if args.test_only: evaluate(model, criterion, data_loader_test, device=device) @@ -146,26 +174,32 @@ def main(args): print("Start training") start_time = time.time() - for epoch in range(args.epochs): + for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) lr_scheduler.step() train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq) evaluate(model, criterion, data_loader_test, device=device) if args.output_dir: - utils.save_on_master({ + checkpoint = { 'model': model_without_ddp.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), - 'args': args}, + 'epoch': epoch, + 'args': args} + utils.save_on_master( + checkpoint, os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) + utils.save_on_master( + checkpoint, + os.path.join(args.output_dir, 'checkpoint.pth')) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str)) -if __name__ == "__main__": +def parse_args(): import argparse parser = argparse.ArgumentParser(description='PyTorch Classification Training') @@ -188,6 +222,20 @@ def main(args): parser.add_argument('--print-freq', default=10, type=int, help='print frequency') parser.add_argument('--output-dir', default='.', help='path where to save') parser.add_argument('--resume', default='', help='resume from checkpoint') + parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='start epoch') + parser.add_argument( + "--cache-dataset", + dest="cache_dataset", + help="Cache the datasets for quicker initialization. 
It also serializes the transforms", + action="store_true", + ) + parser.add_argument( + "--sync-bn", + dest="sync_bn", + help="Use sync batch norm", + action="store_true", + ) parser.add_argument( "--test-only", dest="test_only", @@ -202,7 +250,9 @@ def main(args): args = parser.parse_args() - if args.output_dir: - utils.mkdir(args.output_dir) + return args + +if __name__ == "__main__": + args = parse_args() main(args) diff --git a/references/classification/utils.py b/references/classification/utils.py index 03ea272fcf6..72d15067b30 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -214,13 +214,15 @@ def save_on_master(*args, **kwargs): def init_distributed_mode(args): - if 'SLURM_PROCID' in os.environ: - args.rank = int(os.environ['SLURM_PROCID']) - args.gpu = args.rank % torch.cuda.device_count() - elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: args.rank = int(os.environ["RANK"]) args.world_size = int(os.environ['WORLD_SIZE']) args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + elif hasattr(args, "rank"): + pass else: print('Not using distributed mode') args.distributed = False