
Add FID metric #1302

Closed
wants to merge 37 commits into from

Conversation

kamahori
Contributor

@kamahori kamahori commented Sep 17, 2020

Fixes #998

Description:
Added an FID (Fréchet Inception Distance) metric, which is used for evaluating GANs.
Most of the code is adapted from here, so further optimization may be needed.

Checklist:

  • New tests are added (if a new feature is added)
  • New doc strings: description and/or example code are in RST format
  • Documentation is updated (if required)

@vfdev-5
Collaborator

vfdev-5 commented Sep 17, 2020

@kamahori thanks for working on this metric! I agree that it is not the simplest metric to integrate.
I have a few comments about the API I was thinking about. The idea is that the user can attach the metric to an evaluator engine:

fid_metric = FID()
evaluator = ...
fid_metric.attach(evaluator, "fid")

evaluator.run(validation_dataloader)

See here for more details: https://pytorch.org/ignite/concepts.html

This way we avoid having to set up a dataloader from a path, etc. An ImageNet-pretrained model like Inception can be an optional argument:

class FID(Metric):
    def __init__(self, output_transform=lambda x: x, test_model=None):
        if test_model is None:
            try:
                # e.g. an ImageNet-pretrained Inception as the default feature extractor
                from torchvision.models import inception_v3
                test_model = inception_v3(pretrained=True)
            except ImportError:
                raise RuntimeError("This metric requires torchvision to be installed.")
        self._test_model = test_model
        super().__init__(output_transform=output_transform)

Maybe something else... I'll comment later.

Other remarks: we probably should not use numpy or tqdm in the code. We can replace numpy with torch, and tqdm can be attached by the user on the evaluator.

@kamahori
Contributor Author

Thanks for reviewing @vfdev-5
I fixed several points.

Other remarks: we probably should not use numpy or tqdm in the code. We can replace numpy with torch, and tqdm can be attached by the user on the evaluator.

How about scipy? Replacing scipy.linalg.sqrtm with torch seems a bit complicated.
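For reference, this is roughly where sqrtm enters the computation, following the standard FID formula (just a sketch of the idea, not the code in this PR):

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        # sqrtm can return small imaginary components due to numerical error
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2.0 * covmean)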

@vfdev-5
Collaborator

vfdev-5 commented Sep 17, 2020

How about scipy? Replacing scipy.linalg.sqrtm with torch seems a bit complicated.

Good point, let me check whether it is possible at all...

@kamahori
Contributor Author

Maybe we can use this?

@vfdev-5
Collaborator

vfdev-5 commented Sep 17, 2020

Thanks for the link. It seems to be a wrapper over scipy with backprop enabled. Backprop is not a requirement for us. Can you check if there is a feature request on PyTorch for that? I wonder how difficult it would be to implement...

@kamahori
Contributor Author

I also found this and this, but it seems they are not exactly what we are looking for...
And FYI here is the original scipy implementation.

Collaborator

@vfdev-5 vfdev-5 left a comment

@kamahori thanks for working on this metric!
I left some comments on how we need to reimplement the metric...
Honestly, it is not a simple metric to implement, and I hope we can find some nice solutions :)

[5 review comments on ignite/metrics/fid.py — outdated, resolved]
@vfdev-5
Collaborator

vfdev-5 commented Sep 18, 2020

@kamahori if there is no "simple" option to implement it without scipy, let's put the metric into ignite.contrib.metrics and use scipy. That module already contains other metrics that use extra deps like sklearn, etc.

@kamahori
Contributor Author

@vfdev-5
Thank you for your comments. I fixed the code again.
I moved the FID metric to ignite.contrib.metrics because I could not find any simple way around the scipy dependency.

@vfdev-5
Collaborator

vfdev-5 commented Sep 21, 2020

@kamahori thanks for the update! It looks better now 👍 However, the tough parts remain: computing the metric in an "online" fashion, making it work with DDP, and writing tests.

Let's handle them one by one.

  1. Online metric computation

Let me give you an example of how it would work for a simple metric like accuracy. First, here is how we can compute accuracy in general (not online) on predictions and targets:

y_pred = ...   # tensor (N, K) - predictions, N samples over K class scores
y = ...        # tensor (N,) - targets, e.g. [0, 2, 1, 4, 5, 2, 3, 4] where K = 6

pred_indices = torch.argmax(y_pred, dim=1)
num_correct = torch.eq(pred_indices, y).sum().item()  # number of correctly predicted samples
num_samples = y.shape[0]

accuracy = num_correct / num_samples

In an online manner, we receive a stream of pairs (y_pred, y) and we would like to compute the final accuracy. The naive way is to concatenate all y_pred into a single total_y_pred, and likewise for total_y, and then compute num_correct and num_samples. But this approach is not memory optimal, as we have to store the entire history of (y_pred, y).
However, we can see that accuracy as a metric can be computed as

num_correct = num_correct_pair1 + num_correct_pair2 + ...
num_samples = num_samples_pair1 + num_samples_pair2 + ...
accuracy = num_correct / num_samples

Thus, at each update step we can compute num_correct_pairX and num_samples_pairX for the incoming pair and accumulate them:

def reset(self):
    self._num_correct = 0
    self._num_samples = 0

def update(self, output):
    y_pred, y = output
    pred_indices = torch.argmax(y_pred, dim=1)
    self._num_correct += torch.eq(pred_indices, y).sum().item()  # count of correctly predicted samples
    self._num_samples += y.shape[0]

def compute(self):
    return self._num_correct / self._num_samples

A similar reasoning should be applied to FID. We have to identify what can be precomputed and accumulated at each update call, and what should then be computed in the compute call.
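For FID, one way this could look (just a sketch to illustrate the accumulation idea, for a single feature distribution; a real implementation would keep one set of accumulators for real data and one for generated data):

def reset(self):
    self._num_examples = 0
    self._sum = None        # running sum of feature vectors, shape (D,)
    self._outer_sum = None  # running sum of outer products, shape (D, D)

def update(self, features):
    # features: tensor of shape (N, D) extracted by the test model
    features = features.double()
    if self._sum is None:
        d = features.shape[1]
        self._sum = torch.zeros(d, dtype=torch.float64)
        self._outer_sum = torch.zeros(d, d, dtype=torch.float64)
    self._num_examples += features.shape[0]
    self._sum += features.sum(dim=0)
    self._outer_sum += features.t() @ features

def compute(self):
    n = self._num_examples
    mu = self._sum / n
    # unbiased covariance recovered from the accumulated sums
    sigma = (self._outer_sum - n * torch.outer(mu, mu)) / (n - 1)
    return mu, sigma  # these then go into the Frechet distance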

@kamahori
Contributor Author

Thanks. How about this implementation?
It accumulates the features in update and calculates mu, sigma, and the FID in compute.

Collaborator

@vfdev-5 vfdev-5 left a comment

Thanks for the update! I have some other comments.
I think it would also be good to write some basic tests.

[4 review comments on ignite/contrib/metrics/fid.py — outdated, resolved]
@kamahori
Contributor Author

I implemented update to calculate the covariance in an online manner (I think this should be fine) and added an FID test.
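The update merges each batch's statistics into the running ones, along these lines (a sketch of the idea using Chan et al.'s parallel update formula, not the exact code in the PR):

def update(self, features):
    features = features.double()
    n_b = features.shape[0]
    mu_b = features.mean(dim=0)
    centered = features - mu_b
    m2_b = centered.t() @ centered  # batch scatter matrix
    if self._n == 0:
        self._n, self._mu, self._m2 = n_b, mu_b, m2_b
    else:
        n = self._n + n_b
        delta = mu_b - self._mu
        # Chan et al.: merge two sets of (count, mean, scatter)
        self._m2 = self._m2 + m2_b + torch.outer(delta, delta) * (self._n * n_b / n)
        self._mu = self._mu + delta * (n_b / n)
        self._n = n

def compute(self):
    # unbiased covariance from the accumulated scatter matrix
    return self._mu, self._m2 / (self._n - 1)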

Collaborator

@vfdev-5 vfdev-5 left a comment

@kamahori thanks! This looks much better now! Now we have to write some meaningful tests...

[4 review comments on ignite/contrib/metrics/fid.py — resolved]
Collaborator

@vfdev-5 vfdev-5 left a comment

@kamahori thanks for the updates. I left several comments.

[6 review comments on ignite/contrib/metrics/fid.py — outdated, resolved]
@kamahori
Contributor Author

kamahori commented Oct 2, 2020

I found that calculating the covariance in an online manner is not efficient, so I reverted the _cov function and now compute the covariance only once.
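The idea is a one-shot covariance over the stacked (N, D) feature matrix, something like this (a sketch, not necessarily the exact helper; torch had no built-in cov at the time):

def _cov(x, rowvar=False):
    # each row of x is an observation and each column a feature when
    # rowvar is False; unbiased estimate (divide by N - 1)
    if rowvar:
        x = x.t()
    x = x - x.mean(dim=0, keepdim=True)
    return x.t() @ x / (x.shape[0] - 1)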

@kamahori
Contributor Author

kamahori commented Oct 2, 2020

I think most parts are complete.
Do you know why these errors are happening?

And is there a better way to extract real values, to avoid this?
@vfdev-5

@vfdev-5
Collaborator

vfdev-5 commented Oct 2, 2020

@kamahori thanks for working on this PR, which is not a simple one!

About the failing XLA tests, let's look at them a bit later, as they seem somewhat unrelated to this PR.

First, I think we are handling the inception features incorrectly. This is something to rework. Maybe the simplest way is to replace the final classification layer with an identity, so that we do not need to do a lot of network surgery (see the sketch below).

Second, I'm not convinced about computing the covariance the way you changed it. Now we accumulate both the sum and the whole history of features... Let me also check it myself (however, I'm a bit short on time these weeks).
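For the first point, something like this is what I have in mind (a sketch assuming torchvision's inception_v3, not a prescription for the final code):

import torch
import torchvision

# use the pretrained network purely as a feature extractor: replacing
# the classification head with an identity makes forward() return the
# pooled features directly
model = torchvision.models.inception_v3(pretrained=True)
model.fc = torch.nn.Identity()
model.eval()

with torch.no_grad():
    x = torch.rand(2, 3, 299, 299)  # inception_v3 expects 299x299 inputs
    features = model(x)             # shape (2, 2048)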

@vfdev-5 vfdev-5 added the hacktoberfest-accepted For accepted PRs label Oct 25, 2020
@import-antigravity

Since no one appears to have linked it, I believe this is considered the "reference implementation" for PyTorch: https://github.com/mseitzer/pytorch-fid

@sdesrozis
Contributor

@import-antigravity thank you for the link!

@vfdev-5
Collaborator

vfdev-5 commented Jun 16, 2021

closed in favor of #2049

@vfdev-5 vfdev-5 closed this Jun 16, 2021
This pull request was closed.
Labels
hacktoberfest-accepted For accepted PRs
Development

Successfully merging this pull request may close these issues.

Metrics for GANs
4 participants