From d80174e2db679365f8b58ff8583bdc4af5a8b74c Mon Sep 17 00:00:00 2001 From: Avi Verma Date: Fri, 16 Jun 2023 19:25:55 +0000 Subject: [PATCH] Do not materialize entire randperm in RandomSampler (#103339) In our DDP training workloads, each rank was initializing a `RandomSampler` for a dataset with a length of 3.5 billion items. We noticed that when this sampler was in scope, `gc.collect` calls were taking on the order of seconds to run, which would slow down the entire training iteration. This is because when we call `torch.randperm(n).tolist()`, we create a Python list of 3.5 billion items, which massively slows down the periodic mark-and-sweep garbage collection. This PR swaps out the `.tolist()` call with a `.numpy()` call and manually calls `.item()` on each element as it is being requested. This has two benefits: 1. The first call to `RandomSampler::__next__` should be about twice as fast, since `.numpy()` does not copy the contents of the original tensor. 2. The runtime of `gc.collect()` calls no longer scales linearly with the size of the dataset passed to `RandomSampler`. I've attached some `timeit` samples to illustrate the speedups with this PR: ``` Main (no GC): 51.72115747816861 Main (10 GC calls) 83.61965207383037 PR (no GC) 33.06403830461204 PR (10 GC calls) 33.959467427805066 ``` Code ```python from timeit import timeit baseline_no_gc = """ import torch n = int(1e9) steps = n // 100 x = torch.randperm(n).tolist() x_iter = iter(x) for i in range(steps): next(x_iter) """ baseline_gc = """ import torch import gc n = int(1e9) steps = n // 100 gc_every = steps // 10 x = torch.randperm(n).tolist() x_iter = iter(x) for i in range(steps): next(x_iter) if i % gc_every == 0: gc.collect() """ numpy_no_gc = """ import torch n = int(1e9) steps = n // 100 x = torch.randperm(n).numpy() x_iter = (i.item() for i in x) for i in range(steps): next(x_iter) """ numpy_gc = """ import torch import gc n = int(1e9) steps = n // 100 gc_every = steps // 10 x = 
torch.randperm(n).numpy() x_iter = (i.item() for i in x) for i in range(steps): next(x_iter) if i % gc_every == 0: gc.collect() """ if __name__ == "__main__": print("Main (no GC): ", timeit(baseline_no_gc, number=1)) print("Main (10 GC calls)", timeit(baseline_gc, number=1)) print("PR (no GC)", timeit(numpy_no_gc, number=1)) print("PR (10 GC calls)", timeit(numpy_gc, number=1)) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/103339 Approved by: https://github.com/kit1980 --- torch/utils/data/sampler.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index 5dbcbc088fba6..606f9ec5b6314 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -159,12 +159,13 @@ def __iter__(self) -> Iterator[int]: if self.replacement: for _ in range(self.num_samples // 32): - yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() - yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() + yield from map(int, torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).numpy()) + final_samples = torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator) + yield from map(int, final_samples.numpy()) else: for _ in range(self.num_samples // n): - yield from torch.randperm(n, generator=generator).tolist() - yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n] + yield from map(int, torch.randperm(n, generator=generator).numpy()) + yield from map(int, torch.randperm(n, generator=generator)[:self.num_samples % n].numpy()) def __len__(self) -> int: return self.num_samples