Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ViT to torchvision/models #4594

Merged
merged 39 commits into from
Nov 27, 2021
Merged

Adding ViT to torchvision/models #4594

merged 39 commits into from
Nov 27, 2021

Conversation

yiwen-song
Copy link
Contributor

@yiwen-song yiwen-song commented Oct 12, 2021

The first part of #4593 :)

cc @datumbox @bjuncek

Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR @sallysyw, I didn't check the architecture for now. The code looks great, I just took a brief look at the docstrings and public/private interface

torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
Copy link
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR @sallysyw, great addition.

I've added a few comments related to the conventions used at TorchVision. Let me know your thoughts. I'm happy to review the ML bit if you want, I just need to freshen up the paper.

torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
Copy link

@mannatsingh mannatsingh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a bunch of thoughts / suggestions. Feel free to ignore whatever doesn't make sense or to keep things simple!

torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
Copy link
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I can tell, the only comments that might affect the API of the class are Francisco's remark on the classifier and Mannat's on embedding interpolation. The latter can be addressed on a follow up PR but it's worth considering the implications on the API. The good. thing is, we are just after a release, so even if we merge and want to review the API, we got plainly of time to do so. Other than that, I think the PR is mergeable. I left a couple of minor comments but none of them are blocking and mostly FYIs and nits.

@sallysyw I haven't checked the validity of the ViT implementation itself comparing to the paper. I know you are porting this from trusted sources, but if you want me to check it let me know.

torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/models/vision_transformer.py Outdated Show resolved Hide resolved
@mannatsingh
Copy link

@sallysyw I haven't checked the validity of the ViT implementation itself comparing to the paper. I know you are porting this from trusted sources, but if you want me to check it let me know.

This is a great point! If you train from scratch and the axes are messed up, you still get reasonable results sometimes (speaking from experience lol). We should try and maybe repro a result from here - https://github.com/facebookresearch/ClassyVision/tree/main/examples/vit

Or if it's easier, we can use a pretrained model from any source (like Classy), and evaluate it with this implementation and verify that the accuracy matches!

@yiwen-song
Copy link
Contributor Author

yiwen-song commented Oct 27, 2021

@sallysyw I haven't checked the validity of the ViT implementation itself comparing to the paper. I know you are porting this from trusted sources, but if you want me to check it let me know.

This is a great point! If you train from scratch and the axes are messed up, you still get reasonable results sometimes (speaking from experience lol). We should try and maybe repro a result from here - https://github.com/facebookresearch/ClassyVision/tree/main/examples/vit

Or if it's easier, we can use a pretrained model from any source (like Classy), and evaluate it with this implementation and verify that the accuracy matches!

Thanks @mannatsingh, I've been training vit_b_32 from scratch on AWS cluster, once the training is finished and the results look fine, I'll update it here.
At the same time, I will port the pre-trained weights from original repo or classy and validate it on our model.

@yiwen-song
Copy link
Contributor Author

yiwen-song commented Nov 1, 2021

I've finished the first iteration of training the vit_b_32 and vit_b_16 models from scratch using the ImageNet dataset, here are the numbers I got:

Job ID Model Epochs Nodes Batch Size Per GPU Global Batch Size Original Paper Acc@1 ClassyVision Acc@1 Acc@1 Acc@5
2420 vit_b_32 300 2 256 4096 73.38 73.3 71.876 89.784
2512 vit_b_16 300 8 64 4096 77.91 78.98 76.854 92.959

I plan to further tune the parameters and I'll upload the pre-trained weights in a following PR once I got the accuracy numbers matching the previous results.
I have also checked the number of params in these models and they are exactly the same as the ClassyVision models.

Let me know if there's other concerns before I can merge this PR.
@datumbox @mannatsingh

@yiwen-song
Copy link
Contributor Author

yiwen-song commented Nov 1, 2021

Or if it's easier, we can use a pretrained model from any source (like Classy), and evaluate it with this implementation and verify that the accuracy matches!

Do you know where can I find some pre-trained classy-vision checkpoints?

@datumbox
Copy link
Contributor

datumbox commented Nov 2, 2021

@sallysyw Thanks for the update.

As far as I can see the unit-tests are failing and it seems related. From what I can see the FX feature extractor seg faults on vit_b_16. This needs to be fixed before we merge:

test/test_backbone_utils.py::TestFxFeatureExtraction::test_jit_forward_backward[vit_b_16] Windows fatal exception: access violation

Do you know where can I find some pre-trained classy-vision checkpoints?

