Do not materialize entire randperm in RandomSampler (pytorch#103339)
In our DDP training workloads, each rank was initializing a `RandomSampler` for a dataset with a length of 3.5 billion items. We noticed that while this sampler was in scope, `gc.collect()` calls were taking on the order of seconds to run, which slowed down the entire training iteration. This is because `torch.randperm(n).tolist()` creates a Python list of 3.5 billion items, which massively slows down the periodic mark-and-sweep garbage collection.
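The slowdown comes from CPython's collector traversing every element pointer of a tracked container such as a list, while a NumPy array (or tensor) is a single object wrapping an untracked buffer, so its elements are invisible to the collector. A small illustration of the difference, using a much smaller `n` than the 3.5 billion in our workload (my own sketch, not part of the PR):

```python
import gc
import time

import numpy as np

n = 5_000_000  # far smaller than the 3.5e9 items in the real workload

big_list = list(range(n))  # n separate int objects inside a gc-tracked list
t0 = time.perf_counter()
gc.collect()               # must visit every element pointer in the list
list_gc = time.perf_counter() - t0

del big_list
big_array = np.arange(n)   # one contiguous buffer, no per-element objects
t0 = time.perf_counter()
gc.collect()               # nothing element-wise to traverse
array_gc = time.perf_counter() - t0

print(f"gc with list:  {list_gc:.4f}s")
print(f"gc with array: {array_gc:.4f}s")
```

The absolute numbers depend on the machine, but the list case scales with `n` while the array case does not.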

This PR swaps the `.tolist()` call for a `.numpy()` call and converts each element to a Python `int` lazily, as it is requested. This has two benefits:

1. The first call to `next()` on the sampler's iterator should be about twice as fast, since `.numpy()` returns a view over the tensor's storage rather than copying its contents.
2. The runtime of `gc.collect()` calls no longer scales linearly with the size of the dataset passed to `RandomSampler`.
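The zero-copy claim in point 1 is easy to verify: `Tensor.numpy()` returns an `ndarray` that shares storage with the tensor, so no Python element objects exist until one is requested. A minimal sketch (my own illustration, not from the PR):

```python
import torch

t = torch.arange(5)
a = t.numpy()     # zero-copy: the array views the tensor's storage

t[0] = 99         # mutating the tensor is visible through the array
print(int(a[0]))  # 99

# Elements become plain Python ints only on demand, e.g. via int() or .item()
print(type(int(a[1])))  # <class 'int'>
```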

I've attached some `timeit` samples to illustrate the speedups with this PR:

```
Main (no GC):       51.72115747816861
Main (10 GC calls): 83.61965207383037
PR (no GC):         33.06403830461204
PR (10 GC calls):   33.959467427805066
```

Benchmark code:
```python
from timeit import timeit

baseline_no_gc = """
import torch

n = int(1e9)
steps = n // 100

x = torch.randperm(n).tolist()
x_iter = iter(x)

for i in range(steps):
    next(x_iter)
"""

baseline_gc = """
import torch
import gc
n = int(1e9)
steps = n // 100
gc_every = steps // 10  # triggers gc.collect() 10 times over the run

x = torch.randperm(n).tolist()
x_iter = iter(x)

for i in range(steps):
    next(x_iter)
    if i % gc_every == 0:
        gc.collect()
"""

numpy_no_gc = """
import torch
n = int(1e9)
steps = n // 100

x = torch.randperm(n).numpy()
x_iter = (i.item() for i in x)

for i in range(steps):
    next(x_iter)
"""

numpy_gc = """
import torch
import gc
n = int(1e9)
steps = n // 100
gc_every = steps // 10  # triggers gc.collect() 10 times over the run

x = torch.randperm(n).numpy()
x_iter = (i.item() for i in x)

for i in range(steps):
    next(x_iter)
    if i % gc_every == 0:
        gc.collect()
"""

if __name__ == "__main__":
    print("Main (no GC):", timeit(baseline_no_gc, number=1))
    print("Main (10 GC calls):", timeit(baseline_gc, number=1))
    print("PR (no GC):", timeit(numpy_no_gc, number=1))
    print("PR (10 GC calls):", timeit(numpy_gc, number=1))

```
Pull Request resolved: pytorch#103339
Approved by: https://github.com/kit1980
aviverma01 authored and pytorchmergebot committed Jun 16, 2023 (commit d80174e, parent 67babf7)
torch/utils/data/sampler.py (5 additions, 4 deletions):
```diff
@@ -159,12 +159,13 @@ def __iter__(self) -> Iterator[int]:
 
         if self.replacement:
             for _ in range(self.num_samples // 32):
-                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
-            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
+                yield from map(int, torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).numpy())
+            final_samples = torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator)
+            yield from map(int, final_samples.numpy())
         else:
             for _ in range(self.num_samples // n):
-                yield from torch.randperm(n, generator=generator).tolist()
-            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
+                yield from map(int, torch.randperm(n, generator=generator).numpy())
+            yield from map(int, torch.randperm(n, generator=generator)[:self.num_samples % n].numpy())
 
     def __len__(self) -> int:
         return self.num_samples
```
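As a sanity check on the new pattern (my own sketch, not part of the change), `map(int, tensor.numpy())` yields the same Python ints, in the same order, as the old `.tolist()` call:

```python
import torch

g = torch.Generator()
g.manual_seed(0)
old = torch.randperm(10, generator=g).tolist()  # the pre-PR pattern

g.manual_seed(0)  # reseed so both draws produce the same permutation
new = list(map(int, torch.randperm(10, generator=g).numpy()))  # the new pattern

print(old == new)                        # same values in the same order
print(all(type(i) is int for i in new))  # plain Python ints, as callers expect
```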
