add ConcatDataset API

Patrick-Star125 committed Sep 25, 2023
1 parent db901f9 commit 2ae9ff2

Showing 4 changed files with 139 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/paddle/io/__init__.py
@@ -29,6 +29,7 @@
from .dataloader import WeightedRandomSampler # noqa: F401
from .dataloader import Subset # noqa: F401
from .dataloader import random_split # noqa: F401
from .dataloader import ConcatDataset # noqa: F401

__all__ = [ # noqa
    'Dataset',
@@ -46,4 +47,5 @@
    'WeightedRandomSampler',
    'random_split',
    'Subset',
    'ConcatDataset',
]
1 change: 1 addition & 0 deletions python/paddle/io/dataloader/__init__.py
@@ -19,6 +19,7 @@
from .dataset import ChainDataset
from .dataset import random_split
from .dataset import Subset
from .dataset import ConcatDataset

from .batch_sampler import BatchSampler
from .batch_sampler import DistributedBatchSampler
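With both re-exports in place, the class should be reachable from either namespace. A quick sanity check (a hypothetical snippet, assuming a Paddle build that includes this commit):

    import paddle.io
    import paddle.io.dataloader

    # Both names should resolve to the same class object after this commit.
    assert paddle.io.ConcatDataset is paddle.io.dataloader.ConcatDataset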
74 changes: 74 additions & 0 deletions python/paddle/io/dataloader/dataset.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import bisect
import paddle

from ... import framework
@@ -567,3 +568,76 @@ def _accumulate(iterable, fn=lambda x, y: x + y):
    for element in it:
        total = fn(total, element)
        yield total


class ConcatDataset(Dataset):
    """
    Dataset as a concatenation of multiple datasets.

    This class is useful for assembling different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated.

    Returns:
        Dataset: A Dataset built by concatenating the given datasets.

    Examples:
        .. code-block:: python

            >>> import numpy as np
            >>> import paddle
            >>> from paddle.io import Dataset, ConcatDataset

            >>> # define a random dataset
            >>> class RandomDataset(Dataset):
            ...     def __init__(self, num_samples):
            ...         self.num_samples = num_samples
            ...
            ...     def __getitem__(self, idx):
            ...         image = np.random.random([32]).astype('float32')
            ...         label = np.random.randint(0, 9, (1, )).astype('int64')
            ...         return image, label
            ...
            ...     def __len__(self):
            ...         return self.num_samples
            ...
            >>> dataset = ConcatDataset([RandomDataset(10), RandomDataset(10)])
            >>> for i in range(len(dataset)):
            ...     image, label = dataset[i]
            ...     # do something
    """

    @staticmethod
    def cumsum(sequence):
        # Running prefix sum of dataset lengths: r[i] is the total number
        # of samples contained in datasets[0..i] combined.
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets) -> None:
        super().__init__()
        self.datasets = list(datasets)
        assert (
            len(self.datasets) > 0
        ), 'datasets should not be an empty iterable'
        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        # Find the first dataset whose cumulative size exceeds idx, then
        # convert the global index into an index local to that dataset.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]
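The `__getitem__` mapping is easiest to see on the empty-middle-dataset case exercised by the tests below. Here is a standalone trace of the same prefix-sum plus `bisect_right` logic (a stdlib-only sketch that mirrors, but does not import, the class):

    import bisect

    # Prefix sums for datasets of lengths 5, 0 and 5, as cumsum() builds them.
    lengths = [5, 0, 5]
    cumulative_sizes, total = [], 0
    for n in lengths:
        total += n
        cumulative_sizes.append(total)
    assert cumulative_sizes == [5, 5, 10]

    # Global index 5 falls past both the first dataset and the empty one,
    # so bisect_right lands on dataset 2, local index 0.
    idx = 5
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
    assert (dataset_idx, sample_idx) == (2, 0)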
62 changes: 62 additions & 0 deletions test/legacy_test/test_multiprocess_dataloader_dataset.py
@@ -25,6 +25,7 @@
    Dataset,
    IterableDataset,
    TensorDataset,
    ConcatDataset,
)

IMAGE_SIZE = 32
@@ -440,5 +441,66 @@ def test_iterable_dataset(self):
        self.run_main(dataset, 10, 3)


class RandomIterableDataset(IterableDataset):
    def __init__(self, sample_num):
        self.sample_num = sample_num

    def __iter__(self):
        for i in range(self.sample_num):
            np.random.seed(i)
            image = np.random.random([IMAGE_SIZE]).astype('float32')
            label = np.random.randint(0, 9, (1,)).astype('int64')
            yield image, label


class TestConcatDataset(unittest.TestCase):
    def run_main(self, num_workers, places):
        result = ConcatDataset([[0], [1]])
        self.assertEqual(2, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(1, result[1])

        result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

        result = ConcatDataset([[0, 1, 2, 3, 4], [], [5, 6, 7, 8, 9]])
        self.assertEqual(10, len(result))
        self.assertEqual(0, result[0])
        self.assertEqual(5, result[5])

        result = ConcatDataset([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
        with self.assertRaises(IndexError):
            # this one goes to 11
            result[11]

    def test_main(self):
        places = [paddle.CPUPlace()]
        if paddle.is_compiled_with_cuda():
            places.append(paddle.CUDAPlace(0))
        for p in places:
            self.run_main(num_workers=0, places=p)

    def test_iterable_dataset_err(self):
        d1 = TensorDataset([paddle.rand((7, 3, 28, 28)), paddle.rand((7,))])
        it1 = RandomIterableDataset(10)
        it2 = RandomIterableDataset(10)

        with self.assertRaisesRegex(
            AssertionError, "does not support IterableDataset"
        ):
            ConcatDataset([d1, it2, it1])

        with self.assertRaisesRegex(
            AssertionError, "does not support IterableDataset"
        ):
            ConcatDataset([it2])

        with self.assertRaisesRegex(
            AssertionError, "does not support IterableDataset"
        ):
            ConcatDataset([it1, d1])


if __name__ == '__main__':
    unittest.main()
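Beyond the direct-indexing tests above, the concatenated dataset behaves like any other map-style dataset, so it can be fed to a data loader. A minimal end-to-end sketch (not part of this commit; `DataLoader` and `TensorDataset` come from the existing `paddle.io` API, and the shapes are illustrative):

    import paddle
    from paddle.io import ConcatDataset, DataLoader, TensorDataset

    # Two map-style datasets with the same (feature, label) sample structure.
    d1 = TensorDataset([paddle.rand((4, 3)), paddle.rand((4, 1))])
    d2 = TensorDataset([paddle.rand((6, 3)), paddle.rand((6, 1))])

    # 10 samples in total; shuffled batches may mix samples from both sources.
    loader = DataLoader(ConcatDataset([d1, d2]), batch_size=5, shuffle=True)
    for feature, label in loader:
        print(feature.shape, label.shape)  # [5, 3] and [5, 1] per batch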
