-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feat-BugFix] Resolve custom DataLoader (#5745)
* resolve custom dataloader * update changelog * fix tests * update on comments * resolve comments * add support for custom batch_sampler * Update tests/trainer/test_data_loading.py * resolve test * resolve flake8 * resolve yapf Co-authored-by: Ubuntu <ubuntu@ip-172-31-88-60.ec2.internal> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
- Loading branch information
1 parent
d2c2e50
commit d8f2d8e
Showing
4 changed files
with
187 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import os | ||
|
||
import pytest | ||
import torch | ||
from torch.utils.data import DataLoader | ||
from torch.utils.data.sampler import BatchSampler | ||
from torch.utils.data.sampler import SequentialSampler | ||
|
||
from pytorch_lightning import Trainer | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
from tests.base import BoringModel | ||
from tests.base import RandomDataset | ||
|
||
|
||
class IndexedRandomDataset(RandomDataset):
    """``RandomDataset`` variant whose ``__getitem__`` indexes straight into the data."""

    def __getitem__(self, index):
        # Return the raw row at ``index``; no transform is applied.
        sample = self.data[index]
        return sample
|
||
|
||
class CustomDataLoader(DataLoader):
    """DataLoader that keeps an extra ``num_features`` constructor argument as an attribute."""

    def __init__(self, num_features, dataset, *args, **kwargs):
        # Record the custom argument first so it is available as an attribute
        # even while the parent constructor runs.
        self.num_features = num_features
        DataLoader.__init__(self, dataset, *args, **kwargs)
|
||
|
||
class FailureCustomDataLoader(DataLoader):
    """DataLoader that accepts ``num_features`` but deliberately discards it.

    NOTE(review): presumably the missing attribute is what makes Lightning
    unable to re-instantiate this loader in the DDP error-path test below.
    """

    def __init__(self, num_features, dataset, *args, **kwargs):
        # ``num_features`` is intentionally not stored.
        super(FailureCustomDataLoader, self).__init__(dataset, *args, **kwargs)
|
||
|
||
class CustomBatchSampler(BatchSampler):
    """Trivial ``BatchSampler`` subclass, used only to mark a custom sampler type."""
|
||
|
||
class TestModel(BoringModel):
    """BoringModel whose test dataloaders are built from the custom loader classes."""

    def __init__(self, numbers_test_dataloaders, save_preds_on_dl_idx, mode):
        super().__init__()
        # How many (identical) test dataloaders to expose.
        self._numbers_test_dataloaders = numbers_test_dataloaders
        self._save_preds_on_dl_idx = save_preds_on_dl_idx
        # mode 2 -> CustomDataLoader + CustomBatchSampler; any other mode ->
        # FailureCustomDataLoader (the expected-error path).
        self._mode = mode

    def test_step(self, batch, batch_idx, dataloader_idx=None):
        # Delegate to the parent implementation, dropping the dataloader index.
        return super().test_step(batch, batch_idx)

    def create_dataset(self):
        """Build a single test dataloader according to ``self._mode``."""
        dataset = IndexedRandomDataset(32, 64)
        if self._mode == 2:
            # A DataLoader paired with a batch_sampler keeps batch_size at 1.
            sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=1, drop_last=True)
            return CustomDataLoader(32, dataset, batch_size=1, batch_sampler=sampler)
        return FailureCustomDataLoader(32, dataset, batch_size=2, batch_sampler=None)

    def test_dataloader(self):
        return [self.create_dataset()] * self._numbers_test_dataloaders
|
||
|
||
def check_replace_distrubuted_sampler(
    tmpdir,
    save_preds_on_dl_idx,
    accelerator,
    gpus,
    num_dl_idx,
    mode
):
    """Run ``trainer.test`` on a ``TestModel`` with custom dataloaders.

    For ``mode == 1`` the loader is ``FailureCustomDataLoader`` and Lightning
    is expected to raise a ``MisconfigurationException``; otherwise the run
    should succeed.

    NOTE(review): "distrubuted" is a typo for "distributed", kept because the
    test entry point below calls this name.
    """
    trainer_kwargs = {
        "default_root_dir": tmpdir,
        "limit_test_batches": 2,
        "accelerator": accelerator,
    }
    if accelerator == "ddp_cpu":
        trainer_kwargs["num_processes"] = 2
    else:
        trainer_kwargs["gpus"] = gpus

    model = TestModel(num_dl_idx, save_preds_on_dl_idx, mode)
    # Disable epoch-end aggregation; only the per-step behavior matters here.
    model.test_epoch_end = None

    trainer = Trainer(**trainer_kwargs)
    if mode == 1:
        with pytest.raises(MisconfigurationException, match="DistributedSampler within"):
            trainer.test(model)
    else:
        trainer.test(model)
|
||
|
||
@pytest.mark.skipif(os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') != '1',
                    reason="test should be run outside of pytest")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.parametrize("mode", [1, 2])
def test_replace_distrubuted_sampler_custom_dataloader_custom_batch_sampler(tmpdir, mode):
    """DDP run with a custom DataLoader (mode 1: failure path, mode 2: custom batch sampler)."""
    # Idiom fix: `not x == y` replaced with the equivalent, PEP 8-preferred `x != y`.
    check_replace_distrubuted_sampler(tmpdir, True, "ddp", 2, 2, mode)