From 0b271474e5cc35a40009559d6c4488ab3b02723e Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Fri, 19 Feb 2021 02:13:54 +0100
Subject: [PATCH 1/5] continue towards 1.3 (#6069)

---
 CHANGELOG.md                                   | 17 +++++++++++++++++
 pytorch_lightning/__init__.py                  |  2 +-
 tests/checkpointing/test_legacy_checkpoints.py |  1 +
 3 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b80afe7b24d0f..2ad54381a082b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,23 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [UnReleased] - 2021-MM-DD
+
+### Added
+
+
+### Changed
+
+
+### Deprecated
+
+
+### Removed
+
+
+### Fixed
+
+
 ## [1.2.0] - 2021-02-18
 
 ### Added
diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py
index b816a4e8aafb9..43c0837e13934 100644
--- a/pytorch_lightning/__init__.py
+++ b/pytorch_lightning/__init__.py
@@ -5,7 +5,7 @@
 import time
 
 _this_year = time.strftime("%Y")
-__version__ = '1.2.0'
+__version__ = '1.3.0dev'
 __author__ = 'William Falcon et al.'
 __author_email__ = 'waf2107@columbia.edu'
 __license__ = 'Apache-2.0'
diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 7b1a7facbb3fe..abc692daf15d0 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -52,6 +52,7 @@
         "1.1.6",
         "1.1.7",
         "1.1.8",
+        "1.2.0",
     ]
 )
 def test_resume_legacy_checkpoints(tmpdir, pl_version):
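Patch 1 above bumps the development version and extends the legacy-checkpoint compatibility matrix with the freshly released 1.2.0. For context, the guarantee being tested is simply that a checkpoint written by an older release can still be restored and training can continue. A minimal sketch of that round trip, assuming a checkpoint path (the path below is hypothetical, and `resume_from_checkpoint` is the Trainer argument available in this release line):

```python
import pytorch_lightning as pl
from tests.helpers import BoringModel  # test helper used throughout this patch series

# Hypothetical path to a checkpoint produced by an older release, e.g. 1.2.0.
LEGACY_CKPT = "legacy/checkpoints/1.2.0/epoch=0.ckpt"

model = BoringModel()
trainer = pl.Trainer(max_steps=20, resume_from_checkpoint=LEGACY_CKPT)
trainer.fit(model)  # should restore weights and optimizer state, then keep training
```

This is only an illustration of what the parametrized test exercises for each listed version, not the test code itself.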
From 4b7c0fae00084b72dffe37fdd0ea7d2e9b60d103 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 19 Feb 2021 18:00:27 +0100
Subject: [PATCH 2/5] Fix amp autocast (#6080)

* precision fixes
* add amp test model
* fix test
* revert
* move assert to training step
* fix test
* fix test
* remove unrelated changes
* add changelog
* remove unused import
---
 CHANGELOG.md                            |  2 ++
 .../plugins/precision/native_amp.py     |  3 ++-
 tests/models/test_amp.py                | 22 +++++++++++++++-------
 3 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2ad54381a082b..7dad863d41293 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
+
 
 
 ## [1.2.0] - 2021-02-18
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 60c0f5f84626f..94e6cf376b03a 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -91,4 +91,5 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
     @contextmanager
     def train_step_context(self) -> Generator[autocast, None, None]:
         """Enable autocast context"""
-        yield torch.cuda.amp.autocast()
+        with torch.cuda.amp.autocast():
+            yield
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 2dd6c9d997dbf..53ec32764f3ed 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -27,6 +27,16 @@
 from tests.helpers import BoringModel
 
 
+class AMPTestModel(BoringModel):
+
+    def training_step(self, batch, batch_idx):
+        assert torch.is_autocast_enabled()
+        output = self(batch)
+        assert output.dtype == torch.float16
+        loss = self.loss(batch, output)
+        return {"loss": loss}
+
+
 @pytest.mark.skip(reason='dp + amp not supported currently')  # TODO
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_amp_single_gpu_dp(tmpdir):
@@ -41,7 +51,7 @@ def test_amp_single_gpu_dp(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
 
@@ -60,10 +70,9 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
-
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -81,7 +90,7 @@ def test_amp_multi_gpu_dp(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
 
@@ -100,10 +109,9 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
         precision=16,
    )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
-
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -122,7 +130,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
     # simulate setting slurm flags
     tutils.set_random_master_port()
 
-    model = BoringModel()
+    model = AMPTestModel()
 
     # exp file to get meta
     logger = tutils.get_default_logger(tmpdir)
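The one-line fix in native_amp.py is the classic generator-as-context-manager mistake: `yield torch.cuda.amp.autocast()` hands the un-entered autocast object to the caller, so the training step never actually runs inside the autocast region, which is exactly what the new `AMPTestModel` asserts. A CPU-only sketch of the two patterns, using a stand-in flag instead of CUDA autocast (names here are illustrative, not Lightning code):

```python
from contextlib import contextmanager

_enabled = False  # stand-in for torch.is_autocast_enabled()


@contextmanager
def fake_autocast():
    """Toy replacement for torch.cuda.amp.autocast so the example runs anywhere."""
    global _enabled
    _enabled = True
    try:
        yield
    finally:
        _enabled = False


@contextmanager
def broken_train_step_context():
    # Pre-fix behaviour: the manager object is yielded but never entered.
    yield fake_autocast()


@contextmanager
def fixed_train_step_context():
    # Post-fix behaviour: enter the manager, then hand control back to the caller.
    with fake_autocast():
        yield


with broken_train_step_context():
    assert _enabled is False  # the "autocast" region was never activated

with fixed_train_step_context():
    assert _enabled is True   # the body now runs inside the active region
```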
From f2660acbf9cbc1c48c81a78164d53084604f08a0 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Fri, 19 Feb 2021 22:45:53 +0100
Subject: [PATCH 3/5] add sanity check on nb available GPUs (#6092)

---
 azure-pipelines.yml | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index 6d4d49aff40c1..4d84253473bbc 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -66,8 +66,9 @@ jobs:
       pip list
     displayName: 'Install dependencies'
 
-  - script: |
+  - bash: |
       python tests/collect_env_details.py
+      python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu >= 2, f'GPU: {mgpu}'"
     displayName: 'Env details'
 
   - bash: |
@@ -76,7 +77,7 @@ jobs:
       ls -l legacy/checkpoints/
     displayName: 'Get legacy checkpoints'
 
-  - script: |
+  - bash: |
       python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50
     displayName: 'Testing: standard'
 
@@ -90,11 +91,11 @@ jobs:
       codecov --token=$(CODECOV_TOKEN) --flags=gpu,pytest --name="GPU-coverage" --env=linux,azure
     displayName: 'Statistics'
 
-  - script: |
+  - bash: |
       python -m pytest benchmarks pl_examples -v --maxfail=2 --durations=0
     displayName: 'Testing: extended'
 
-  - script: |
+  - bash: |
       python setup.py install --user --quiet
       bash pl_examples/run_ddp-example.sh
       pip uninstall -y pytorch-lightning
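Patch 3 folds the GPU sanity check into a CI one-liner so that a misconfigured agent fails fast instead of surfacing as obscure multi-GPU test failures later. Written out as a standalone script (equivalent logic, only reformatted for readability):

```python
import torch

# Fail fast if the machine does not expose the expected number of GPUs.
EXPECTED_GPUS = 2
mgpu = torch.cuda.device_count()
assert mgpu >= EXPECTED_GPUS, f"GPU: {mgpu}"
print(f"found {mgpu} CUDA device(s)")
```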
From 3bdc0673ea5fcb10035d783df0d913be4df499b6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 20 Feb 2021 13:30:21 +0100
Subject: [PATCH 4/5] consistent behavior for reduce method across all Plugins (#6011)

* reduction docs
* docs for abstract base method
* make mean the default
* add preliminary chlog

Co-authored-by: Jirka Borovec
---
 CHANGELOG.md                                       |  4 ++++
 .../plugins/training_type/ddp.py                   | 20 ++++++++++++----
 .../plugins/training_type/ddp2.py                  | 24 ++++++++++++++-----
 .../plugins/training_type/ddp_spawn.py             | 20 ++++++++++++----
 pytorch_lightning/plugins/training_type/dp.py      | 23 +++++++++++++-----
 .../plugins/training_type/horovod.py               | 22 +++++++++++++----
 .../plugins/training_type/single_device.py         | 16 +++++++++++--
 .../training_type/training_type_plugin.py          | 11 +++++++--
 8 files changed, 111 insertions(+), 29 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7dad863d41293..66a64d52195a7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
 
+- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
+
+
+
 ## [1.2.0] - 2021-02-18
 
 ### Added
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 6a4e948e899cf..80161d6e59b6b 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -278,10 +278,22 @@ def model_to_device(self):
         torch.cuda.set_device(self.root_device)
         self.model.to(self.root_device)
 
-    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
-        if isinstance(output, torch.Tensor):
-            output = sync_ddp_if_available(output, group, reduce_op)
-        return output
+    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
+        """
+        Reduces a tensor from several distributed processes to one aggregated tensor.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            group: the process group to gather results from. Defaults to all processes (world)
+            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
+                Can also be a string 'sum' to calculate the sum during reduction.
+
+        Return:
+            reduced value, except when the input was not a tensor the output remains unchanged
+        """
+        if isinstance(tensor, torch.Tensor):
+            tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean"))
+        return tensor
 
     def training_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py
index a7c8477a40c2d..a94bb5459bb1e 100644
--- a/pytorch_lightning/plugins/training_type/ddp2.py
+++ b/pytorch_lightning/plugins/training_type/ddp2.py
@@ -25,14 +25,26 @@ def setup(self, model):
         self.task_idx = self.cluster_environment.local_rank()
         # the difference to DDP is that we don't call children processes here
 
-    def reduce(self, output, *args, **kwargs):
-        if isinstance(output, Result):
-            output.dp_reduce()
+    def reduce(self, tensor, *args, **kwargs):
+        """
+        Reduces a tensor from all processes to one aggregated tensor.
+        In DDP2, the reduction here is only across local devices within the node.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            *args: ignored for DDP2
+            **kwargs: ignored for DDP2
+
+        Return:
+            reduced value, except when the input was not a tensor the output remains unchanged
+        """
+        if isinstance(tensor, Result):
+            tensor.dp_reduce()
 
-        elif isinstance(output, torch.Tensor):
-            output = output.mean()
+        elif isinstance(tensor, torch.Tensor):
+            tensor = tensor.mean()
 
-        return output
+        return tensor
 
     @property
     def root_device(self):
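The DDP variants above delegate to `sync_ddp_if_available`, while DP/DDP2 simply average over the values held on the node's devices. Independent of Lightning's internals, the behaviour the new docstrings describe for the default `'mean'` reduction is an all-reduce followed by a division by the world size. A standalone sketch in raw `torch.distributed` (gloo backend, two CPU processes), shown only to make the semantics concrete:

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each process holds its own value; the "mean" reduction averages them.
    value = torch.tensor([float(rank + 1)])   # rank 0 -> 1.0, rank 1 -> 2.0
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    value /= world_size                       # SUM then divide == mean across processes
    assert value.item() == 1.5

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```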
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 66d3cb7bf4619..ca25a6d8bc382 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -256,10 +256,22 @@ def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, opti
         if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
             prepare_for_backward(self.model, closure_loss)
 
-    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
-        if isinstance(output, torch.Tensor):
-            output = sync_ddp_if_available(output, group, reduce_op)
-        return output
+    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
+        """
+        Reduces a tensor from several distributed processes to one aggregated tensor.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            group: the process group to gather results from. Defaults to all processes (world)
+            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
+                Can also be a string 'sum' to calculate the sum during reduction.
+
+        Return:
+            reduced value, except when the input was not a tensor the output remains unchanged
+        """
+        if isinstance(tensor, torch.Tensor):
+            tensor = sync_ddp_if_available(tensor, group, reduce_op=(reduce_op or "mean"))
+        return tensor
 
     def training_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py
index e1002faf8a3b4..c2b16303e5d4e 100644
--- a/pytorch_lightning/plugins/training_type/dp.py
+++ b/pytorch_lightning/plugins/training_type/dp.py
@@ -31,14 +31,25 @@ def setup(self, model):
         model.to(self.root_device)
         self._model = DataParallel(LightningParallelModule(model), self.parallel_devices)
 
-    def reduce(self, output, *args, **kwargs):
-        if isinstance(output, Result):
-            output.dp_reduce()
+    def reduce(self, tensor, *args, **kwargs):
+        """
+        Reduces a tensor from all parallel processes to one aggregated tensor.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            *args: ignored for DP
+            **kwargs: ignored for DP
+
+        Return:
+            reduced value, except when the input was not a tensor the output remains unchanged
+        """
+        if isinstance(tensor, Result):
+            tensor.dp_reduce()
 
-        elif isinstance(output, torch.Tensor):
-            output = output.mean()
+        elif isinstance(tensor, torch.Tensor):
+            tensor = tensor.mean()
 
-        return output
+        return tensor
 
     @property
     def root_device(self):
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index 351d945675a0c..e940cb1d7229b 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -127,23 +127,35 @@ def model_to_device(self):
             torch.cuda.set_device(self.root_device)
         self.model.to(self.root_device)
 
-    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
+    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
+        """
+        Reduces a tensor from several distributed processes to one aggregated tensor.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            group: the process group to gather results from. Defaults to all processes (world)
+            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
+                Can also be a string 'sum' to calculate the sum during reduction.
+
+        Return:
+            reduced value, except when the input was not a tensor the output remains unchanged
+        """
         if group is not None:
             raise ValueError(
                 "Horovod does not support allreduce using a subcommunicator at this time. "
                 "Unset `group`."
             )
 
-        if reduce_op is None or reduce_op == "sum":
-            reduce_op = hvd.Sum
-        elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
+        if reduce_op in (None, "avg", "mean"):
             reduce_op = hvd.Average
+        elif reduce_op == "sum":
+            reduce_op = hvd.Sum
         else:
             raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")
 
         # sync all processes before reduction
         hvd.join()
-        return hvd.allreduce(output, op=reduce_op)
+        return hvd.allreduce(tensor, op=reduce_op)
 
     def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None):
         if group is not None:
diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py
index 4b1d24301b8a0..5bf0597ed7f18 100644
--- a/pytorch_lightning/plugins/training_type/single_device.py
+++ b/pytorch_lightning/plugins/training_type/single_device.py
@@ -19,8 +19,20 @@ def on_tpu(self) -> bool:
     def on_gpu(self) -> bool:
         return self.device.type == "cuda" and torch.cuda.is_available()
 
-    def reduce(self, output: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
-        return output
+    def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]:
+        """
+        Reduces a tensor from several distributed processes to one aggregated tensor.
+        As this plugin only operates with a single device, the reduction is simply the identity.
+
+        Args:
+            tensor: the tensor to sync and reduce
+            *args: ignored
+            **kwargs: ignored
+
+        Return:
+            the unmodified input as reduction is not needed for single process operation
+        """
+        return tensor
 
     @property
     def root_device(self) -> torch.device:
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index d7c3b4d4d77e1..b60b63df23e48 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -55,8 +55,15 @@ def is_global_zero(self) -> bool:
         """Whether the current process is the rank zero process not only on the local node, but for all nodes."""
 
     @abstractmethod
-    def reduce(self, output: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
-        """Reduces the given output (e.g. across GPUs/Processes)"""
+    def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
+        """
+        Reduces the given tensor (e.g. across GPUs/processes).
+
+        Args:
+            tensor: the tensor to sync and reduce
+            *args: plugin-specific positional arguments
+            **kwargs: plugin-specific keyword arguments
+        """
 
     @abstractmethod
     def barrier(self, name: Optional[str] = None) -> None:
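The contract patch 4 documents is easiest to see on the simplest plugin: with a single device there is nothing to aggregate, so `reduce` is the identity, while the distributed plugins average by default and also accept `reduce_op="sum"`. A small sanity check against the single-device case; constructing the plugin directly like this is unusual outside tests, and the `SingleDevicePlugin(device)` constructor signature is assumed from this release:

```python
import torch
from pytorch_lightning.plugins import SingleDevicePlugin

plugin = SingleDevicePlugin(torch.device("cpu"))  # assumed ctor: SingleDevicePlugin(device)
loss = torch.tensor(1.2345)

# With a single process there is nothing to sync, so the input comes back untouched.
assert plugin.reduce(loss) is loss
# Extra arguments are accepted and ignored, mirroring the distributed plugins' signature.
assert plugin.reduce(loss, reduce_op="sum") is loss
```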
From 97a81c3cfed9d6677d411672eecb0ed38514cb04 Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Sat, 20 Feb 2021 12:58:54 +0000
Subject: [PATCH 5/5] [Hot Fix] Give priority to plugins to set distributed mode, and then accelerator (#6089)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Give priority to plugins to set distributed mode, and then accelerator
* Add CHANGELOG.md
* Update CHANGELOG.md
* Remove very scary line
* Ensure we set cluster environment after slurm configured if necessary
* Simplify the fix with a reset

Co-authored-by: Carlos Mocholí
---
 CHANGELOG.md                                     |  2 ++
 .../connectors/accelerator_connector.py          |  4 +++-
 .../test_accelerator_connector.py                | 24 +++++++++++++++++++-
 3 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 66a64d52195a7..52cdb000a2a0f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
 
+- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
+
 ## [1.2.0] - 2021-02-18
 
 ### Added
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index d32970d61fa9b..7021081d6cc90 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -163,6 +163,9 @@ def handle_given_plugins(
 
         for plug in plugins:
             if isinstance(plug, str):
+                # Reset the distributed type as the user has overridden training type
+                # via the plugins argument
+                self._distrib_type = None
                 self.set_distributed_mode(plug)
 
             elif isinstance(plug, TrainingTypePlugin):
@@ -196,7 +199,6 @@ def handle_given_plugins(
         )
 
         self._training_type_plugin = training_type
-        self._training_type_plugin = self.training_type_plugin
         self._precision_plugin = precision
         self._cluster_environment = cluster_environment or self.select_cluster_environment()
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 76d4a597d8ecb..82b631807c8e9 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -23,7 +23,14 @@
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
 from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.plugins import DDP2Plugin, DDPPlugin, DDPSpawnPlugin, PrecisionPlugin, SingleDevicePlugin
+from pytorch_lightning.plugins import (
+    DDP2Plugin,
+    DDPPlugin,
+    DDPShardedPlugin,
+    DDPSpawnPlugin,
+    PrecisionPlugin,
+    SingleDevicePlugin,
+)
 from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
 from tests.helpers.boring_model import BoringModel
 
@@ -378,3 +385,18 @@ def on_fit_start(self, trainer, pl_module):
 
     with pytest.raises(SystemExit):
         trainer.fit(model)
+
+
+@pytest.mark.parametrize(
+    ["accelerator", "plugin"],
+    [('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],
+)
+def test_plugin_accelerator_choice(accelerator, plugin):
+    """
+    Ensure that when a plugin and accelerator is passed in, the plugin takes precedence.
+    """
+    trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2)
+    assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
+
+    trainer = Trainer(plugins=plugin, num_processes=2)
+    assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
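From a user's point of view, patch 5 means that when both an accelerator string and a training-type plugin string are given, the plugin now decides the distributed mode. A usage sketch mirroring the new test, with the caveats that `'ddp_sharded'` requires fairscale to be installed and that the resulting plugin still expects at least two processes or GPUs at runtime:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPShardedPlugin

# The plugin string overrides the accelerator's default distributed mode.
trainer = Trainer(accelerator="ddp_spawn", plugins="ddp_sharded", num_processes=2)
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)
```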