[RFC] Introduce a dataloader factory class to better manage data modules #6776

Closed
ananthsub opened this issue Apr 1, 2021 · 3 comments
Labels
feature Is an improvement or enhancement help wanted Open to be worked on won't fix This will not be worked on

@ananthsub
Contributor

ananthsub commented Apr 1, 2021

🚀 Feature

Create a new class to manage dataloaders
Proposal: https://docs.google.com/document/d/1c0dBmASUfQy0kIpliGD7sGmdzC0sgbuOQStkM02UySM/edit

Motivation

  • DataModules bundle together training, validation, and testing dataloaders
  • Often, we want to configure different dataloader settings for each of these phases
    • Example: configure larger batch sizes for validation (no gradients so we can use more memory for batches/activations)
    • Example: handle uneven end of data differently across training vs validation (drop tail in training, wraparound for validation)
  • This combination of bundling different phases together with a variety of knobs creates complex datamodule initialization logic
  • Furthermore, a generic datamodule is difficult to implement because it can be used for training (where the train dataloader must be defined and the val dataloader optionally), validation (where only the val dataloader must be defined), testing (where only the test dataloader must be defined), or prediction (where only the predict dataloader must be defined).
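To make the "uneven end of data" example concrete, here is a minimal, framework-free sketch of the two policies. It assumes plain index lists stand in for a dataset; the `batch_indices` helper and its `mode` values are hypothetical and only illustrate the behavior (in PyTorch, the drop-tail case corresponds to `DataLoader(drop_last=True)`).

```python
from typing import List


def batch_indices(n: int, batch_size: int, mode: str) -> List[List[int]]:
    """Split the indices 0..n-1 into batches of `batch_size` (hypothetical helper).

    mode="drop": discard a trailing partial batch (training-style drop tail).
    mode="wrap": fill a trailing partial batch with indices from the start
                 (validation-style wraparound, so every sample is evaluated).
    """
    if mode not in ("drop", "wrap"):
        raise ValueError(f"unknown mode: {mode!r}")
    indices = list(range(n))
    batches = [indices[i:i + batch_size] for i in range(0, n, batch_size)]
    if batches and len(batches[-1]) < batch_size:
        if mode == "drop":
            batches.pop()
        else:
            missing = batch_size - len(batches[-1])
            # Cycle through the dataset from the beginning to pad the last batch.
            batches[-1] = batches[-1] + [indices[j % n] for j in range(missing)]
    return batches


print(batch_indices(10, 4, "drop"))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(batch_indices(10, 4, "wrap"))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 0, 1]]
```

Supporting both policies, plus a different batch size per phase, through one datamodule constructor is the kind of knob proliferation the proposal is trying to avoid.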

Pitch

from abc import ABC, abstractmethod
from typing import List, Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.core.hooks import CheckpointHooks
from torch.utils.data import DataLoader

class DataLoaderFactory(ABC, CheckpointHooks):

    def __init__(self) -> None:
        # Pointer to the trainer object.
        # Placeholder until we define a proper TrainerContext class (e.g. a frozen dataclass)
        # to pass things like progress tracking, rank, or world size to the factory.
        self.trainer: Optional["pl.Trainer"] = None

        # Private attrs to keep track of whether or not data hooks have been called yet
        self._prepare_data_called: bool = False
        self._setup_called: bool = False
        self._teardown_called: bool = False

        # This should not be a trainer setting
        # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#prepare-data-per-node
        self.prepare_data_per_node: bool = True

    def prepare_data(self) -> None:
        pass

    def setup(self) -> None:
        pass
    
    @abstractmethod
    def get_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        """Return the dataloader(s) for the phase this factory serves."""

    def teardown(self) -> None:
        pass

    @property
    def prepare_data_called(self) -> bool:
        return self._prepare_data_called

    @prepare_data_called.setter
    def prepare_data_called(self, val: bool) -> None:
        self._prepare_data_called = val

    @property
    def setup_called(self) -> bool:
        return self._setup_called

    @setup_called.setter
    def setup_called(self, val: bool) -> None:
        self._setup_called = val

    @property
    def teardown_called(self) -> bool:
        return self._teardown_called

    @teardown_called.setter
    def teardown_called(self, val: bool) -> None:
        self._teardown_called = val
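The `*_called` flags suggest the trainer would run each data hook at most once per factory. The following is a minimal sketch of that lifecycle, assuming a stdlib stand-in for the class above (no torch or Lightning), plain attributes in place of the property/setter pairs, and a list of index batches in place of a `DataLoader`. `FactoryBase`, `RangeFactory`, and the driver `run_data_hooks` are hypothetical names, not Lightning API.

```python
from abc import ABC, abstractmethod
from typing import List


class FactoryBase(ABC):
    """Stdlib stand-in for the proposed DataLoaderFactory (no torch or Lightning needed)."""

    def __init__(self) -> None:
        self.trainer = None
        # Plain attributes replace the proposal's property/setter pairs.
        self.prepare_data_called = False
        self.setup_called = False
        self.teardown_called = False

    def prepare_data(self) -> None:
        pass

    def setup(self) -> None:
        pass

    @abstractmethod
    def get_dataloader(self) -> List[List[int]]:
        """Return batches; a list of index lists stands in for a DataLoader."""

    def teardown(self) -> None:
        pass


class RangeFactory(FactoryBase):
    """Concrete factory over the integers 0..n-1."""

    def __init__(self, n: int, batch_size: int) -> None:
        super().__init__()
        self.n = n
        self.batch_size = batch_size
        self.data: List[int] = []
        self.setup_count = 0  # lets us check that setup() ran only once

    def setup(self) -> None:
        self.setup_count += 1
        self.data = list(range(self.n))

    def get_dataloader(self) -> List[List[int]]:
        size = self.batch_size
        return [self.data[i:i + size] for i in range(0, len(self.data), size)]


def run_data_hooks(factory: FactoryBase) -> List[List[int]]:
    """Hypothetical trainer-side driver: run each data hook at most once, then build the loader."""
    if not factory.prepare_data_called:
        factory.prepare_data()
        factory.prepare_data_called = True
    if not factory.setup_called:
        factory.setup()
        factory.setup_called = True
    return factory.get_dataloader()


factory = RangeFactory(n=5, batch_size=2)
run_data_hooks(factory)
print(run_data_hooks(factory))  # [[0, 1], [2, 3], [4]]
print(factory.setup_count)      # 1
```

Because `get_dataloader` is abstract, the base class cannot be instantiated on its own; each phase supplies a small concrete subclass.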

As a convenience, we can add optional attributes to the datamodule for these classes, one each for train/val/test/predict, along with a method to instantiate a datamodule from these factories, similar to this: https://github.com/PyTorchLightning/pytorch-lightning/blob/a72a7992a283f2eb5183d129a8cf6466903f1dc8/pytorch_lightning/core/datamodule.py#L343-L398

class LightningDataModule(...):  # existing base classes elided
    train_dataloader_factory: Optional[DataLoaderFactory] = None
    val_dataloader_factory: Optional[DataLoaderFactory] = None
    test_dataloader_factory: Optional[DataLoaderFactory] = None
    predict_dataloader_factory: Optional[DataLoaderFactory] = None

    @classmethod
    def from_dataloader_factories(
        cls,
        train_dataloader_factory: Optional[DataLoaderFactory] = None,
        val_dataloader_factory: Optional[DataLoaderFactory] = None,
        test_dataloader_factory: Optional[DataLoaderFactory] = None,
        predict_dataloader_factory: Optional[DataLoaderFactory] = None,
    ) -> "LightningDataModule":
        datamodule = cls()
        if train_dataloader_factory is not None:
            datamodule.train_dataloader = train_dataloader_factory.get_dataloader
        if val_dataloader_factory is not None:
            datamodule.val_dataloader = val_dataloader_factory.get_dataloader
        if test_dataloader_factory is not None:
            datamodule.test_dataloader = test_dataloader_factory.get_dataloader
        if predict_dataloader_factory is not None:
            datamodule.predict_dataloader = predict_dataloader_factory.get_dataloader
        return datamodule
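As a usage illustration, here is a minimal sketch of `from_dataloader_factories` with per-phase settings (drop-tail training batches, larger validation batches). It assumes stdlib stand-ins: `ListFactory` and `MiniDataModule` are hypothetical classes that mimic only the binding logic above, and a list of index batches replaces a real `DataLoader`.

```python
from typing import Callable, List, Optional


class ListFactory:
    """Stand-in factory: the 'dataloader' is a list of batches over range(n)."""

    def __init__(self, n: int, batch_size: int, drop_last: bool = False) -> None:
        self.n, self.batch_size, self.drop_last = n, batch_size, drop_last

    def get_dataloader(self) -> List[List[int]]:
        data = list(range(self.n))
        batches = [data[i:i + self.batch_size] for i in range(0, self.n, self.batch_size)]
        if self.drop_last and batches and len(batches[-1]) < self.batch_size:
            batches.pop()
        return batches


class MiniDataModule:
    """Stand-in for LightningDataModule exposing only two dataloader hooks."""

    train_dataloader: Optional[Callable[[], List[List[int]]]] = None
    val_dataloader: Optional[Callable[[], List[List[int]]]] = None

    @classmethod
    def from_dataloader_factories(
        cls,
        train_dataloader_factory: Optional[ListFactory] = None,
        val_dataloader_factory: Optional[ListFactory] = None,
    ) -> "MiniDataModule":
        datamodule = cls()
        # Bind each factory's get_dataloader as the matching hook, as in the proposal.
        if train_dataloader_factory is not None:
            datamodule.train_dataloader = train_dataloader_factory.get_dataloader
        if val_dataloader_factory is not None:
            datamodule.val_dataloader = val_dataloader_factory.get_dataloader
        return datamodule


dm = MiniDataModule.from_dataloader_factories(
    train_dataloader_factory=ListFactory(10, batch_size=4, drop_last=True),
    val_dataloader_factory=ListFactory(10, batch_size=8),  # larger validation batches
)
print(dm.train_dataloader())  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(dm.val_dataloader())    # [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9]]
```

Each phase's settings live in its own factory, so the datamodule constructor stays free of per-phase knobs; a phase whose factory is omitted keeps the class default.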

This could also replace the raw dataloaders currently accepted by the Trainer.fit()/validate()/test()/predict() APIs. These classes are an improvement because they can access trainer context, which plain dataloaders cannot.
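One thing trainer context enables is per-rank sharding decided at loader-creation time, which a dataloader built before `fit()` cannot see. A minimal sketch, assuming a hypothetical frozen `TrainerContext` dataclass (the proposal only mentions one as a possible future class) attached through the `trainer` attribute, and a simple strided split that omits the padding and shuffling a real distributed sampler would apply:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class TrainerContext:
    """Hypothetical read-only context the trainer could hand to a factory."""
    global_rank: int
    world_size: int


class ShardedFactory:
    """Builds per-rank batches; needs rank and world size, which a pre-built dataloader never sees."""

    def __init__(self, n: int, batch_size: int) -> None:
        self.n = n
        self.batch_size = batch_size
        self.trainer: Optional[TrainerContext] = None  # set by the trainer before use

    def get_dataloader(self) -> List[List[int]]:
        if self.trainer is None:
            raise RuntimeError("trainer context has not been attached yet")
        ctx = self.trainer
        # Strided sharding: rank r gets samples r, r + world_size, r + 2 * world_size, ...
        shard = list(range(ctx.global_rank, self.n, ctx.world_size))
        return [shard[i:i + self.batch_size] for i in range(0, len(shard), self.batch_size)]


factory = ShardedFactory(n=10, batch_size=2)
factory.trainer = TrainerContext(global_rank=1, world_size=2)
print(factory.get_dataloader())  # [[1, 3], [5, 7], [9]]
```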

cc @justusschock @awaelchli @carmocca

Alternatives

Additional context

@ananthsub ananthsub added feature Is an improvement or enhancement help wanted Open to be worked on labels Apr 1, 2021
@tmcclintock

Rather than holding a pointer to the trainer and waiting for TrainerContext to be implemented, why not pass the trainer state in at call time to get_dataloader? This would avoid potentially tangled state between the trainer and the factory.
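One reading of this suggestion is to make the context an argument of `get_dataloader` instead of a stored attribute. A minimal sketch, assuming a hypothetical frozen `TrainerContext` with an illustrative `current_epoch` field, and using a rotation as a deterministic stand-in for epoch-dependent shuffling:

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class TrainerContext:
    """Hypothetical read-only snapshot of trainer state."""
    global_rank: int
    world_size: int
    current_epoch: int


class CallTimeFactory:
    """Variant of the proposal: context is a call argument, not a stored trainer pointer."""

    def __init__(self, n: int, batch_size: int) -> None:
        self.n = n
        self.batch_size = batch_size

    def get_dataloader(self, ctx: TrainerContext) -> List[List[int]]:
        shard = list(range(ctx.global_rank, self.n, ctx.world_size))
        # Rotate the shard by epoch as a stand-in for epoch-dependent shuffling.
        k = ctx.current_epoch % len(shard) if shard else 0
        shard = shard[k:] + shard[:k]
        return [shard[i:i + self.batch_size] for i in range(0, len(shard), self.batch_size)]


factory = CallTimeFactory(n=6, batch_size=2)
print(factory.get_dataloader(TrainerContext(global_rank=0, world_size=2, current_epoch=0)))  # [[0, 2], [4]]
print(factory.get_dataloader(TrainerContext(global_rank=0, world_size=2, current_epoch=1)))  # [[2, 4], [0]]
```

The factory stores no trainer state, so the same instance can serve different contexts without any risk of a stale pointer; the trade-off, raised in the reply below, is that other hooks cannot reuse that state ahead of time.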

@justusschock
Member

justusschock commented Apr 6, 2021

Because you might want to prepare other things earlier that also rely on the trainer. For example, when frequently recreating a dataloader, you may not want to repeat every step each time, so you may want to rely on trainer state in other methods too, or have those steps triggered asynchronously.
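A minimal sketch of that argument, assuming a hypothetical `TrainerContext` attached before `setup()` runs: the expensive, rank-dependent work happens once in `setup()`, and `get_dataloader()` only rebuilds cheap batches from the cached result. `CachingFactory`, `_build_index`, and the `index_builds` counter are illustrative names, not part of the proposal.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class TrainerContext:
    """Hypothetical read-only snapshot of trainer state."""
    global_rank: int
    world_size: int


class CachingFactory:
    """Does rank-dependent, expensive work once in setup(); get_dataloader() stays cheap."""

    def __init__(self, n: int, batch_size: int) -> None:
        self.n = n
        self.batch_size = batch_size
        self.trainer: Optional[TrainerContext] = None
        self.index_builds = 0  # counts how often the expensive step runs
        self._shard: List[int] = []

    def setup(self) -> None:
        # Relies on the stored context *before* any dataloader is requested.
        if self.trainer is None:
            raise RuntimeError("attach the trainer context before setup()")
        self._shard = self._build_index(self.trainer.global_rank, self.trainer.world_size)

    def _build_index(self, rank: int, world_size: int) -> List[int]:
        self.index_builds += 1  # stand-in for e.g. scanning files or tokenizing
        return list(range(rank, self.n, world_size))

    def get_dataloader(self) -> List[List[int]]:
        # Called every time a loader is (re)created; reuses the cached shard.
        size = self.batch_size
        return [self._shard[i:i + size] for i in range(0, len(self._shard), size)]


factory = CachingFactory(n=8, batch_size=3)
factory.trainer = TrainerContext(global_rank=0, world_size=2)
factory.setup()
for _ in range(3):
    loader = factory.get_dataloader()
print(loader)                # [[0, 2, 4], [6]]
print(factory.index_builds)  # 1
```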

@stale stale bot added the won't fix This will not be worked on label May 7, 2021
@ananthsub
Contributor Author

This ended up being more of a configuration issue than anything else. The core DataModule API is mostly good (#7301 + other properties not used anywhere else). Adding dedicated add_{train/val/test}_dataloader_options methods to our datamodule is one way to support per-stage complexity instead of putting everything into the constructor.
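A minimal sketch of what such per-stage option methods could look like. It assumes a stdlib stand-in datamodule, only two illustrative options (`batch_size`, `drop_last`), and lists of index batches in place of real dataloaders; `OptionsDataModule`, `_add_options`, and `_DEFAULTS` are hypothetical, and the `add_*_dataloader_options` methods are only named in the comment above, not an existing Lightning API.

```python
from typing import Any, Dict, List


class OptionsDataModule:
    """Sketch: per-stage options are added via methods instead of constructor arguments."""

    _DEFAULTS: Dict[str, Any] = {"batch_size": 1, "drop_last": False}

    def __init__(self, n: int) -> None:
        self.n = n  # the constructor only takes what every stage shares
        self._options: Dict[str, Dict[str, Any]] = {}

    def _add_options(self, stage: str, **kwargs: Any) -> "OptionsDataModule":
        unknown = set(kwargs) - set(self._DEFAULTS)
        if unknown:
            raise TypeError(f"unknown {stage} options: {sorted(unknown)}")
        self._options[stage] = {**self._DEFAULTS, **kwargs}
        return self  # allow chaining

    def add_train_dataloader_options(self, **kwargs: Any) -> "OptionsDataModule":
        return self._add_options("train", **kwargs)

    def add_val_dataloader_options(self, **kwargs: Any) -> "OptionsDataModule":
        return self._add_options("val", **kwargs)

    def add_test_dataloader_options(self, **kwargs: Any) -> "OptionsDataModule":
        return self._add_options("test", **kwargs)

    def _batches(self, stage: str) -> List[List[int]]:
        opts = self._options.get(stage, self._DEFAULTS)
        size = opts["batch_size"]
        data = list(range(self.n))
        batches = [data[i:i + size] for i in range(0, self.n, size)]
        if opts["drop_last"] and batches and len(batches[-1]) < size:
            batches.pop()
        return batches

    def train_dataloader(self) -> List[List[int]]:
        return self._batches("train")

    def val_dataloader(self) -> List[List[int]]:
        return self._batches("val")

    def test_dataloader(self) -> List[List[int]]:
        return self._batches("test")


dm = (
    OptionsDataModule(n=7)
    .add_train_dataloader_options(batch_size=2, drop_last=True)
    .add_val_dataloader_options(batch_size=4)
)
print(dm.train_dataloader())  # [[0, 1], [2, 3], [4, 5]]
print(dm.val_dataloader())    # [[0, 1, 2, 3], [4, 5, 6]]
```

Stages without explicit options fall back to the defaults, so each stage's complexity is opt-in rather than another constructor argument.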

@Lightning-AI Lightning-AI deleted a comment from stale bot Jun 8, 2021