Revert part of #10279 (#10376)
carmocca committed Nov 8, 2021
1 parent 504556f commit 012bc12
Showing 2 changed files with 35 additions and 62 deletions.
17 changes: 7 additions & 10 deletions pytorch_lightning/lite/lite.py
@@ -238,18 +238,15 @@ def _setup_dataloader(
)
sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)

-        dataloader_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
-        try:
-            dataloader = type(dataloader)(**dataloader_kwargs)
-        except TypeError:
-            dataloader_kwargs.pop("dataset")
-            dataloader = type(dataloader)(**dataloader_kwargs)
+        # the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
+        dataloader = TrainerDataLoadingMixin._update_dataloader(dataloader, sampler)
+
        # add worker_init_fn for correct seeding in worker processes
        TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank)
-        return _LiteDataLoader(
-            dataloader=self._strategy.process_dataloader(dataloader),
-            device=self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None,
-        )
+
+        dataloader = self._strategy.process_dataloader(dataloader)
+        device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None
+        return _LiteDataLoader(dataloader=dataloader, device=device)

def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
"""Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
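
Similarly, the `_auto_add_worker_init_fn` call in the hunk above attaches a `worker_init_fn` so that every dataloader worker on every rank is seeded distinctly but reproducibly. A rough standalone illustration of that kind of hook (a simplified stand-in; Lightning attaches its own `pl_worker_init_function`, which derives the seeds in a more elaborate way):

import random
from functools import partial

import numpy as np
import torch
from torch.utils.data import DataLoader


def _seed_worker(worker_id: int, rank: int, base_seed: int) -> None:
    # give every (rank, worker) pair its own seed so random augmentations do not repeat across processes
    seed = base_seed + 1000 * rank + worker_id
    random.seed(seed)
    np.random.seed(seed % 2**32)
    torch.manual_seed(seed)


loader = DataLoader(
    range(16), batch_size=4, num_workers=2, worker_init_fn=partial(_seed_worker, rank=0, base_seed=42)
)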
80 changes: 28 additions & 52 deletions tests/lite/test_lite.py
@@ -24,7 +24,12 @@
from torch.utils.data import DataLoader, DistributedSampler, Sampler

from pytorch_lightning.lite import LightningLite
-from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
+from pytorch_lightning.lite.wrappers import (
+    _LiteDataLoader,
+    _LiteModule,
+    _LiteOptimizer,
+    _replace_dataloader_init_method,
+)
from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin
from pytorch_lightning.utilities import DistributedType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -192,57 +197,6 @@ def run(self):
LiteWithCustomDataLoader().run()


-def test_setup_custom_dataloaders():
-    """Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader."""
-    lite = EmptyLite()
-
-    class CustomDataLoader(DataLoader):
-        def __init__(self, value: int = 2, *args, **kwargs):
-            self.value = value
-            super().__init__(range(value), *args, **kwargs)
-
-    dataloader = CustomDataLoader(2, batch_size=2)
-
-    # single dataloader
-    lite_dataloader = lite.setup_dataloaders(dataloader)
-    assert lite_dataloader._dataloader
-    assert lite_dataloader.value == 2
-    batch0 = next(iter(lite_dataloader))
-    assert torch.equal(batch0, torch.tensor([0, 1]))
-
-    class CustomDataLoader2(DataLoader):
-        def __init__(self, range, *args, **kwargs):
-            self.range = range
-            super().__init__(range, *args, **kwargs)
-
-    dataloader = CustomDataLoader2(range(2), batch_size=2)
-
-    # single dataloader
-    lite_dataloader = lite.setup_dataloaders(dataloader)
-    assert lite_dataloader._dataloader
-    batch0 = next(iter(lite_dataloader))
-    assert torch.equal(batch0, torch.tensor([0, 1]))
-
-    class CustomDataLoader(DataLoader):
-        def __init__(self, value: int, *args, **kwargs):
-            super().__init__(range(value), *args, **kwargs)
-
-    class LiteWithCustomDataLoader(LightningLite):
-        def run(self):
-            # This doesn't fail as the context manager would save all the arguments provided
-            # to the dataloaders.
-            dataloader = CustomDataLoader(2, batch_size=2)
-            self.setup_dataloaders(dataloader)
-
-    LiteWithCustomDataLoader().run()
-
-    with pytest.raises(
-        MisconfigurationException, match="Trying to inject `DistributedSampler` into the `CustomDataLoader` instance"
-    ):
-        dataloader = CustomDataLoader(2, batch_size=2)
-        lite_dataloader = lite.setup_dataloaders(dataloader)
-
-
def test_setup_dataloaders_twice_fails():
"""Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
lite = EmptyLite()
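
For context on the `MisconfigurationException` the removed test expected: the re-instantiation step can only rebuild a dataloader whose explicit `__init__` parameters are readable back from instance attributes of the same name. A simplified illustration of that check, with a hypothetical helper name and a plain `TypeError` standing in for Lightning's `MisconfigurationException`:

import inspect

from torch.utils.data import DataLoader


def _check_reinstantiable(dataloader: DataLoader) -> None:
    # every explicit, non-defaulted __init__ parameter must be recoverable from the instance,
    # otherwise the dataloader cannot be re-created with a DistributedSampler injected
    params = inspect.signature(type(dataloader).__init__).parameters
    missing = [
        name
        for name, param in params.items()
        if name != "self"
        and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
        and param.default is inspect.Parameter.empty
        and not hasattr(dataloader, name)
    ]
    if missing:
        raise TypeError(
            f"Trying to inject `DistributedSampler` into the `{type(dataloader).__name__}` instance: "
            f"the attributes {missing} cannot be recovered, so it cannot be re-instantiated."
        )

Under this rule, the first `CustomDataLoader` in the deleted test (which stores `self.value`) can be rebuilt, while the last one (which never stores `value`) cannot, matching the behaviour the deleted assertions exercised.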
@@ -490,3 +444,25 @@ def run(self):
assert self.is_global_zero == (self.local_rank == 0)

Lite(strategy=DeepSpeedPlugin(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()


+def test_replace_dataloader_init_method():
+    """Test that the context manager enables saving the parameters passed to the DataLoader __init__ method."""
+
+    class CustomDataLoader(DataLoader):
+        def __init__(self, extra_argument: int, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+    dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
+    lite = EmptyLite()
+    with pytest.raises(MisconfigurationException, match="extra_argument"):
+        dataloader = lite.setup_dataloaders(dataloader)
+
+    with _replace_dataloader_init_method():
+        dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
+        assert dataloader.extra_argument == 1
+        dataloader = lite.setup_dataloaders(dataloader)
+
+        dataloader = CustomDataLoader(1, range(1))
+        assert dataloader.extra_argument == 1
+        dataloader = lite.setup_dataloaders(dataloader)

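
The new test relies on `_replace_dataloader_init_method` recording every argument passed to a `DataLoader` subclass constructor as an attribute on the instance, so the dataloader can later be rebuilt with an updated sampler. A rough sketch of that mechanism, under the assumption that only direct subclasses need patching and using hypothetical names (`_record_dataloader_init_args`, `_wrap_init`) rather than the actual Lightning code:

import inspect
from contextlib import contextmanager

from torch.utils.data import DataLoader


def _wrap_init(init):
    def wrapper(self, *args, **kwargs):
        # bind the call to the subclass signature and remember every explicit argument
        signature = inspect.signature(init)
        bound = signature.bind(self, *args, **kwargs)
        for name, value in bound.arguments.items():
            param = signature.parameters[name]
            if name == "self" or param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
                continue
            setattr(self, name, value)
        init(self, *args, **kwargs)

    return wrapper


@contextmanager
def _record_dataloader_init_args():
    originals = {cls: cls.__init__ for cls in DataLoader.__subclasses__()}
    for cls, init in originals.items():
        cls.__init__ = _wrap_init(init)
    try:
        yield
    finally:
        for cls, init in originals.items():
            cls.__init__ = init

Inside such a context manager, `CustomDataLoader(extra_argument=1, dataset=range(1))` ends up with an `extra_argument` attribute, which is exactly what the assertions in the new test check before handing the dataloader to `setup_dataloaders`.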