[Feat-BugFix] Resolve custom DataLoader #5745

Merged (15 commits) on Feb 5, 2021
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -113,7 +113,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))


- Refactored Accelerators and Plugins
* Added base classes for plugins ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715))
* Added parallel plugins for DP, DDP, DDPSpawn, DDP2 and Horovod ([#5714](https://github.com/PyTorchLightning/pytorch-lightning/pull/5714))
* Added new Accelerators for CPU, GPU and TPU ([#5719](https://github.com/PyTorchLightning/pytorch-lightning/pull/5719))
@@ -169,6 +169,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed loading yaml ([#5619](https://github.com/PyTorchLightning/pytorch-lightning/pull/5619))


- Fixed support for custom DataLoaders with DDP if they can be re-instantiated (illustrated below) ([#5745](https://github.com/PyTorchLightning/pytorch-lightning/pull/5745))



## [1.1.4] - YYYY-MM-DD

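To make the changelog entry above concrete: a custom DataLoader counts as "re-instantiable" when every extra `__init__` argument is also stored as an attribute of the same name, so the trainer can read the arguments back and rebuild the loader with a `DistributedSampler`. The sketch below is only an illustration, not part of the diff; the `num_features` argument is a stand-in mirroring the `CustomDataLoader` added to the tests in this PR.

from torch.utils.data import DataLoader


class ReInstantiableDataLoader(DataLoader):
    def __init__(self, num_features, dataset, *args, **kwargs):
        # Exposing the extra argument as an attribute is what allows the
        # trainer to re-instantiate this loader with a DistributedSampler.
        self.num_features = num_features
        super().__init__(dataset, *args, **kwargs)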
67 changes: 61 additions & 6 deletions pytorch_lightning/trainer/data_loading.py
@@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import multiprocessing
import platform
from abc import ABC
from copy import deepcopy
from typing import Callable, Iterable, List, Optional, Tuple, Union

-from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.accelerators.legacy.accelerator import Accelerator
@@ -107,16 +107,71 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:

return dataloader

@staticmethod
def _resolve_batch_sampler(dl_args, dataloader, sampler):
batch_sampler = getattr(dataloader, "batch_sampler")
if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
batch_sampler = type(batch_sampler)(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=batch_sampler.drop_last
)
dl_args['batch_sampler'] = batch_sampler
dl_args['batch_size'] = 1
dl_args['shuffle'] = False
dl_args['sampler'] = None
dl_args['drop_last'] = False
else:
dl_args['sampler'] = sampler
dl_args['shuffle'] = False
dl_args['batch_sampler'] = None
return dl_args

def replace_sampler(self, dataloader, sampler):
-skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
+skip_keys = ('sampler', 'batch_sampler', 'dataset_kind')
skip_signature_keys = ('args', 'kwargs', 'self')

attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}

params = set(inspect.signature(dataloader.__init__).parameters)
contains_dataset = True

if type(dataloader) is not DataLoader:
contains_dataset = "dataset" in params
params.update(inspect.signature(DataLoader.__init__).parameters)

dl_args = {
-k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
+name: attrs[name] for name in params
+if name in attrs and name not in skip_keys
}

-dl_args['sampler'] = sampler
-dl_args['shuffle'] = False
+dl_args = self._resolve_batch_sampler(dl_args, dataloader, sampler)

multiprocessing_context = dataloader.multiprocessing_context
dl_args['multiprocessing_context'] = multiprocessing_context

missing_kwargs = params.difference(skip_signature_keys).difference(dl_args)
if missing_kwargs:
"""
Example:
class CustomDataLoader(DataLoader):
def __init__(self, num_features, dataset, *args, **kwargs):
self.num_features = num_features
super().__init__(dataset, *args, **kwargs)
"""
dataloader_cls_name = dataloader.__class__.__name__
raise MisconfigurationException(
f"Trying to inject DistributedSampler within {dataloader_cls_name} class."
"This would fail as your DataLoader doesn't expose all its __init__ parameters as attributes. "
f"Missing attributes are {missing_kwargs}. "
f"HINT: If you wrote the {dataloader_cls_name} class, add the `__init__` arguments as attributes or ",
"manually add DistributedSampler as "
f"{dataloader_cls_name}(dataset, ..., sampler=DistributedSampler(dataset, ...)).",
)

if not contains_dataset:
dl_args.pop('dataset')

dataloader = type(dataloader)(**dl_args)
dataloader.multiprocessing_context = multiprocessing_context
return dataloader
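For readers skimming the diff, the following standalone sketch (not part of this PR) shows the core idea behind `replace_sampler`: collect the loader's public attributes that match its `__init__` signature, swap in the new sampler, and re-instantiate the same class. It is deliberately simplified; the real implementation above also rebuilds custom batch samplers via `_resolve_batch_sampler`, preserves the multiprocessing context, and raises `MisconfigurationException` when constructor arguments are not exposed as attributes.

import inspect

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def rebuild_with_sampler(dataloader: DataLoader, sampler) -> DataLoader:
    # Constructor parameters of the concrete class plus the base DataLoader.
    params = set(inspect.signature(dataloader.__init__).parameters)
    params.update(inspect.signature(DataLoader.__init__).parameters)
    # Public attributes that can be fed back into the constructor.
    attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
    skip = ("sampler", "batch_sampler", "dataset_kind")
    dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip}
    dl_args.update(sampler=sampler, shuffle=False, batch_sampler=None)
    return type(dataloader)(**dl_args)


loader = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)
# num_replicas/rank are passed explicitly here, so no process group is required.
ddp_loader = rebuild_with_sampler(loader, DistributedSampler(loader.dataset, num_replicas=2, rank=0))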
1 change: 1 addition & 0 deletions tests/special_tests.sh
@@ -28,3 +28,4 @@ python ${DEFAULTS} tests/trainer/logging_/test_train_loop_logging_1_0.py::test_l
python ${DEFAULTS} tests/callbacks/test_pruning.py::test_pruning_callback_ddp
python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_ddp
python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp
python ${DEFAULTS} tests/trainer/test_data_loading.py::test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler
123 changes: 123 additions & 0 deletions tests/trainer/test_data_loading.py
@@ -0,0 +1,123 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler, SequentialSampler

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel, RandomDataset


class IndexedRandomDataset(RandomDataset):

def __getitem__(self, index):
return self.data[index]


class CustomDataLoader(DataLoader):

def __init__(self, num_features, dataset, *args, **kwargs):
self.num_features = num_features
super().__init__(dataset, *args, **kwargs)


class FailureCustomDataLoader(DataLoader):

def __init__(self, num_features, dataset, *args, **kwargs):
super().__init__(dataset, *args, **kwargs)


class CustomBatchSampler(BatchSampler):
pass


class TestModel(BoringModel):

def __init__(self, numbers_test_dataloaders,
save_preds_on_dl_idx, mode):
super().__init__()
self._numbers_test_dataloaders = numbers_test_dataloaders
self._save_preds_on_dl_idx = save_preds_on_dl_idx
self._mode = mode

def test_step(self, batch, batch_idx, dataloader_idx=None):
return super().test_step(batch, batch_idx)

def create_dataset(self):
dataset = IndexedRandomDataset(32, 64)
batch_sampler = None
batch_size = 2
if self._mode == 3:
batch_size = 1
batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=batch_size, drop_last=True)
dataloader_cls = CustomDataLoader
else:
dataloader_cls = FailureCustomDataLoader if self._mode > 0 else CustomDataLoader
return dataloader_cls(32, dataset, batch_size=batch_size, batch_sampler=batch_sampler)

def test_dataloader(self):
return [self.create_dataset()] * self._numbers_test_dataloaders


def check_replace_distrubuted_sampler(
tmpdir,
save_preds_on_dl_idx,
accelerator,
gpus,
num_dl_idx,
mode
):
num_processes = 2
limit_test_batches = 2
trainer_args = {
"default_root_dir": tmpdir,
"limit_test_batches" : limit_test_batches,
"accelerator": accelerator,
}

if accelerator == "ddp_cpu":
trainer_args["num_processes"] = num_processes
else:
trainer_args["gpus"] = gpus

model = TestModel(num_dl_idx, save_preds_on_dl_idx, mode)
model.test_epoch_end = None

trainer = Trainer(**trainer_args)
if mode < 3:
if mode == 1:
match = "Missing attributes are"
else:
match = "DistributedSampler within"
with pytest.raises(MisconfigurationException, match=match):
_ = trainer.test(model)
else:
_ = trainer.test(model)


@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.parametrize("mode", [1, 3])
def test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
check_replace_distrubuted_sampler(tmpdir, True, "ddp", 2, 2, mode)


@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires a GPU machine")
def test_replace_distrubuted_sampler_1_gpu_mode(tmpdir):
check_replace_distrubuted_sampler(tmpdir, True, None, 1, 1, 2)