[bugfix] Reduce memory leaks #8490

Merged
merged 41 commits on Jul 21, 2021
Changes from 36 commits

Commits (41)
aa89cc9
reduce memory leak
tchaton Jul 20, 2021
f57e21d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2021
341163b
update changelog
tchaton Jul 20, 2021
63eeaf0
Merge branch 'reduce_memory_leak' of https://github.com/PyTorchLightn…
tchaton Jul 20, 2021
f9ca9dc
Apply suggestions from code review
Borda Jul 20, 2021
1804975
resolve flake8
tchaton Jul 20, 2021
7a7b95f
Merge branch 'reduce_memory_leak' of https://github.com/PyTorchLightn…
tchaton Jul 20, 2021
7f06053
update on comments
tchaton Jul 20, 2021
b48e34d
resolve bug
tchaton Jul 20, 2021
eef89bc
update
tchaton Jul 20, 2021
13b335d
Undo whitespace changes
carmocca Jul 20, 2021
7fca3d8
remove bug
tchaton Jul 20, 2021
03e8faa
Merge branch 'reduce_memory_leak' of https://github.com/PyTorchLightn…
tchaton Jul 20, 2021
3fd83e3
resolve flake8
tchaton Jul 20, 2021
1d8b484
revert change
tchaton Jul 20, 2021
38ff815
update on comments
tchaton Jul 20, 2021
ef9a4cd
delete the ddp wrapper as it hold memory
tchaton Jul 20, 2021
9acb0c1
Merge branch 'master' into reduce_memory_leak
tchaton Jul 20, 2021
9ab40de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2021
000138e
resolve flake8
tchaton Jul 20, 2021
4e439f4
Merge branch 'reduce_memory_leak' of https://github.com/PyTorchLightn…
tchaton Jul 20, 2021
559436f
update on comments
tchaton Jul 20, 2021
8c9145d
update changelog
tchaton Jul 20, 2021
7c015cb
resolve test
tchaton Jul 20, 2021
70affbb
Update CHANGELOG
carmocca Jul 20, 2021
26bb10f
Refactor teardown
carmocca Jul 20, 2021
231b02b
Fix comment
carmocca Jul 20, 2021
0d1c365
Do it for non-gpu too
carmocca Jul 20, 2021
9077898
remove ref when the model is not a lightning_module
tchaton Jul 20, 2021
2ba1e9e
Fix import error
carmocca Jul 20, 2021
a8018df
Merge branch 'master' into reduce_memory_leak
tchaton Jul 20, 2021
666383c
move down
tchaton Jul 20, 2021
f16b8de
resolve bug
tchaton Jul 20, 2021
a915396
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 20, 2021
8c84391
resolve assignement
tchaton Jul 20, 2021
2d3223a
Merge branch 'reduce_memory_leak' of https://github.com/PyTorchLightn…
tchaton Jul 20, 2021
47c1ad3
update
tchaton Jul 21, 2021
9c347a2
move above
tchaton Jul 21, 2021
b26f98b
Fix device calls to support tpu training
kaushikb11 Jul 21, 2021
89a7033
Merge branch 'reduce_memory_leak' of https://github.com/PyTorchLightn…
kaushikb11 Jul 21, 2021
2719d03
Updat todo
kaushikb11 Jul 21, 2021
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -485,8 +485,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed missing call to `LightningModule.untoggle_optimizer` in training loop when running gradient accumulation with multiple optimizers ([#8284](https://github.com/PyTorchLightning/pytorch-lightning/pull/8284))


- Fixed hash of LightningEnum to work with value instead of name([#8421](https://github.com/PyTorchLightning/pytorch-lightning/pull/8421)).


- Fixed `move_data_to_device` to return the batch if the object `to` function didn't return `self` ([#8433](https://github.com/PyTorchLightning/pytorch-lightning/pull/8433))


@@ -496,6 +498,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed clearing dataloader references before attaching new dataloaders in consecutive `Trainer.{fit,validate,test,predict}´ runs ([#8442](https://github.com/PyTorchLightning/pytorch-lightning/pull/8442))


- Fixed memory leaks on GPU by moving `optimizer_states`, `ResultCollection.extra`, `ResultMetric` attributes, and `LoggerConnector` metrics to `cpu`. Also, delete the DDP wrapper on `teardown` ([#8490](https://github.com/PyTorchLightning/pytorch-lightning/pull/8490))


- Fixed `SWA` callback using LightningModule `prevent_trainer_and_dataloaders_deepcopy` to avoid OOM ([#8472](https://github.com/PyTorchLightning/pytorch-lightning/pull/8472))


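The changelog entry above covers several leaks at once; as a side illustration, here is a minimal, hedged sketch of just the optimizer-state part in plain PyTorch (names such as `move_optimizer_state` are mine, not Lightning's). The idea is to move every tensor held in `optimizer.state` to the CPU once training ends so the CUDA allocator can release that memory:

```python
# Minimal sketch in plain PyTorch (not the exact Lightning implementation):
# park all optimizer state tensors on the CPU so GPU memory can be reclaimed.
import torch


def move_optimizer_state(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    """Move every tensor stored in the optimizer state to ``device``, in place."""
    for param, state in optimizer.state.items():
        optimizer.state[param] = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in state.items()
        }


model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model(torch.randn(2, 4)).sum().backward()
optimizer.step()  # populates exp_avg / exp_avg_sq state tensors

# After training finishes, move the state off the accelerator and free cached blocks.
move_optimizer_state(optimizer, torch.device("cpu"))
if torch.cuda.is_available():
    torch.cuda.empty_cache()
```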
11 changes: 4 additions & 7 deletions pytorch_lightning/accelerators/accelerator.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, Generator, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union

import torch
from torch import Tensor
@@ -104,21 +103,19 @@ def start_predicting(self, trainer: 'pl.Trainer') -> None:

def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
self._move_optimizer_state()
self._move_optimizer_state(self.root_device)

self.training_type_plugin.pre_dispatch()
if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
self.setup_optimizers(trainer)

self.precision_plugin.pre_dispatch()

def _move_optimizer_state(self) -> None:
def _move_optimizer_state(self, device: torch.device) -> None:
""" Moves the state of the optimizers to the GPU if needed. """
for opt in self.optimizers:
state: DefaultDict = defaultdict(dict)
for p, v in opt.state.items():
state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
opt.state = state
opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)

def dispatch(self, trainer: 'pl.Trainer') -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
4 changes: 4 additions & 0 deletions pytorch_lightning/accelerators/gpu.py
@@ -52,3 +52,7 @@ def set_nvidia_flags(local_rank: int) -> None:
all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")

def teardown(self) -> None:
super().teardown()
self._move_optimizer_state(torch.device("cpu"))
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -138,3 +138,6 @@ def teardown(self) -> None:
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()

# Un-reference the wrapper if any was used.
self.model = self.lightning_module
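To make the "un-reference the wrapper" step easier to follow in isolation, here is a toy sketch (the `TinyPlugin` class is a hypothetical stand-in, not PyTorch Lightning code): a live `DistributedDataParallel` wrapper keeps its reducer, buckets, and hooks alive, so pointing `model` back at the bare module during teardown lets that memory be garbage-collected.

```python
# Toy illustration of the teardown pattern above; `TinyPlugin` is hypothetical.
import torch
from torch.nn.parallel import DistributedDataParallel


class TinyPlugin:
    def __init__(self, module: torch.nn.Module) -> None:
        self.lightning_module = module
        self.model: torch.nn.Module = module  # may later point at a DDP wrapper

    def wrap(self) -> None:
        # In real usage this requires an initialized process group.
        self.model = DistributedDataParallel(self.lightning_module)

    def teardown(self) -> None:
        self.lightning_module.cpu()
        torch.cuda.empty_cache()  # safe no-op when CUDA was never initialized
        # Un-reference the wrapper so only the bare module stays reachable.
        self.model = self.lightning_module
```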
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -39,7 +39,7 @@ class TrainingTypePlugin(Plugin, ABC):
"""

def __init__(self) -> None:
self._model = None
self._model: Optional[Module] = None
self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None
self._call_configure_sharded_model_hook = True

@@ -121,12 +121,12 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs
"""Hook to do something after each optimizer step."""

@property
def model(self) -> Module:
def model(self) -> Optional[Module]:
"""Returns the potentially wrapped LightningModule"""
return self._model

@model.setter
def model(self, new_model: Module) -> None:
def model(self, new_model: Optional[Module]) -> None:
self._model = new_model

@property
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -23,6 +23,7 @@
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT

@@ -312,3 +313,9 @@ def progress_bar_metrics(self) -> Dict[str, float]:
metrics = self.metrics[MetricSource.PBAR]
self._progress_bar_metrics.update(metrics)
return self._progress_bar_metrics

def teardown(self):
args = (torch.Tensor, move_data_to_device, "cpu")
self._logged_metrics = apply_to_collection(self._logged_metrics, *args)
self._progress_bar_metrics = apply_to_collection(self._progress_bar_metrics, *args)
self._callback_metrics = apply_to_collection(self._callback_metrics, *args)
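A rough standalone analogue of this `teardown` step, assuming only plain dictionaries of metrics rather than Lightning's `apply_to_collection` helper: move any tensor values in the cached metric dicts to the CPU so no logged metric keeps a CUDA tensor alive.

```python
import torch


def metrics_to_cpu(metrics: dict) -> dict:
    # Move tensor values to the CPU; leave scalars and other types untouched.
    return {
        name: value.cpu() if isinstance(value, torch.Tensor) else value
        for name, value in metrics.items()
    }


callback_metrics = {"train_loss": torch.tensor(0.25), "epoch": 3}
callback_metrics = metrics_to_cpu(callback_metrics)
```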
21 changes: 9 additions & 12 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -21,9 +21,8 @@

from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.metrics import metrics_to_scalars
@@ -254,12 +253,7 @@ def __getstate__(self, drop_value: bool = False) -> dict:
if not self.is_tensor and drop_value:
# Avoid serializing ResultMetrics which are passed Metrics
skip.append('value')
with self.sync_context(
should_sync=not self.meta.sync.rank_zero_only,
process_group=self.meta.sync.group,
distributed_available=distributed_available
):
d = {k: v for k, v in self.__dict__.items() if k not in skip}
d = {k: v for k, v in self.__dict__.items() if k not in skip}
d['meta'] = d['meta'].__getstate__()
d['_class'] = self.__class__.__name__
return d
@@ -276,6 +270,12 @@ def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> 'Resul
result_metric.__setstate__(state, sync_fn=sync_fn)
return result_metric

def to(self, *args: Any, **kwargs: Any) -> 'DeviceDtypeModuleMixin':
self.__dict__.update(
apply_to_collection(self.__dict__, (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)
)
return self


class ResultMetricCollection(dict):
"""
@@ -597,10 +597,7 @@ def extract_batch_size(self, batch: Any) -> None:
def to(self, *args, **kwargs) -> 'ResultCollection':
"""Move all data to the given device."""

def to_(item: Union[torch.Tensor, Metric], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Metric]:
return item.to(*args, **kwargs)

apply_to_collection(self, (torch.Tensor, Metric), to_, *args, **kwargs)
self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs))

if self.minimize is not None:
self.minimize = self.minimize.to(*args, **kwargs)
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -964,6 +964,7 @@ def _post_dispatch(self):
# which need to happen before.
self.accelerator.teardown()
self._active_loop.teardown()
self.logger_connector.teardown()

def _dispatch(self):
if self.evaluating:
52 changes: 52 additions & 0 deletions tests/trainer/test_trainer.py
@@ -25,6 +25,7 @@
import pytest
import torch
from omegaconf import OmegaConf
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import SGD
from torch.utils.data import DataLoader

@@ -1969,3 +1970,54 @@ def training_step(self, batch, batch_idx):
# simulate random failure in training_step on rank 0
with pytest.raises(DeadlockDetectedException, match="CustomException"):
trainer.fit(model)


@RunIf(min_gpus=1)
def test_multiple_trainer_constant_memory_allocated(tmpdir):
"""
This test ensures that calling the trainer several times resets the GPU memory back to its initial allocation.
"""

class TestModel(BoringModel):

def training_step(self, batch, batch_idx):
loss = super().training_step(batch, batch_idx)
self.log("train_loss", loss["loss"])
return loss

def configure_optimizers(self):
return torch.optim.Adam(self.layer.parameters(), lr=0.1)

class Check(Callback):

def on_epoch_start(self, trainer, *_):
assert isinstance(trainer.training_type_plugin.model, DistributedDataParallel)

initial = torch.cuda.memory_allocated(0)

model = TestModel()
trainer_kwargs = dict(
default_root_dir=tmpdir,
fast_dev_run=True,
gpus=1,
accelerator="ddp",
progress_bar_refresh_rate=0,
callbacks=Check()
)
trainer = Trainer(**trainer_kwargs)
trainer.fit(model)

assert trainer.training_type_plugin.model is model
assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu")
assert trainer.callback_metrics['train_loss'].device == torch.device("cpu")

memory_1 = torch.cuda.memory_allocated(0)
deepcopy(trainer)
memory_2 = torch.cuda.memory_allocated(0)
assert memory_1 == memory_2 == initial

trainer_2 = Trainer(**trainer_kwargs)
trainer_2.fit(model)
memory_3 = torch.cuda.memory_allocated(0)

assert initial == memory_1 == memory_3