From bf9cd51b173999edf9ff1ad61c3f4ab082ce9826 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sun, 21 Feb 2021 13:26:29 +0000 Subject: [PATCH 1/4] Expose deepspeed config parameters to init function due to instability in parameters --- .../plugins/training_type/deepspeed.py | 40 ++++++++++++-- tests/plugins/test_deepspeed_plugin.py | 52 +++++++++++++++++++ 2 files changed, 89 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 0f9a8378052a5e..f04da82e049806 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -79,6 +79,11 @@ def __init__( num_nodes: int = 1, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, + loss_scale: float = 0, + initial_scale_power: int = 32, + loss_scale_window: int = 1000, + hysteresis: int = 2, + min_loss_scale: int = 1 ) -> None: """ @@ -127,6 +132,18 @@ def __init__( logging_level: Set logging level for deepspeed. (Default: ``logging.WARN``) + loss_scale: Loss scaling value for FP16 training. + 0.0 results in dynamic loss scaling, otherwise static (Default: 0) + + initial_scale_power: Power of the initial dynamic loss scale value. Loss scale is computed + by 2^initial_scale_power (Default 32) + + loss_scale_window: Window in which to raise/lower the dynamic FP16 loss scaling value (Default: 1000) + + hysteresis: FP16 Delay shift in Dynamic Loss scaling (Default: 2) + + min_loss_scale: The minimum FP16 dynamic loss scaling value (Default: 1000) + """ if not _DEEPSPEED_AVAILABLE: raise MisconfigurationException( @@ -154,6 +171,13 @@ def __init__( self._config_initialized = False deepspeed.utils.logging.logger.setLevel(logging_level) + # default FP16 parameters. + self.loss_scale = loss_scale + self.initial_scale_power = initial_scale_power + self.loss_scale_window = loss_scale_window + self.hysteresis = hysteresis + self.min_loss_scale = min_loss_scale + def _load_config(self, config): if config is None and self.DEEPSPEED_ENV_VAR in os.environ: rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable") @@ -297,9 +321,19 @@ def _format_precision_config(self): amp_level = self.lightning_module.trainer.accelerator_connector.amp_level precision = self.lightning_module.trainer.accelerator_connector.precision if precision == 16: - if "amp" not in self.config and amp_type == AMPType.NATIVE: - self.config["fp16"] = {"enabled": True} - elif "apex" not in self.config and amp_type == AMPType.APEX: + if "fp16" not in self.config and amp_type == AMPType.NATIVE: + # FP16 is a DeepSpeed standalone AMP implementation + rank_zero_info("Enabling DeepSpeed FP16.") + self.config["fp16"] = { + "enabled": True, + "loss_scale": self.loss_scale, + "initial_scale_power": self.initial_scale_power, + "loss_scale_window": self.loss_scale_window, + "hysteresis": self.hysteresis, + "min_loss_scale": self.min_loss_scale + } + elif "amp" not in self.config and amp_type == AMPType.APEX: + rank_zero_only("Enabling DeepSpeed APEX Implementation.") self.config["amp"] = { "enabled": True, "opt_level": amp_level, diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index fbb53974efd33c..945b4180da1f12 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -304,6 +304,58 @@ def on_train_start(self) -> None: _assert_save_model_is_equal(model, tmpdir, trainer) +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +def test_deepspeed_custom_precision_params(tmpdir): + """ + Ensure if we modify the FP16 parameters via the DeepSpeedPlugin, the deepspeed config contains these changes. + """ + + class TestModel(BoringModel): + + def on_train_start(self) -> None: + assert self.trainer.training_type_plugin.config['fp16']['loss_scale'] == 10 + assert self.trainer.training_type_plugin.config['fp16']['initial_scale_power'] == 10 + assert self.trainer.training_type_plugin.config['fp16']['loss_scale_window'] == 10 + assert self.trainer.training_type_plugin.config['fp16']['hysteresis'] == 10 + assert self.trainer.training_type_plugin.config['fp16']['min_loss_scale'] == 10 + raise SystemExit() + + model = TestModel() + trainer = Trainer( + plugins=[ + DeepSpeedPlugin( + loss_scale=10, initial_scale_power=10, loss_scale_window=10, hysteresis=10, min_loss_scale=10 + ) + ], + precision=16, + gpus=1 + ) + with pytest.raises(SystemExit): + trainer.fit(model) + + +@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_config): + """ + Ensure if we use a config and turn off cpu_offload, that this is set to False within the config. + """ + + deepspeed_zero_config['zero_optimization']['cpu_offload'] = False + + class TestModel(BoringModel): + + def on_train_start(self) -> None: + assert self.trainer.training_type_plugin.config['zero_optimization']['cpu_offload'] is False + raise SystemExit() + + model = TestModel() + trainer = Trainer(plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)], precision=16, gpus=1) + with pytest.raises(SystemExit): + trainer.fit(model) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") From bffb11cd1e8c15db9720362bab8b486fe172b08d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sun, 21 Feb 2021 13:27:52 +0000 Subject: [PATCH 2/4] See if tests can run on normal CI, without special tests --- tests/plugins/test_deepspeed_plugin.py | 9 --------- tests/special_tests.sh | 3 --- 2 files changed, 12 deletions(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 945b4180da1f12..e230cdda14fa4c 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -211,9 +211,6 @@ def test_invalid_deepspeed_defaults_no_precision(tmpdir): @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") -@pytest.mark.skipif( - not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" -) def test_warn_deepspeed_override_backward(tmpdir): """ Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning. @@ -232,9 +229,6 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") -@pytest.mark.skipif( - not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" -) def test_deepspeed_run_configure_optimizers(tmpdir): """ Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), @@ -268,9 +262,6 @@ def on_train_start(self) -> None: @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.") -@pytest.mark.skipif( - not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" -) def test_deepspeed_config(tmpdir, deepspeed_zero_config): """ Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 472f7afda5e9e0..ffb21255a6d3c6 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -17,9 +17,6 @@ export PL_RUNNING_SPECIAL_TESTS=1 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp -python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_warn_deepspeed_override_backward -python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_run_configure_optimizers -python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_config python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_multigpu python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual From 973b43cbb427539fa6348e1b5d5331826c7169a4 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sun, 21 Feb 2021 13:32:56 +0000 Subject: [PATCH 3/4] Add changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 55895318cba4fe..24612c45d7e22e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Move lightning module to correct device type when using LightningDistributedWrapper ([#6070](https://github.com/PyTorchLightning/pytorch-lightning/pull/6070) +- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115) + + ## [1.2.0] - 2021-02-18 ### Added From 5a536cf86ae70c44d186f626999c2c57d12f561e Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Sun, 21 Feb 2021 15:17:32 +0000 Subject: [PATCH 4/4] Update pytorch_lightning/plugins/training_type/deepspeed.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index f04da82e049806..75e5bf74be6432 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -136,7 +136,7 @@ def __init__( 0.0 results in dynamic loss scaling, otherwise static (Default: 0) initial_scale_power: Power of the initial dynamic loss scale value. Loss scale is computed - by 2^initial_scale_power (Default 32) + by ``2^initial_scale_power`` (Default: 32) loss_scale_window: Window in which to raise/lower the dynamic FP16 loss scaling value (Default: 1000)