diff --git a/CHANGELOG.md b/CHANGELOG.md index 7caadaf4f4048..cb044e58c2a5b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `ddp_fully_sharded` support ([#7487](https://github.com/PyTorchLightning/pytorch-lightning/pull/7487)) +- Added `__len__` to `IndexBatchSamplerWrapper` ([#7681](https://github.com/PyTorchLightning/pytorch-lightning/pull/7681)) + + - Added `should_rank_save_checkpoint` property to Training Plugins ([#7684](https://github.com/PyTorchLightning/pytorch-lightning/pull/7684)) diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index d064040d8e019..559e1161ce676 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -132,6 +132,9 @@ def __iter__(self) -> Iterator[List[int]]: self.batch_indices = batch yield batch + def __len__(self) -> int: + return len(self._sampler) + @property def drop_last(self) -> bool: return self._sampler.drop_last diff --git a/tests/overrides/test_distributed.py b/tests/overrides/test_distributed.py index d09ac9c8bad06..e742eb6ecccd9 100644 --- a/tests/overrides/test_distributed.py +++ b/tests/overrides/test_distributed.py @@ -11,11 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterable + import pytest from torch.utils.data import BatchSampler, SequentialSampler from pytorch_lightning import seed_everything from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler +from pytorch_lightning.utilities.data import has_len @pytest.mark.parametrize("shuffle", [False, True]) @@ -54,3 +57,13 @@ def test_index_batch_sampler(tmpdir): for batch in index_batch_sampler: assert index_batch_sampler.batch_indices == batch + + +def test_index_batch_sampler_methods(): + dataset = range(15) + sampler = SequentialSampler(dataset) + batch_sampler = BatchSampler(sampler, 3, False) + index_batch_sampler = IndexBatchSamplerWrapper(batch_sampler) + + assert isinstance(index_batch_sampler, Iterable) + assert has_len(index_batch_sampler)