Do not materialize entire randperm in RandomSampler #103339
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/103339. Note: links to docs will display an error until the docs builds have been completed.

✅ No Failures as of commit 8f5e445 with merge base bc2caa7. This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/utils/data/sampler.py (Outdated)

```diff
-yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
+indices = torch.randperm(n, generator=generator)
+for i in indices:
+    yield i.item()
```
One minor issue is that repeated `.item()` calls are slow (#29973), so maybe `tolist()` is actually not that bad? (Especially since `indices` itself is materialized as a tensor.)
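A minimal sketch of the trade-off being discussed (sizes and timings are illustrative, not from the PR):

```python
import timeit

import torch

n = 1_000_000
indices = torch.randperm(n)

# Per-element .item() crosses the Python/C++ boundary once per index.
t_item = timeit.timeit(lambda: [i.item() for i in indices], number=1)

# .tolist() is one bulk conversion, but it materializes n Python ints at once.
t_tolist = timeit.timeit(lambda: indices.tolist(), number=1)

print(f".item() loop: {t_item:.3f}s, .tolist(): {t_tolist:.3f}s")
```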
Thanks for the suggestion! I've updated the code to call `.numpy()` on the tensor before iterating over it, which should avoid the slow `.item()` calls.
Well, if you can afford an extra allocation, why not just `yield from indices.tolist()`? Because the indices Python list would take too much memory? I don't know what the current state of affairs is on whether numpy is an obligatory dependency.
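A back-of-the-envelope comparison of the two representations (a sketch; exact sizes vary by platform and Python version):

```python
import sys

import torch

n = 1_000_000
t = torch.randperm(n)

# A list of n Python ints: ~8 bytes per list slot plus ~28 bytes per int object.
lst = t.tolist()
list_bytes = sys.getsizeof(lst) + sum(sys.getsizeof(x) for x in lst)

# A numpy view: one Python object over a flat 8-bytes-per-element buffer.
arr = t.numpy()
arr_bytes = arr.nbytes

print(f"list: ~{list_bytes / 1e6:.0f} MB, numpy: ~{arr_bytes / 1e6:.0f} MB")
```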
Maybe worth filing a feature request for iterating a tensor as Python items? Or supporting memoryview on tensors (so that they can be iterated)...
I've added some comments in the PR description, but the main issue is that calling `.tolist()` on a torch tensor with 1 billion+ elements adds massive garbage-collection overhead, because we've just allocated billions of individual Python int objects that need to be managed separately. By using a numpy array instead, the garbage collector only needs to keep track of one object regardless of the dataset size, making garbage collection much faster during training.
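A rough illustration of that GC effect (scaled down from billions to millions of elements; absolute timings will vary by machine):

```python
import gc
import time

import torch

n = 10_000_000  # scaled down from the 3.5B items in the original workload

as_list = torch.randperm(n).tolist()  # n separate Python int objects
start = time.perf_counter()
gc.collect()
print(f"gc.collect with Python list alive: {time.perf_counter() - start:.3f}s")
del as_list

as_array = torch.randperm(n).numpy()  # a single object, regardless of n
start = time.perf_counter()
gc.collect()
print(f"gc.collect with numpy array alive: {time.perf_counter() - start:.3f}s")
```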
I also created #103352.
Could this just be `yield from torch.randperm(n, generator=generator).numpy()` (+ some indexing), keeping the existing terse syntax?
Done in 7a7a1e9. Note that there's an additional `map` call compared to the original suggestion, to ensure that the yielded type matches.
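Roughly the shape this takes (a sketch of the oversampling branch; see commit 7a7a1e9 for the actual change):

```python
# map(int, ...) converts each numpy int64 back to a Python int lazily,
# without materializing the whole permutation as a Python list.
yield from map(int, torch.randperm(n, generator=generator).numpy()[: self.num_samples % n])
```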
@aviverma01 Should we also change the other branch, `yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()`?
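Applying the same pattern to the replacement-sampling branch might look like this (a sketch, not necessarily the exact committed code):

```python
yield from map(int, torch.randint(high=n, size=(32,), dtype=torch.int64,
                                  generator=generator).numpy())
```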
Done in b450690
@kit1980 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
I've triggered more tests and also imported this internally to make sure nothing breaks.
Thanks @kit1980, would you be able to help fix the "Meta Internal-Only Changes Check"? Also, I think some of the tests may be flaky. Would it be possible to re-kick the failing tests?
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased
@kit1980 Looks like there may be some remaining flaky tests, and I believe the "Meta Internal-Only Changes Check" is still failing. Any chance you could help, or show me how to fix it?
@kit1980 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@mergebot merge -i
@aviverma01 sorry, I misspelled the bot name.
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Hi @aviverma01 @kit1980, the change this PR introduces seems to block manually specifying the DataLoader's generator with a non-CPU device:

```python
gen = torch.Generator(device=torch.device("mps:0"))
data_loader = data.DataLoader(dataset, batch_size=8, shuffle=True, generator=gen)
```

The error:

Was this intended? Or do you have any idea about this? Thanks :)
@pytorchbot revert -m "Cause issues on MPS, and also fails without numpy" -c nosignal
I'm reverting this. I've realized there is another issue with the PR: it fails without numpy, which is actually an optional dependency.
@pytorchbot successfully started a revert job. Check the current status here.
@aviverma01 your PR has been successfully reverted.
This reverts commit d80174e. Reverted #103339 on behalf of https://github.com/kit1980 due to Cause issues on MPS, and also fails without numpy ([comment](#103339 (comment)))
#112187) This reverts commit d80174e. Reverted #103339 on behalf of https://github.com/kit1980 due to Cause issues on MPS, and also fails without numpy ([comment](#103339 (comment))) Co-authored-by: PyTorch MergeBot <pytorchmergebot@users.noreply.github.com>
In our DDP training workloads, each rank was initializing a `RandomSampler` for a dataset with a length of 3.5 billion items. We noticed that when this sampler was in scope, `gc.collect` calls were taking on the order of seconds to run, which would slow down the entire training iteration. This is because when we call `torch.randperm(n).tolist()`, we create a Python list of 3.5 billion items, which massively slows down the periodic mark-and-sweep garbage collection.

This PR swaps out the `.tolist()` call with a `.numpy()` call and manually calls `.item()` on each element as it is being requested. This has two benefits:

- `RandomSampler::__next__` should be about twice as fast, since `.numpy` does not copy the contents of the original tensor
- The duration of `gc.collect()` calls no longer scales linearly with the size of the dataset passed to `RandomSampler`

I've attached some `timeit` samples to illustrate the speedups with this PR.
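A condensed before/after sketch of the `__iter__` change described above (a simplified stand-in class with assumed attribute names, not the real torch implementation; the actual method also handles the `replacement` branch):

```python
import torch

class RandomSamplerSketch:
    """Condensed sketch of the sampler change; hypothetical names."""

    def __init__(self, n, generator=None):
        self.n = n
        self.generator = generator

    # Before: materializes the full permutation as one huge Python list,
    # which the garbage collector must traverse on every collection.
    def iter_before(self):
        yield from torch.randperm(self.n, generator=self.generator).tolist()

    # After: iterates a numpy view of the same tensor, converting one
    # index at a time, so only a single container object is ever alive.
    def iter_after(self):
        yield from map(int, torch.randperm(self.n, generator=self.generator).numpy())
```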