
Pass datamodule kwargs to datasets #666

Closed
adamjstewart opened this issue Jul 10, 2022 · 0 comments · Fixed by #730
Labels
backwards-incompatible Changes that are not backwards compatible datamodules PyTorch Lightning datamodules
Summary

We should remove dataset-specific arguments from datamodules and instead pass them directly to the dataset through kwargs.

Rationale

As an example, we'll look at TropicalCycloneWindEstimation and CycloneDataModule.

TropicalCycloneWindEstimation has a number of options that users may want to configure:

  • root: root directory containing data
  • split: train or test
  • transforms: data augmentations to apply
  • download: whether or not to download the dataset
  • api_key: MLHUB_API_KEY
  • checksum: whether or not to checksum the download

However, CycloneDataModule only exposes a subset of these:

  • root_dir: different name for some reason
  • api_key: MLHUB_API_KEY

If a user wants, for example, to download the dataset automatically, they currently have to modify the source code of CycloneDataModule. If we instead pass the **kwargs from CycloneDataModule directly to TropicalCycloneWindEstimation, we end up with less code duplication, more features, and greater consistency.
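The forwarding pattern can be sketched with toy classes (names like `ToyDataset` and `ToyDataModule` are hypothetical stand-ins, not the real torchgeo API):

```python
from typing import Any, Optional


class ToyDataset:
    """Stand-in for a dataset such as TropicalCycloneWindEstimation."""

    def __init__(
        self,
        root: str = "data",
        split: str = "train",
        download: bool = False,
        api_key: Optional[str] = None,
    ) -> None:
        self.root = root
        self.split = split
        self.download = download
        self.api_key = api_key


class ToyDataModule:
    """Stand-in datamodule: keeps only loader options, forwards the rest."""

    def __init__(self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any) -> None:
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.kwargs = kwargs  # every other option goes straight to the dataset

    def prepare_data(self) -> None:
        # The datamodule no longer needs to name root/download/api_key itself.
        self.dataset = ToyDataset(split="train", **self.kwargs)


dm = ToyDataModule(batch_size=32, root="/tmp/cyclone", download=True)
dm.prepare_data()
```

Any dataset option (root, download, checksum, ...) now reaches the dataset without the datamodule having to declare it explicitly.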

Implementation

Here is what the change would look like for CycloneDataModule:

```diff
diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py
index 9a97242..32db799 100644
--- a/torchgeo/datamodules/cyclone.py
+++ b/torchgeo/datamodules/cyclone.py
@@ -26,11 +26,9 @@ class CycloneDataModule(pl.LightningDataModule):
 
     def __init__(
         self,
-        root_dir: str,
         seed: int,
         batch_size: int = 64,
         num_workers: int = 0,
-        api_key: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
         """Initialize a LightningDataModule for NASA Cyclone based DataLoaders.
@@ -45,11 +43,10 @@ class CycloneDataModule(pl.LightningDataModule):
                 downloaded
         """
         super().__init__()  # type: ignore[no-untyped-call]
-        self.root_dir = root_dir
         self.seed = seed
         self.batch_size = batch_size
         self.num_workers = num_workers
-        self.api_key = api_key
+        self.kwargs = kwargs
 
     def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
         """Transform a single sample from the Dataset.
@@ -76,12 +73,7 @@ class CycloneDataModule(pl.LightningDataModule):
         This includes optionally downloading the dataset. This is done once per node,
         while :func:`setup` is done once per GPU.
         """
-        TropicalCycloneWindEstimation(
-            self.root_dir,
-            split="train",
-            download=self.api_key is not None,
-            api_key=self.api_key,
-        )
+        TropicalCycloneWindEstimation(split="train", **self.kwargs)
 
     def setup(self, stage: Optional[str] = None) -> None:
         """Create the train/val/test splits based on the original Dataset objects.
```

This should be done for all other datamodules.
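One trade-off worth noting (my observation, not stated in the issue): because the datamodule no longer names its dataset arguments, a mistyped or unsupported option is accepted silently at datamodule construction and only surfaces as a TypeError when the dataset is instantiated, e.g. in prepare_data. A toy sketch with hypothetical class names:

```python
from typing import Any


class StrictDataset:
    """Toy dataset that accepts only a fixed set of keyword arguments."""

    def __init__(self, root: str = "data", download: bool = False) -> None:
        self.root = root
        self.download = download


class ForwardingDataModule:
    """Toy datamodule that forwards all unknown kwargs to the dataset."""

    def __init__(self, batch_size: int = 64, **kwargs: Any) -> None:
        self.batch_size = batch_size
        self.kwargs = kwargs

    def prepare_data(self) -> None:
        self.dataset = StrictDataset(**self.kwargs)


# The typo ('downlaod') is accepted silently here ...
dm = ForwardingDataModule(batch_size=32, downlaod=True)
try:
    dm.prepare_data()  # ... and only surfaces here, as a TypeError
except TypeError as err:
    print("caught:", err)
```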

Alternatives

The alternative is to add new parameters to each datamodule as we need them. This results in an inconsistent API, with each datamodule exposing a different ad hoc subset of its dataset's options.

Additional information

No response

@adamjstewart adamjstewart added datamodules PyTorch Lightning datamodules backwards-incompatible Changes that are not backwards compatible labels Jul 10, 2022
@adamjstewart adamjstewart added this to the 0.4.0 milestone Aug 22, 2022