Skip to content

Commit

Permalink
Miscellaneous improvements to the classification reference scripts (#894
Browse files Browse the repository at this point in the history
)

* Miscellaneous improvements to the classification reference scripts

* Fix lint
  • Loading branch information
fmassa authored May 8, 2019
1 parent 43ab2fe commit ae81313
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 31 deletions.
104 changes: 77 additions & 27 deletions references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,18 @@ def evaluate(model, criterion, data_loader, device):
return metric_logger.acc1.global_avg


def _get_cache_path(filepath):
import hashlib
h = hashlib.sha1(filepath.encode()).hexdigest()
cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
cache_path = os.path.expanduser(cache_path)
return cache_path


def main(args):
if args.output_dir:
utils.mkdir(args.output_dir)

utils.init_distributed_mode(args)
print(args)

Expand All @@ -76,28 +87,45 @@ def main(args):

print("Loading training data")
st = time.time()
scale = (0.08, 1.0)
if args.model == 'mobilenet_v2':
scale = (0.2, 1.0)
dataset = torchvision.datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(224, scale=scale),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
cache_path = _get_cache_path(traindir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_train from {}".format(cache_path))
dataset, _ = torch.load(cache_path)
else:
dataset = torchvision.datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path)
print("Took", time.time() - st)

print("Loading validation data")
dataset_test = torchvision.datasets.ImageFolder(
valdir,
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))
cache_path = _get_cache_path(valdir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_test from {}".format(cache_path))
dataset_test, _ = torch.load(cache_path)
else:
dataset_test = torchvision.datasets.ImageFolder(
valdir,
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
]))
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset_test, valdir), cache_path)

print("Creating data loaders")
if args.distributed:
Expand All @@ -118,7 +146,7 @@ def main(args):
print("Creating model")
model = torchvision.models.__dict__[args.model]()
model.to(device)
if args.distributed:
if args.distributed and args.sync_bn:
model = torch.nn.utils.convert_sync_batchnorm(model)

model_without_ddp = model
Expand All @@ -131,41 +159,47 @@ def main(args):
optimizer = torch.optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

# if using mobilenet, step_size=2 and gamma=0.94
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1

if args.test_only:
evaluate(model, criterion, data_loader_test, device=device)
return

print("Start training")
start_time = time.time()
for epoch in range(args.epochs):
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
lr_scheduler.step()
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
evaluate(model, criterion, data_loader_test, device=device)
if args.output_dir:
utils.save_on_master({
checkpoint = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'args': args},
'epoch': epoch,
'args': args}
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'checkpoint.pth'))

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))


if __name__ == "__main__":
def parse_args():
import argparse
parser = argparse.ArgumentParser(description='PyTorch Classification Training')

Expand All @@ -188,6 +222,20 @@ def main(args):
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
parser.add_argument('--output-dir', default='.', help='path where to save')
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument(
"--cache-dataset",
dest="cache_dataset",
help="Cache the datasets for quicker initialization. It also serializes the transforms",
action="store_true",
)
parser.add_argument(
"--sync-bn",
dest="sync_bn",
help="Use sync batch norm",
action="store_true",
)
parser.add_argument(
"--test-only",
dest="test_only",
Expand All @@ -202,7 +250,9 @@ def main(args):

args = parser.parse_args()

if args.output_dir:
utils.mkdir(args.output_dir)
return args


# Script entry point: parse the command-line arguments and pass them to main().
if __name__ == "__main__":
args = parse_args()
main(args)
10 changes: 6 additions & 4 deletions references/classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,15 @@ def save_on_master(*args, **kwargs):


def init_distributed_mode(args):
if 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
print('Not using distributed mode')
args.distributed = False
Expand Down

0 comments on commit ae81313

Please sign in to comment.