Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add teardown hook to LightningDataModule #4673

Merged
merged 25 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))


- Added `teardown()` hook to LightningDataModule ([#4673](https://github.com/PyTorchLightning/pytorch-lightning/pull/4673))


- Added `auto_insert_metric_name` parameter to `ModelCheckpoint` ([#6277](https://github.com/PyTorchLightning/pytorch-lightning/pull/6277))


Expand Down
4 changes: 2 additions & 2 deletions docs/source/common/lightning_module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1256,7 +1256,7 @@ prepare_data
setup
~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.setup
.. automethod:: pytorch_lightning.core.hooks.DataHooks.setup
:noindex:

tbptt_split_batch
Expand All @@ -1268,7 +1268,7 @@ tbptt_split_batch
teardown
~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.teardown
.. automethod:: pytorch_lightning.core.hooks.DataHooks.teardown
:noindex:

train_dataloader
Expand Down
13 changes: 12 additions & 1 deletion docs/source/extensions/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size)

def teardown(self, stage: Optional[str] = None):
# Used to clean-up when the run is finished
...

But now, as the complexity of your processing grows (transforms, multiple-GPU training), you can
let Lightning handle those details for you while making this dataset reusable so you can share with
colleagues or use in different projects.
Expand Down Expand Up @@ -243,7 +247,10 @@ There are also data operations you might want to perform on every GPU. Use setup
self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)


.. warning:: `setup` is called from every process. Setting state here is okay.
.. warning:: ``setup`` is called from every process. Setting state here is okay.


.. note:: ``teardown`` can be used to clean up the state. It is also called from every process


train_dataloader
Expand Down Expand Up @@ -411,10 +418,14 @@ You can of course use DataModules in plain PyTorch code as well.
for batch in dm.val_dataloader():
...

dm.teardown(stage='fit')

# lazy load test data
dm.setup(stage='test')
for batch in dm.test_dataloader():
...

dm.teardown(stage='test')

But overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a unified
structure.
67 changes: 54 additions & 13 deletions pytorch_lightning/core/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""LightningDataModule for loading DataLoaders with ease."""

import functools
from abc import abstractmethod
from argparse import ArgumentParser, Namespace
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -44,6 +43,8 @@ def __call__(cls, *args, **kwargs):
cls.prepare_data = track_data_hook_calls(rank_zero_only(cls.prepare_data))
# Track setup calls
cls.setup = track_data_hook_calls(cls.setup)
# Track teardown calls
cls.teardown = track_data_hook_calls(cls.teardown)

# Get instance of LightningDataModule by mocking its __init__ via __call__
obj = type.__call__(cls, *args, **kwargs)
Expand All @@ -52,12 +53,13 @@ def __call__(cls, *args, **kwargs):


def track_data_hook_calls(fn):
"""A decorator that checks if prepare_data/setup have been called.
"""A decorator that checks if prepare_data/setup/teardown has been called.

- When ``dm.prepare_data()`` is called, ``dm.has_prepared_data`` gets set to True
- When ``dm.setup()``, ``dm.has_setup_{fit,validate,test}`` get set to True
- When ``dm.setup(stage)`` is called, where stage is any of ``{fit,validate,test,predict}``.
Its corresponding `dm_has_setup_{stage}` attribute gets set to True
- ``dm.teardown()`` and ``dm.teardown(stage)`` act exactly like ``dm.setup``

Args:
fn (function): Function that will be tracked to see if it has been called.
Expand All @@ -71,9 +73,10 @@ def wrapped_fn(*args, **kwargs):

# The object instance from which setup or prepare_data was called
obj = args[0]
name = fn.__name__

# If calling setup, we check the stage and assign stage-specific bool args
if fn.__name__ == "setup":
if name in ("setup", "teardown"):

# Get stage either by grabbing from args or checking kwargs.
# If not provided, set call status of 'fit', 'validate', and 'test' to True.
Expand All @@ -82,11 +85,11 @@ def wrapped_fn(*args, **kwargs):

if stage is None:
for s in ("fit", "validate", "test"):
setattr(obj, f"_has_setup_{s}", True)
setattr(obj, f"_has_{name}_{s}", True)
else:
setattr(obj, f"_has_setup_{stage}", True)
setattr(obj, f"_has_{name}_{stage}", True)

carmocca marked this conversation as resolved.
Show resolved Hide resolved
if fn.__name__ == "prepare_data":
elif name == "prepare_data":
obj._has_prepared_data = True

return fn(*args, **kwargs)
Expand Down Expand Up @@ -119,14 +122,18 @@ def val_dataloader(self):
def test_dataloader(self):
test_split = Dataset(...)
return DataLoader(test_split)
def teardown(self):
# clean up after fit or test
# called on every process in DDP

A DataModule implements 5 key methods:
A DataModule implements 6 key methods:

* **prepare_data** (things to do on 1 GPU/TPU not on every GPU/TPU in distributed mode).
* **setup** (things to do on every accelerator in distributed mode).
* **train_dataloader** the training dataloader.
* **val_dataloader** the val dataloader(s).
* **test_dataloader** the test dataloader(s).
* **teardown** (things to do on every accelerator in distributed mode when finished)


This allows you to share a full dataset without explaining how to download,
Expand Down Expand Up @@ -154,11 +161,17 @@ def __init__(

# Private attrs to keep track of whether or not data hooks have been called yet
self._has_prepared_data = False

self._has_setup_fit = False
self._has_setup_validate = False
self._has_setup_test = False
self._has_setup_predict = False

self._has_teardown_fit = False
self._has_teardown_validate = False
self._has_teardown_test = False
self._has_teardown_predict = False

@property
def train_transforms(self):
"""
Expand Down Expand Up @@ -259,13 +272,41 @@ def has_setup_predict(self) -> bool:
"""
return self._has_setup_predict

@abstractmethod
def prepare_data(self, *args, **kwargs):
pass
@property
def has_teardown_fit(self) -> bool:
"""Return bool letting you know if ``datamodule.teardown(stage='fit')`` has been called or not.

@abstractmethod
def setup(self, stage: Optional[str] = None):
pass
carmocca marked this conversation as resolved.
Show resolved Hide resolved
Returns:
bool: True ``if datamodule.teardown(stage='fit')`` has been called. False by default.
"""
return self._has_teardown_fit

@property
def has_teardown_validate(self) -> bool:
"""Return bool letting you know if ``datamodule.teardown(stage='validate')`` has been called or not.

Returns:
bool: True if ``datamodule.teardown(stage='validate')`` has been called. False by default.
"""
return self._has_teardown_validate

@property
def has_teardown_test(self) -> bool:
"""Return bool letting you know if ``datamodule.teardown(stage='test')`` has been called or not.

Returns:
bool: True if ``datamodule.teardown(stage='test')`` has been called. False by default.
"""
return self._has_teardown_test

@property
def has_teardown_predict(self) -> bool:
"""Return bool letting you know if ``datamodule.teardown(stage='predict')`` has been called or not.

Returns:
bool: True if ``datamodule.teardown(stage='predict')`` has been called. False by default.
"""
return self._has_teardown_predict

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
Expand Down
72 changes: 36 additions & 36 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,42 +25,6 @@
class ModelHooks:
"""Hooks to be used in LightningModule."""

def setup(self, stage: Optional[str] = None) -> None:
"""
Called at the beginning of fit (train + validate), validate, test, predict, or tune.
This is a good hook when you need to build models dynamically or adjust something about them.
This hook is called on every process when using DDP.

Args:
stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``

Example::

class LitModel(...):
def __init__(self):
self.l1 = None

def prepare_data(self):
download_data()
tokenize()

# don't do this
self.something = else

def setup(stage):
data = Load_data(...)
self.l1 = nn.Linear(28, data.num_classes)

"""

def teardown(self, stage: Optional[str] = None) -> None:
"""
Called at the end of fit (train + validate), validate, test, predict, or tune.

Args:
stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
"""

def on_fit_start(self) -> None:
"""
Called at the very beginning of fit.
Expand Down Expand Up @@ -383,6 +347,42 @@ def prepare_data(self):
model.test_dataloader()
"""

def setup(self, stage: Optional[str] = None) -> None:
"""
Called at the beginning of fit (train + validate), validate, test, predict, or tune.
This is a good hook when you need to build models dynamically or adjust something about them.
This hook is called on every process when using DDP.

Args:
stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``

Example::

class LitModel(...):
def __init__(self):
self.l1 = None

def prepare_data(self):
download_data()
tokenize()

# don't do this
self.something = else

def setup(stage):
data = Load_data(...)
self.l1 = nn.Linear(28, data.num_classes)

"""

def teardown(self, stage: Optional[str] = None) -> None:
"""
Called at the end of fit (train + validate), validate, test, predict, or tune.

Args:
stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'``
"""

def train_dataloader(self) -> Any:
"""
Implement one or more PyTorch DataLoaders for training.
Expand Down
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,12 @@ def call_setup_hook(self, model: LightningModule) -> None:

def call_teardown_hook(self, model: LightningModule) -> None:
state = self._teardown_state

if self.datamodule is not None:
called = getattr(self.datamodule, f'has_teardown_{state}')
if not called:
self.datamodule.teardown(stage=state)

self.profiler.teardown(stage=state)
self.teardown(stage=state)
model.teardown(stage=state)
Expand Down
Loading