Skip to content

Commit

Permalink
Add SubsetRandomSampler
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll committed Sep 25, 2023
1 parent 589f0f2 commit da857e5
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions python/paddle/io/dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np

from ...framework import core
from ...tensor import randperm


class Sampler:
Expand Down Expand Up @@ -340,3 +341,52 @@ def __iter__(self):
def __len__(self):
mul = np.prod(self.weights.shape) // self.weights.shape[-1]
return self.num_samples * mul


class SubsetRandomSampler(Sampler):
r"""
Randomly sample elements from a given list of indices, without replacement.
Args:
indices (sequence): a sequence of indices
Examples:
.. code-block:: python
>>> from paddle.io import Dataset, SubsetRandomSampler
>>> class RandomDataset(Dataset):
... def __init__(self, num_samples):
... self.num_samples = num_samples
...
... def __getitem__(self, idx):
... image = np.random.random([784]).astype('float32')
... label = np.random.randint(0, 9, (1, )).astype('int64')
... return image, label
...
... def __len__(self):
... return self.num_samples
...
>>> sampler = SubsetRandomSampler(indices=[1, 3, 5, 7, 9])
>>> for index in sampler:
... print(index)
5
3
1
7
9
see `paddle.io.Sampler`
"""

def __init__(self, indices):
self.indices = indices

def __iter__(self):
for i in randperm(len(self.indices)):
yield self.indices[i]

def __len__(self) -> int:
return len(self.indices)

0 comments on commit da857e5

Please sign in to comment.