[Feat-BugFix] Resolve custom DataLoader #5745

Merged (15 commits) on Feb 5, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -124,6 +124,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added RPC and Sharded plugins ([#5732](https://github.com/PyTorchLightning/pytorch-lightning/pull/5732))
* Added missing `LightningModule`-wrapper logic to new plugins and accelerator ([#5734](https://github.com/PyTorchLightning/pytorch-lightning/pull/5734))


### Deprecated

- Function `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
@@ -172,6 +173,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed loading yaml ([#5619](https://github.com/PyTorchLightning/pytorch-lightning/pull/5619))


- Fixed support for custom DataLoaders with DDP if they can be re-instantiated ([#5745](https://github.com/PyTorchLightning/pytorch-lightning/pull/5745))



## [1.1.4] - YYYY-MM-DD

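For context on the "can be re-instantiated" wording above: when Lightning injects a `DistributedSampler` under DDP, it rebuilds the user's DataLoader from its `__init__` signature, so every extra constructor argument must be recoverable as an attribute. A minimal sketch of the two cases, using the loader classes from the test added in this PR:

from torch.utils.data import DataLoader


class CustomDataLoader(DataLoader):
    # Re-instantiable: the extra argument is stored as an attribute, so the
    # Trainer can rebuild the loader with a DistributedSampler injected.
    def __init__(self, num_features, dataset, *args, **kwargs):
        self.num_features = num_features
        super().__init__(dataset, *args, **kwargs)


class FailureCustomDataLoader(DataLoader):
    # Not re-instantiable: `num_features` is dropped, so the Trainer cannot
    # recover it and raises a MisconfigurationException under DDP.
    def __init__(self, num_features, dataset, *args, **kwargs):
        super().__init__(dataset, *args, **kwargs)
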
73 changes: 66 additions & 7 deletions pytorch_lightning/trainer/data_loading.py
@@ -11,22 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import multiprocessing
import platform
from abc import ABC
from copy import deepcopy
from typing import Callable, Iterable, List, Optional, Tuple, Union

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler
from torch.utils.data import SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.accelerators.legacy.accelerator import Accelerator
from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.data import has_iterable_dataset
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -113,14 +117,69 @@ def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:

return dataloader

@staticmethod
def _resolve_batch_sampler(dl_args, dataloader, sampler):
batch_sampler = getattr(dataloader, "batch_sampler")
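        # A user-defined BatchSampler subclass has to be re-created around the new
        # (distributed) sampler; the stock BatchSampler can be rebuilt by DataLoader
        # itself from `batch_size` and `drop_last`.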
if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
batch_sampler = type(batch_sampler)(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=batch_sampler.drop_last,
)
dl_args['batch_sampler'] = batch_sampler
dl_args['batch_size'] = 1
dl_args['shuffle'] = False
dl_args['sampler'] = None
dl_args['drop_last'] = False
else:
dl_args['sampler'] = sampler
dl_args['shuffle'] = False
dl_args['batch_sampler'] = None

return dl_args

def replace_sampler(self, dataloader, sampler):
skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
skip_keys = ('sampler', 'batch_sampler', 'dataset_kind')
skip_signature_keys = ('args', 'kwargs', 'self')

attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}

params = set(inspect.signature(dataloader.__init__).parameters)
contains_dataset = True

dl_args = {k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys}
if type(dataloader) is not DataLoader:
contains_dataset = "dataset" in params
params.update(inspect.signature(DataLoader.__init__).parameters)

dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys}

dl_args = self._resolve_batch_sampler(dl_args, dataloader, sampler)

dl_args['sampler'] = sampler
dl_args['shuffle'] = False
multiprocessing_context = dataloader.multiprocessing_context
dl_args['multiprocessing_context'] = multiprocessing_context

missing_kwargs = params.difference(skip_signature_keys).difference(dl_args)
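        # Any __init__ parameter that cannot be filled in from the loader's attributes
        # makes re-instantiation impossible, so raise with a hint instead of failing later.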
if missing_kwargs:
"""
Example:
class CustomDataLoader(DataLoader):
def __init__(self, num_features, dataset, *args, **kwargs):
self.num_features = num_features
super().__init__(dataset, *args, **kwargs)
"""
dataloader_cls_name = dataloader.__class__.__name__
raise MisconfigurationException(
f"Trying to inject DistributedSampler within {dataloader_cls_name} class."
"This would fail as your DataLoader doesn't expose all its __init__ parameters as attributes. "
f"Missing attributes are {missing_kwargs}. "
f"HINT: If you wrote the {dataloader_cls_name} class, add the `__init__` arguments as attributes or ",
"manually add DistributedSampler as "
f"{dataloader_cls_name}(dataset, ..., sampler=DistributedSampler(dataset, ...)).",
)

if not contains_dataset:
dl_args.pop('dataset')

dataloader = type(dataloader)(**dl_args)
dataloader.multiprocessing_context = multiprocessing_context
return dataloader
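
The MisconfigurationException above points at two remedies: expose the extra `__init__` arguments as attributes (as `CustomDataLoader` in the test below does), or construct the `DistributedSampler` yourself so the Trainer never has to re-instantiate the loader. A minimal sketch of the second option, reusing `IndexedRandomDataset` and `FailureCustomDataLoader` from the test file below; it assumes the hook runs in an already-initialised distributed process, which is the case when Lightning calls the `*_dataloader` hooks under DDP:

from torch.utils.data.distributed import DistributedSampler


def test_dataloader(self):
    # LightningModule hook sketch: with a DistributedSampler already attached,
    # the Trainer is not expected to swap samplers, so the non-re-instantiable
    # loader never needs to be rebuilt.
    dataset = IndexedRandomDataset(32, 64)
    return FailureCustomDataLoader(32, dataset, sampler=DistributedSampler(dataset))
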
1 change: 1 addition & 0 deletions tests/special_tests.sh
@@ -28,3 +28,4 @@ python ${DEFAULTS} tests/trainer/logging_/test_train_loop_logging_1_0.py::test_l
python ${DEFAULTS} tests/callbacks/test_pruning.py::test_pruning_callback_ddp
python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_ddp
python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp
python ${DEFAULTS} tests/trainer/test_data_loading.py::test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler
117 changes: 117 additions & 0 deletions tests/trainer/test_data_loading.py
@@ -0,0 +1,117 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler
from torch.utils.data.sampler import SequentialSampler

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel
from tests.base import RandomDataset


class IndexedRandomDataset(RandomDataset):

def __getitem__(self, index):
return self.data[index]


class CustomDataLoader(DataLoader):

def __init__(self, num_features, dataset, *args, **kwargs):
self.num_features = num_features
super().__init__(dataset, *args, **kwargs)


class FailureCustomDataLoader(DataLoader):

def __init__(self, num_features, dataset, *args, **kwargs):
super().__init__(dataset, *args, **kwargs)


class CustomBatchSampler(BatchSampler):
pass


class TestModel(BoringModel):

def __init__(self, numbers_test_dataloaders,
save_preds_on_dl_idx, mode):
super().__init__()
self._numbers_test_dataloaders = numbers_test_dataloaders
self._save_preds_on_dl_idx = save_preds_on_dl_idx
self._mode = mode

def test_step(self, batch, batch_idx, dataloader_idx=None):
return super().test_step(batch, batch_idx)

def create_dataset(self):
dataset = IndexedRandomDataset(32, 64)
batch_sampler = None
batch_size = 2
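        # mode 2: a re-instantiable CustomDataLoader wrapping a custom BatchSampler;
        # any other mode: FailureCustomDataLoader, which does not expose `num_features`.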
if self._mode == 2:
batch_size = 1
batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=batch_size, drop_last=True)
dataloader_cls = CustomDataLoader
else:
dataloader_cls = FailureCustomDataLoader
return dataloader_cls(32, dataset, batch_size=batch_size, batch_sampler=batch_sampler)

def test_dataloader(self):
return [self.create_dataset()] * self._numbers_test_dataloaders


def check_replace_distrubuted_sampler(
tmpdir,
save_preds_on_dl_idx,
accelerator,
gpus,
num_dl_idx,
mode
):
num_processes = 2
limit_test_batches = 2
trainer_args = {
"default_root_dir": tmpdir,
"limit_test_batches" : limit_test_batches,
"accelerator": accelerator,
}

if accelerator == "ddp_cpu":
trainer_args["num_processes"] = num_processes
else:
trainer_args["gpus"] = gpus

model = TestModel(num_dl_idx, save_preds_on_dl_idx, mode)
model.test_epoch_end = None

trainer = Trainer(**trainer_args)
if mode == 1:
match = "DistributedSampler within"
with pytest.raises(MisconfigurationException, match=match):
trainer.test(model)
else:
trainer.test(model)


@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.parametrize("mode", [1, 2])
def test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
check_replace_distrubuted_sampler(tmpdir, True, "ddp", 2, 2, mode)