This is a good idea to do. I don't know where you could find one (worth checking with internal teams to see if they have checkpoints you can use) but it's definitely a necessary step prior merging. I would also recommend, unless you have other important reasons, to fully reproduce the achieved accuracies before merging. In the past, there were cases were the cause of the drop in accuracy was identified in the architecture itself. Though you can mitigate that risk by loading pre-trained checkpoints and reproducing their accuracies to the digit, this doesn't cover for bugs for when the model is on training mode.

Let me know if there's other concerns before I can merge this PR.

Some of the above comments, for example Francisco's #4594 (comment) and #4594 (comment), affect the architecture so if you merge now you might have to make radical changes later. Given that the next release is months from now, we got time to fix them. Still given that there are outstanding comments, it would be good to confirm with @mannatsingh and @fmassa that an early merge is OK in this case. Another option is to merge this on prototype where the criteria are more relaxed. Thoughts?

Copy link
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All look great to me. I left one last comment for something missed in earlier reviews and 2 optional nits. Hopefully this is the last round before merging,

test/test_backbone_utils.py Outdated Show resolved Hide resolved
torchvision/prototype/models/vision_transformer.py Outdated Show resolved Hide resolved
torchvision/prototype/models/vision_transformer.py Outdated Show resolved Hide resolved
@yiwen-song
Copy link
Contributor Author

whoops - trying to import prototype in torchvision.models causes this error on circleCI:

ModuleNotFoundError: `torchvision.prototype.datasets` depends on PyTorch's `torchdata` (https://github.com/pytorch/data). You can install it with `pip install git+https://github.com/pytorch/data.git`. Note that you cannot install it with `pip install torchdata`, since this is another package.
Tests failed for torchvision-0.12.0.dev20211124-py36_cpu.tar.bz2 - moving package to /opt/conda/conda-bld/broken
WARNING:conda_build.build:Tests failed for torchvision-0.12.0.dev20211124-py36_cpu.tar.bz2 - moving package to /opt/conda/conda-bld/broken
WARNING conda_build.build:tests_failed(2970): Tests failed for torchvision-0.12.0.dev20211124-py36_cpu.tar.bz2 - moving package to /opt/conda/conda-bld/broken
TESTS FAILED: torchvision-0.12.0.dev20211124-py36_cpu.tar.bz2

But I confirmed locally that after the changes the tests passed on linux cpu.

cc @datumbox

Copy link
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sallysyw I tested them locally and the tests fail:

[W ir_emitter.cpp:4340] Warning: Dict values consist of heterogeneous types, which means that the dict has been typed as containing Dict[str, Union[Tensor, Tuple[Tensor, Optional[Tensor]]]]. To use any of the values in this Dict, it will be necessary to add an `assert isinstance` statement before first use to trigger type refinement.
  File "<eval_with_key>.76", line 196
    eq_14 = dim_12 == 3;  dim_12 = None
    _assert_14 = torch._assert(eq_14, 'Expected (seq_length, batch_size, hidden_dim) got Proxy(getattr_14)');  eq_14 = None
    return {'encoder.dropout': encoder_dropout, 'encoder.layers.encoder_layer_5.ln': encoder_layers_encoder_layer_5_ln_1, 'encoder.layers.encoder_layer_6.ln': encoder_layers_encoder_layer_6_ln_1, 'encoder.layers.encoder_layer_7.add': add_15, 'encoder.layers.encoder_layer_7.add_1': add_16, 'encoder.layers.encoder_layer_8.ln_1': encoder_layers_encoder_layer_8_ln_2, 'encoder.layers.encoder_layer_8.mlp.linear_1': encoder_layers_encoder_layer_8_mlp_linear_2, 'encoder.layers.encoder_layer_10.self_attention': encoder_layers_encoder_layer_10_self_attention, 'encoder.layers.encoder_layer_10.add': add_21, 'encoder.layers.encoder_layer_10.mlp.dropout_1': encoder_layers_encoder_layer_10_mlp_dropout_2}
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
 (function operator())

test/test_backbone_utils.py:187 (TestFxFeatureExtraction.test_jit_forward_backward[vit_b_16])
Traceback (most recent call last):
  File "./vision/test/test_backbone_utils.py", line 198, in test_jit_forward_backward
    sum(o.mean() for o in fgn_out.values()).backward()
  File "./vision/test/test_backbone_utils.py", line 198, in <genexpr>
    sum(o.mean() for o in fgn_out.values()).backward()
AttributeError: 'tuple' object has no attribute 'mean'

