Add FID metric #1302
Conversation
@kamahori thanks for working on this metric ! I agree that it is not the simplest metric to integrate. The usage should look like:

```python
fid_metric = FID()
evaluator = ...
fid_metric.attach(evaluator, "fid")
evaluator.run(validation_dataloader)
```

See here for more details: https://pytorch.org/ignite/concepts.html

This way we can avoid setting up a dataloader from a path etc. An ImageNet-pretrained model like Inception can be an optional argument:

```python
class FID(Metric):
    def __init__(self, ..., test_model=None):
        ...
        if test_model is None:
            try:
                from torchvision ... import inception
                test_model = ..
            except ImportError:
                ...
```

Maybe something else... I'll comment later.

Other remarks: we probably should not use numpy or tqdm in the code. numpy can be replaced by torch, and tqdm can be attached by the user on the evaluator.
Thanks for reviewing @vfdev-5
How about scipy? Can it be replaced as well?
Good point, let me check whether it is possible or not at all...
Maybe we can use this ?
Thanks for the link. Seems like it is a wrapper over scipy with backprop enabled. Backprop is not a requirement for us. Can you check if there is a FR on pytorch for that ? I wonder if it is difficult to implement ...
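For the special case that matters here, a matrix square root can be done in pure torch: covariance matrices are symmetric positive semi-definite, so an eigendecomposition suffices. This is only a sketch under that assumption; note that FID actually needs the square root of a *product* of two covariances, which is not symmetric in general, so this does not fully replace `scipy.linalg.sqrtm`:

```python
import torch


def sqrtm_psd(mat: torch.Tensor) -> torch.Tensor:
    # Eigendecomposition-based matrix square root. Valid only for
    # symmetric positive semi-definite matrices such as covariances.
    vals, vecs = torch.linalg.eigh(mat)
    vals = vals.clamp(min=0.0)  # guard against tiny negative eigenvalues
    return (vecs * vals.sqrt()) @ vecs.T


torch.manual_seed(0)
a = torch.rand(5, 3)
cov = a.T @ a  # symmetric PSD by construction
root = sqrtm_psd(cov)
```

Here `root @ root` reconstructs `cov` up to floating-point error.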
@kamahori thanks for working on this metric !
I left some comments on how we need to reimplement the metric...
Honestly, it is not a simple metric to implement and I hope we can find some nice solutions :)
@kamahori if there is no "simple" option to make it without scipy, let's put the metric into
@vfdev-5
@kamahori thanks for the update ! It looks better now 👍 However, there remain some tough parts: computing the metric in an "online" fashion, making it work with DDP, and writing tests. Let's handle them one by one.
Let me give you an example of how it would work for a simple metric like accuracy. First, here is how we can compute accuracy in general (without online computation) on predictions and targets:

```python
y_pred = ...  # tensor (N, K) - predictions, N samples of K features
y = ...       # tensor (N, ) - targets, N samples, e.g. [0, 2, 1, 4, 5, 2, 3, 4] where K=6
pred_indices = torch.argmax(y_pred, dim=1)
num_correct = torch.eq(pred_indices, y).sum()  # number of correctly predicted samples
num_samples = y.shape[0]
accuracy = num_correct / num_samples
```

In an online manner, we instead receive a bunch of `(y_pred, y)` pairs, one per batch, and accumulate the counters across batches. Thus, at `reset` we initialize the counters, at `update` we accumulate per batch, and at `compute` we return the final value:

```python
def reset(self):
    self._num_correct = 0
    self._num_samples = 0

def update(self, output):
    y_pred, y = output
    pred_indices = torch.argmax(y_pred, dim=1)
    self._num_correct += torch.eq(pred_indices, y).sum().item()  # correct samples in this batch
    self._num_samples += y.shape[0]

def compute(self):
    return self._num_correct / self._num_samples
```

A similar reasoning should be applied for FID: we have to identify what can be accumulated in `update` and what is computed in `compute`.
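Following the same reset/update/compute reasoning, the feature mean and covariance that FID needs can be accumulated online from per-batch moment sums. This is a generic sketch (not the PR's actual code): keep the running sum and the running sum of outer products, and recover mean and unbiased covariance at the end.

```python
import torch


class OnlineStats:
    # Accumulates sum(x) and sum(x x^T) per batch so that the mean and the
    # unbiased covariance can be recovered in compute(), mirroring the
    # reset/update/compute pattern used for accuracy above.
    def reset(self):
        self._n = 0
        self._sum = None
        self._outer = None

    def update(self, feats):  # feats: (N, D) batch of feature vectors
        if self._sum is None:
            d = feats.shape[1]
            self._sum = torch.zeros(d)
            self._outer = torch.zeros(d, d)
        self._n += feats.shape[0]
        self._sum += feats.sum(dim=0)
        self._outer += feats.T @ feats

    def compute(self):
        mean = self._sum / self._n
        # sum((x - m)(x - m)^T) = sum(x x^T) - n * m m^T
        cov = (self._outer - self._n * torch.outer(mean, mean)) / (self._n - 1)
        return mean, cov


stats = OnlineStats()
stats.reset()
for _ in range(3):
    stats.update(torch.rand(10, 4))
mean, cov = stats.compute()
```

This keeps memory constant in the number of samples, at the cost of accumulating a D×D matrix per update.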
Thanks. How about this implementation?
Thanks for the update ! I have some other comments.
I think it would be also good to write some basic tests.
I implemented
@kamahori thanks ! This looks much better now ! Now we have to write some meaningful tests...
I fixed the implementation, but I still have some errors in CI. Do you know what the problem is? @vfdev-5
@kamahori thanks for the updates. I left several comments.
I found that calculating covariance in an online manner is not efficient, so I reverted it.
I think most parts are completed.
And is there any better way to extract real values to avoid this?
@kamahori thanks for working on this PR, which is not simple ! About the failing XLA tests, let's look at them a bit later, as they seem unrelated to this PR.

First, I think we are handling the inception features incorrectly. This is something to rework. Maybe the simplest way is to replace the final classification layer by an identity, so that we do not need to do a lot of network surgery.

Second, I'm not that convinced about computing covariance the way you changed it. Now we accumulate the sum and the whole history of features... Let me also check it myself (however, I'm a bit short of time these weeks).
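The "replace the classification head by an identity" trick mentioned above avoids network surgery entirely: `nn.Identity` passes its input through unchanged, so the model's `forward` returns the penultimate features. Illustrated on a stand-in model (`ToyClassifier` is hypothetical, standing in for e.g. torchvision's `inception_v3` and its `fc` head):

```python
import torch
import torch.nn as nn


class ToyClassifier(nn.Module):
    # Stand-in for a pretrained classifier: a feature extractor
    # followed by a final classification layer `fc`.
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 16)
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        return self.fc(self.features(x))


model = ToyClassifier().eval()
model.fc = nn.Identity()  # forward() now returns the 16-d features

with torch.no_grad():
    out = model(torch.rand(2, 8))
```

The same one-liner works on real torchvision models whose head is an attribute (`model.fc = nn.Identity()`), leaving the rest of the forward pass untouched.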
Since no one appears to have linked it, I believe this is considered the "reference implementation" for pytorch: https://github.com/mseitzer/pytorch-fid
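For context, the core quantity such implementations compute is the Fréchet distance between two Gaussians, d² = ||μ₁ − μ₂||² + Tr(Σ₁ + Σ₂ − 2(Σ₁Σ₂)^½). A torch-only sketch (not this PR's code): only the trace of the matrix square root is needed, and it equals the sum of the square roots of the eigenvalues of Σ₁Σ₂, assuming both covariances are positive semi-definite.

```python
import torch


def frechet_distance(mu1, cov1, mu2, cov2):
    # FID between N(mu1, cov1) and N(mu2, cov2):
    #   ||mu1 - mu2||^2 + Tr(cov1 + cov2 - 2 * (cov1 @ cov2)^(1/2))
    # Tr((cov1 @ cov2)^(1/2)) is the sum of the square roots of the
    # eigenvalues of cov1 @ cov2, which avoids scipy.linalg.sqrtm at the
    # price of a complex eigensolve.
    diff = mu1 - mu2
    eigvals = torch.linalg.eigvals(cov1 @ cov2)
    trace_sqrt = eigvals.real.clamp(min=0.0).sqrt().sum()
    return (diff @ diff + torch.trace(cov1 + cov2) - 2.0 * trace_sqrt).item()
```

Sanity checks: identical Gaussians give 0, and shifting one mean by a unit vector under identity covariances gives 1.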
@import-antigravity thank you for the link !
closed in favor of #2049
Fixes #998
Description:
Added the FID metric, which is used for evaluating GANs.
Most of the code is incorporated from here, so further optimization might be needed.
Check list: