Adding new ResNet50 weights #4734
Conversation
stamping
It seems that the
Thanks @datumbox!
Adding some highlights to assist review:
```diff
@@ -14,6 +14,12 @@
 from torchvision.transforms.functional import InterpolationMode


 try:
     from torchvision.prototype import models as PM
```
Try to import the prototype models but without failing.
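A minimal, self-contained sketch of that guarded-import pattern: attempt the prototype import, but degrade gracefully when the prototype area is unavailable. Falling back to `None` and the helper function are illustrative assumptions, not necessarily what torchvision's reference scripts do.

```python
# Sketch of a guarded optional import, mirroring the diff above.
# PM = None on failure is an illustrative fallback choice.
try:
    from torchvision.prototype import models as PM
except ImportError:
    PM = None  # prototype models unavailable; callers must check for None

def prototype_models_available() -> bool:
    # Later code can branch on this instead of failing at import time.
    return PM is not None
```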
```diff
@@ -142,11 +148,18 @@ def load_data(traindir, valdir, args):
         print("Loading dataset_test from {}".format(cache_path))
         dataset_test, _ = torch.load(cache_path)
     else:
         if not args.weights:
```
Which preprocessing we use depends on whether weights are specified.
```diff
         else:
             fn = PM.__dict__[args.model]
             weights = PM._api.get_weight(fn, args.weights)
             preprocessing = weights.transforms()
```
Having the weights defined means we will be accessing the prototype models. Those have the preprocessing attached to the weights, so we fetch it and construct the preprocessing class.
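A hedged sketch of the "preprocessing attached to the weights" idea described above. The enum, its payload layout, and the stand-in transform are illustrative assumptions, not torchvision's actual API.

```python
# Illustrative only: a weights enum whose entries carry a factory for
# their own eval-time preprocessing, so callers just ask the weights.
import enum

class Resize:
    def __init__(self, size):
        self.size = size
    def __call__(self, img):
        # Stand-in transform; a real one would resize the image tensor.
        return f"resized({img}, {self.size})"

class ResNet50Weights(enum.Enum):
    # The dict payload (url + transforms factory) is an assumption.
    ImageNet1K_V2 = {"url": "resnet50-v2.pth", "transforms": lambda: Resize(232)}

    def transforms(self):
        # Each weight entry constructs its own preprocessing.
        return self.value["transforms"]()

preprocessing = ResNet50Weights.ImageNet1K_V2.transforms()
```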
```diff
@@ -74,3 +75,38 @@ def __getattr__(self, name):
             if f.name == name:
                 return object.__getattribute__(self.value, name)
         return super().__getattr__(name)


 def get_weight(fn: Callable, weight_name: str) -> Weights:
```
For now I consider this a private method. We will eventually need to make it public, because getting the enum class from a string is useful, but it's unclear whether we should do so by passing the model builder together with the weight name or by constructing it from the fully qualified name.
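A hedged reconstruction of the annotation-based lookup being discussed: read the weights enum class off the builder's signature, then resolve the member by name. The enum and builder below are stand-ins, not torchvision's real API.

```python
# Illustrative sketch: resolve a weights enum member from a model
# builder's type annotation plus a string name.
import enum
import inspect
from typing import Callable

class ResNet50Weights(enum.Enum):
    ImageNet1K_V1 = "resnet50-v1"
    ImageNet1K_V2 = "resnet50-v2"

def resnet50(weights: ResNet50Weights):
    ...  # stand-in builder; only the annotation matters here

def get_weight(fn: Callable, weight_name: str) -> enum.Enum:
    # Inspect the builder's signature for the weights enum class,
    # then look the requested member up by name.
    annotation = inspect.signature(fn).parameters["weights"].annotation
    if not (isinstance(annotation, type) and issubclass(annotation, enum.Enum)):
        raise ValueError(f"{fn.__name__} does not annotate a weights enum")
    return annotation[weight_name]
```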
Sorry, I only got a chance to look at it now.
Relying on the model builder's annotation seems like a pretty involved way of retrieving the weights. Should we go simple here and just register all the weights in some sort of private `_AVAILABLE_WEIGHTS` dict? `get_weight()` would then just be a query into this private dict.
(This is my only comment, the rest of the PR looks great!)
@NicolasHug Thanks for looking at it. FYI, I merged after Prabhat's review so that we pass this to the FBsync, but I plan to make changes in follow-up PRs.
I agree that this is involved, and that's why I haven't exposed it as public. I've added an entry at #4652 to review the mechanism and, more specifically, to sync with you on making it Torchhub-friendly. One option, as you said, is to have a registration mechanism similar to the one proposed here to keep track of method/weight combos and also flag the "best/latest" weights. I have on purpose omitted all the versioning parts of the original RFC to allow discussions across Audio and Text to continue and see if we can adopt a common solution. But I think they are currently looking into moving in a different direction that has no model builders, so we might be able to bring this feature sooner.
Summary:
* Update model checkpoint for resnet50.
* Add get_weight method to retrieve weights from name.
* Update the references to support prototype weights.
* Fixing mypy typing.
* Switching to a python 3.6 supported equivalent.
* Add unit-test.
* Add optional num_classes.

Reviewed By: NicolasHug
Differential Revision: D31916330
fbshipit-source-id: 2ac0f9202f62a78078b0917e6730d7fc0925acdf
Related to #3995
This PR does two things:
Concerning the new model weights, the model was trained using the Batteries Included primitives and achieves the following accuracy:
The linked issue provides high-level details of the recipe, but I'll also follow up with a blog post on how it was trained.
cc @datumbox @vfdev-5 @pmeier @bjuncek