From c9af1a7aec63f78b0654b82228c4c628c1c695b9 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Wed, 21 Jul 2021 11:37:05 +0200
Subject: [PATCH] [bugfix] Reduce memory leaks (#8490)

* reduce memory leak
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update changelog
* Apply suggestions from code review

Co-authored-by: Ethan Harris

* resolve flake8
* update on comments
* resolve bug
* update
* Undo whitespace changes
* remove bug
* resolve flake8
* revert change
* update on comments
* delete the ddp wrapper as it holds memory
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolve flake8
* update on comments
* update changelog
* resolve test
* Update CHANGELOG
* Refactor teardown
* Fix comment
* Do it for non-gpu too
* remove ref when the model is not a lightning_module
* Fix import error
* move down
* resolve bug
* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* resolve assignment
* update
* move above
* Fix device calls to support tpu training
* Update todo

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec
Co-authored-by: Ethan Harris
Co-authored-by: Carlos Mocholi
Co-authored-by: Kaushik B
---
 CHANGELOG.md                                  |  5 ++
 pytorch_lightning/accelerators/accelerator.py | 10 ++--
 pytorch_lightning/accelerators/gpu.py         |  4 ++
 pytorch_lightning/accelerators/tpu.py         | 12 ++++-
 .../plugins/training_type/parallel.py         |  5 ++
 .../training_type/training_type_plugin.py     |  6 +--
 .../logger_connector/logger_connector.py      |  7 +++
 .../connectors/logger_connector/result.py     | 21 ++++----
 pytorch_lightning/trainer/trainer.py          |  1 +
 tests/trainer/test_trainer.py                 | 52 +++++++++++++++++++
 10 files changed, 101 insertions(+), 22 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 26ecae0d0b981..a40288c58656d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -485,8 +485,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed missing call to `LightningModule.untoggle_optimizer` in training loop when running gradient accumulation with multiple optimizers ([#8284](https://github.com/PyTorchLightning/pytorch-lightning/pull/8284))
 
+
 - Fixed hash of LightningEnum to work with value instead of name ([#8421](https://github.com/PyTorchLightning/pytorch-lightning/pull/8421)).
 
+
 - Fixed `move_data_to_device` to return the batch if the object `to` function didn't return `self` ([#8433](https://github.com/PyTorchLightning/pytorch-lightning/pull/8433))
 
 
@@ -496,6 +498,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed clearing dataloader references before attaching new dataloaders in consecutive `Trainer.{fit,validate,test,predict}` runs ([#8442](https://github.com/PyTorchLightning/pytorch-lightning/pull/8442))
 
 
+- Fixed memory leaks on GPU by moving `optimizer_states`, `ResultCollection.extra`, `ResultMetric` attributes, and `LoggerConnector` metrics to `cpu`. Also, deleted the DDP wrapper on `teardown` ([#8490](https://github.com/PyTorchLightning/pytorch-lightning/pull/8490))
+
+
 - Fixed `SWA` callback using LightningModule `prevent_trainer_and_dataloaders_deepcopy` to avoid OOM ([#8472](https://github.com/PyTorchLightning/pytorch-lightning/pull/8472))
 
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 5d86c54028b6e..fe20647d5a384 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
-from collections import defaultdict
-from typing import Any, Callable, DefaultDict, Dict, Generator, Iterable, List, Optional, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
 
 import torch
 from torch import Tensor
@@ -112,13 +111,12 @@ def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
 
         self.precision_plugin.pre_dispatch()
 
-    def _move_optimizer_state(self) -> None:
+    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
         """ Moves the state of the optimizers to the GPU if needed. """
+        device = device or self.root_device
         for opt in self.optimizers:
-            state: DefaultDict = defaultdict(dict)
             for p, v in opt.state.items():
-                state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
-            opt.state = state
+                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device)
 
     def dispatch(self, trainer: 'pl.Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py
index 0592cffa1a4bc..ac90a5e6926fc 100644
--- a/pytorch_lightning/accelerators/gpu.py
+++ b/pytorch_lightning/accelerators/gpu.py
@@ -52,3 +52,7 @@ def set_nvidia_flags(local_rank: int) -> None:
         all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
         devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
         _log.info(f"LOCAL_RANK: {local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]")
+
+    def teardown(self) -> None:
+        super().teardown()
+        self._move_optimizer_state(torch.device("cpu"))
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index 83b678d171ffa..936b1c836bf74 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable +from typing import Any, Callable, Optional +import torch from torch.optim import Optimizer import pytorch_lightning as pl @@ -21,6 +22,7 @@ from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin from pytorch_lightning.utilities import _XLA_AVAILABLE +from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.exceptions import MisconfigurationException if _XLA_AVAILABLE: @@ -49,3 +51,11 @@ def run_optimizer_step( self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any ) -> None: xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs}) + + def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: + """ Moves the state of the optimizers to the TPU if needed. """ + # TODO: `self.root_device` would raise error if called outside the spawn process + # while training on 8 and more cores. + for opt in self.optimizers: + for p, v in opt.state.items(): + opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device) diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index e1c9a7149d066..f708bab24562e 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -133,6 +133,11 @@ def block_backward_sync(self): yield None def teardown(self) -> None: + # Un-reference the wrapper if any was used. + # todo (tchaton): Add support for all plugins. + if isinstance(self.model, DistributedDataParallel): + self.model = self.lightning_module + if self.on_gpu: # GPU teardown self.lightning_module.cpu() diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index e49d170a93d66..57de5b8195c2c 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -39,7 +39,7 @@ class TrainingTypePlugin(Plugin, ABC): """ def __init__(self) -> None: - self._model = None + self._model: Optional[Module] = None self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None self._call_configure_sharded_model_hook = True @@ -121,12 +121,12 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs """Hook to do something after each optimizer step.""" @property - def model(self) -> Module: + def model(self) -> Optional[Module]: """Returns the potentially wrapped LightningModule""" return self._model @model.setter - def model(self, new_model: Module) -> None: + def model(self, new_model: Optional[Module]) -> None: self._model = new_model @property diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index a539318b3773d..265f8e6b3ebdb 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -23,6 +23,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import _METRIC, MetricSource from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import DeviceType +from pytorch_lightning.utilities.apply_func import 
apply_to_collection, move_data_to_device from pytorch_lightning.utilities.metrics import metrics_to_scalars from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT @@ -312,3 +313,9 @@ def progress_bar_metrics(self) -> Dict[str, float]: metrics = self.metrics[MetricSource.PBAR] self._progress_bar_metrics.update(metrics) return self._progress_bar_metrics + + def teardown(self): + args = (torch.Tensor, move_data_to_device, "cpu") + self._logged_metrics = apply_to_collection(self._logged_metrics, *args) + self._progress_bar_metrics = apply_to_collection(self._progress_bar_metrics, *args) + self._callback_metrics = apply_to_collection(self._callback_metrics, *args) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index ae0989010a9a6..e94eb7f9fa6a0 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -21,9 +21,8 @@ from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections +from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device from pytorch_lightning.utilities.data import extract_batch_size -from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.metrics import metrics_to_scalars @@ -254,12 +253,7 @@ def __getstate__(self, drop_value: bool = False) -> dict: if not self.is_tensor and drop_value: # Avoid serializing ResultMetrics which are passed Metrics skip.append('value') - with self.sync_context( - should_sync=not self.meta.sync.rank_zero_only, - process_group=self.meta.sync.group, - distributed_available=distributed_available - ): - d = {k: v for k, v in self.__dict__.items() if k not in skip} + d = {k: v for k, v in self.__dict__.items() if k not in skip} d['meta'] = d['meta'].__getstate__() d['_class'] = self.__class__.__name__ return d @@ -276,6 +270,12 @@ def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> 'Resul result_metric.__setstate__(state, sync_fn=sync_fn) return result_metric + def to(self, *args: Any, **kwargs: Any) -> 'DeviceDtypeModuleMixin': + self.__dict__.update( + apply_to_collection(self.__dict__, (torch.Tensor, Metric), move_data_to_device, *args, **kwargs) + ) + return self + class ResultMetricCollection(dict): """ @@ -597,10 +597,7 @@ def extract_batch_size(self, batch: Any) -> None: def to(self, *args, **kwargs) -> 'ResultCollection': """Move all data to the given device.""" - def to_(item: Union[torch.Tensor, Metric], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Metric]: - return item.to(*args, **kwargs) - - apply_to_collection(self, (torch.Tensor, Metric), to_, *args, **kwargs) + self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)) if self.minimize is not None: self.minimize = self.minimize.to(*args, **kwargs) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f6342b5e8e458..ef6b74f1deb96 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -964,6 +964,7 @@ def _post_dispatch(self): # which need to happen before. 
         self.accelerator.teardown()
         self._active_loop.teardown()
+        self.logger_connector.teardown()
 
     def _dispatch(self):
         if self.evaluating:
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 4f2043d80c805..683919560c7ae 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -25,6 +25,7 @@
 import pytest
 import torch
 from omegaconf import OmegaConf
+from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import SGD
 from torch.utils.data import DataLoader
 
@@ -1969,3 +1970,54 @@ def training_step(self, batch, batch_idx):
     # simulate random failure in training_step on rank 0
     with pytest.raises(DeadlockDetectedException, match="CustomException"):
         trainer.fit(model)
+
+
+@RunIf(min_gpus=1)
+def test_multiple_trainer_constant_memory_allocated(tmpdir):
+    """
+    This test ensures that running the trainer several times resets the GPU memory back to its initial level.
+    """
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx):
+            loss = super().training_step(batch, batch_idx)
+            self.log("train_loss", loss["loss"])
+            return loss
+
+        def configure_optimizers(self):
+            return torch.optim.Adam(self.layer.parameters(), lr=0.1)
+
+    class Check(Callback):
+
+        def on_epoch_start(self, trainer, *_):
+            assert isinstance(trainer.training_type_plugin.model, DistributedDataParallel)
+
+    initial = torch.cuda.memory_allocated(0)
+
+    model = TestModel()
+    trainer_kwargs = dict(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        gpus=1,
+        accelerator="ddp",
+        progress_bar_refresh_rate=0,
+        callbacks=Check()
+    )
+    trainer = Trainer(**trainer_kwargs)
+    trainer.fit(model)
+
+    assert trainer.training_type_plugin.model is model
+    assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu")
+    assert trainer.callback_metrics['train_loss'].device == torch.device("cpu")
+
+    memory_1 = torch.cuda.memory_allocated(0)
+    deepcopy(trainer)
+    memory_2 = torch.cuda.memory_allocated(0)
+    assert memory_1 == memory_2 == initial
+
+    trainer_2 = Trainer(**trainer_kwargs)
+    trainer_2.fit(model)
+    memory_3 = torch.cuda.memory_allocated(0)
+
+    assert initial == memory_1 == memory_3
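
A note on why moving the optimizer state works: Adam materializes per-parameter `exp_avg`/`exp_avg_sq` buffers on the same device as the parameters, and those buffers keep GPU memory allocated for as long as the optimizer (or a `Trainer` that references it) stays alive. The sketch below reproduces the pattern outside Lightning and is illustrative only: the helper name `move_optimizer_state` is invented here, and plain `torch` stands in for Lightning's `apply_to_collection` / `move_data_to_device` utilities.

import torch


def move_optimizer_state(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    # Walk each parameter's state dict and move every tensor to `device`,
    # mirroring what the patched `Accelerator._move_optimizer_state` does.
    for param, state in optimizer.state.items():
        optimizer.state[param] = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in state.items()
        }


if torch.cuda.is_available():
    model = torch.nn.Linear(512, 512).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    # One optimizer step materializes the exp_avg / exp_avg_sq buffers on the GPU.
    model(torch.randn(8, 512, device="cuda")).sum().backward()
    optimizer.step()
    allocated_before = torch.cuda.memory_allocated()

    # Moving the state off the device frees its GPU allocation even though
    # the optimizer object itself stays alive.
    move_optimizer_state(optimizer, torch.device("cpu"))
    assert torch.cuda.memory_allocated() < allocated_before
    assert optimizer.state[model.weight]["exp_avg"].device.type == "cpu"

The same reference-holding argument applies to the `DistributedDataParallel` wrapper dropped in `ParallelPlugin.teardown`: as long as the wrapper stays reachable from the `Trainer`, the memory it holds cannot be garbage-collected, which is what `test_multiple_trainer_constant_memory_allocated` guards against.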