Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding new ResNet50 weights #4734

Merged
merged 10 commits into from
Oct 25, 2021
Merged

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Oct 25, 2021

Related to #3995

This PR does 2 things:

  • Adds new ResNet50 weights to the prototype
  • Extends the Reference Scripts to be able to validate models from the prototype area

Concerning the new model weights, it was trained using the Batteries Included primitives and achieves the following accuracy:

torchrun --nproc_per_node=1 train.py --test-only --weights ImageNet1K_RefV2 --model resnet50 -b 1
Test:  Acc@1 80.674 Acc@5 95.166

The linked issue provides high-level details on the recipe but I'll also follow up with a blogpost on how it was trained.

cc @datumbox @vfdev-5 @pmeier @bjuncek

Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stamping

@datumbox datumbox marked this pull request as draft October 25, 2021 11:37
@datumbox
Copy link
Contributor Author

It seems that the get_args is supported only from Python 3.8 and onwards. So I'll write a workaround for now.

@datumbox datumbox marked this pull request as ready for review October 25, 2021 12:21
Copy link
Contributor

@prabhat00155 prabhat00155 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @datumbox!

Copy link
Contributor Author

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding some highlights to assist review:

@@ -14,6 +14,12 @@
from torchvision.transforms.functional import InterpolationMode


try:
from torchvision.prototype import models as PM
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try to import the prototype models but without failing.

@@ -142,11 +148,18 @@ def load_data(traindir, valdir, args):
print("Loading dataset_test from {}".format(cache_path))
dataset_test, _ = torch.load(cache_path)
else:
if not args.weights:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which preprocessing we will use depends on whether weights are defined.

else:
fn = PM.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)
preprocessing = weights.transforms()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having a definition of the weights means we will be accessing the prototype models. Those have the preprocessing attached to the weights, so we fetch them and construct the preprocessing class.

@@ -74,3 +75,38 @@ def __getattr__(self, name):
if f.name == name:
return object.__getattribute__(self.value, name)
return super().__getattr__(name)


def get_weight(fn: Callable, weight_name: str) -> Weights:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now I consider it a private method. We will eventually need to make it public because getting the enum class from a string is useful but it's unclear whether we should do it by passing the model_builder and then weight_name or construct it via the fully qualified name.

Copy link
Member

@NicolasHug NicolasHug Oct 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I only got a chance to looks at it now.

Relying on the model_builder's annotation seems like a pretty involved way of retrieving the weights.

Should we go simple here and just register all the weights in some sort of private _AVAILABLE_WEIGHTS dict? get_weight() would then just be a query into this private dict

(This is my only comment, the rest of the PR looks great!)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NicolasHug Thanks for looking at it. FYI I merged after Prabhat's review so that we pass this to the FBsync but I plan to make changes on follow up PRs.

I agree that this is involved and that's why I haven't exposed it as public. I've added an entry at #4652 to review the mechanism and more specifically sync with you on making it Torchhub friendly. One option as you said is to have a similar registration mechanism as proposed here to keep track of method/weight combos and flag also the "best/latest" weights. I have on purpose omitted all the versioning parts of the original RFC to allow for discussions across Audio and Text to continue and see if we can adopt a common solution. But I think they are currently looking into moving towards a different direction that has no model builders, so we might be able to bring this feature sooner.

@datumbox datumbox merged commit dc11399 into pytorch:main Oct 25, 2021
@datumbox datumbox deleted the models/sota_resnet50 branch October 25, 2021 14:32
facebook-github-bot pushed a commit that referenced this pull request Oct 26, 2021
Summary:
* Update model checkpoint for resnet50.

* Add get_weight method to retrieve weights from name.

* Update the references to support prototype weights.

* Fixing mypy typing.

* Switching to a python 3.6 supported equivalent.

* Add unit-test.

* Add optional num_classes.

Reviewed By: NicolasHug

Differential Revision: D31916330

fbshipit-source-id: 2ac0f9202f62a78078b0917e6730d7fc0925acdf
cyyever pushed a commit to cyyever/vision that referenced this pull request Nov 16, 2021
* Update model checkpoint for resnet50.

* Add get_weight method to retrieve weights from name.

* Update the references to support prototype weights.

* Fixing mypy typing.

* Switching to a python 3.6 supported equivalent.

* Add unit-test.

* Add optional num_classes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Improve the accuracy of Classification models by using SOTA recipes and primitives
4 participants