I think we moved the code to prototype to early (that's on me) before we confirm that all issues are fixed. Since the tests on prototypes are not executed on CI, we now risk not detecting issues with the implementation.

I propose undoing the move to prototype, running the tests on the CI and ensure everything works prior moving everything to the prototype again. I did this job on a separate no-merge PR at #4984 and as you can see the tests fail. You are welcome to merge my changes into your current PR to investigate.

@yiwen-song
Copy link
Contributor Author

@sallysyw I tested them locally and the tests fail:

[W ir_emitter.cpp:4340] Warning: Dict values consist of heterogeneous types, which means that the dict has been typed as containing Dict[str, Union[Tensor, Tuple[Tensor, Optional[Tensor]]]]. To use any of the values in this Dict, it will be necessary to add an `assert isinstance` statement before first use to trigger type refinement.
  File "<eval_with_key>.76", line 196
    eq_14 = dim_12 == 3;  dim_12 = None
    _assert_14 = torch._assert(eq_14, 'Expected (seq_length, batch_size, hidden_dim) got Proxy(getattr_14)');  eq_14 = None
    return {'encoder.dropout': encoder_dropout, 'encoder.layers.encoder_layer_5.ln': encoder_layers_encoder_layer_5_ln_1, 'encoder.layers.encoder_layer_6.ln': encoder_layers_encoder_layer_6_ln_1, 'encoder.layers.encoder_layer_7.add': add_15, 'encoder.layers.encoder_layer_7.add_1': add_16, 'encoder.layers.encoder_layer_8.ln_1': encoder_layers_encoder_layer_8_ln_2, 'encoder.layers.encoder_layer_8.mlp.linear_1': encoder_layers_encoder_layer_8_mlp_linear_2, 'encoder.layers.encoder_layer_10.self_attention': encoder_layers_encoder_layer_10_self_attention, 'encoder.layers.encoder_layer_10.add': add_21, 'encoder.layers.encoder_layer_10.mlp.dropout_1': encoder_layers_encoder_layer_10_mlp_dropout_2}
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
 (function operator())

test/test_backbone_utils.py:187 (TestFxFeatureExtraction.test_jit_forward_backward[vit_b_16])
Traceback (most recent call last):
  File "./vision/test/test_backbone_utils.py", line 198, in test_jit_forward_backward
    sum(o.mean() for o in fgn_out.values()).backward()
  File "./vision/test/test_backbone_utils.py", line 198, in <genexpr>
    sum(o.mean() for o in fgn_out.values()).backward()
AttributeError: 'tuple' object has no attribute 'mean'

I think we moved the code to prototype to early (that's on me) before we confirm that all issues are fixed. Since the tests on prototypes are not executed on CI, we now risk not detecting issues with the implementation.

I propose undoing the move to prototype, running the tests on the CI and ensure everything works prior moving everything to the prototype again. I did this job on a separate no-merge PR at #4984 and as you can see the tests fail. You are welcome to merge my changes into your current PR to investigate.

hmm... I think the seed generated on your machine is different from mine and that's why I didn't catch this failure previously.
Can you please print out the nodes in your test run and I'll investigate (or filter) it?
@datumbox

Copy link
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sallysyw All worked like a charm. I pushed the removal of the extra classed from your branch and merged main. We should be good to merge whenever you want. Thanks for the great contribution, looking forward to the weights.

@yiwen-song yiwen-song merged commit 47281bb into pytorch:main Nov 27, 2021
@github-actions
Copy link

Hey @sallysyw!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

facebook-github-bot pushed a commit that referenced this pull request Nov 30, 2021
Summary:
* [vit] Adding ViT to torchvision/models

* adding pre-logits layer + resolving comments

* Fix the model attribute bug

* Change version to arch

* fix failing unittests

* remove useless prints

* reduce input size to fix unittests

* Increase windows-cpu executor to 2xlarge

* Use `batch_first=True` and remove classifier

* Change resource_class back to xlarge

* Remove vit_h_14

* Remove vit_h_14 from __all__

* Move vision_transformer.py into prototype

* Fix formatting issue

* remove arch in builder

* Fix type err in model builder

* address comments and trigger unittests

* remove the prototype import in torchvision.models

* Adding vit back to models to trigger CircleCI test

* fix test_jit_forward_backward

* Move all to prototype.

* Adopt new helper methods and fix prototype tests.

* Remove unused import.

Reviewed By: NicolasHug

Differential Revision: D32694316

fbshipit-source-id: fa2867555fb7ae65f8dab537517386f6694585a2

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Vasilis Vryniotis <vvryniotis@fb.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants