Support DataLoaders with missing arguments in replace_sampler #8519

Merged: 13 commits, Jul 26, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -328,6 +328,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added private `prevent_trainer_and_dataloaders_deepcopy` context manager on the `LightningModule` ([#8472](https://github.com/PyTorchLightning/pytorch-lightning/pull/8472))


- Improved error messages in `replace_sampler` when the `DataLoader` attributes are not included in the signature or the signature is missing optional arguments ([#8519](https://github.com/PyTorchLightning/pytorch-lightning/pull/8519))


- Moved `DeviceDtypeModuleMixin` and `HyperparametersMixin` mixin to `core` ([#8396](https://github.com/PyTorchLightning/pytorch-lightning/pull/8396))


@@ -522,6 +525,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `ModelPruning` callback `on_save_checkpoint` to avoid making a `deepcopy` potentially leading to OOM ([#8472](https://github.com/PyTorchLightning/pytorch-lightning/pull/8472))


- Fixed the sampler replacement logic for `DataLoader`s which do not define all `DataLoader` attributes as `__init__` parameters ([#8519](https://github.com/PyTorchLightning/pytorch-lightning/pull/8519))


- Fixed DeepSpeed Windows support ([#8488](https://github.com/PyTorchLightning/pytorch-lightning/pull/8488))


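To illustrate the two `replace_sampler` entries above: the reworked logic can re-build any `DataLoader` subclass whose extra `__init__` arguments are stored as instance attributes, and raises a clearer `MisconfigurationException` when they are not. A minimal sketch of a subclass that now works (the `RangeDataset` and `CustomDataLoader` names are illustrative, not part of the PR):

from torch.utils.data import DataLoader, Dataset

class RangeDataset(Dataset):
    # Tiny stand-in dataset, used only for this illustration.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return idx

class CustomDataLoader(DataLoader):
    def __init__(self, num_features, dataset, *args, **kwargs):
        # Storing the extra argument as an attribute is what lets the sampler
        # replacement logic recover it when re-instantiating the loader.
        self.num_features = num_features
        super().__init__(dataset, *args, **kwargs)

loader = CustomDataLoader(32, RangeDataset(), batch_size=4)

When the Trainer later replaces the sampler, `num_features` and `dataset` are picked up from the instance and passed back to `__init__`, while the injected `sampler`, `shuffle` and `batch_sampler` keywords are absorbed by `**kwargs`.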
115 changes: 64 additions & 51 deletions pytorch_lightning/trainer/data_loading.py
@@ -113,9 +113,7 @@ def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None:
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)

def auto_add_sampler(
self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
) -> DataLoader:
def auto_add_sampler(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
# don't do anything if it's not a dataloader
is_dataloader = isinstance(dataloader, DataLoader)
# don't manipulate iterable datasets
@@ -147,7 +145,7 @@ def auto_add_sampler(
return dataloader

@staticmethod
def _resolve_batch_sampler(dl_args, dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
def _resolve_batch_sampler(dataloader, sampler, mode: Optional[RunningStage] = None) -> Dict[str, Any]:
batch_sampler = getattr(dataloader, "batch_sampler")
is_predicting = mode == RunningStage.PREDICTING
# check whether the batch sampler type differs from the PyTorch default.
@@ -159,62 +157,77 @@ def _resolve_batch_sampler(dl_args, dataloader, sampler, mode: Optional[RunningS
)
if is_predicting:
batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
dl_args['batch_sampler'] = batch_sampler
dl_args['batch_size'] = 1
dl_args['shuffle'] = False
dl_args['sampler'] = None
dl_args['drop_last'] = False
else:
dl_args['sampler'] = sampler
dl_args['shuffle'] = False
dl_args['batch_sampler'] = None

return dl_args
return {
'sampler': None,
'shuffle': False,
'batch_sampler': batch_sampler,
'batch_size': 1,
'drop_last': False,
}
return {
'sampler': sampler,
'shuffle': False,
'batch_sampler': None,
}
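
# Editorial note (illustrative, not part of the diff): the refactored
# `_resolve_batch_sampler` no longer mutates a `dl_args` dict supplied by the
# caller; it builds and returns a fresh dict of keyword arguments. For a loader
# that uses the stock PyTorch `BatchSampler` outside of `RunningStage.PREDICTING`,
# the result is simply:
#
#     {'sampler': sampler, 'shuffle': False, 'batch_sampler': None}
#
# With a custom batch sampler, or when predicting, the batch sampler is re-created
# around the new `sampler` (and wrapped in `IndexBatchSamplerWrapper` for
# prediction), while `sampler`, `shuffle`, `batch_size` and `drop_last` are reset to
# their defaults so that the batch sampler alone controls batching.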

def replace_sampler(self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None) -> DataLoader:
skip_keys = ('sampler', 'batch_sampler', 'dataset_kind')
skip_signature_keys = ('args', 'kwargs', 'self')
if not isinstance(dataloader, DataLoader):
raise ValueError(f'The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`')

# get the dataloader instance attributes
attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}

params = set(inspect.signature(dataloader.__init__).parameters)
contains_dataset = True

if type(dataloader) is not DataLoader:
contains_dataset = "dataset" in params
params.update(inspect.signature(DataLoader.__init__).parameters)

dl_args = {name: attrs[name] for name in params if name in attrs and name not in skip_keys}

dl_args = self._resolve_batch_sampler(dl_args, dataloader, sampler, mode=mode)

multiprocessing_context = dataloader.multiprocessing_context
dl_args['multiprocessing_context'] = multiprocessing_context

missing_kwargs = params.difference(skip_signature_keys).difference(dl_args)
if missing_kwargs:
"""
Example:
class CustomDataLoader(DataLoader):
def __init__(self, num_features, dataset, *args, **kwargs):
self.num_features = num_features
super().__init__(dataset, *args, **kwargs)
"""
# not part of `vars`
attrs['multiprocessing_context'] = dataloader.multiprocessing_context

# get the dataloader instance `__init__` parameters
params = dict(inspect.signature(dataloader.__init__).parameters)

# keep only the params whose default is different to the current attr value
non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}
# add `dataset` as it might have been replaced with `*args`
non_defaults.add('dataset')

# kwargs to re-construct the dataloader
dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
dl_kwargs.update(self._resolve_batch_sampler(dataloader, sampler, mode=mode))

required_args = {
p.name
for p in params.values()
if p.kind in (p.POSITIONAL_ONLY,
p.POSITIONAL_OR_KEYWORD) and p.default is p.empty and p.name not in dl_kwargs
}
# the dataloader has required args which we could not extract from the existing attributes
if required_args:
required_args = sorted(required_args)
dataloader_cls_name = dataloader.__class__.__name__
raise MisconfigurationException(
f"Trying to inject DistributedSampler within {dataloader_cls_name} class."
"This would fail as your DataLoader doesn't expose all its __init__ parameters as attributes. "
f"Missing attributes are {missing_kwargs}. "
f"HINT: If you wrote the {dataloader_cls_name} class, add the `__init__` arguments as attributes or ",
"manually add DistributedSampler as "
f"{dataloader_cls_name}(dataset, ..., sampler=DistributedSampler(dataset, ...)).",
f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
f"The missing attributes are {required_args}. "
f"HINT: If you wrote the `{dataloader_cls_name}` class, define `self.missing_arg_name` or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
)
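
# Editorial example (illustrative, not part of the diff): the error above fires when
# a required `__init__` argument was never saved on the instance and therefore
# cannot be recovered from `attrs`, e.g.:
#
#     class NoAttrLoader(DataLoader):
#         def __init__(self, num_features, dataset, **kwargs):
#             # `num_features` is intentionally NOT stored as `self.num_features`
#             super().__init__(dataset, **kwargs)
#
# Here `num_features` ends up in `required_args`, so the loader cannot be
# re-instantiated and the exception lists it as a missing attribute.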

if not contains_dataset:
dl_args.pop('dataset')
has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
if not has_variadic_kwargs:
# the dataloader signature does not allow keyword arguments that need to be passed
missing_kwargs = dl_kwargs.keys() - params.keys()
if missing_kwargs:
missing_kwargs = sorted(missing_kwargs)
dataloader_cls_name = dataloader.__class__.__name__
raise MisconfigurationException(
f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
f"The missing arguments are {missing_kwargs}. "
f"HINT: If you wrote the `{dataloader_cls_name}` class, add the `__init__` arguments or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_cls_name}(dataset, sampler=DistributedSampler(dataset))`."
)
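
# Editorial example (illustrative, not part of the diff): the second error covers a
# signature that cannot accept the keyword arguments Lightning needs to inject
# (`sampler`, `shuffle`, `batch_sampler`) and has no `**kwargs` to absorb them, e.g.:
#
#     class StrictLoader(DataLoader):
#         def __init__(self, dataset, batch_size):
#             super().__init__(dataset, batch_size=batch_size)
#
# Re-building `StrictLoader(**dl_kwargs)` would pass `sampler=...`, `shuffle=False`
# and `batch_sampler=None`, none of which appear in its `__init__` signature.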

dataloader = type(dataloader)(**dl_args)
dataloader.multiprocessing_context = multiprocessing_context
dl_cls = type(dataloader)
dataloader = dl_cls(**dl_kwargs)
return dataloader

def _get_distributed_sampler(