[bugfix] Reduce memory leaks (#8490)
* reduce memory leak

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update changelog

* Apply suggestions from code review

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>

* resolve flake8

* update on comments

* resolve bug

* update

* Undo whitespace changes

* remove bug

* resolve flake8

* revert change

* update on comments

* delete the ddp wrapper as it holds memory

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolve flake8

* update on comments

* update changelog

* resolve test

* Update CHANGELOG

* Refactor teardown

* Fix comment

* Do it for non-gpu too

* remove ref when the model is not a lightning_module

* Fix import error

* move down

* resolve bug

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolve assignment

* update

* move above

* Fix device calls to support tpu training

* Update todo

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: Kaushik B <kaushikbokka@gmail.com>
6 people committed Jul 21, 2021
1 parent d0038b5 commit c9af1a7
Showing 10 changed files with 101 additions and 22 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -485,8 +485,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed missing call to `LightningModule.untoggle_optimizer` in training loop when running gradient accumulation with multiple optimizers ([#8284](https://github.com/PyTorchLightning/pytorch-lightning/pull/8284))


- Fixed hash of LightningEnum to work with value instead of name([#8421](https://github.com/PyTorchLightning/pytorch-lightning/pull/8421)).


- Fixed `move_data_to_device` to return the batch if the object `to` function didn't return `self` ([#8433](https://github.com/PyTorchLightning/pytorch-lightning/pull/8433))


@@ -496,6 +498,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed clearing dataloader references before attaching new dataloaders in consecutive `Trainer.{fit,validate,test,predict}` runs ([#8442](https://github.com/PyTorchLightning/pytorch-lightning/pull/8442))


- Fixed memory leaks on GPU by moving `optimizer_states`, `ResultCollection.extra`, `ResultMetric` attributes, and `LoggerConnector` metrics to `cpu`. Also, deleted the DDP wrapper on `teardown` ([#8490](https://github.com/PyTorchLightning/pytorch-lightning/pull/8490))


- Fixed `SWA` callback using LightningModule `prevent_trainer_and_dataloaders_deepcopy` to avoid OOM ([#8472](https://github.com/PyTorchLightning/pytorch-lightning/pull/8472))


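The entry above summarizes the strategy the hunks below implement: anything that still holds references to GPU tensors after a run (optimizer state, cached metrics, the DDP wrapper) is either moved to the CPU or dereferenced during teardown. Here is a minimal, illustrative sketch of that pattern using the same `apply_to_collection` / `move_data_to_device` utilities the diff relies on; the helper name `release_gpu_memory` is made up for this example and is not part of the commit.

```python
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device


def release_gpu_memory(optimizer: torch.optim.Optimizer, metrics: dict) -> dict:
    """Move every tensor referenced by the optimizer state and the cached metrics to the CPU."""
    cpu = torch.device("cpu")
    for param, state in optimizer.state.items():
        # e.g. Adam keeps per-parameter `exp_avg` / `exp_avg_sq` tensors on the training device
        optimizer.state[param] = apply_to_collection(state, torch.Tensor, move_data_to_device, cpu)
    # logged metrics are plain (possibly nested) containers of tensors
    return apply_to_collection(metrics, torch.Tensor, move_data_to_device, cpu)
```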
10 changes: 4 additions & 6 deletions pytorch_lightning/accelerators/accelerator.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
-from collections import defaultdict
-from typing import Any, Callable, DefaultDict, Dict, Generator, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union

 import torch
 from torch import Tensor
@@ -112,13 +111,12 @@ def pre_dispatch(self, trainer: 'pl.Trainer') -> None:

         self.precision_plugin.pre_dispatch()

-    def _move_optimizer_state(self) -> None:
+    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
         """ Moves the state of the optimizers to the GPU if needed. """
+        device = device or self.root_device
         for opt in self.optimizers:
-            state: DefaultDict = defaultdict(dict)
             for p, v in opt.state.items():
-                state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
-            opt.state = state
+                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)

     def dispatch(self, trainer: 'pl.Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
4 changes: 4 additions & 0 deletions pytorch_lightning/accelerators/gpu.py
@@ -52,3 +52,7 @@ def set_nvidia_flags(local_rank: int) -> None:
         all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
         devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
         _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")
+
+    def teardown(self) -> None:
+        super().teardown()
+        self._move_optimizer_state(torch.device("cpu"))
12 changes: 11 additions & 1 deletion pytorch_lightning/accelerators/tpu.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable
+from typing import Any, Callable, Optional

+import torch
 from torch.optim import Optimizer

 import pytorch_lightning as pl
@@ -21,6 +22,7 @@
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
 from pytorch_lightning.utilities import _XLA_AVAILABLE
+from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 if _XLA_AVAILABLE:
@@ -49,3 +51,11 @@ def run_optimizer_step(
         self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
     ) -> None:
         xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})
+
+    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
+        """ Moves the state of the optimizers to the TPU if needed. """
+        # TODO: `self.root_device` would raise error if called outside the spawn process
+        # while training on 8 and more cores.
+        for opt in self.optimizers:
+            for p, v in opt.state.items():
+                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
5 changes: 5 additions & 0 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -133,6 +133,11 @@ def block_backward_sync(self):
             yield None

     def teardown(self) -> None:
+        # Un-reference the wrapper if any was used.
+        # todo (tchaton): Add support for all plugins.
+        if isinstance(self.model, DistributedDataParallel):
+            self.model = self.lightning_module
+
         if self.on_gpu:
             # GPU teardown
             self.lightning_module.cpu()
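One way to observe the effect of this un-referencing (mirroring the assertions in the new test at the bottom of this commit): after `fit()` returns, the training type plugin should hold the plain `LightningModule` again rather than the `DistributedDataParallel` wrapper, so nothing keeps the wrapper and its GPU state alive. The helper below is illustrative only and not part of the commit.

```python
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl


def assert_wrapper_released(trainer: "pl.Trainer") -> None:
    """After `trainer.fit(...)`, the plugin should again hold the plain LightningModule, not the DDP wrapper."""
    plugin = trainer.training_type_plugin
    assert not isinstance(plugin.model, DistributedDataParallel)
    assert plugin.model is trainer.lightning_module
```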
@@ -39,7 +39,7 @@ class TrainingTypePlugin(Plugin, ABC):
     """

     def __init__(self) -> None:
-        self._model = None
+        self._model: Optional[Module] = None
         self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None
         self._call_configure_sharded_model_hook = True

@@ -121,12 +121,12 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs
         """Hook to do something after each optimizer step."""

     @property
-    def model(self) -> Module:
+    def model(self) -> Optional[Module]:
         """Returns the potentially wrapped LightningModule"""
         return self._model

     @model.setter
-    def model(self, new_model: Module) -> None:
+    def model(self, new_model: Optional[Module]) -> None:
         self._model = new_model

     @property
@@ -23,6 +23,7 @@
 from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.utilities import DeviceType
+from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
 from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT

@@ -312,3 +313,9 @@ def progress_bar_metrics(self) -> Dict[str, float]:
         metrics = self.metrics[MetricSource.PBAR]
         self._progress_bar_metrics.update(metrics)
         return self._progress_bar_metrics
+
+    def teardown(self):
+        args = (torch.Tensor, move_data_to_device, "cpu")
+        self._logged_metrics = apply_to_collection(self._logged_metrics, *args)
+        self._progress_bar_metrics = apply_to_collection(self._progress_bar_metrics, *args)
+        self._callback_metrics = apply_to_collection(self._callback_metrics, *args)
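A small self-contained illustration of what the new `teardown` above does (assuming a CUDA device is available): metrics cached by the connector may still live on the GPU if they were logged from device tensors, and remapping them through `apply_to_collection` drops those GPU references.

```python
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device

# a stand-in for the connector's cached metrics, logged from GPU tensors during training
callback_metrics = {"train_loss": torch.tensor(0.25, device="cuda:0")}

callback_metrics = apply_to_collection(callback_metrics, torch.Tensor, move_data_to_device, "cpu")
print(callback_metrics["train_loss"].device)  # cpu
```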
21 changes: 9 additions & 12 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -21,9 +21,8 @@

 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities import rank_zero_warn
-from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
+from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device
 from pytorch_lightning.utilities.data import extract_batch_size
-from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
@@ -254,12 +253,7 @@ def __getstate__(self, drop_value: bool = False) -> dict:
         if not self.is_tensor and drop_value:
             # Avoid serializing ResultMetrics which are passed Metrics
             skip.append('value')
-        with self.sync_context(
-            should_sync=not self.meta.sync.rank_zero_only,
-            process_group=self.meta.sync.group,
-            distributed_available=distributed_available
-        ):
-            d = {k: v for k, v in self.__dict__.items() if k not in skip}
+        d = {k: v for k, v in self.__dict__.items() if k not in skip}
         d['meta'] = d['meta'].__getstate__()
         d['_class'] = self.__class__.__name__
         return d
@@ -276,6 +270,12 @@ def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> 'Resul
         result_metric.__setstate__(state, sync_fn=sync_fn)
         return result_metric

+    def to(self, *args: Any, **kwargs: Any) -> 'DeviceDtypeModuleMixin':
+        self.__dict__.update(
+            apply_to_collection(self.__dict__, (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)
+        )
+        return self
+

 class ResultMetricCollection(dict):
     """
@@ -597,10 +597,7 @@ def extract_batch_size(self, batch: Any) -> None:
     def to(self, *args, **kwargs) -> 'ResultCollection':
         """Move all data to the given device."""

-        def to_(item: Union[torch.Tensor, Metric], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Metric]:
-            return item.to(*args, **kwargs)
-
-        apply_to_collection(self, (torch.Tensor, Metric), to_, *args, **kwargs)
+        self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs))

         if self.minimize is not None:
             self.minimize = self.minimize.to(*args, **kwargs)
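The `to()` method added to `ResultMetric` above follows a general pattern: remap every tensor attribute of an object by pushing its `__dict__` through `apply_to_collection`. Below is a hypothetical minimal class showing the same idea; `CachedMetric` is made up for illustration and is not a Lightning class.

```python
from typing import Any

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device


class CachedMetric:
    def __init__(self, value: torch.Tensor) -> None:
        self.value = value
        self.cumulated_batch_size = torch.tensor(0)

    def to(self, *args: Any, **kwargs: Any) -> "CachedMetric":
        # move every tensor attribute (and nothing else) to the requested device
        self.__dict__.update(apply_to_collection(self.__dict__, torch.Tensor, move_data_to_device, *args, **kwargs))
        return self


metric = CachedMetric(torch.ones(1))
metric.to("cpu")  # all tensor attributes now follow the requested device
```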
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -964,6 +964,7 @@ def _post_dispatch(self):
         # which need to happen before.
         self.accelerator.teardown()
         self._active_loop.teardown()
+        self.logger_connector.teardown()

     def _dispatch(self):
         if self.evaluating:
52 changes: 52 additions & 0 deletions tests/trainer/test_trainer.py
@@ -25,6 +25,7 @@
 import pytest
 import torch
 from omegaconf import OmegaConf
+from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import SGD
 from torch.utils.data import DataLoader

@@ -1969,3 +1970,54 @@ def training_step(self, batch, batch_idx):
     # simulate random failure in training_step on rank 0
     with pytest.raises(DeadlockDetectedException, match="CustomException"):
         trainer.fit(model)
+
+
+@RunIf(min_gpus=1)
+def test_multiple_trainer_constant_memory_allocated(tmpdir):
+    """
+    This test ensures that calling the trainer several times resets the allocated GPU memory back to its initial value.
+    """
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)
+            self.log("train_loss", loss["loss"])
+            return loss
+
+        def configure_optimizers(self):
+            return torch.optim.Adam(self.layer.parameters(), lr=0.1)
+
+    class Check(Callback):
+
+        def on_epoch_start(self, trainer, *_):
+            assert isinstance(trainer.training_type_plugin.model, DistributedDataParallel)
+
+    initial = torch.cuda.memory_allocated(0)
+
+    model = TestModel()
+    trainer_kwargs = dict(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        gpus=1,
+        accelerator="ddp",
+        progress_bar_refresh_rate=0,
+        callbacks=Check()
+    )
+    trainer = Trainer(**trainer_kwargs)
+    trainer.fit(model)
+
+    assert trainer.training_type_plugin.model is model
+    assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu")
+    assert trainer.callback_metrics['train_loss'].device == torch.device("cpu")
+
+    memory_1 = torch.cuda.memory_allocated(0)
+    deepcopy(trainer)
+    memory_2 = torch.cuda.memory_allocated(0)
+    assert memory_1 == memory_2 == initial
+
+    trainer_2 = Trainer(**trainer_kwargs)
+    trainer_2.fit(model)
+    memory_3 = torch.cuda.memory_allocated(0)
+
+    assert initial == memory_1 == memory_3
