diff --git a/allennlp/data/data_loaders/multi_process_data_loader.py b/allennlp/data/data_loaders/multi_process_data_loader.py index b33ccf0ac73..abec5fdc156 100644 --- a/allennlp/data/data_loaders/multi_process_data_loader.py +++ b/allennlp/data/data_loaders/multi_process_data_loader.py @@ -203,7 +203,9 @@ def __init__( deque(self.iter_instances(), maxlen=0) def __len__(self) -> int: - if self.max_instances_in_memory is None: + if self.batches_per_epoch is not None: + return self.batches_per_epoch + elif self.max_instances_in_memory is None: # We haven't read the instances yet, so we do so now, caching them as we go. if not self._instances: deque(self.iter_instances(), maxlen=0) @@ -218,8 +220,6 @@ def __len__(self) -> int: return num_instances // batch_size else: return 1 + num_instances // batch_size - elif self.batches_per_epoch is not None: - return self.batches_per_epoch else: # We can't know the number of batches for a lazy loader when batches_per_epoch # is not specified. diff --git a/tests/data/data_loaders/multi_process_data_loader_test.py b/tests/data/data_loaders/multi_process_data_loader_test.py index caafefd896e..9f1a370253b 100644 --- a/tests/data/data_loaders/multi_process_data_loader_test.py +++ b/tests/data/data_loaders/multi_process_data_loader_test.py @@ -156,3 +156,14 @@ def test_drop_last(): for batch in batches: assert len(batch["index"]) == 16 assert len(batches) == 6 + + +def test_batches_per_epoch(): + loader = MultiProcessDataLoader( + MockDatasetReader(), "some path", batch_size=4, batches_per_epoch=10 + ) + vocab = Vocabulary.from_instances(loader.iter_instances()) + loader.index_with(vocab) + + assert len(loader) == 10 + assert len(list(loader)) == 10