diff --git a/python/paddle/io/__init__.py b/python/paddle/io/__init__.py
index 6c2e0dae678347..1c21ee80b0b8ff 100755
--- a/python/paddle/io/__init__.py
+++ b/python/paddle/io/__init__.py
@@ -29,6 +29,7 @@
 from .dataloader import WeightedRandomSampler  # noqa: F401
 from .dataloader import Subset  # noqa: F401
 from .dataloader import random_split  # noqa: F401
+from .dataloader import ConcatDataset  # noqa: F401

 __all__ = [  # noqa
     'Dataset',
@@ -46,4 +47,5 @@
     'WeightedRandomSampler',
     'random_split',
     'Subset',
+    'ConcatDataset',
 ]
diff --git a/python/paddle/io/dataloader/__init__.py b/python/paddle/io/dataloader/__init__.py
index bb65463f70afc7..24d70743880176 100644
--- a/python/paddle/io/dataloader/__init__.py
+++ b/python/paddle/io/dataloader/__init__.py
@@ -19,6 +19,7 @@
 from .dataset import ChainDataset
 from .dataset import random_split
 from .dataset import Subset
+from .dataset import ConcatDataset

 from .batch_sampler import BatchSampler
 from .batch_sampler import DistributedBatchSampler
diff --git a/python/paddle/io/dataloader/dataset.py b/python/paddle/io/dataloader/dataset.py
index 4daf410a318362..f9045508b49154 100755
--- a/python/paddle/io/dataloader/dataset.py
+++ b/python/paddle/io/dataloader/dataset.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import bisect
 import paddle

 from ... import framework
@@ -567,3 +568,76 @@ def _accumulate(iterable, fn=lambda x, y: x + y):
     for element in it:
         total = fn(total, element)
         yield total
+
+
+class ConcatDataset(Dataset):
+    """
+    Dataset as a concatenation of multiple datasets.
+
+    This class is useful to assemble different existing datasets.
+
+    Args:
+        datasets (sequence): List of datasets to be concatenated.
+
+    Returns:
+        Dataset: A Dataset which is the concatenation of the given datasets.
+
+    Examples:
+
+        .. code-block:: python
+
+            >>> import numpy as np
+            >>> import paddle
+            >>> from paddle.io import Dataset, ConcatDataset
+
+
+            >>> # define a random dataset
+            >>> class RandomDataset(Dataset):
+            ...     def __init__(self, num_samples):
+            ...         self.num_samples = num_samples
+            ...
+            ...     def __getitem__(self, idx):
+            ...         image = np.random.random([32]).astype('float32')
+            ...         label = np.random.randint(0, 9, (1,)).astype('int64')
+            ...         return image, label
+            ...
+            ...     def __len__(self):
+            ...         return self.num_samples
+            ...
+            >>> dataset = ConcatDataset([RandomDataset(10), RandomDataset(10)])
+            >>> for i in range(len(dataset)):
+            ...     image, label = dataset[i]
+            ...     # do something
+    """
+
+    @staticmethod
+    def cumsum(sequence):
+        r, s = [], 0
+        for e in sequence:
+            l = len(e)
+            r.append(l + s)
+            s += l
+        return r
+
+    def __init__(self, datasets) -> None:
+        super().__init__()
+        self.datasets = list(datasets)
+        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'
+        for d in self.datasets:
+            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
+        self.cumulative_sizes = self.cumsum(self.datasets)
+
+    def __len__(self):
+        return self.cumulative_sizes[-1]
+
+    def __getitem__(self, idx):
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError("absolute value of index should not exceed dataset length")
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        return self.datasets[dataset_idx][sample_idx]
diff --git a/test/legacy_test/test_multiprocess_dataloader_dataset.py b/test/legacy_test/test_multiprocess_dataloader_dataset.py
index d10d51d6a02410..1473dcc3863271 100755
--- a/test/legacy_test/test_multiprocess_dataloader_dataset.py
+++ b/test/legacy_test/test_multiprocess_dataloader_dataset.py
@@ -25,6 +25,7 @@
     Dataset,
     IterableDataset,
     TensorDataset,
+    ConcatDataset,
 )

 IMAGE_SIZE = 32
@@ -440,5 +441,66 @@ def test_iterable_dataset(self):
         self.run_main(dataset, 10, 3)


+class RandomIterableDataset(IterableDataset):
+    def __init__(self, sample_num):
+        self.sample_num = sample_num
+
+    def __iter__(self):
+        for i in range(self.sample_num):
+            np.random.seed(i)
+            image = np.random.random([IMAGE_SIZE]).astype('float32')
+            label = np.random.randint(0, 9, (1,)).astype('int64')
+            yield image, label
+
+
+class TestConcatDataset(unittest.TestCase):
+    def run_main(self, num_workers, places):
+        result = ConcatDataset([[0], [1]])
+        self.assertEqual(2, len(result))
+        self.assertEqual(0, result[0])
+        self.assertEqual(1, result[1])
+
+        result = ConcatDataset([[0, 1, 2, 3, 4],
+                                [5, 6, 7, 8, 9]])
+        self.assertEqual(10, len(result))
+        self.assertEqual(0, result[0])
+        self.assertEqual(5, result[5])
+
+        result = ConcatDataset([[0, 1, 2, 3, 4],
+                                [],
+                                [5, 6, 7, 8, 9]])
+        self.assertEqual(10, len(result))
+        self.assertEqual(0, result[0])
+        self.assertEqual(5, result[5])
+
+        result = ConcatDataset([[0, 1, 2, 3, 4],
+                                [5, 6, 7, 8, 9]])
+        with self.assertRaises(IndexError):
+            # this one goes to 11
+            result[11]
+
+
+    def test_main(self):
+        places = [paddle.CPUPlace()]
+        if paddle.is_compiled_with_cuda():
+            places.append(paddle.CUDAPlace(0))
+        for p in places:
+            self.run_main(num_workers=0, places=p)
+
+    def test_iterable_dataset_err(self):
+        d1 = TensorDataset([paddle.rand((7, 3, 28, 28)), paddle.rand((7,))])
+        it1 = RandomIterableDataset(10)
+        it2 = RandomIterableDataset(10)
+
+        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
+            ConcatDataset([d1, it2, it1])
+
+        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
+            ConcatDataset([it2])
+
+        with self.assertRaisesRegex(AssertionError, "does not support IterableDataset"):
+            ConcatDataset([it1, d1])
+
+
 if __name__ == '__main__':
     unittest.main()
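
Note (not part of the patch): the core of `ConcatDataset` is the index mapping in `__getitem__`, which turns a global index into a `(dataset_idx, sample_idx)` pair using the cumulative sizes and `bisect.bisect_right`. Below is a minimal standalone sketch of that mapping; the `lengths` values and the `locate` helper are made up for illustration.

```python
import bisect

# Hypothetical dataset lengths, including an empty dataset in the middle.
lengths = [5, 0, 5]

# Running total of lengths, as ConcatDataset.cumsum computes: [5, 5, 10].
cumulative = []
total = 0
for n in lengths:
    total += n
    cumulative.append(total)

def locate(idx):
    # bisect_right returns the first dataset whose cumulative size is
    # strictly greater than idx, so an empty dataset (whose cumulative
    # entry equals the previous one) is skipped automatically.
    dataset_idx = bisect.bisect_right(cumulative, idx)
    if dataset_idx == 0:
        sample_idx = idx
    else:
        sample_idx = idx - cumulative[dataset_idx - 1]
    return dataset_idx, sample_idx

assert locate(0) == (0, 0)  # first sample comes from the first dataset
assert locate(4) == (0, 4)  # last sample of the first dataset
assert locate(5) == (2, 0)  # skips the empty middle dataset entirely
assert locate(9) == (2, 4)  # last sample overall
```

Using `bisect_left` here would break at the boundaries: for `idx = 5`, `bisect_left([5, 5, 10], 5)` returns 0, pointing at a dataset that only holds indices 0..4. That is why the implementation relies on `bisect_right`, and why the test suite includes a concatenation with an empty dataset in the middle.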