Cannot obtain the accuracy stated in the doc for inception_v3 pretrained on Imagenet #6066

Open
m-parchami opened this issue May 22, 2022 · 6 comments



m-parchami commented May 22, 2022

Hi.
I'm trying to evaluate the pretrained inception_v3 model from torch.hub on the ImageNet (ILSVRC 2012) validation set. I use the following evaluation code:

import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm


def compute_accuracy(self):
    num_correct = 0
    num_images = 0

    # Standard ImageNet normalization constants
    _IMAGE_MEAN_VALUE = [0.485, 0.456, 0.406]
    _IMAGE_STD_VALUE = [0.229, 0.224, 0.225]

    imgnet_loader = torch.utils.data.DataLoader(
        ImageFolder('/home/amin/dataset/ILSVRC/val',
                    transforms.Compose([
                        transforms.Resize(299),
                        transforms.CenterCrop(299),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=_IMAGE_MEAN_VALUE, std=_IMAGE_STD_VALUE),
                    ])),
        batch_size=16, shuffle=True,
        num_workers=8, pin_memory=True)

    # Run on GPU when available, otherwise fall back to CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    self.model = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True).to(device)
    self.model.eval()

    with torch.no_grad():
        for images, targets in tqdm(imgnet_loader, desc="Compute Accuracy", total=len(imgnet_loader)):
            images = images.to(device)
            targets = targets.to(device)

            logits = self.model(images)   # in eval mode the model returns plain logits
            preds = logits.argmax(dim=1)

            num_correct += (preds == targets).sum().item()
            num_images += images.size(0)

    classification_acc = num_correct / float(num_images) * 100
    return classification_acc

However, the accuracy I get is 77.216, while it should be 77.45 according to this page. I noticed that the model applies its own input normalization via its transform_input flag, so if we normalize beforehand (as the example code suggests), transform_input should be set to False. With self.model.transform_input = False I get 77.472, which is closer to the expected value but still not an exact match.
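Concretely, the only change relative to the code above is:

self.model = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True).to(device)
self.model.transform_input = False   # skip the model's built-in input re-normalization
self.model.eval()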

Assuming the issue isn't in my code, I also found the thread on mean/std values (#1439), so I tested some of the values suggested there as well and got these results:

| mean | std | model's transform_input | Accuracy (%) |
| --- | --- | --- | --- |
| [0.485, 0.456, 0.406] | [0.229, 0.224, 0.225] | disabled | 77.472 |
| [0.485, 0.456, 0.406] | [0.229, 0.224, 0.225] | enabled | 77.216 |
| [0.4803, 0.4569, 0.4083] | [0.2806, 0.2736, 0.2877] | disabled | 77.448 |
| [0.4803, 0.4569, 0.4083] | [0.2806, 0.2736, 0.2877] | enabled | 76.986 |
| [0.4845, 0.4541, 0.4025] | [0.2724, 0.2637, 0.2761] | disabled | 77.456 |
| [0.4845, 0.4541, 0.4025] | [0.2724, 0.2637, 0.2761] | enabled | 77.03 |
| [0.4701, 0.4340, 0.3832] | [0.2845, 0.2733, 0.2805] | disabled | 77.44 |
| [0.4701, 0.4340, 0.3832] | [0.2845, 0.2733, 0.2805] | enabled | 77.01 |

I appreciate any input on this.
Thanks.

m-parchami changed the title from "Cannot obtain accuracy stated in the doc for inception_v3 pretrained on Imagenet" to "Cannot obtain the accuracy stated in the doc for inception_v3 pretrained on Imagenet" on May 22, 2022
@NicolasHug (Member) commented:

@m-parchami I'd recommend using the torchvision training references to check the accuracy of a given model - that's what is used to report the accuracies - https://github.com/pytorch/vision/tree/main/references/classification#inception-v3
Also, please note that this isn't a torchhub-related issue but rather a torchvision issue (you're in luck, we maintain both).

Various things can affect the results: obviously slight differences in the code, but also the batch size, shuffling, some samples potentially being dropped or duplicated depending on the number of GPUs, the use of CUDA determinism, etc.

NicolasHug transferred this issue from pytorch/hub on May 23, 2022

m-parchami commented May 23, 2022

Thanks @NicolasHug for moving it to the right place and not just closing it :) Sorry for the inconvenience.

Regarding your comment: I assumed that with model.eval(), batch normalization and dropout can no longer affect the output. I also ran the code on a single GPU only, to avoid any samples being duplicated or dropped. And because the model is pretrained, I assumed that CUDA determinism can't affect any weights, and hence the output.

I'd appreciate it if you or anyone else could elaborate on how the model can give different results when both the model and the loader are obtained directly from PyTorch itself. The transforms were also copied directly from the docs.

I also noticed an inconsistency within PyTorch's own documentation: here the accuracy of the model is stated as 77.45, while here, in the torchvision.models docs, it is stated as 77.294.

@NicolasHug (Member) commented:

> I also noticed an inconsistency within PyTorch's own documentation: here the accuracy of the model is stated as 77.45, while here, in the torchvision.models docs, it is stated as 77.294.

Good catch, thanks. In general, the torchhub docs might be trailing a bit. You'll find the latest accuracies here: https://pytorch.org/vision/main/models.html (soon to be https://pytorch.org/vision/stable/models.html).

> I assumed that with model.eval(), batch normalization and dropout can no longer affect the output

You're right, but different batch sizes can still lead to slightly different computations (regardless of batchnorm, dropout, or anything else).
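A tiny illustration of the point: the same reduction computed with different groupings (which is effectively what a different batch size does) can differ in the last bits of floating-point precision.

import torch

x = torch.randn(10_000)
full = x.sum()                                        # one reduction over the whole tensor
chunked = sum(chunk.sum() for chunk in x.split(16))   # grouped as if batch_size=16
print((full - chunked).item())                        # typically a tiny non-zero difference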

> Because the model is pretrained, I assumed that CUDA determinism can't affect any weights, and hence the output

CUDA and PyTorch determinism affect the model not only during training but also during eval. You'll find more details here: https://pytorch.org/docs/stable/notes/randomness.html. This is why we recommend using our scripts with the --test-only option, to make sure the reproducibility settings are consistent.
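For reference, a minimal sketch of those settings (based on the randomness notes; the exact knobs depend on the PyTorch version):

import torch

torch.manual_seed(0)                      # seed the CPU and CUDA RNGs
torch.backends.cudnn.benchmark = False    # disable non-deterministic cuDNN autotuning
torch.use_deterministic_algorithms(True)  # error out if a non-deterministic op is used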

> I'd appreciate it if you or anyone else could elaborate on how the model can give different results when both the model and the loader are obtained directly from PyTorch itself

Another slight discrepancy is that you used vision:v0.10.0 while we're at 0.12. I doubt inception_v3 has changed in between, but that might be a variable to eliminate as well :)


m-parchami commented May 24, 2022

Thank you for being so thorough.
With the help of the links you sent, I figured out the issue. According to the bottom of this page, the latest pretrained inception_v3 model uses Resize(342). With that, I got exactly the same top-1 and top-5 accuracies, even with the script from my first post (which uses torch.hub).

So just changing Resize(299) to Resize(342) on torch.hub's page and updating the top-1 and top-5 numbers based on the torchvision.models page would make everything consistent and accurate.
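Concretely, the only change to the transforms from my first post is the resize size:

preprocess = transforms.Compose([
    transforms.Resize(342),    # was Resize(299)
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])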

Can I send a PR (to the hub repo) for this to save everyone some time? I'll fix it both on the hub webpage and in the corresponding Colab notebook.

@datumbox (Contributor) commented:

Actually, I think the TorchHub documentation needs to change to use the new multi-weight support API. Instead of manually defining the preprocessing and running into problems, this should be done with the transforms attached to the weights. Here is an example in the docs.
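For illustration, a rough sketch of what that could look like with the multi-weight API (names taken from the current main-branch docs, so they may still change before the release):

from torchvision.models import inception_v3, Inception_V3_Weights

weights = Inception_V3_Weights.IMAGENET1K_V1
model = inception_v3(weights=weights)
model.eval()

# The matching preprocessing (resize, crop, normalization) ships with the weights,
# so nothing has to be copied by hand from the docs.
preprocess = weights.transforms()
# batch = preprocess(img).unsqueeze(0)
# prediction = model(batch).softmax(dim=1)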

We probably want to merge such a PR after the upcoming release, though, because that's a new API.

@NicolasHug (Member) commented:

> We probably want to merge such a PR after the upcoming release, though, because that's a new API.

Yes, this is tracked in pytorch/hub#287.
