
torchvision QAT tutorial: update for QAT with DDP #2280

Merged · 1 commit merged into master from qat_ddp_20200601 on Jun 3, 2020

Conversation

@vkuzo (Contributor) commented on Jun 1, 2020

Summary:

We've made two recent changes to QAT in PyTorch core:

  1. add support for SyncBatchNorm
  2. make eager mode QAT prepare scripts respect device affinity

This PR updates the torchvision QAT reference script to take
advantage of both of these. This should be landed after
pytorch/pytorch#39337 (the last PT
fix) to avoid compatibility issues.
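
For context, the intended eager-mode flow looks roughly like this (a minimal sketch with a placeholder model and hyperparameters, not the reference script itself; it assumes the process group has already been initialized):

```python
import torch
import torchvision

local_rank = 0  # placeholder; in practice the launcher provides this, one process per GPU
device = torch.device('cuda', local_rank)

# quantizable torchvision model, kept in float form for QAT
model = torchvision.models.quantization.mobilenet_v2(pretrained=False, quantize=False)
model.to(device)
model.fuse_model()

# eager-mode QAT prepare; with (2), the inserted observer/fake-quant modules
# stay on the same device as the modules they are attached to
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

# with (1), BatchNorm statistics can be synchronized across DDP ranks during QAT
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```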

Test Plan:

```
python -m torch.distributed.launch \
  --nproc_per_node 8 \
  --use_env \
  references/classification/train_quantization.py \
  --data-path {imagenet1k_subset} \
  --output-dir {tmp} \
  --sync-bn
```

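Since the command passes --use_env, torch.distributed.launch exports the rank information as environment variables instead of appending a --local_rank argument. A hypothetical sketch of how a script launched this way sets up its process (not the actual torchvision helper):

```python
import os
import torch
import torch.distributed as dist

# with --use_env, the launcher sets LOCAL_RANK (along with RANK and WORLD_SIZE)
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl", init_method="env://")
```
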
Reviewers:

Subscribers:

Tasks:

Tags:

@vkuzo vkuzo requested review from fmassa and raghuramank100 on June 1, 2020 18:02
@fmassa (Member) left a comment

Thanks!

Wow, I just realized that there might have been an issue with the location of the .to(device) in the previous implementation.

Also, can't you convert the model to SyncBatchNorm before creating the optimizer? It could maybe be safer, if it works?

In references/classification/train_quantization.py:

```
@@ -65,7 +66,8 @@ def main(args):
        step_size=args.lr_step_size,
        gamma=args.lr_gamma)

    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
```

@fmassa (Member) commented:

At some point in time, converting to SyncBatchNorm would (erroneously) create some tensors on the CPU, despite the model being on the GPU. I suppose this has been fixed? Otherwise we would need to call model.to(device) again after SyncBatchNorm.

Now that I think a bit more about this, it looks like fine-tuning was wrong before because of where the to(device) call was placed, since it happened after creating the optimizer?

@vkuzo (Contributor, author) replied:

yep, the SyncBatchNorm device affinity was fixed in pytorch/pytorch#38729 and this diff assumes the fix is present in torch.

#2230 (by me) originally moved the to(device) down to unbreak the tutorial, because at that time device affinity was broken with QAT. This PR moves it back to the original place, since the bugs are fixed.
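
To spell out what device affinity means here, an illustrative sketch (assuming a torch build that includes the fixes, not code from this PR):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).cuda()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

# With the device-affinity fix, the observer/fake-quant modules inserted by
# prepare_qat live on the same device as the layers they observe; previously
# their buffers could land on the CPU, forcing an extra model.to(device) later.
print(next(model.parameters()).device)               # cuda:0
print(model[0].weight_fake_quant.zero_point.device)  # cuda:0 with the fix
```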

@vkuzo vkuzo force-pushed the qat_ddp_20200601 branch from 0d6b665 to 3a07ab7 on June 2, 2020 18:37
@vkuzo vkuzo force-pushed the qat_ddp_20200601 branch from 3a07ab7 to 7d85818 on June 2, 2020 18:39

@vkuzo (Contributor, author) commented on Jun 2, 2020

> Also, can't you convert the model to SyncBatchNorm before creating the optimizer? It could maybe be safer, if it works?

good point, fixed

@vkuzo vkuzo requested a review from fmassa on June 3, 2020 15:59

@fmassa (Member) left a comment

Thanks!

@fmassa fmassa merged commit 3902140 into master on Jun 3, 2020
@fmassa fmassa deleted the qat_ddp_20200601 branch on June 3, 2020 17:24