
A question about metrics on ddp #2398

Closed
haichao592 opened this issue Jun 28, 2020 · 14 comments
Labels
discussion (In a discussion stage) · question (Further information is requested)

Comments

@haichao592

Thanks for the nice work! 🥳🥳🥳

Background

The new metrics package was released with built-in DDP support.
DistributedSampler is used by the dataloaders to work with DDP.
It adds extra samples so that the dataset is evenly divisible across processes.
This means that during evaluation and testing, the added extra samples affect the eval result.
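
For illustration, a minimal sketch of the duplication (toy dataset of 10 items, 4 replicas; no process group is needed because num_replicas and rank are passed explicitly):

    # 10 samples split across 4 replicas -> 12 indices in total,
    # i.e. samples 0 and 1 are evaluated twice
    from torch.utils.data import DistributedSampler

    dataset = list(range(10))  # toy dataset; the sampler only needs len()

    all_indices = []
    for rank in range(4):
        sampler = DistributedSampler(dataset, num_replicas=4, rank=rank, shuffle=False)
        all_indices.extend(list(sampler))

    print(len(all_indices))     # 12, not 10
    print(sorted(all_indices))  # [0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9]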

Question

Does the metrics package remove these redundant extra samples to make sure the eval result is not affected?

I have looked into the code and found that the answer is no.
Did I miss something, or has it just not been addressed yet?

@haichao592 haichao592 added the question label Jun 28, 2020
@justusschock
Member

Hi @haichao592, AFAIK the DistributedSampler simply does not behave like this. It only divides the dataset into disjoint subsets and does not add any extra samples. This is also the reason we haven't implemented this.

@zerogerc
Contributor

zerogerc commented Jul 2, 2020

@justusschock, the DistributedSampler actually adds samples from the start of the dataset to make the number of examples evenly divisible:

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))] 

Reference: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py#L79

@justusschock
Member

justusschock commented Jul 2, 2020

Okay, thanks @zerogerc, learned something new :)

But to answer the question here: no, we don't cover that, and I also think we can't do it that easily, because we don't know which samples those are (there may also be shuffling) and we also don't know which batch you are currently feeding into your metric. So far this is only a module that calculates and syncs with some sanity checks. The more advanced features are yet to come, but I'm not sure if this is possible :)

@SkafteNicki do you think we can add this feature? Probably not without knowing the indices and removing duplicates, which we won't do, right?

@zerogerc
Contributor

zerogerc commented Jul 2, 2020

Hi @justusschock, just my thoughts on this matter.

I believe that even divisibility is only critical for backward propagation and gradient computation. I think that writing a distributed sampler which does not add extra samples would solve the problem for metric calculation.
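
A rough sketch of such a sampler (hypothetical, not part of any package; each rank just takes a strided slice of the indices without padding, so ranks may end up with different numbers of samples):

    import torch
    from torch.utils.data import Sampler

    class UnpaddedDistributedSampler(Sampler):
        """Hypothetical sampler: splits the dataset across processes without
        adding extra samples, so nothing is evaluated twice."""

        def __init__(self, dataset, num_replicas, rank, shuffle=False, seed=0):
            self.dataset = dataset
            self.num_replicas = num_replicas
            self.rank = rank
            self.shuffle = shuffle
            self.seed = seed
            self.epoch = 0

        def __iter__(self):
            if self.shuffle:
                g = torch.Generator()
                g.manual_seed(self.seed + self.epoch)
                indices = torch.randperm(len(self.dataset), generator=g).tolist()
            else:
                indices = list(range(len(self.dataset)))
            # strided slice per rank, no padding added
            return iter(indices[self.rank::self.num_replicas])

        def __len__(self):
            return len(range(self.rank, len(self.dataset), self.num_replicas))

        def set_epoch(self, epoch):
            self.epoch = epoch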

@SkafteNicki
Member

I wasn't aware of this either, and I cannot really understand why this is the default behavior of PyTorch (at least without any kind of warning).
I guess the only way to prevent this would be to write a custom DistributedSampler that corrects it.

@justusschock
Member

@SkafteNicki @zerogerc when you think about it, this makes sense (and cannot be changed that easily without dropping the last batches), since you cannot sync/reduce something that is not available on some of the processes (which would be the case with a number of batches/samples that is not evenly divisible).

@SkafteNicki
Member

@justusschock for something like that to work, would a torch.distributed.barrier be needed before any reduction?

@zerogerc
Contributor

zerogerc commented Jul 2, 2020

@justusschock, @SkafteNicki
One more concern.

I think that such a custom sampler would break the MEAN reduce operation for metrics. For example, accuracy cannot be correctly reduced if each process handles a different number of examples. For a correct reduction, each process would have to track the number of examples it handled. Another approach is to return tps and sups and always use the SUM operation for the DDP reduction.
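
A sketch of that second approach (the function name and raw-count interface here are made up for illustration; it assumes an initialized process group, otherwise it falls back to the local counts):

    import torch
    import torch.distributed as dist

    # Hypothetical metric sync: keep raw counts per rank, SUM-reduce them,
    # and divide once at the end instead of averaging per-rank accuracies.
    def ddp_accuracy(correct: int, total: int) -> torch.Tensor:
        counts = torch.tensor([correct, total], dtype=torch.float64)
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(counts, op=dist.ReduceOp.SUM)
        return counts[0] / counts[1]

    # Why MEAN is wrong with uneven ranks: rank 0 sees 10 samples (8 correct),
    # rank 1 sees 4 samples (3 correct). MEAN of accuracies = (0.8 + 0.75) / 2
    # = 0.775, while the true accuracy is (8 + 3) / (10 + 4) ≈ 0.786.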

@SkafteNicki
Member

@zerogerc we are aware of that problem, and I have an upcoming PR that will fix that issue by using SUM to sync the metric states (for accuracy, tps and sups), which afterwards get divided.

@SkafteNicki
Member

This problem was also discussed in PyTorch (pytorch/pytorch#22584), with the conclusion that the only way to do this would be to keep track of the sample ids and then filter the computations based on them.

Interestingly enough, this is the very reason the official ImageNet example only adds a distributed sampler to its train dataloader and not to its validation dataloader (https://github.com/pytorch/examples/blob/49ec0bd72b85be55579ae8ceb278c66145f593e1/imagenet/main.py#L216-L233)
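
For reference, a minimal sketch of that id-tracking idea (the wrapper and helper below are hypothetical, just to show the bookkeeping involved):

    import torch
    from torch.utils.data import Dataset

    class IndexedDataset(Dataset):
        """Hypothetical wrapper that returns the sample index alongside the sample."""
        def __init__(self, dataset):
            self.dataset = dataset

        def __getitem__(self, idx):
            return idx, self.dataset[idx]

        def __len__(self):
            return len(self.dataset)

    def drop_duplicates(ids, preds, targets):
        # keep only the first occurrence of every sample id before computing
        # the final metric value
        seen, keep = set(), []
        for pos, sample_id in enumerate(ids.tolist()):
            if sample_id not in seen:
                seen.add(sample_id)
                keep.append(pos)
        keep = torch.tensor(keep, dtype=torch.long)
        return preds[keep], targets[keep]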

@zerogerc
Contributor

zerogerc commented Jul 2, 2020

I think that tracking sample ids and filtering would add significant overhead to the validation step.

@zerogerc
Contributor

zerogerc commented Jul 2, 2020

Judging from the implementation of DistributedDataParallel, the computation only hangs if gradient computation is required, so an uneven distribution should work. Needs testing, though.

https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/distributed.py#L501
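
A minimal way to test that claim could look like this (hypothetical model/dataloader, assuming an already-initialized process group; under no_grad() the DDP reducer never fires, so only the final collectives must be reached by all ranks):

    import torch
    import torch.distributed as dist

    @torch.no_grad()
    def eval_uneven(model, dataloader, device):
        # each rank may iterate a different number of batches; nothing should
        # hang until the explicit all_reduce of the counts at the end
        model.eval()
        correct = torch.zeros(1, device=device)
        total = torch.zeros(1, device=device)
        for x, y in dataloader:
            preds = model(x.to(device)).argmax(dim=-1)
            correct += (preds == y.to(device)).sum()
            total += y.numel()
        dist.all_reduce(correct)  # default op is SUM
        dist.all_reduce(total)
        return (correct / total).item()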

@stale stale bot added the won't fix label Aug 31, 2020
@stale stale bot removed the won't fix label Sep 1, 2020
@Borda Borda added the discussion label Sep 15, 2020
@carmocca
Contributor

> Judging from the implementation of DistributedDataParallel, the computation only hangs if gradient computation is required, so an uneven distribution should work. Needs testing, though.

@zerogerc did you test it?

@stale stale bot added the won't fix label Nov 19, 2020
@SkafteNicki SkafteNicki removed the won't fix label Nov 19, 2020
@SkafteNicki
Member

This issue should be solved by PR #5141, which enables support for uneven inputs in LightningDDP.

@SkafteNicki SkafteNicki linked a pull request Dec 16, 2020 that will close this issue
@Lightning-AI Lightning-AI deleted a comment from stale bot Feb 4, 2021
@Lightning-AI Lightning-AI deleted a comment from stale bot Feb 4, 2021
@Borda Borda closed this as completed Feb 4, 2021
@Lightning-AI Lightning-AI locked and limited conversation to collaborators Feb 4, 2021

This issue was moved to a discussion.

You can continue the conversation there. Go to discussion →
