Commit 6dbdf43: Support DataLoaders with missing arguments in replace_sampler (#8519)
* Support `DataLoader`s with missing arguments in `replace_sampler`

* Fix for multiprocessing context

* Fixes and test improvements

* Fixes and test improvements

* Fixes and test improvements

* Test any variadic name

* Update CHANGELOG

* Make sure extra attributes can be present

* Skip on old Windows

* Update pytorch_lightning/trainer/data_loading.py

* Update pytorch_lightning/trainer/data_loading.py

* Check is dataloader

* Typo
carmocca authored Jul 26, 2021
1 parent c519fce commit 6dbdf43
Showing 4 changed files with 273 additions and 203 deletions.
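For context before the diffs: the change targets `DataLoader` subclasses whose extra `__init__` arguments are mirrored to instance attributes but hidden behind `*args`/`**kwargs`. A minimal sketch of such a loader, adapted from the docstring example this commit removes (the class name and the `num_features` argument are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

class CustomDataLoader(DataLoader):
    def __init__(self, num_features, dataset, *args, **kwargs):
        # the extra argument is stored under the same name, which is what the
        # attribute-based introspection in the new `replace_sampler` relies on
        self.num_features = num_features
        super().__init__(dataset, *args, **kwargs)

loader = CustomDataLoader(num_features=16, dataset=TensorDataset(torch.randn(8, 16)))

Previously, re-instantiating such a loader to swap in a `DistributedSampler` failed because `num_features` is not a `DataLoader.__init__` parameter; with this commit, any argument recoverable from a same-named instance attribute is passed through again.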
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -332,6 +332,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added private `prevent_trainer_and_dataloaders_deepcopy` context manager on the `LightningModule` ([#8472](https://github.com/PyTorchLightning/pytorch-lightning/pull/8472))
 
 
+- Improved error messages in `replace_sampler` when the `DataLoader` attributes are not included in the signature or the signature is missing optional arguments ([#8519](https://github.com/PyTorchLightning/pytorch-lightning/pull/8519))
+
+
 - Moved `DeviceDtypeModuleMixin` and `HyperparametersMixin` mixin to `core` ([#8396](https://github.com/PyTorchLightning/pytorch-lightning/pull/8396))


@@ -529,6 +532,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `ModelPruning` callback `on_save_checkpoint` to avoid making a `deepcopy` potentially leading to OOM ([#8472](https://github.com/PyTorchLightning/pytorch-lightning/pull/8472))
 
 
+- Fixed the sampler replacement logic for `DataLoader`s which do not define all `DataLoader` attributes as `__init__` parameters ([#8519](https://github.com/PyTorchLightning/pytorch-lightning/pull/8519))
+
+
 - Fixed DeepSpeed Windows support ([#8488](https://github.com/PyTorchLightning/pytorch-lightning/pull/8488))


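To make these two entries concrete, here are sketches of the two misconfigurations the new messages describe; both class names are invented for illustration:

from torch.utils.data import DataLoader

class NoAttributeLoader(DataLoader):
    # `num_features` is required but never stored on `self`, so it cannot be
    # recovered from the instance: this triggers the "missing attributes" error
    def __init__(self, num_features, dataset, **kwargs):
        super().__init__(dataset, **kwargs)

class NarrowSignatureLoader(DataLoader):
    # no `**kwargs`, so keyword arguments such as `sampler` cannot be passed
    # through on re-instantiation: this triggers the "missing arguments" error
    def __init__(self, dataset, batch_size=2):
        super().__init__(dataset, batch_size=batch_size)

In either case `replace_sampler` now raises a `MisconfigurationException` whose message names the offending arguments.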
115 changes: 64 additions & 51 deletions pytorch_lightning/trainer/data_loading.py
@@ -113,9 +113,7 @@ def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None:
         if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
             dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)
 
-    def auto_add_sampler(
-        self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
-    ) -> DataLoader:
+    def auto_add_sampler(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
         # don't do anything if it's not a dataloader
         is_dataloader = isinstance(dataloader, DataLoader)
         # don't manipulate iterable datasets
@@ -147,7 +145,7 @@ def auto_add_sampler(
         return dataloader
 
     @staticmethod
-    def _resolve_batch_sampler(dl_args, dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
+    def _resolve_batch_sampler(dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
         batch_sampler = getattr(dataloader, "batch_sampler")
         is_predicting = mode == RunningStage.PREDICTING
         # checking the batch sampler type is different than PyTorch default.
@@ -159,62 +157,77 @@ def _resolve_batch_sampler(dl_args, dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
             )
             if is_predicting:
                 batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
-            dl_args['batch_sampler'] = batch_sampler
-            dl_args['batch_size'] = 1
-            dl_args['shuffle'] = False
-            dl_args['sampler'] = None
-            dl_args['drop_last'] = False
-        else:
-            dl_args['sampler'] = sampler
-            dl_args['shuffle'] = False
-            dl_args['batch_sampler'] = None
-
-        return dl_args
+            return {
+                'sampler': None,
+                'shuffle': False,
+                'batch_sampler': batch_sampler,
+                'batch_size': 1,
+                'drop_last': False,
+            }
+        return {
+            'sampler': sampler,
+            'shuffle': False,
+            'batch_sampler': None,
+        }
 
     def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader:
-        skip_keys = ('sampler', 'batch_sampler', 'dataset_kind')
-        skip_signature_keys = ('args', 'kwargs', 'self')
+        if not isinstance(dataloader, DataLoader):
+            raise ValueError(f'The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`')
 
+        # get the dataloader instance attributes
         attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
-
-        params = set(inspect.signature(dataloader.__init__).parameters)
-        contains_dataset = True
-
-        if type(dataloader) is not DataLoader:
-            contains_dataset = "dataset" in params
-            params.update(inspect.signature(DataLoader.__init__).parameters)
-
-        dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys}
-
-        dl_args = self._resolve_batch_sampler(dl_args, dataloader, sampler, mode=mode)
-
-        multiprocessing_context = dataloader.multiprocessing_context
-        dl_args['multiprocessing_context'] = multiprocessing_context
-
-        missing_kwargs = params.difference(skip_signature_keys).difference(dl_args)
-        if missing_kwargs:
-            """
-            Example:
-                class CustomDataLoader(DataLoader):
-                    def __init__(self, num_features, dataset, *args, **kwargs):
-                        self.num_features = num_features
-                        super().__init__(dataset, *args, **kwargs)
-            """
+        # not part of `vars`
+        attrs['multiprocessing_context'] = dataloader.multiprocessing_context
+
+        # get the dataloader instance `__init__` parameters
+        params = dict(inspect.signature(dataloader.__init__).parameters)
+
+        # keep only the params whose default is different to the current attr value
+        non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
+        # add `dataset` as it might have been replaced with `*args`
+        non_defaults.add('dataset')
+
+        # kwargs to re-construct the dataloader
+        dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
+        dl_kwargs.update(self._resolve_batch_sampler(dataloader, sampler, mode=mode))
+
+        required_args = {
+            p.name
+            for p in params.values()
+            if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
+            and p.default is p.empty
+            and p.name not in dl_kwargs
+        }
+        # the dataloader has required args which we could not extract from the existing attributes
+        if required_args:
+            required_args = sorted(required_args)
             dataloader_cls_name = dataloader.__class__.__name__
             raise MisconfigurationException(
-                f"Trying to inject DistributedSampler within {dataloader_cls_name} class."
-                "This would fail as your DataLoader doesn't expose all its __init__ parameters as attributes. "
-                f"Missing attributes are {missing_kwargs}. "
-                f"HINT: If you wrote the {dataloader_cls_name} class, add the `__init__` arguments as attributes or ",
-                "manually add DistributedSampler as "
-                f"{dataloader_cls_name}(dataset, ..., sampler=DistributedSampler(dataset, ...)).",
+                f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
+                "This would fail as some of the `__init__` arguments are not available as instance attributes. "
+                f"The missing attributes are {required_args}. "
+                f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
+                "manually add the `DistributedSampler` as: "
+                f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
            )
 
-        if not contains_dataset:
-            dl_args.pop('dataset')
+        has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
+        if not has_variadic_kwargs:
+            # the dataloader signature does not allow keyword arguments that need to be passed
+            missing_kwargs = dl_kwargs.keys() - params.keys()
+            if missing_kwargs:
+                missing_kwargs = sorted(missing_kwargs)
+                dataloader_cls_name = dataloader.__class__.__name__
+                raise MisconfigurationException(
+                    f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
+                    "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
+                    f"The missing arguments are {missing_kwargs}. "
+                    f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
+                    "manually add the `DistributedSampler` as: "
+                    f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
+                )
 
-        dataloader = type(dataloader)(**dl_args)
-        dataloader.multiprocessing_context = multiprocessing_context
+        dl_cls = type(dataloader)
+        dataloader = dl_cls(**dl_kwargs)
         return dataloader

     def _get_distributed_sampler(

(Diffs for the remaining two changed files are not shown.)
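As a standalone illustration of the attribute-based reconstruction used in `replace_sampler` above (a condensed sketch, not Lightning API; the helper name is invented):

import inspect

def constructor_kwargs(obj):
    # public instance attributes that may mirror `__init__` parameters
    attrs = {k: v for k, v in vars(obj).items() if not k.startswith("_")}
    params = dict(inspect.signature(obj.__init__).parameters)
    # keep attributes whose value differs from the parameter default; those
    # were most likely passed explicitly when the object was constructed
    return {
        name: attrs[name]
        for name, p in params.items()
        if name in attrs and p.default != attrs[name]
    }

Re-instantiation is then roughly `type(obj)(**{**constructor_kwargs(obj), 'sampler': new_sampler})`, with the two `MisconfigurationException` branches above guarding the cases where an argument is neither recoverable nor accepted.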

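Finally, the fallback both HINT messages point to is constructing the `DistributedSampler` yourself instead of letting the trainer inject it (a sketch; `num_replicas` and `rank` are hard-coded so it runs without an initialized process group):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(32))
# in real DDP runs, num_replicas/rank are inferred from the process group
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
loader = DataLoader(dataset, sampler=sampler, batch_size=4)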