From 55d451dc2614ae72f2806a6f46d52778bc774bc6 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Mon, 18 May 2020 09:37:38 -0700
Subject: [PATCH] vision classification QAT tutorial: fix for DDP (redo)

Summary:

Redo of https://github.com/pytorch/vision/pull/2191

Makes the classification QAT tutorial not crash when used with DDP.
There were two issues:

1. The model was moved to GPU before the observers were inserted, but
   observers are created on CPU. In the context of this repo, the fix is
   to finish preparing the model for QAT (fusing modules and inserting
   observers) before moving it to GPU. We can potentially follow up with
   a better error message in a separate PR.
2. The QAT conversion was run on the DDP-wrapped model, which caused
   various problems. The fix is to unwrap the model from DDP before
   cloning it for evaluation.

There is still work to do to verify that BatchNorm behaves correctly
under QAT + DDP, but that is saved for a separate PR.

Test Plan:

```
python -m torch.distributed.launch --use_env references/classification/train_quantization.py --data-path {path_to_imagenet_1k} --output_dir {output_dir}
```

Reviewers:

Subscribers:

Tasks:

Tags:
---
 references/classification/train_quantization.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py
index e59b8d4a64e..511d2ba1adb 100644
--- a/references/classification/train_quantization.py
+++ b/references/classification/train_quantization.py
@@ -51,7 +51,6 @@ def main(args):
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
     model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
-    model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
         model.fuse_model()
@@ -66,6 +65,8 @@ def main(args):
                                                    step_size=args.lr_step_size,
                                                    gamma=args.lr_gamma)
 
+    model.to(device)
+
    criterion = nn.CrossEntropyLoss()
    model_without_ddp = model
    if args.distributed:
@@ -129,7 +130,7 @@ def main(args):
         print('Evaluate QAT model')
         evaluate(model, criterion, data_loader_test, device=device)
 
-        quantized_eval_model = copy.deepcopy(model)
+        quantized_eval_model = copy.deepcopy(model_without_ddp)
         quantized_eval_model.eval()
         quantized_eval_model.to(torch.device('cpu'))
         torch.quantization.convert(quantized_eval_model, inplace=True)
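
For context, a minimal sketch (not part of the patch) of the ordering this
fix enforces for eager-mode QAT + DDP. The toy model, the device id, and
the assumption that the process group is already initialized are all
hypothetical illustration; the step numbering mirrors the issues above:

```python
import copy

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Hypothetical float model; any quantizable module works here.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# 1. Prepare for QAT on CPU: observers/fake-quant modules are created on
#    the same device as the model, so insert them *before* moving to GPU.
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

# 2. Only now move to GPU and wrap with DDP, keeping a handle to the
#    unwrapped model (assumes init_process_group was already called).
device = torch.device('cuda')
model.to(device)
model_without_ddp = model
model = DistributedDataParallel(model, device_ids=[0])

# ... QAT training loop ...

# 3. Convert the *unwrapped* model: deepcopy the inner module rather than
#    the DDP wrapper, and run the conversion on CPU.
quantized_eval_model = copy.deepcopy(model_without_ddp)
quantized_eval_model.eval()
quantized_eval_model.to(torch.device('cpu'))
torch.quantization.convert(quantized_eval_model, inplace=True)
```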