From 3a07ab78ba6cf7f30316535bd8aa764c17b3c0e8 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Mon, 1 Jun 2020 10:57:04 -0700
Subject: [PATCH] torchvision QAT tutorial: update for QAT with DDP

Summary:

We've made two recent changes to QAT in PyTorch core:

1. added support for SyncBatchNorm
2. made the eager-mode QAT prepare scripts respect device affinity

This PR updates the torchvision QAT reference script to take advantage of
both of these changes. It should land after
https://github.com/pytorch/pytorch/pull/39337 (the last PyTorch fix) to
avoid compatibility issues.

Test Plan:

```
python -m torch.distributed.launch --nproc_per_node 8 --use_env references/classification/train_quantization.py --data-path {imagenet1k_subset} --output-dir {tmp} --sync-bn
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 references/classification/train_quantization.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py
index 511d2ba1adb..b50e7455d19 100644
--- a/references/classification/train_quantization.py
+++ b/references/classification/train_quantization.py
@@ -51,6 +51,7 @@ def main(args):
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
     model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
+    model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
         model.fuse_model()
@@ -65,7 +66,8 @@ def main(args):
                                                    step_size=args.lr_step_size,
                                                    gamma=args.lr_gamma)
 
-    model.to(device)
+    if args.distributed and args.sync_bn:
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
     criterion = nn.CrossEntropyLoss()
     model_without_ddp = model
@@ -224,6 +226,12 @@ def parse_args():
              It also serializes the transforms",
         action="store_true",
     )
+    parser.add_argument(
+        "--sync-bn",
+        dest="sync_bn",
+        help="Use sync batch norm",
+        action="store_true",
+    )
     parser.add_argument(
         "--test-only",
         dest="test_only",
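
Note (not part of the diff): the sketch below is a rough illustration of the
ordering this patch relies on, pulled together from the reference script. It
assumes the eager-mode quantization APIs available around PyTorch 1.6
(`torch.quantization.prepare_qat`, `torch.quantization.get_default_qat_qconfig`)
and an already-initialized process group; the function name
`prepare_qat_ddp_model` and its arguments are hypothetical, since the real
script does all of this inline in `main(args)`.

```python
# Rough sketch only: mirrors the step ordering in train_quantization.py after
# this patch. `prepare_qat_ddp_model`, `model_name`, and `gpu` are hypothetical
# names; torch.distributed.init_process_group is assumed to have been called.
import torch
import torchvision


def prepare_qat_ddp_model(model_name, gpu, distributed=True, sync_bn=True):
    device = torch.device('cuda', gpu)

    # Always start QAT from a pre-trained fp32 reference model.
    model = torchvision.models.quantization.__dict__[model_name](
        pretrained=True, quantize=False)

    # Move to the target device *before* fusing/preparing, so the observers
    # and fake-quant modules inserted by prepare_qat land on the same device
    # (the "device affinity" change in PyTorch core that this patch relies on).
    model.to(device)

    model.fuse_model()
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(model, inplace=True)

    # New in this patch: optionally convert BatchNorm layers to SyncBatchNorm
    # for distributed training, gated by the --sync-bn flag in the script.
    if distributed and sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[gpu])
    return model
```

The key ordering is device placement first, then fuse/prepare for QAT, then
the SyncBatchNorm conversion and DDP wrapping, which matches where the two
hunks above place `model.to(device)` and `convert_sync_batchnorm`.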