Extract dataloader utilities from TrainerDataLoadingMixin #10145

Merged Nov 19, 2021 · 23 commits
Changes from 16 commits

Commits
2c755cb - move dataloader utils from "mixin" to utils (awaelchli, Oct 25, 2021)
b304bb2 - remove unused imports (awaelchli, Oct 25, 2021)
680ec8d - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 25, 2021)
84cd04e - circular import (awaelchli, Oct 26, 2021)
cbd20b6 - Merge remote-tracking branch 'origin/refactor/extract-dataloader-util…' (awaelchli, Oct 26, 2021)
ef18f96 - Merge branch 'master' into refactor/extract-dataloader-utils (awaelchli, Nov 18, 2021)
e974e0d - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 18, 2021)
fd148c6 - move update_dataloader to utilities (awaelchli, Nov 18, 2021)
25f6baf - keep update dataloader protected (awaelchli, Nov 18, 2021)
aa012d7 - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 18, 2021)
74e3687 - move auto_add_worker_init_fn utility (awaelchli, Nov 18, 2021)
74b1bc0 - Update pytorch_lightning/plugins/training_type/ipu.py (awaelchli, Nov 18, 2021)
c484738 - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 18, 2021)
9e91845 - fix patching with bound method by wrapping in partial (awaelchli, Nov 18, 2021)
27d4b25 - downgrade to protected attr (awaelchli, Nov 18, 2021)
5e55e98 - try unbound method (awaelchli, Nov 18, 2021)
f0074bc - debug (awaelchli, Nov 18, 2021)
3357051 - debug (awaelchli, Nov 18, 2021)
79481c4 - attempt fix (awaelchli, Nov 18, 2021)
dd4bf6d - undo debugging statements (awaelchli, Nov 18, 2021)
49ff435 - shorten import (awaelchli, Nov 18, 2021)
e729650 - attempt without partial (awaelchli, Nov 18, 2021)
626ff06 - Merge branch 'master' into refactor/extract-dataloader-utils (awaelchli, Nov 19, 2021)
7 changes: 3 additions & 4 deletions pytorch_lightning/lite/lite.py
@@ -40,10 +40,9 @@
TrainingTypePlugin,
)
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.utilities import _StrategyType, DeviceType, move_data_to_device
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.data import has_iterable_dataset
from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, _update_dataloader, has_iterable_dataset
from pytorch_lightning.utilities.device_parser import _parse_devices
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
@@ -239,10 +238,10 @@ def _setup_dataloader(
sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)

# the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
dataloader = TrainerDataLoadingMixin._update_dataloader(dataloader, sampler)
dataloader = _update_dataloader(dataloader, sampler)

# add worker_init_fn for correct seeding in worker processes
TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank)
_auto_add_worker_init_fn(dataloader, self.global_rank)

dataloader = self._strategy.process_dataloader(dataloader)
device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None
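For context, a minimal sketch of how the two relocated helpers fit together after this change, based on the call sites above. `_update_dataloader` and `_auto_add_worker_init_fn` are the protected utilities now living in `pytorch_lightning.utilities.data`; the wrapper function and the explicit `DistributedSampler` construction below are illustrative stand-ins for what Lite's `_setup_dataloader` does, not the actual implementation.

from torch.utils.data import DataLoader, DistributedSampler

from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, _update_dataloader


def setup_dataloader_sketch(dataloader: DataLoader, global_rank: int, world_size: int) -> DataLoader:
    # Build the sampler that should replace the one the user's dataloader was created with.
    sampler = DistributedSampler(dataloader.dataset, num_replicas=world_size, rank=global_rank)

    # Re-instantiate the dataloader so the new sampler becomes part of its __init__ arguments.
    dataloader = _update_dataloader(dataloader, sampler)

    # Attach a worker_init_fn for per-rank worker seeding (only takes effect when PL_SEED_WORKERS is set).
    _auto_add_worker_init_fn(dataloader, global_rank)
    return dataloader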
12 changes: 8 additions & 4 deletions pytorch_lightning/plugins/training_type/ipu.py
@@ -13,12 +13,14 @@
# limitations under the License.
import json
import os
from functools import partial
from typing import Any, List, Optional, Union

import torch
from torch.utils.data import DataLoader

import pytorch_lightning as pl
import pytorch_lightning.utilities.data
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
@@ -27,6 +29,7 @@
from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _POPTORCH_AVAILABLE:
@@ -113,7 +116,8 @@ def setup(self) -> None:
# patch the dataloader creation function with the custom `poptorch.DataLoader`.
# this violates the intended control flow for the plugins, but since this is experimental, we have chosen
# to use the simpler solution before adding abstractions to override the `DataLoader` class
self.lightning_module.trainer._update_dataloader = self._convert_to_poptorch_loader
self._update_dataloader_original = pytorch_lightning.utilities.data._update_dataloader
pytorch_lightning.utilities.data._update_dataloader = partial(IPUPlugin._convert_to_poptorch_loader, self)

def pre_dispatch(self) -> None:
precision = self.lightning_module.trainer.precision
@@ -193,8 +197,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
def _convert_to_poptorch_loader(
self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None
) -> "poptorch.DataLoader":
# use full path to avoid circular imports
dl_kwargs = pl.trainer.trainer.TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler)
# Override to drop last uneven batch, as IPUs does not support uneven inputs.
dl_kwargs["drop_last"] = True

@@ -259,7 +262,8 @@ def predict_step(self, *args, **kwargs):

def teardown(self) -> None:
# undo dataloader patching
self.lightning_module.trainer._update_dataloader = pl.trainer.trainer.TrainerDataLoadingMixin._update_dataloader
pytorch_lightning.utilities.data._update_dataloader = self._update_dataloader_original

for model in self.poptorch_models.values():
model.destroy()

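The `setup`/`teardown` changes above replace the earlier approach of assigning a bound method onto the trainer with a module-level patch of `pytorch_lightning.utilities.data._update_dataloader`, wrapped in `functools.partial` so that `self` is bound while the patched attribute stays a plain function with the `(dataloader, sampler, mode)` signature callers expect. Below is a simplified sketch of that patch-and-restore pattern; `SketchPlugin` and its converter are hypothetical stand-ins, not the real `IPUPlugin`.

from functools import partial
from typing import Optional

from torch.utils.data import DataLoader

import pytorch_lightning.utilities.data


class SketchPlugin:
    """Illustrates the module-level patch/restore life cycle used by the IPU plugin above."""

    def _convert_dataloader(self, dataloader: DataLoader, sampler, mode: Optional[object] = None) -> DataLoader:
        # The real plugin returns a poptorch.DataLoader rebuilt from the original
        # dataloader's init kwargs; this stand-in just passes the input through.
        return dataloader

    def setup(self) -> None:
        # Remember the original so the patch can be undone in teardown().
        self._update_dataloader_original = pytorch_lightning.utilities.data._update_dataloader
        # partial(cls.method, self) binds `self` up front, avoiding assigning a bound
        # method directly onto the module attribute.
        pytorch_lightning.utilities.data._update_dataloader = partial(SketchPlugin._convert_dataloader, self)

    def teardown(self) -> None:
        # Restore the original utility so later runs are unaffected.
        pytorch_lightning.utilities.data._update_dataloader = self._update_dataloader_original

Keeping a reference to the original function in `setup` is what allows `teardown` to undo the patch cleanly.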
176 changes: 15 additions & 161 deletions pytorch_lightning/trainer/data_loading.py
@@ -11,38 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import multiprocessing
import os
from abc import ABC
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Collection, List, Optional, Tuple, Union

from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler
from torch.utils.data.dataset import IterableDataset
from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

import pytorch_lightning as pl
from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper, UnrepeatedDistributedSampler
from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import (
_capture_metadata_collate,
CaptureIterableDataset,
CaptureMapDataset,
FastForwardSampler,
from pytorch_lightning.utilities.auto_restart import _capture_metadata_collate
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_update_dataloader,
has_iterable_dataset,
has_len_all_ranks,
)
from pytorch_lightning.utilities.data import get_len, has_iterable_dataset, has_len_all_ranks
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import pl_worker_init_function


class TrainerDataLoadingMixin(ABC):
@@ -63,8 +60,8 @@ class TrainerDataLoadingMixin(ABC):
overfit_batches: Union[int, float]
distributed_sampler_kwargs: dict
accelerator: Accelerator
accelerator_connector: AcceleratorConnector
call_hook: Callable
_accelerator_connector: AcceleratorConnector

def _worker_check(self, dataloader: DataLoader, name: str) -> None:
if not isinstance(dataloader, DataLoader):
@@ -114,11 +111,6 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
" in the `DataLoader` init to improve performance."
)

@staticmethod
def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)

def _requires_distributed_sampler(self, dataloader) -> bool:
return (
self._accelerator_connector.replace_sampler_ddp
@@ -159,7 +151,7 @@ def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[Runn
or self._accelerator_connector.use_ipu # IPUs use a custom `DataLoader`
):
sampler = self._resolve_sampler(dataloader, shuffle=shuffle, mode=mode)
dataloader = self._update_dataloader(dataloader, sampler, mode=mode)
dataloader = _update_dataloader(dataloader, sampler, mode=mode)

if cycle_iterator is not None:
cycle_iterator.loader = dataloader
@@ -182,138 +174,6 @@ def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional

return dataloader.sampler

@staticmethod
def _dataloader_init_kwargs_resolve_sampler(
dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
"""This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for
its re-instantiation.

If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`,
so Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped
into a `FastForwardSampler`.
"""
batch_sampler = getattr(dataloader, "batch_sampler")
is_predicting = mode == RunningStage.PREDICTING
# checking the batch sampler type is different than PyTorch default.
if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting):
batch_sampler = type(batch_sampler)(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=(False if is_predicting else batch_sampler.drop_last),
)
if is_predicting:
batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

if _fault_tolerant_training():
fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
fast_forward_sampler.setup(dataloader_batch_size=1)

return {
"sampler": None,
"shuffle": False,
"batch_sampler": batch_sampler,
"batch_size": 1,
"drop_last": False,
}

if _fault_tolerant_training():
fast_forward_sampler = sampler = FastForwardSampler(sampler)
fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

return {"sampler": sampler, "shuffle": False, "batch_sampler": None}

@staticmethod
def _get_dataloader_init_kwargs(
dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None
) -> Dict[str, Any]:
if not isinstance(dataloader, DataLoader):
raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

# get the dataloader instance attributes
attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
# not part of `vars`
attrs["multiprocessing_context"] = dataloader.multiprocessing_context

# get the dataloader instance `__init__` parameters
params = dict(inspect.signature(dataloader.__init__).parameters)
has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
if has_variadic_kwargs:
# if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
params.update(inspect.signature(DataLoader.__init__).parameters)
del params["self"]

# keep only the params whose default is different to the current attr value
non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
# add `dataset` as it might have been replaced with `*args`
non_defaults.add("dataset")

# kwargs to re-construct the dataloader
dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
dl_kwargs.update(
TrainerDataLoadingMixin._dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode)
)

required_args = {
p.name
for p in params.values()
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
and p.default is p.empty
and p.name not in dl_kwargs
}
# the dataloader has required args which we could not extract from the existing attributes
if required_args:
required_args = sorted(required_args)
dataloader_cls_name = dataloader.__class__.__name__
raise MisconfigurationException(
f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
f"The missing attributes are {required_args}. "
f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
)

if not has_variadic_kwargs:
# the dataloader signature does not allow keyword arguments that need to be passed
missing_kwargs = dl_kwargs.keys() - params.keys()
if missing_kwargs:
missing_kwargs = sorted(missing_kwargs)
dataloader_cls_name = dataloader.__class__.__name__
raise MisconfigurationException(
f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
f"The missing arguments are {missing_kwargs}. "
f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
)

if isinstance(dl_kwargs["dataset"], IterableDataset):
dl_kwargs["batch_sampler"] = None
dl_kwargs["sampler"] = None

if _fault_tolerant_training():
dataset = dl_kwargs["dataset"]
if isinstance(dataset, IterableDataset):
# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
elif get_len(dataset) != float("inf"):
dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
else:
raise MisconfigurationException(
"This shouldn't happen, please open an issue on Lightning Github repository."
)

return dl_kwargs

@staticmethod
def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[RunningStage] = None) -> DataLoader:
dl_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
dl_cls = type(dataloader)
dataloader = dl_cls(**dl_kwargs)
return dataloader

@staticmethod
def _get_distributed_sampler(
dataloader: DataLoader,
@@ -350,7 +210,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, "train_dataloader")

# add worker_init_fn for correct seeding in worker processes
apply_to_collection(self.train_dataloader, DataLoader, self._auto_add_worker_init_fn, rank=self.global_rank)
apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)

# add collate_fn to collect metadata for fault tolerant training
if _fault_tolerant_training():
@@ -445,9 +305,7 @@ def _reset_eval_dataloader(
"You requested to overfit but enabled val/test dataloader shuffling."
" We are turning it off for you."
)
dataloaders[loader_i] = self._update_dataloader(
loader, SequentialSampler(loader.dataset), mode=mode
)
dataloaders[loader_i] = _update_dataloader(loader, SequentialSampler(loader.dataset), mode=mode)
else:
rank_zero_warn(
f"Your `{mode.dataloader_prefix}_dataloader` has `shuffle=True`,"
@@ -461,9 +319,7 @@
dataloaders = [self.prepare_dataloader(dl, False, mode=mode) for dl in dataloaders if dl is not None]

# add worker_init_fn for correct seeding in worker processes
apply_to_collection(
dataloaders, dtype=DataLoader, function=self._auto_add_worker_init_fn, rank=self.global_rank
)
apply_to_collection(dataloaders, dtype=DataLoader, function=_auto_add_worker_init_fn, rank=self.global_rank)

loader_num_batches = []

@@ -607,9 +463,7 @@ def resolve_has_no_sequential_sampler(dataloader: DataLoader):
)

def replace_sampler(dataloader: DataLoader) -> DataLoader:
return TrainerDataLoadingMixin._update_dataloader(
dataloader, SequentialSampler(dataloader.dataset), mode=RunningStage.TRAINING
)
return _update_dataloader(dataloader, SequentialSampler(dataloader.dataset), mode=RunningStage.TRAINING)

dataloader = apply_to_collection(dataloader, DataLoader, replace_sampler)

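For reference, the removed `TrainerDataLoadingMixin._update_dataloader` body is essentially the two-step pattern below, now relocated to `pytorch_lightning.utilities.data`. This is a hedged sketch assuming the helpers keep the signatures shown in this diff; the first function mirrors the overfit-batches hunk above, and the second spells out what `_update_dataloader` does internally.

from torch.utils.data import DataLoader, Sampler, SequentialSampler

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs, _update_dataloader


def replace_with_sequential_sampler(dataloader: DataLoader) -> DataLoader:
    # One call resolves the init kwargs and rebuilds the dataloader with the new sampler.
    return _update_dataloader(dataloader, SequentialSampler(dataloader.dataset), mode=RunningStage.TRAINING)


def rebuild_manually(dataloader: DataLoader, sampler: Sampler) -> DataLoader:
    # Roughly what _update_dataloader does: recover the __init__ arguments from the
    # instance attributes, swap in the new sampler, and call the dataloader class again.
    dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler)
    return type(dataloader)(**dl_kwargs)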