diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index 5dbcbc088fba6..606f9ec5b6314 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -159,12 +159,13 @@ def __iter__(self) -> Iterator[int]: if self.replacement: for _ in range(self.num_samples // 32): - yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() - yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() + yield from map(int, torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).numpy()) + final_samples = torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator) + yield from map(int, final_samples.numpy()) else: for _ in range(self.num_samples // n): - yield from torch.randperm(n, generator=generator).tolist() - yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n] + yield from map(int, torch.randperm(n, generator=generator).numpy()) + yield from map(int, torch.randperm(n, generator=generator)[:self.num_samples % n].numpy()) def __len__(self) -> int: return self.num_samples