From ddfd5fe6e3e0c37cfc428858ef7c4fdb877c0578 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 15 Aug 2023 14:39:51 +0200
Subject: [PATCH] Disable memory sharing on model parameters in ddp-spawn
 (#18238)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

(cherry picked from commit a0ca2c8bcdc79e02fa0c249bfef0697c111d9604)
---
 src/lightning/fabric/CHANGELOG.md             |  6 ++
 .../strategies/launchers/multiprocessing.py   | 26 +++++++++
 src/lightning/pytorch/CHANGELOG.md            |  2 +
 .../strategies/launchers/multiprocessing.py   |  6 +-
 .../launchers/test_multiprocessing.py         | 10 ++--
 .../test_multiprocessing_integration.py       | 55 +++++++++++++++++++
 .../launchers/test_multiprocessing.py         | 28 ++++++++++
 7 files changed, 127 insertions(+), 6 deletions(-)
 create mode 100644 tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py

diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index ba482b8ee58b9..70b1aa6cb357e 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [UnReleased] - 2023-08-DD
+
+### Fixed
+
+- Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn"` and `accelerator="cpu"`; this has a necessary memory impact, as parameters are replicated for each process now ([#18238](https://github.com/Lightning-AI/lightning/pull/18238))
+
 
 ## [2.0.7] - 2023-08-14
 
diff --git a/src/lightning/fabric/strategies/launchers/multiprocessing.py b/src/lightning/fabric/strategies/launchers/multiprocessing.py
index da85b39a623d7..f7b224c242094 100644
--- a/src/lightning/fabric/strategies/launchers/multiprocessing.py
+++ b/src/lightning/fabric/strategies/launchers/multiprocessing.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 import os
 from dataclasses import dataclass
 from multiprocessing.queues import SimpleQueue
@@ -19,7 +20,10 @@
 import torch
 import torch.backends.cudnn
 import torch.multiprocessing as mp
+from lightning_utilities import apply_to_collection
+from torch.nn import Module
 
+from lightning.fabric.accelerators.cpu import CPUAccelerator
 from lightning.fabric.strategies.launchers.launcher import _Launcher
 from lightning.fabric.utilities.apply_func import move_data_to_device
 from lightning.fabric.utilities.imports import _IS_INTERACTIVE
@@ -122,6 +126,10 @@ def _wrapping_function(
     ) -> None:
         if global_states:
             global_states.restore()
+
+        if self._start_method == "spawn" and isinstance(self._strategy.accelerator, CPUAccelerator):
+            args, kwargs = _disable_module_memory_sharing((args, kwargs))
+
         os.environ["LOCAL_RANK"] = str(process_idx)
         results = function(*args, **kwargs)
 
@@ -190,3 +198,21 @@ def _check_bad_cuda_fork() -> None:
     if _IS_INTERACTIVE:
         message += " You will have to restart the Python kernel."
     raise RuntimeError(message)
+
+
+def _disable_module_memory_sharing(data: Any) -> Any:
+    """Disables memory sharing on parameters and buffers of `nn.Module`s contained in the given collection.
+
+    Note: This is only required when running on CPU.
+    """
+    # PyTorch enables memory sharing automatically on all tensors that are passed through `mp.spawn`.
+    # For model weights and buffers, this is undesired and can lead to race conditions between processes.
+    # Hence, we copy the tensors in the entire module to ensure it doesn't share memory with other processes.
+
+    @torch.no_grad()
+    def unshare(module: Module) -> Module:
+        for tensor in itertools.chain(module.parameters(), module.buffers()):
+            tensor.data = tensor.data.clone()
+        return module
+
+    return apply_to_collection(data, function=unshare, dtype=Module)
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 42ae22dd12d6c..4fa4377c62e73 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed redundant `iter()` call to dataloader when checking dataloading configuration ([#18415](https://github.com/Lightning-AI/lightning/pull/18415))
 
+- Fixed model parameters getting shared between processes when running with `strategy="ddp_spawn"` and `accelerator="cpu"`; this has a necessary memory impact, as parameters are replicated for each process now ([#18238](https://github.com/Lightning-AI/lightning/pull/18238))
+
 
 ## [2.0.5] - 2023-07-07
 
diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py
index b42a169ab66b7..7c165d1982eff 100644
--- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py
+++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py
@@ -27,10 +27,11 @@
 from torch import Tensor
 
 import lightning.pytorch as pl
-from lightning.fabric.strategies.launchers.multiprocessing import _check_bad_cuda_fork
+from lightning.fabric.strategies.launchers.multiprocessing import _check_bad_cuda_fork, _disable_module_memory_sharing
 from lightning.fabric.utilities import move_data_to_device
 from lightning.fabric.utilities.seed import _collect_rng_states, _set_rng_states
 from lightning.fabric.utilities.types import _PATH
+from lightning.pytorch.accelerators import CPUAccelerator
 from lightning.pytorch.strategies.launchers.launcher import _Launcher
 from lightning.pytorch.trainer.connectors.signal_connector import _SIGNUM
 from lightning.pytorch.trainer.states import TrainerFn, TrainerState
@@ -145,6 +146,9 @@ def _wrapping_function(
     ) -> None:
         if global_states:
             global_states.restore()
+        if self._start_method == "spawn" and isinstance(self._strategy.accelerator, CPUAccelerator):
+            args, kwargs = _disable_module_memory_sharing((args, kwargs))
+
         os.environ["LOCAL_RANK"] = str(process_idx)
         results = function(*args, **kwargs)
 
diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py
index 91c209dcc5ee3..f54226f1174e5 100644
--- a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py
@@ -23,20 +23,20 @@
 
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize("start_method", ["fork", "forkserver"])
-def test_multiprocessing_launcher_interactive_compatible(start_method):
+def test_interactive_compatible(start_method):
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
     assert launcher.is_interactive_compatible == (start_method == "fork")
 
 
 @mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
-def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
+def test_forking_on_unsupported_platform(_):
     with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
         _MultiProcessingLauncher(strategy=Mock(), start_method="fork")
 
 
 @pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
 @mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
-def test_multiprocessing_launcher_start_method(mp_mock, start_method):
+def test_start_method(mp_mock, start_method):
     mp_mock.get_all_start_methods.return_value = [start_method]
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
     launcher.launch(function=Mock())
@@ -51,7 +51,7 @@ def test_multiprocessing_launcher_start_method(mp_mock, start_method):
 
 @pytest.mark.parametrize("start_method", ["spawn", pytest.param("fork", marks=RunIf(standalone=True))])
 @mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
-def test_multiprocessing_launcher_restore_globals(mp_mock, start_method):
+def test_restore_globals(mp_mock, start_method):
     """Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'."""
     mp_mock.get_all_start_methods.return_value = [start_method]
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
@@ -89,7 +89,7 @@ def test_global_state_snapshot():
 @pytest.mark.parametrize("start_method", ["fork", "forkserver"])
 @mock.patch("torch.cuda.is_initialized", return_value=True)
 @mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
-def test_multiprocessing_launcher_check_for_bad_cuda_fork(mp_mock, _, start_method):
+def test_check_for_bad_cuda_fork(mp_mock, _, start_method):
     mp_mock.get_all_start_methods.return_value = [start_method]
     launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
     with pytest.raises(RuntimeError, match="Lightning can't create new processes if CUDA is already initialized"):
diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py
new file mode 100644
index 0000000000000..e53b8087b8084
--- /dev/null
+++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py
@@ -0,0 +1,55 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+import torch
+import torch.nn as nn
+
+from lightning.fabric import Fabric
+from tests_fabric.helpers.runif import RunIf
+
+
+class SimpleModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layer = nn.Linear(2, 2)
+        self.tied_layer = nn.Linear(2, 2)
+        self.tied_layer.weight = self.layer.weight
+        self.register_buffer("buffer", torch.ones(3))
+
+
+@pytest.mark.parametrize("strategy", ["ddp_spawn", pytest.param("ddp_fork", marks=RunIf(skip_windows=True))])
+def test_memory_sharing_disabled(strategy):
+    """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
+    conditions on model updates."""
+    tensor = torch.rand(4)
+    model = SimpleModel()
+    assert not tensor.is_shared()
+    assert not model.layer.weight.is_shared()
+    assert model.layer.weight.data_ptr() == model.tied_layer.weight.data_ptr()
+
+    fabric = Fabric(accelerator="cpu", devices=2, strategy=strategy)
+    fabric.launch(_test_memory_sharing_disabled, tensor, model)
+
+
+def _test_memory_sharing_disabled(fabric, tensor, model):
+    is_spawn = fabric.strategy.launcher._start_method == "spawn"
+    assert not is_spawn or tensor.is_shared()
+    assert not model.layer.weight.is_shared()
+    assert not model.tied_layer.weight.is_shared()
+    assert not model.buffer.is_shared()
+
+    # weights remain tied
+    assert model.layer.weight.data_ptr() == model.tied_layer.weight.data_ptr()
+    assert torch.equal(model.layer.weight.data, model.tied_layer.weight.data)
+    fabric.barrier()
diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
index 73b006e6f7831..f38b7575950ca 100644
--- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
+++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py
@@ -175,3 +175,31 @@ def test_kill():
     with patch("os.kill") as kill_patch:
         launcher.kill(15)
     assert kill_patch.mock_calls == [call(proc0.pid, 15), call(proc1.pid, 15)]
+
+
+class SimpleModel(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.tied_layer = torch.nn.Linear(32, 2)
+        self.tied_layer.weight = self.layer.weight
+        self.register_buffer("buffer", torch.ones(3))
+
+    def on_fit_start(self) -> None:
+        assert not self.layer.weight.is_shared()
+        assert not self.tied_layer.weight.is_shared()
+        assert not self.buffer.is_shared()
+
+        # weights remain tied
+        assert self.layer.weight.data_ptr() == self.tied_layer.weight.data_ptr()
+        assert torch.equal(self.layer.weight.data, self.tied_layer.weight.data)
+
+
+def test_memory_sharing_disabled():
+    """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
+    conditions on model updates."""
+    model = SimpleModel()
+    assert not model.layer.weight.is_shared()
+    assert model.layer.weight.data_ptr() == model.tied_layer.weight.data_ptr()
+
+    trainer = Trainer(accelerator="cpu", devices=2, strategy="ddp_spawn", max_steps=0)
+    trainer.fit(model)
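
For context on the mechanism the patch relies on: cloning a tensor gives it fresh, non-shared storage, while weight tying survives because the tied attributes reference the same `Parameter` object. The sketch below is illustrative only and not part of the patch; `TinyModel` and `unshare` are hypothetical names, and `share_memory()` is called explicitly to emulate the sharing that `mp.spawn` applies to CPU tensors passed to worker processes.

```python
# Illustrative sketch (not part of the patch). `TinyModel` and `unshare` are
# hypothetical names; `share_memory()` stands in for the sharing that
# `mp.spawn` applies to CPU tensors it passes to worker processes.
import itertools

import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(2, 2)
        self.tied_layer = nn.Linear(2, 2)
        self.tied_layer.weight = self.layer.weight  # weight tying
        self.register_buffer("buffer", torch.ones(3))


@torch.no_grad()
def unshare(module: nn.Module) -> nn.Module:
    # Give every parameter and buffer a private copy of its storage so that
    # sibling processes can no longer write into the same memory.
    for tensor in itertools.chain(module.parameters(), module.buffers()):
        tensor.data = tensor.data.clone()
    return module


model = TinyModel()
model.share_memory()  # emulate what passing the model through mp.spawn does on CPU
assert model.layer.weight.is_shared() and model.buffer.is_shared()

unshare(model)
assert not model.layer.weight.is_shared()
assert not model.buffer.is_shared()
# Tying is preserved: both attributes are the same Parameter object, so the
# single clone of its `.data` keeps them pointing at the same storage.
assert model.layer.weight.data_ptr() == model.tied_layer.weight.data_ptr()
```

In the actual helper, this clone loop is applied via `apply_to_collection` to every `nn.Module` found in the spawn arguments, and only when launching with `start_method="spawn"` on the CPU accelerator.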