[Feat-BugFix] Resolve custom DataLoader #5745

Merged: 15 commits, Feb 5, 2021
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -113,7 +113,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))


- Refactored Accelerators and Plugins
* Added base classes for plugins ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715))
* Added parallel plugins for DP, DDP, DDPSpawn, DDP2 and Horovod ([#5714](https://github.com/PyTorchLightning/pytorch-lightning/pull/5714))
* Added new Accelerators for CPU, GPU and TPU ([#5719](https://github.com/PyTorchLightning/pytorch-lightning/pull/5719))
@@ -169,6 +169,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed loading yaml ([#5619](https://github.com/PyTorchLightning/pytorch-lightning/pull/5619))


- Fixed support for custom DataLoaders with DDP if they can be re-instantiated ([#5745](https://github.com/PyTorchLightning/pytorch-lightning/pull/5745))
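To illustrate what "re-instantiated" means in this entry: when injecting a `DistributedSampler`, Lightning rebuilds the loader from its constructor arguments, so every extra `__init__` argument must be readable back as an attribute. A minimal sketch, mirroring the `CustomDataLoader` used in this PR's tests (the class and argument names are illustrative):

```python
from torch.utils.data import DataLoader


class CustomDataLoader(DataLoader):
    def __init__(self, num_features, dataset, *args, **kwargs):
        # Storing `num_features` as an attribute is what makes this loader
        # re-instantiable: the trainer can read the value back and rebuild
        # the loader with a DistributedSampler injected.
        self.num_features = num_features
        super().__init__(dataset, *args, **kwargs)
```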



## [1.1.4] - YYYY-MM-DD

41 changes: 38 additions & 3 deletions pytorch_lightning/trainer/data_loading.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import multiprocessing
import platform
from abc import ABC
@@ -109,14 +109,49 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:

def replace_sampler(self, dataloader, sampler):
    skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
    skip_valid_keys = ['args', 'kwargs', 'self']

    params = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}

    valid_kwargs = set(inspect.signature(dataloader.__init__).parameters)
    contains_dataset = True

    if type(dataloader) is not DataLoader:
        contains_dataset = "dataset" in valid_kwargs
        valid_kwargs.update(inspect.signature(DataLoader.__init__).parameters)

    dl_args = {
        name: params[name] for name in valid_kwargs
        if name in params and name not in skip_keys
    }

    dl_args['sampler'] = sampler
    dl_args['shuffle'] = False
    dl_args['batch_sampler'] = None
    multiprocessing_context = dataloader.multiprocessing_context
    dl_args['multiprocessing_context'] = multiprocessing_context

    missing_kwargs = valid_kwargs.difference(skip_valid_keys).difference(dl_args)
    if missing_kwargs:
        # Every extra `__init__` argument must be stored as an attribute, e.g.:
        #
        #     class CustomDataLoader(DataLoader):
        #         def __init__(self, num_features, dataset, *args, **kwargs):
        #             self.num_features = num_features
        #             super().__init__(dataset, *args, **kwargs)
        dataloader_cls_name = dataloader.__class__.__name__
        raise MisconfigurationException(
            f"Trying to inject DistributedSampler within {dataloader_cls_name} class. "
            "This would fail as your DataLoader doesn't expose all its __init__ parameters as attributes. "
            f"Missing attributes are {missing_kwargs}. "
            f"HINT: If you wrote the {dataloader_cls_name} class, add the `__init__` arguments as attributes or "
            "manually add DistributedSampler as "
            f"{dataloader_cls_name}(dataset, ..., sampler=DistributedSampler(dataset, ...))."
        )

    if not contains_dataset:
        dl_args.pop('dataset')

    dataloader = type(dataloader)(**dl_args)
    dataloader.multiprocessing_context = multiprocessing_context
    return dataloader
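The HINT in the error message above points to a way of bypassing re-instantiation entirely: construct the `DistributedSampler` yourself. A minimal sketch of that workaround, assuming `torch.distributed` is already initialized (as it is inside a DDP worker); the `build_test_dataloader` helper and batch size are placeholders:

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def build_test_dataloader(dataset):
    # Passing the sampler explicitly means Lightning never needs to rebuild
    # this loader, so the attribute-exposure requirement does not apply.
    # DistributedSampler infers num_replicas/rank from the process group.
    sampler = DistributedSampler(dataset, shuffle=False)
    return DataLoader(dataset, batch_size=2, sampler=sampler)
```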
2 changes: 2 additions & 0 deletions tests/special_tests.sh
@@ -28,3 +28,5 @@ python ${DEFAULTS} tests/trainer/logging_/test_train_loop_logging_1_0.py::test_l
python ${DEFAULTS} tests/callbacks/test_pruning.py::test_pruning_callback_ddp
python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_ddp
python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp
python ${DEFAULTS} tests/trainer/test_trainer.py::test_prediction_collection_ddp
python ${DEFAULTS} tests/trainer/test_trainer.py::test_misconfiguration_on_dataloader
88 changes: 88 additions & 0 deletions tests/trainer/test_trainer.py
@@ -26,6 +26,7 @@
import pytest
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

import tests.base.develop_utils as tutils
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
@@ -1628,3 +1629,90 @@ def test_pytorch_profiler_nested(tmpdir):

expected_c = ['add']
assert [e.name for e in pa['c']] == expected_c


class IndexedRandomDataset(RandomDataset):

def __getitem__(self, index):
return self.data[index]


class CustomDataLoader(DataLoader):

def __init__(self, num_features, dataset, *args, **kwargs):
self.num_features = num_features
super().__init__(dataset, *args, **kwargs)


class FailureCustomDataLoader(DataLoader):

def __init__(self, num_features, dataset, *args, **kwargs):
super().__init__(dataset, *args, **kwargs)


class TestModel(BoringModel):

    def __init__(self, numbers_test_dataloaders, save_preds_on_dl_idx, failure):
super().__init__()
self._numbers_test_dataloaders = numbers_test_dataloaders
self._save_preds_on_dl_idx = save_preds_on_dl_idx
self._failure = failure

def create_dataset(self):
dataloader_cls = FailureCustomDataLoader if self._failure > 0 else CustomDataLoader
return dataloader_cls(32, IndexedRandomDataset(32, 64), batch_size=2)

def test_dataloader(self):
return [self.create_dataset()] * self._numbers_test_dataloaders


def check_prediction_collection(tmpdir, save_preds_on_dl_idx, accelerator, gpus,
num_dl_idx, failure=0):
num_processes = 2
limit_test_batches = 2
    trainer_args = {
        "default_root_dir": tmpdir,
        "limit_test_batches": limit_test_batches,
        "accelerator": accelerator,
    }

if accelerator == "ddp_cpu":
trainer_args["num_processes"] = num_processes
else:
trainer_args["gpus"] = gpus

model = TestModel(num_dl_idx, save_preds_on_dl_idx, failure)
model.test_epoch_end = None

    trainer = Trainer(**trainer_args)
    if failure == 1:
        # pytest.raises also fails the test if no exception is raised at all.
        with pytest.raises(MisconfigurationException, match="Missing attributes are {'num_features'}"):
            trainer.test(model)
        return

    try:
        trainer.test(model)
    except MisconfigurationException as e:
        assert "inject DistributedSampler within FailureCustomDataLoader" in str(e)

@pytest.mark.skipif(os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') != '1',
                    reason="test should be run outside of pytest")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_misconfiguration_on_dataloader(tmpdir):
    """
    Test that Lightning raises a MisconfigurationException when the user's custom DataLoader
    cannot be re-instantiated.
    """
check_prediction_collection(tmpdir, True, "ddp", 2, 2, failure=1)


@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires a GPU machine")
def test_prediction_collection_1_gpu_failure(tmpdir):
"""
Test `PredictionCollection` will raise warning as we are using an invalid custom Dataloader
"""
check_prediction_collection(tmpdir, True, None, 1, 1, failure=2)