diff --git a/CHANGELOG.md b/CHANGELOG.md
index b1d1d393b321a..d29725d6b34ee 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -332,6 +332,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added private `prevent_trainer_and_dataloaders_deepcopy` context manager on the `LightningModule` ([#8472](https://github.com/PyTorchLightning/pytorch-lightning/pull/8472))

+- Improved error messages in `replace_sampler` when the `DataLoader` attributes are not included in the signature or the signature is missing optional arguments ([#8519](https://github.com/PyTorchLightning/pytorch-lightning/pull/8519))
+
+
 - Moved `DeviceDtypeModuleMixin` and `HyperparametersMixin` mixin to `core` ([#8396](https://github.com/PyTorchLightning/pytorch-lightning/pull/8396))


@@ -529,6 +532,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed `ModelPruning` callback `on_save_checkpoint` to avoid making a `deepcopy` potentially leading to OOM ([#8472](https://github.com/PyTorchLightning/pytorch-lightning/pull/8472))

+- Fixed the sampler replacement logic for `DataLoader`s which do not define all `DataLoader` attributes as `__init__` parameters ([#8519](https://github.com/PyTorchLightning/pytorch-lightning/pull/8519))
+
+
 - Fixed DeepSpeed Windows support ([#8488](https://github.com/PyTorchLightning/pytorch-lightning/pull/8488))

diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index b54ca96b6424f..78d57fa6d9bb7 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -113,9 +113,7 @@ def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None:
         if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
             dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)

-    def auto_add_sampler(
-        self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
-    ) -> DataLoader:
+    def auto_add_sampler(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
         # don't do anything if it's not a dataloader
         is_dataloader = isinstance(dataloader, DataLoader)
         # don't manipulate iterable datasets
@@ -147,7 +145,7 @@ def auto_add_sampler(
         return dataloader

     @staticmethod
-    def _resolve_batch_sampler(dl_args, dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
+    def _resolve_batch_sampler(dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
         batch_sampler = getattr(dataloader, "batch_sampler")
         is_predicting = mode == RunningStage.PREDICTING
         # checking the batch sampler type is different than PyTorch default.
@@ -159,62 +157,77 @@ def _resolve_batch_sampler(dl_args, dataloader, sampler, mode: Optional[RunningS
             )
         if is_predicting:
             batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
-            dl_args['batch_sampler'] = batch_sampler
-            dl_args['batch_size'] = 1
-            dl_args['shuffle'] = False
-            dl_args['sampler'] = None
-            dl_args['drop_last'] = False
-        else:
-            dl_args['sampler'] = sampler
-            dl_args['shuffle'] = False
-            dl_args['batch_sampler'] = None
-
-        return dl_args
+            return {
+                'sampler': None,
+                'shuffle': False,
+                'batch_sampler': batch_sampler,
+                'batch_size': 1,
+                'drop_last': False,
+            }
+        return {
+            'sampler': sampler,
+            'shuffle': False,
+            'batch_sampler': None,
+        }

     def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader:
-        skip_keys = ('sampler', 'batch_sampler', 'dataset_kind')
-        skip_signature_keys = ('args', 'kwargs', 'self')
+        if not isinstance(dataloader, DataLoader):
+            raise ValueError(f'The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`')

+        # get the dataloader instance attributes
         attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
-
-        params = set(inspect.signature(dataloader.__init__).parameters)
-        contains_dataset = True
-
-        if type(dataloader) is not DataLoader:
-            contains_dataset = "dataset" in params
-            params.update(inspect.signature(DataLoader.__init__).parameters)
-
-        dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys}
-
-        dl_args = self._resolve_batch_sampler(dl_args, dataloader, sampler, mode=mode)
-
-        multiprocessing_context = dataloader.multiprocessing_context
-        dl_args['multiprocessing_context'] = multiprocessing_context
-
-        missing_kwargs = params.difference(skip_signature_keys).difference(dl_args)
-        if missing_kwargs:
-            """
-            Example:
-                class CustomDataLoader(DataLoader):
-                    def __init__(self, num_features, dataset, *args, **kwargs):
-                        self.num_features = num_features
-                        super().__init__(dataset, *args, **kwargs)
-            """
+        # not part of `vars`
+        attrs['multiprocessing_context'] = dataloader.multiprocessing_context
+
+        # get the dataloader instance `__init__` parameters
+        params = dict(inspect.signature(dataloader.__init__).parameters)
+
+        # keep only the params whose default is different to the current attr value
+        non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
+        # add `dataset` as it might have been replaced with `*args`
+        non_defaults.add('dataset')
+
+        # kwargs to re-construct the dataloader
+        dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
+        dl_kwargs.update(self._resolve_batch_sampler(dataloader, sampler, mode=mode))
+
+        required_args = {
+            p.name
+            for p in params.values()
+            if p.kind in (p.POSITIONAL_ONLY,
+                          p.POSITIONAL_OR_KEYWORD) and p.default is p.empty and p.name not in dl_kwargs
+        }
+        # the dataloader has required args which we could not extract from the existing attributes
+        if required_args:
+            required_args = sorted(required_args)
             dataloader_cls_name = dataloader.__class__.__name__
             raise MisconfigurationException(
-                f"Trying to inject DistributedSampler within {dataloader_cls_name} class."
-                "This would fail as your DataLoader doesn't expose all its __init__ parameters as attributes. "
-                f"Missing attributes are {missing_kwargs}. "
-                f"HINT: If you wrote the {dataloader_cls_name} class, add the `__init__` arguments as attributes or ",
-                "manually add DistributedSampler as "
-                f"{dataloader_cls_name}(dataset, ..., sampler=DistributedSampler(dataset, ...)).",
+                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
+                "This would fail as some of the `__init__` arguments are not available as instance attributes. "
+                f"The missing attributes are {required_args}. "
+                f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
+                "manually add the `DistributedSampler` as: "
+                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
             )

-        if not contains_dataset:
-            dl_args.pop('dataset')
+        has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
+        if not has_variadic_kwargs:
+            # the dataloader signature does not allow keyword arguments that need to be passed
+            missing_kwargs = dl_kwargs.keys() - params.keys()
+            if missing_kwargs:
+                missing_kwargs = sorted(missing_kwargs)
+                dataloader_cls_name = dataloader.__class__.__name__
+                raise MisconfigurationException(
+                    f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
+                    "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
+                    f"The missing arguments are {missing_kwargs}. "
+                    f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
+                    "manually add the `DistributedSampler` as: "
+                    f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
+                )

-        dataloader = type(dataloader)(**dl_args)
-        dataloader.multiprocessing_context = multiprocessing_context
+        dl_cls = type(dataloader)
+        dataloader = dl_cls(**dl_kwargs)
         return dataloader

     def _get_distributed_sampler(
diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py
index 5d4da1be7ddbe..4af0182e6c833 100644
--- a/tests/trainer/test_data_loading.py
+++ b/tests/trainer/test_data_loading.py
@@ -11,102 +11,94 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
+from re import escape
+
 import pytest
-from torch.utils.data import DataLoader
-from torch.utils.data.sampler import BatchSampler, SequentialSampler
+from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data.sampler import BatchSampler, Sampler, SequentialSampler

 from pytorch_lightning import Trainer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_7
 from tests.helpers import BoringModel, RandomDataset
-from tests.helpers.runif import RunIf
-
-
-class IndexedRandomDataset(RandomDataset):
-
-    def __getitem__(self, index):
-        return self.data[index]
-
-
-class CustomDataLoader(DataLoader):
-
-    def __init__(self, num_features, dataset, *args, **kwargs):
-        self.num_features = num_features
-        super().__init__(dataset, *args, **kwargs)
-
-
-class FailureCustomDataLoader(DataLoader):
-
-    def __init__(self, num_features, dataset, *args, **kwargs):
-        super().__init__(dataset, *args, **kwargs)


-class CustomBatchSampler(BatchSampler):
-    pass
+@pytest.mark.skipif(
+    sys.platform == "win32" and not _TORCH_GREATER_EQUAL_1_7, reason="Bad `torch.distributed` support on Windows"
+)
+@pytest.mark.parametrize('mode', (1, 2))
+def test_replace_distributed_sampler(tmpdir, mode):

+    class IndexedRandomDataset(RandomDataset):

-class TestModel(BoringModel):
+        def __getitem__(self, index):
+            return self.data[index]

-    def __init__(self, numbers_test_dataloaders, save_preds_on_dl_idx, mode):
-        super().__init__()
-        self._numbers_test_dataloaders = numbers_test_dataloaders
-        self._save_preds_on_dl_idx = save_preds_on_dl_idx
-        self._mode = mode
+    class CustomDataLoader(DataLoader):

-    def test_step(self, batch, batch_idx, dataloader_idx=None):
-        return super().test_step(batch, batch_idx)
+        def __init__(self, num_features, dataset, *args, **kwargs):
+            self.num_features = num_features
+            super().__init__(dataset, *args, **kwargs)

-    def create_dataset(self):
-        dataset = IndexedRandomDataset(32, 64)
-        batch_sampler = None
-        batch_size = 2
-        if self._mode == 2:
-            batch_size = 1
-            batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=batch_size, drop_last=True)
-            dataloader_cls = CustomDataLoader
-        else:
-            dataloader_cls = FailureCustomDataLoader
-        return dataloader_cls(32, dataset, batch_size=batch_size, batch_sampler=batch_sampler)
-
-    def test_dataloader(self):
-        return [self.create_dataset()] * self._numbers_test_dataloaders
+    class FailureCustomDataLoader(DataLoader):

+        def __init__(self, num_features, dataset, *args, **kwargs):
+            super().__init__(dataset, *args, **kwargs)

-def check_replace_distributed_sampler(tmpdir, save_preds_on_dl_idx, accelerator, gpus, num_dl_idx, mode):
-    num_processes = 2
-    limit_test_batches = 2
-    trainer_args = {
-        "default_root_dir": tmpdir,
-        "limit_test_batches": limit_test_batches,
-        "accelerator": accelerator,
-    }
+    class CustomBatchSampler(BatchSampler):
+        pass

-    if accelerator == "ddp_cpu":
-        trainer_args["num_processes"] = num_processes
-    else:
-        trainer_args["gpus"] = gpus
+    class TestModel(BoringModel):

-    model = TestModel(num_dl_idx, save_preds_on_dl_idx, mode)
+        def __init__(self, numbers_test_dataloaders, mode):
+            super().__init__()
+            self._numbers_test_dataloaders = numbers_test_dataloaders
+            self._mode = mode
+
+        def test_step(self, batch, batch_idx, dataloader_idx=None):
+            return super().test_step(batch, batch_idx)
+
+        def on_test_start(self) -> None:
+            dataloader = self.trainer.test_dataloaders[0]
+            assert isinstance(dataloader, CustomDataLoader)
+            assert dataloader.batch_size is None
+
+            batch_sampler = dataloader.batch_sampler
+            assert isinstance(batch_sampler, CustomBatchSampler)
+            assert batch_sampler.batch_size == 1
+            assert batch_sampler.drop_last
+            assert isinstance(batch_sampler.sampler, DistributedSampler)
+
+        def create_dataset(self):
+            dataset = IndexedRandomDataset(32, 64)
+            batch_sampler = None
+            batch_size = 2
+            if self._mode == 2:
+                batch_size = 1
+                batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=batch_size, drop_last=True)
+                dataloader_cls = CustomDataLoader
+            else:
+                dataloader_cls = FailureCustomDataLoader
+            return dataloader_cls(32, dataset, batch_size=batch_size, batch_sampler=batch_sampler)
+
+        def test_dataloader(self):
+            return [self.create_dataset()] * self._numbers_test_dataloaders
+
+    model = TestModel(2, mode)
     model.test_epoch_end = None
-    trainer = Trainer(**trainer_args)
+    trainer = Trainer(
+        default_root_dir=tmpdir, limit_test_batches=2, plugins="ddp_find_unused_parameters_false", num_processes=1
+    )
     if mode == 1:
-        match = "DistributedSampler within"
+        match = escape("missing attributes are ['num_features']")
         with pytest.raises(MisconfigurationException, match=match):
             trainer.test(model)
     else:
         trainer.test(model)


-@RunIf(min_gpus=2, special=True)
-def test_replace_distributed_sampler_custom_dataloader_custom_batch_sampler_0(tmpdir):
-    check_replace_distributed_sampler(tmpdir, True, "ddp", 2, 2, mode=1)
-
-
-@RunIf(min_gpus=2, special=True)
-def test_replace_distributed_sampler_custom_dataloader_custom_batch_sampler_1(tmpdir):
-    check_replace_distributed_sampler(tmpdir, True, "ddp", 2, 2, mode=2)
-
-
 @pytest.mark.parametrize("num_workers", [0, 1])
 def test_dataloader_warnings(num_workers):

@@ -127,3 +119,143 @@ def on_train_start(self, *_) -> None:
     trainer = Trainer(accelerator="ddp_spawn")
     with pytest.warns(UserWarning, match=warn_str), pytest.raises(SystemExit):
         trainer.fit(TestModel(), dl)
+
+
+def test_replace_sampler_raises():
+    trainer = Trainer()
+    with pytest.raises(ValueError, match="needs to subclass `torch.utils.data.DataLoader"):
+        trainer.replace_sampler(object(), object(), mode='fit')  # noqa
+
+
+def test_dataloaders_with_missing_keyword_arguments():
+    trainer = Trainer()
+    ds = RandomDataset(10, 20)
+
+    class TestDataLoader(DataLoader):
+
+        def __init__(self, dataset):
+            super().__init__(dataset)
+
+    loader = TestDataLoader(ds)
+    sampler = SequentialSampler(ds)
+    match = escape("missing arguments are ['batch_sampler', 'sampler', 'shuffle']")
+    with pytest.raises(MisconfigurationException, match=match):
+        trainer.replace_sampler(loader, sampler, mode='fit')
+    match = escape("missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler', 'shuffle']")
+    with pytest.raises(MisconfigurationException, match=match):
+        trainer.replace_sampler(loader, sampler, mode='predict')
+
+    class TestDataLoader(DataLoader):
+
+        def __init__(self, dataset, *args, **kwargs):
+            super().__init__(dataset)
+
+    loader = TestDataLoader(ds)
+    sampler = SequentialSampler(ds)
+    trainer.replace_sampler(loader, sampler, mode='fit')
+    trainer.replace_sampler(loader, sampler, mode='predict')
+
+    class TestDataLoader(DataLoader):
+
+        def __init__(self, *foo, **bar):
+            super().__init__(*foo, **bar)
+
+    loader = TestDataLoader(ds)
+    sampler = SequentialSampler(ds)
+    trainer.replace_sampler(loader, sampler, mode='fit')
+    trainer.replace_sampler(loader, sampler, mode='predict')
+
+    class TestDataLoader(DataLoader):
+
+        def __init__(self, num_feat, dataset, *args, shuffle=False):
+            self.num_feat = num_feat
+            super().__init__(dataset)
+
+    loader = TestDataLoader(1, ds)
+    sampler = SequentialSampler(ds)
+    match = escape("missing arguments are ['batch_sampler', 'sampler']")
+    with pytest.raises(MisconfigurationException, match=match):
+        trainer.replace_sampler(loader, sampler, mode='fit')
+    match = escape("missing arguments are ['batch_sampler', 'batch_size', 'drop_last', 'sampler']")
+    with pytest.raises(MisconfigurationException, match=match):
+        trainer.replace_sampler(loader, sampler, mode='predict')
+
+    class TestDataLoader(DataLoader):
+
+        def __init__(self, num_feat, dataset, **kwargs):
+            self.feat_num = num_feat
+            super().__init__(dataset)
+
+    loader = TestDataLoader(1, ds)
+    sampler = SequentialSampler(ds)
+    match = escape("missing attributes are ['num_feat']")
+    with pytest.raises(MisconfigurationException, match=match):
+        trainer.replace_sampler(loader, sampler, mode='fit')
+    match = escape("missing attributes are ['num_feat']")
+    with pytest.raises(MisconfigurationException, match=match):
+        trainer.replace_sampler(loader, sampler, mode='predict')
+
+
+def test_replace_sampler_with_multiprocessing_context():
+    """This test verifies that replace_sampler conserves multiprocessing context"""
+    train = RandomDataset(32, 64)
+    context = 'spawn'
+    train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True)
+    trainer = Trainer()
+    new_data_loader = trainer.replace_sampler(train, SequentialSampler(train.dataset))
+    assert new_data_loader.multiprocessing_context == train.multiprocessing_context
+
+
+def test_dataloader_reinit_for_subclass():
+
+    class CustomDataLoader(DataLoader):
+
+        def __init__(
+            self,
+            dataset,
+            batch_size=1,
+            shuffle=False,
+            sampler=None,
+            batch_sampler=None,
+            num_workers=0,
+            collate_fn=None,
+            pin_memory=False,
+            drop_last=False,
+            timeout=0,
+            worker_init_fn=None,
+            dummy_kwarg=None,
+        ):
+            super().__init__(
+                dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn, pin_memory, drop_last,
+                timeout, worker_init_fn
+            )
+            self.dummy_kwarg = dummy_kwarg
+            self.something_unrelated = 1
+
+    trainer = Trainer(num_processes=1, accelerator='ddp_cpu')
+
+    class CustomDummyObj:
+        sampler = None
+
+    result = trainer.auto_add_sampler(CustomDummyObj(), shuffle=True)
+    assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"
+
+    dataset = list(range(10))
+    result = trainer.auto_add_sampler(CustomDataLoader(dataset), shuffle=True)
+    assert isinstance(result, DataLoader)
+    assert isinstance(result, CustomDataLoader)
+    assert result.dummy_kwarg is None
+
+    # Shuffled DataLoader should also work
+    result = trainer.auto_add_sampler(CustomDataLoader(dataset, shuffle=True), shuffle=True)
+    assert isinstance(result, DataLoader)
+    assert isinstance(result, CustomDataLoader)
+    assert result.dummy_kwarg is None
+
+    class CustomSampler(Sampler):
+        pass
+
+    # Should raise an error if existing sampler is being replaced
+    dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
+    with pytest.raises(MisconfigurationException, match='will be replaced by `DistributedSampler`'):
+        trainer.auto_add_sampler(dataloader, shuffle=True)
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 34ecef6f0f598..4e38fee91557b 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -21,7 +21,6 @@
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.dataset import Dataset, IterableDataset, Subset
 from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import SequentialSampler

 import tests.helpers.pipelines as tpipes
 from pytorch_lightning import Callback, seed_everything, Trainer
@@ -988,69 +987,6 @@ def gen(self):
     assert trainer.current_epoch == 1


-@RunIf(min_gpus=2)
-def test_dataloader_reinit_for_subclass(tmpdir):
-
-    class CustomDataLoader(torch.utils.data.DataLoader):
-
-        def __init__(
-            self,
-            dataset,
-            batch_size=1,
-            shuffle=False,
-            sampler=None,
-            batch_sampler=None,
-            num_workers=0,
-            collate_fn=None,
-            pin_memory=False,
-            drop_last=False,
-            timeout=0,
-            worker_init_fn=None,
-            dummy_kwarg=None,
-            **kwargs
-        ):
-            super().__init__(
-                dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn, pin_memory, drop_last,
-                timeout, worker_init_fn
-            )
-
-            self.dummy_kwarg = dummy_kwarg
-
-    trainer = Trainer(
-        gpus=[0, 1],
-        num_nodes=1,
-        accelerator='ddp_spawn',
-        default_root_dir=tmpdir,
-    )
-
-    class CustomDummyObj:
-        sampler = None
-
-    result = trainer.auto_add_sampler(CustomDummyObj(), shuffle=True)
-    assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"
-
-    dataset = list(range(1000))
-    result = trainer.auto_add_sampler(CustomDataLoader(dataset), shuffle=True)
-    assert isinstance(result, torch.utils.data.DataLoader)
-    assert isinstance(result, CustomDataLoader)
-    assert hasattr(result, 'dummy_kwarg')
-
-    # Shuffled DataLoader should also work
-    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), shuffle=True)
-    assert isinstance(result, torch.utils.data.DataLoader)
-    assert isinstance(result, CustomDataLoader)
-    assert hasattr(result, 'dummy_kwarg')
-
-    class CustomSampler(torch.utils.data.Sampler):
-        pass
-
-    # Should raise an error if existing sampler is being replaced
-    with pytest.raises(MisconfigurationException, match='DistributedSampler'):
-        trainer.auto_add_sampler(
-            CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), shuffle=True
-        )
-
-
 class DistribSamplerCallback(Callback):

     def __init__(self, expected_seeds=(0, 0, 0)):
@@ -1542,23 +1478,6 @@ def test_dataloaders_reset_and_attach(tmpdir):
     assert trainer.predict_dataloaders[0] is dataloader_1


-def test_replace_sampler_with_multiprocessing_context(tmpdir):
-    """
-    This test verifies that replace_sampler conserves multiprocessing context
-    """
-    train = RandomDataset(32, 64)
-    context = 'spawn'
-    train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True)
-    trainer = Trainer(
-        max_epochs=1,
-        progress_bar_refresh_rate=20,
-        overfit_batches=5,
-    )
-
-    new_data_loader = trainer.replace_sampler(train, SequentialSampler(train.dataset))
-    assert (new_data_loader.multiprocessing_context == train.multiprocessing_context)
-
-
 @pytest.mark.parametrize('multiple_trainloader_mode', ["min_size", "max_size_cycle"])
 def test_correct_dataloader_idx_in_hooks(tmpdir, multiple_trainloader_mode):
     """