torchvision QAT tutorial: update for QAT with DDP
Summary:

We've made two recent changes to QAT in PyTorch core:
1. added support for SyncBatchNorm
2. made eager-mode QAT prepare scripts respect device affinity (a sketch combining both is shown below)
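
As a minimal sketch, here is how the two changes compose under DDP. The model choice, qconfig backend, and `local_rank` handling are illustrative assumptions, not code from this PR, and a process group is assumed to be initialized already:

```
import torch
import torchvision

# Illustrative assumptions (not from the script): torch.distributed is already
# initialized, and this process owns GPU `local_rank`.
local_rank = 0
device = torch.device("cuda", local_rank)

# Start from a pre-trained fp32 reference model, as the script does.
model = torchvision.models.quantization.resnet18(pretrained=True, quantize=False)
model.to(device)  # move first: eager-mode prepare now respects device affinity (change 2)

model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# Convert BatchNorm modules to SyncBatchNorm for multi-GPU QAT (change 1),
# then wrap in DDP as usual.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```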

This PR updates the torchvision QAT reference script to take
advantage of both changes. It should land after
pytorch/pytorch#39337 (the last of the PyTorch fixes)
to avoid compatibility issues.

Test Plan:

```
python -m torch.distributed.launch \
  --nproc_per_node 8 \
  --use_env \
  references/classification/train_quantization.py \
  --data-path {imagenet1k_subset} \
  --output-dir {tmp} \
  --sync-bn
```
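
Note: `--use_env` makes `torch.distributed.launch` export rank information as environment variables instead of a `--local_rank` argument, so each worker can initialize itself along these lines (a sketch of the usual pattern, not the script's exact init code):

```
import os
import torch

local_rank = int(os.environ["LOCAL_RANK"])  # set by the launcher with --use_env
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")
```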

vkuzo committed Jun 2, 2020
1 parent 34810c0 commit 3a07ab7
Showing 1 changed file with 9 additions and 1 deletion.
references/classification/train_quantization.py:

```diff
@@ -51,6 +51,7 @@ def main(args):
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
     model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
+    model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
         model.fuse_model()
@@ -65,7 +66,8 @@ def main(args):
         step_size=args.lr_step_size,
         gamma=args.lr_gamma)
 
-    model.to(device)
+    if args.distributed and args.sync_bn:
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
     criterion = nn.CrossEntropyLoss()
     model_without_ddp = model
@@ -224,6 +226,12 @@ def parse_args():
              It also serializes the transforms",
         action="store_true",
     )
+    parser.add_argument(
+        "--sync-bn",
+        dest="sync_bn",
+        help="Use sync batch norm",
+        action="store_true",
+    )
     parser.add_argument(
         "--test-only",
         dest="test_only",
```
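
The net effect of the diff is an ordering constraint: the model moves to its device before QAT prepare (so, per change 2 above, observers land on the same device), and SyncBatchNorm conversion happens only for distributed runs with `--sync-bn`. A hypothetical sanity check for the device-affinity half (the helper name is mine, not from the script):

```
import torch

def assert_single_device(model):
    # Hypothetical helper: every parameter and buffer, including observer /
    # fake-quant state inserted by prepare_qat, should share one device.
    devices = {p.device for p in model.parameters()}
    devices |= {b.device for b in model.buffers()}
    assert len(devices) <= 1, "model spans multiple devices: {}".format(devices)
```

Per the summary above, before the device-affinity fix a check like this could fail after prepare_qat on a CUDA model, because newly inserted observers defaulted to CPU.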
